Pytorch模型量化( 四 )

就目前来说,截止目前为止,只有如下的op和顺序才可以 (这个mapping关系就定义在DEFAULT_OP_LIST_TO_FUSER_METHOD中):

  • Convolution, BatchNorm
  • Convolution, BatchNorm, ReLU
  • Convolution, ReLU
  • Linear, ReLU
  • BatchNorm, ReLU
  • ConvTranspose, BatchNorm
2、设置qconfigqconfig要设置到模型或者Module上 。
#如果要部署在x86 server上model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')#如果要部署在ARM上model_fp32.qconfig = torch.quantization.get_default_qconfig('qnnpack')x86和arm之外目前不支持 。
3、prepareprepare用来给每个子module插入Observer,用来收集和定标数据 。
以activation的observer为例,观察输入数据得到 四元组中的 min_val 和 max_val,至少观察个几百个迭代的数据吧,然后由这四元组得到 scale 和 zp 这两个参数的值 。
model_fp32_prepared= torch.quantization.prepare(model_fp32_fused)4、喂数据这一步不是训练 。是为了获取数据的分布特点,来更好的计算activation的 scale 和 zp。至少要喂上几百个迭代的数据 。
#至少观察个几百迭代for data in data_loader:    model_fp32_prepared(data)5、转换模型第四步完成后,各个op权重的四元组 (min_val,max_val,qmin, qmax) 中的 min_val , max_val 已经有了,各个op activation的四元组 (min_val,max_val,qmin, qmax) 中的 min_val , max_val 也已经观察出来了 。那么在这一步我们将调用convert API:
model_prepared_int8 = torch.quantization.convert(model_fp32_prepared)我们来吃一个完整的例子:
# -*- coding:utf-8 -*-# Author:凌逆战 | Never# Date: 2022/10/17"""权重和激活都会被量化"""import torchfrom torch import nn# 定义一个浮点模型,其中一些层可以被静态量化class F32Model(torch.nn.Module):    def __init__(self):        super(F32Model, self).__init__()        self.quant = torch.quantization.QuantStub()  # QuantStub: 转换张量从浮点到量化        self.conv = nn.Conv2d(1, 1, 1)        self.fc = nn.Linear(2, 2, bias=False)        self.relu = nn.ReLU()        self.dequant = torch.quantization.DeQuantStub()  # DeQuantStub: 将量化张量转换为浮点    def forward(self, x):        x = self.quant(x)  # 手动指定张量: 从浮点转换为量化        x = self.conv(x)        x = self.fc(x)        x = self.relu(x)        x = self.dequant(x)  # 手动指定张量: 从量化转换到浮点        return xmodel_fp32 = F32Model()model_fp32.eval()  # 模型必须设置为eval模式,静态量化逻辑才能工作# 1、如果要部署在ARM上;果要部署在x86 server上 ‘fbgemm’model_fp32.qconfig = torch.quantization.get_default_qconfig('qnnpack')# 2、在适用的情况下,将一些层进行融合,可以加速# 常见的融合包括在:DEFAULT_OP_LIST_TO_FUSER_METHODmodel_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['fc', 'relu']])# 3、准备模型,插入observers,观察 activation 和 weightmodel_fp32_prepared = torch.quantization.prepare(model_fp32_fused)# 4、代表性数据集,获取数据的分布特点,来更好的计算activation的 scale 和 zpinput_fp32 = torch.randn(1, 1, 2, 2)  # (batch_size, channel, W, H)model_fp32_prepared(input_fp32)# 5、量化模型model_int8 = torch.quantization.convert(model_fp32_prepared)# 运行模型,相关计算将在int8中进行output_fp32 = model_fp32(input_fp32)output_int8 = model_int8(input_fp32)print(output_fp32)# tensor([[[[0.6315, 0.0000],#           [0.2466, 0.0000]]]], grad_fn=<ReluBackward0>)print(output_int8)# tensor([[[[0.3886, 0.0000],#           [0.2475, 0.0000]]]])

经验总结扩展阅读