PyTorchでGlobal Average Poolingする際に用いられるtorch.nn.AdaptiveAvgPool2d

PyTorch
AdaptiveAvgPool2d — PyTorch 2.1 documentation

今回は、PyTorchでGlobal Average Poolingを行う方法についてまとめていきます。結論からいうと、torch.nn.AdaptiveAvgPool2dを出力サイズ(1,1)として用います。

torchvisionでのGlobal Average Poolingの実装

torchvisionのResNetでどのようにGlobal Average Poolingが実装されているのかを確認してみたところ、以下のように、torch.nn.AdaptiveAvgPool2dが用いられていました。

https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py より

この層によって値がどのように変化するか見てみる。試しに(4,3,256,256)のサイズの画像を入力してみると、self.avgpoolへの入力は(4,2048,8,8)で、self.avgpoolによって(4,2048,1,1)と形状が変化していることがわかります。ここで、2048チャンネルの特徴量がそれぞれ平均され、各チャンネルが(8,8)から(1,1)のサイズへプーリングされます。

class GAP(nn.Module):
    def __init__(self):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))
        
    def forward(self, input):
        output = self.avgpool(input)
        print(f'{input.shape} -> {output.shape}') # Global Average Pooling内で形状がどう変化するか確認
        return output
    
model = torchvision.models.resnet50(weights=None)
model.avgpool = GAP() # 入出力サイズの確認のため、GAP層だけ置き換える

input = torch.rand(4,3,256,256)
output = model(input)
# torch.Size([4, 2048, 8, 8]) -> torch.Size([4, 2048, 1, 1])

torch.nn.AdaptiveAvgPool2dの使い方

torch.nn.AdaptiveAvgPool2dの使い方は非常にシンプルで、引数output_sizeとしてこの層が出力するサイズを指定するだけで、入力のサイズに応じてプーリングをしてくれます。

Global Average Poolingとして用いるためには、output_size=(1,1)としてやるだけです。

avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))

1チャンネル目が0~0.5、2チャンネル目が0.5~1となる値をGlobal Average Poolingさせてみます。それぞれのチャンネルの平均が取られ、約0.25, 0.75が出力されることが確認できました。

adaptiveavgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))

h = 8
w = h
c = 2

input = torch.linspace(0, 1, c*h*w, dtype=torch.float).reshape(1, c, h, w)
output = adaptiveavgpool(input)

plt.figure(figsize=(10,4))

plt.subplot(1,2,1)
im1 = plt.imshow(input[0,0,...])
plt.colorbar(im1)
im1.set_clim([0,1])
plt.title(f'mean:{float(output[0,0])}') # 1チャンネル目の平均値

plt.subplot(1,2,2)
im2 = plt.imshow(input[0,1,...])
plt.colorbar(im2)
im1.set_clim([0,1])
plt.title(f'mean:{float(output[0,1])}') # 2チャンネル目の平均値

plt.show()

AdaptiveAvgPool2dとAvgPool2dの違い

PyTorchには同じくAveragePoolingを行うモジュールに、torch.nn.AvgPool2dがあります。これらの違いについて確認していきます。

AdaptiveAvgPool2doutput_sizeを指定することで、プーリングを行うカーネルサイズ等を指定する必要がなく、入力サイズに応じてプーリングしてくれるというものでした。

これに対し、AvgPool2dではkernel_sizestrideを指定するため、出力サイズは入力サイズに応じて変動します。

# Global Average Pooling
adaptiveavgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))

# AvgPool2d ... 今回は(2,2)ずつ平均を取るプーリングとなる
avgpool = torch.nn.AvgPool2d(kernel_size=2, stride=2)

# AdaptiveAvgPoolの場合
output = adaptiveavgpool(input)
print(f'{input.shape} -> {output.shape}')
# torch.Size([1, 2, 8, 8]) -> torch.Size([1, 2, 1, 1])

# AvgPoolの場合
print(f'{input.shape} -> {output.shape}')
# torch.Size([1, 2, 8, 8]) -> torch.Size([1, 2, 4, 4])

コメント

タイトルとURLをコピーしました