PyTorchの量子化をかるく動かしてみる【Quantization】

PyTorch

はじめに

PyTorchでは、ディープラーニングのモデルを量子化する様々な方法が用意されています。今回はPyTorchでサポートされている量子化方法について、PyTorchのドキュメントに記載されているコードをベースに確認していこうと思います。

PyTorchでの量子化について

 深層学習の分野では、重みやバイアスといったパラメータの量子化bit数を下げる変換のことを量子化と呼びます。 PyTorchではモデルは浮動小数点(FP32)でパラメータを保存しており、量子化することでこれらを固定小数点(INT8)へ変換することができます。量子化によりモデルサイズは1/4となり、計算も2~4倍ほど高速になるそうです。

量子化の方法は大きく分けると、FP32で学習したモデルを学習後にINT8へ変換するPost Training Quantization(PTQ) と、学習時に疑似的な量子化を考慮することで量子化による精度低下を抑えるQuantization Aware Training(QAT) の2つがあります。PyTorchではこれらの手法がサポートされており、非常に簡単に量子化をすることができます。

Post Training Quantization

学習後にモデルを量子化するPTQには、重みのみを量子化する動的量子化(Dynamic Quantization)と、Activation(活性化関数をかけた値)も量子化される静的量子化(Static Quantization)の2種類があります。

Dynamic Quantization

torch.quantization.quantize_dynamicへ量子化前のモデルを渡し、量子化したいレイヤーを指定することでDynamic Quantizationができます。以下のコードでいうと、model_fp32はFP32で重みが保存されており、model_int8はそれをINT8へ変換したものになります。
乗算による計算量よりもパラメータの移動がボトルネックとなるようなケースではこの方法が良いようです。

import torch

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return x

model_fp32 = M()

model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

実際に各モデルの重みを見てみます。確かに重みがtorch.float32からtorch.qint8へ変換できています。

weight_fp32 = model_fp32.fc.weight
weight_int8 = model_int8.fc.weight() # ()つけないと重みが取得できない

print(f'量子化前:{weight_fp32.dtype}\n{weight_fp32}\n')
#量子化前:torch.float32
#Parameter containing:
#tensor([[-0.2993, -0.1329,  0.3400, -0.1602],
#        [-0.4522,  0.0911, -0.4214, -0.3618],
#        [ 0.2701,  0.1939, -0.3010, -0.0901],
#        [ 0.1691, -0.3158,  0.3606, -0.1663]], requires_grad=True)

print(f'量子化後:{weight_int8.dtype}\n{weight_int8}')
#量子化後:torch.qint8
#tensor([[-0.2979, -0.1312,  0.3405, -0.1596],
#        [-0.4540,  0.0922, -0.4221, -0.3618],
#        [ 0.2696,  0.1951, -0.3015, -0.0887],
#        [ 0.1702, -0.3157,  0.3618, -0.1667]], size=(4, 4), dtype=torch.qint8,
#       quantization_scheme=torch.per_tensor_affine, scale=0.0035468367859721184,
#       zero_point=0)

つづいて、Activationの値を確認します。Dynamic Quantizationではこちらはtorch.float32のままです。また、量子化による誤差によって、モデルの出力値が若干変化することも確認できます。

input_fp32 = torch.randn(1, 4)

output_fp32 = model_fp32(input_fp32)
output_int8 = model_int8(input_fp32)

print(f'入力値:{input_fp32.dtype}\n')
#入力値:torch.float32

print(f'量子化前:{output_fp32.dtype}\n{output_fp32[0,:]}\n')
#量子化前:torch.float32
#tensor([ 0.1221, -0.3486, -0.3761,  0.1072], grad_fn=<SliceBackward0>)

print(f'量子化後:{output_int8.dtype}\n{output_int8[0,:]}')
#量子化後:torch.float32
#tensor([ 0.1152, -0.3552, -0.3827,  0.1081])

Static Quantization

Static Quantizationの場合は、量子化したいモデルにtorch.quantization.QuantStub()とtorch.quantization.DeQuantStub()をforwardで行う計算の最初と最後へとそれぞれ追加します。ここではConvolution層とReLUによるモデルを例にしています。

