【PyTorch】GroupNormの動作を確認してみる

PyTorch

PyTorchでGroup Normalizationを行うtorch.nn.GroupNormについて調べた内容をまとめます。

GroupNorm — PyTorch 2.1 documentation

Group Normalizaitonとは

Group Normalizationは、2018年に登場したニューラルネットワークの標準化手法で、入力チャンネルをいくつかのグループに分割し、グループごとに標準化を行います。Batch Normalizationにはバッチサイズが小さい場合に性能が低下してしまうという問題点があり、バッチサイズによらない手法としてはLayer NormalizationとInstance Normalizationがあります。

Group Normalizationはそれら2手法の中間に位置するような手法となっており、バッチサイズによらず安定した学習が可能で、Layer NormalizationとInstance Normalizationよりも良い性能を示すとのことです。

Wu, Yuxin, and Kaiming He. “Group normalization.” Proceedings of the European conference on computer vision (ECCV). 2018.
APA より引用
Wu, Yuxin, and Kaiming He. “Group normalization.” Proceedings of the European conference on computer vision (ECCV). 2018.
APA より引用

torch.nn.GroupNormの使い方

GroupNormの引数には、num_groupsnum_channelsを指定する必要があります。num_channelsでは入力のチャンネル数を指定し、num_groupsではそれをいくつのグループに分割するかを指定します。ここでは6チャンネルの入力を考えているため、num_channels=6となります。

import torch
import torch.nn as nn

input = torch.rand(4,6,2,2)
m = nn.GroupNorm(3,6) # 3つのグループごとに標準化 
# m = nn.GroupNorm(6, 6) # 全チャンネルを独立に標準化 = Instance Normalizationと同じ
# m = nn.GroupNorm(1, 6) # 全チャンネルで標準化 = Layer Normalizationと同じ
output = m(input)

ちなみに、num_groups=num_channelsの場合、全チャンネルを独立に標準化するためInstance Normalizationと同一の処理となります。また、num_groups=1の場合、全チャンネルを1つのグループとしてまとめて標準化するため、Layer Normalizationと同一の処理となります。

参考

[1] Wu, Yuxin, and Kaiming He. “Group normalization.” Proceedings of the European conference on computer vision (ECCV). 2018.

[2] https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html

コメント

タイトルとURLをコピーしました