Pytorch模型量化( 三 )

Pytorch模型量化

文章插图
Pytorch模型量化

文章插图
DemoModel(  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))  (relu): ReLU()  (fc): Linear(in_features=2, out_features=2, bias=True))DemoModel(  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))  (relu): ReLU()  (fc): DynamicQuantizedLinear(in_features=2, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine))tensor([[[[-0.5361,  0.0741],          [-0.2033,  0.4149]]]], grad_fn=<AddBackward0>)tensor([[[[-0.5371,  0.0713],          [-0.2040,  0.4126]]]])Post Training Static Quantization (训练后静态量化)静态量化需要把模型的权重和激活都进行量化,静态量化需要把训练集或者和训练集分布类似的数据喂给模型(注意没有反向传播),然后通过每个op输入的分布 来计算activation的量化参数(scale和zp)——称之为Calibrate(定标),因为静态量化的前向推理过程自始至终都是int计算,activation需要确保一个op的输入符合下一个op的输入 。
PyTorch会使用以下5步来完成模型的静态量化:
1、fuse_model合并一些可以合并的layer 。这一步的目的是为了提高速度和准确度:
fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None)比如给fuse_modules传递下面的参数就会合并网络中的conv1、bn1、relu1:
torch.quantization.fuse_modules(F32Model, [['fc', 'relu']], inplace=True)一旦合并成功,那么原始网络中的fc就会被替换为新的合并后的module(因为其是list中的第一个元素),而relu(list中剩余的元素)会被替换为nn.Identity(),这个模块是个占位符,直接输出输入 。举个例子,对于下面的一个小网络:
import torchfrom torch import nnclass F32Model(nn.Module):    def __init__(self):        super(F32Model, self).__init__()        self.fc = nn.Linear(3, 2,bias=False)        self.relu = nn.ReLU(inplace=False)    def forward(self, x):        x = self.fc(x)        x = self.relu(x)        return xmodel_fp32 = F32Model()print(model_fp32)# F32Model(#   (fc): Linear(in_features=3, out_features=2, bias=False)#   (relu): ReLU()# )model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['fc', 'relu']])print(model_fp32_fused)# F32Model(#   (fc): LinearReLU(#     (0): Linear(in_features=3, out_features=2, bias=False)#     (1): ReLU()#   )#   (relu): Identity()# )modules_to_fuse参数的list可以包含多个item list,或者是submodule的op list也可以,比如:[ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] 。有的人会说了,我要fuse的module被Sequential封装起来了,如何传参?参考下面的代码:
torch.quantization.fuse_modules(a_sequential_module, ['0', '1', '2'], inplace=True)

经验总结扩展阅读