Static QuantizationではActivationも量子化するための量子化パラメータを決定する必要があり、その作業をキャリブレーションといいます。

また、この例では量子化の精度を向上させるため、Conv層とそれに続く活性化関数のReLUはConvReLU2dというモジュールへ合体されます。これらの間にバッチノーマライゼーションを追加したConvBnReLU2dなども用意されています。(参考ページ)

import torch

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        x_dequant = self.dequant(x) # 型を比較したいため、dequant前後の値を出力する
        return x_dequant, x

model_fp32 = M()
model_fp32.eval() # 量子化の際は推論モードにする

model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# 畳み込み層と活性化関数を合体させる
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)

# Activationのための量子化パラメータを決定するため、キャリブレーションを行う
# ここでは乱数を用いているが、本当は実際のデータセットを使用する
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# モデルをINT8へ変換する
model_int8 = torch.quantization.convert(model_fp32_prepared)

こちらも重みを確認してみます。重みが量子化後にINT8となっていることが確認できます。

weight_fp32 = model_fp32.conv.weight
weight_int8 = model_int8.conv.weight() # ()つけないと重みが取得できない

print(f'量子化前:{weight_fp32.dtype}\n{weight_fp32}\n')
#量子化前:torch.float32
#Parameter containing:
#tensor([[[[0.1745]]]], requires_grad=True)

print(f'量子化後:{weight_int8.dtype}\n{weight_int8}')
#量子化後:torch.qint8
#tensor([[[[0.1738]]]], size=(1, 1, 1, 1), dtype=torch.qint8,
#       quantization_scheme=torch.per_channel_affine,
#       scale=tensor([0.0014], dtype=torch.float64), zero_point=tensor([0]),
#       axis=0)

続いてActivationです。モデルが最後に出力する値は量子化後もFP32のままですが、dequantをする前はINT8となっており、Activationも量子化されていることがわかります。dequantでそれをFP32に戻しているようです。

input_fp32 = torch.randn(1,1,3,3)

output_fp32_dequant, output_fp32 = model_fp32(input_fp32)
output_int8_dequant, output_int8 = model_int8(input_fp32)

# model_fp32
print(f'dequant前 : {output_fp32.dtype}')
#dequant前 : torch.float32
print(f'dequant後 : {output_fp32_dequant.dtype}')
#dequant後 : torch.float32

# model_int8
print(f'dequant前 : {output_int8.dtype}')
#dequant前 : torch.quint8
print(f'dequant後 : {output_int8_dequant.dtype}')
#dequant後 : torch.float32

実行結果

Quantization Aware Training

Static Quantizationと同様に、モデルにquantとdequantを挿入しています。今回のモデルではConv層とReLUの間にバッチノーマライゼーションが追加されているため、それら3つを合体させています。また、QATでは学習を行う前にtorch.quantization.prepare_qatをする必要があります。

今回は実際に学習部分は動かしていませんが、training_loop部分で、量子化による誤差を考慮しながらモデルの学習を行うことができます。そのため、学習後に量子化を行うPTQよりも精度の良い量子化が可能となります。

import torch

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

model_fp32 = M()
model_fp32.train()

model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

model_fp32.eval() # サンプルコードにはないが、fuseする際に推論モードにしないとエラーとなる
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'bn', 'relu']])
model_fp32_fused.train() # prepare_qatする際に学習モードでないとエラーとなる

model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)


# ここで学習を行う
# training_loop(model_fp32_prepared)


model_fp32_prepared.eval()
model_int8 = torch.quantization.convert(model_fp32_prepared)

最後に

今回はPyTorchでサポートされている3種類の量子化方法についてドキュメントのコードをベースに実際に動かして、量子化の手順を確認しました。

3つの方法の中で最も簡単に量子化を行うことができるのはDynamic Quantizationですが、精度良く量子化を行うためにはQuantization Aware Trainingが良さそうです。

今後、各量子化方法をより詳細に確認してみたいと思っています。

コメント

タイトルとURLをコピーしました