PyTorchで勾配クリッピングをする方法

PyTorch

ニューラルネットワークの学習時に勾配爆発が発生しないようにする方法として、勾配クリッピング(Gradient Clipping)があります。PyTorchでこれを行うための方法についてまとめました。

勾配クリッピングとは

学習時に勾配が大きくなりすぎてしまい、学習が不安定になってしまうという問題のことを勾配爆発問題(Exploding gradients problem) といいます。これを防ぐための手法として、勾配クリッピングがあります。

勾配クリッピングは、勾配があらかじめ設定しておいた閾値以上となってしまった場合に、勾配の値を閾値にクリッピングするという手法です。下図では、実線が勾配クリッピングなしの場合の様子を表しており、破線が勾配クリッピングありを示しています。極端に大きな勾配になってしまった際に、パラメータが大きく更新されることを防ぐ効果があることがわかります。

Pascanu, Razvan, Tomas Mikolov, and Yoshua Bengio. “On the difficulty of training recurrent neural networks.” International conference on machine learning. PMLR, 2013. より引用

torch.nn.utils.clip_grad_norm_

PyTorchで勾配クリッピングを行う際には、torch.nn.utils.clip_grad_norm_を用いるのが便利です。勾配クリッピングを行うタイミングは、ロスから勾配を計算するのとパラメータを更新する間、つまりloss.backward()とoptimizer.step()の間で行います。

torch.nn.utils.clip_grad_norm_ — PyTorch 2.1 documentation

以下では、全結合層2層のDNNの勾配を実際に確認して、勾配クリッピングがどのようにされるのか確認しました。

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 2)
        self.linear2 = nn.Linear(2, 2)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, inputs):
        x = self.sigmoid(self.linear1(inputs))
        x = self.sigmoid(self.linear2(x))
        return x
    
model = Model()

inputs = torch.rand([2, 4])
targets = 1.

outputs = model(inputs)

loss = torch.abs(outputs - targets).mean()
loss.backward()



# 勾配クリッピング前の値を確認
print(model.linear1.weight.grad)
print(model.linear1.bias.grad)
print(model.linear2.weight.grad)
print(model.linear2.bias.grad)
# tensor([[-0.0059, -0.0062, -0.0023, -0.0105],
#         [-0.0049, -0.0055, -0.0018, -0.0091]])
# tensor([-0.0113, -0.0098])
# tensor([[-0.0635, -0.0808],
#         [-0.0569, -0.0723]])
# tensor([-0.1204, -0.1078])

今回はパラメータ全体の勾配のL2ノルムが0.1と制限されるように勾配クリッピングを行いました。確認すると、確かにノルムが0.1となっていることがわかります。

norm_type = 2 # L2ノルム
grad_clip = 0.1 # ノルムの閾値を設定

# 勾配クリッピングを行う
grad_norm_before = nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=grad_clip, norm_type=norm_type)
# 戻り値が不要であればこう書いてよい
# nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=grad_clip, norm_type=norm_type) 



# 勾配クリッピング後の値を確認
print(model.linear1.weight.grad)
print(model.linear1.bias.grad)
print(model.linear2.weight.grad)
print(model.linear2.bias.grad)
# tensor([[-0.0028, -0.0029, -0.0011, -0.0049],
#         [-0.0023, -0.0026, -0.0009, -0.0043]])
# tensor([-0.0053, -0.0046])
# tensor([[-0.0297, -0.0378],
#         [-0.0266, -0.0338]])
# tensor([-0.0563, -0.0504])



# 勾配のノルムが指定した値になっているか確認
tmp1 = torch.cat([model.linear1.weight.grad, model.linear1.bias.grad.unsqueeze(1)], dim=1)
tmp2 = torch.cat([model.linear2.weight.grad, model.linear2.bias.grad.unsqueeze(1)], dim=1)
tmp = torch.cat([tmp1, tmp2], dim=1)
grad_norm_after = torch.norm(tmp, p=2)

print('Before:', grad_norm_before)
print('After :', grad_norm_after)
# Before: tensor(0.2138)
# After : tensor(0.1000)

torch.nn.utils.clip_grad_value_との違い

PyTorchには、clip_grad_norm_と非常に似た、torch.nn.utils.clip_grad_value_もあります。

torch.nn.utils.clip_grad_value_ — PyTorch 2.1 documentation

clip_grad_norm_では、さきほど確認したようにパラメータ全体の勾配のノルムを指定した値になるよう制限します。このため、勾配全体がスケーリングされます。一方、clip_grad_value_では閾値を超えた勾配のみがクリッピングされます。

さきほどと同様に確認します。閾値を0.02としたことで、1層目の勾配は変化せず、2層目の勾配はすべて0.02となってしまいました。

# 勾配クリッピング前の値を確認
print(model.linear1.weight.grad)
print(model.linear1.bias.grad)
print(model.linear2.weight.grad)
print(model.linear2.bias.grad)
# tensor([[-0.0046, -0.0017, -0.0033, -0.0028],
#         [-0.0016, -0.0006, -0.0011, -0.0009]])
# tensor([-0.0067, -0.0023])
# tensor([[-0.0594, -0.0518],
#         [-0.0564, -0.0492]])
# tensor([-0.1234, -0.1171])



grad_clip = 0.02 # 勾配の閾値

# 勾配クリッピングを行う
nn.utils.clip_grad_value_(parameters=model.parameters(), clip_value=grad_clip)



# 勾配クリッピング後の値を確認
print(model.linear1.weight.grad)
print(model.linear1.bias.grad)
print(model.linear2.weight.grad)
print(model.linear2.bias.grad)
# tensor([[-0.0046, -0.0017, -0.0033, -0.0028],
#         [-0.0016, -0.0006, -0.0011, -0.0009]])
# tensor([-0.0067, -0.0023])
# tensor([[-0.0200, -0.0200],
#         [-0.0200, -0.0200]])
# tensor([-0.0200, -0.0200])

コメント

タイトルとURLをコピーしました