A-卷积网络压缩方法总结( 五 )

student 网络的输出结果 。
但是,直接使用 teacher 网络的 softmax 的输出结果 \(q\),可能不大合适 。因此,一个网络训练好之后,对于正确的答案会有一个很高的置信度 。例如,在 MNIST 数据中,对于某个 2 的输入,对于 2 的预测概率会很高,而对于 2 类似的数字,例如 3 和 7 的预测概率为 \(10^-6\) 和 \(10^-9\) 。这样的话,teacher 网络学到数据的相似信息(例如数字 2 和 3,7 很类似)很难传达给 student 网络,因为它们的概率值接近0 。因此,论文提出了 softmax-T(软标签计算公式),公式如下所示:
\[q_{i} = \frac{z_{i}/T}{\sum_{j}z_{j}/T}\]这里 \(q_i\) 是 \(student\) 网络学习的对象(soft targets),\(z_i\) 是 teacher 模型 softmax 前一层的输出 logit 。如果将 \(T\) 取 1,上述公式变成 softmax,根据 logit 输出各个类别的概率 。如果 \(T\) 接近于 0,则最大的值会越近 1,其它值会接近 0,近似于 onehot 编码 。
所以,可以知道 student 模型最终的损失函数由两部分组成:

  • 第一项是由小模型的预测结果与大模型的“软标签”所构成的交叉熵(cross entroy);
  • 第二项为预测结果与普通类别标签的交叉熵 。
这两个损失函数的重要程度可通过一定的权重进行调节,在实际应用中,T 的取值会影响最终的结果,一般而言,较大的 T 能够获得较高的准确度,T(蒸馏温度参数) 属于知识蒸馏模型训练超参数的一种 。T 是一个可调节的超参数、T 值越大、概率分布越软(论文中的描述),曲线便越平滑,相当于在迁移学习的过程中添加了扰动,从而使得学生网络在借鉴学习的时候更有效、泛化能力更强,这其实就是一种抑制过拟合的策略 。知识蒸馏的整个过程如下图:
A-卷积网络压缩方法总结

文章插图
student 模型的实际模型结构和小模型一样,但是损失函数包含了两部分,分类网络的知识蒸馏 mxnet 代码示例如下::
# -*-coding-*-: utf-8"""本程序没有给出具体的模型结构代码,主要给出了知识蒸馏 softmax 损失计算部分 。"""import mxnet as mxdef get_symbol(data, class_labels, resnet_layer_num,Temperature,mimic_weight,num_classes=2):backbone = StudentBackbone(data)# Backbone 为分类网络 backbone 类flatten = mx.symbol.Flatten(data=https://www.huyubaike.com/biancheng/conv1, name="flatten")fc_class_score_s = mx.symbol.FullyConnected(data=https://www.huyubaike.com/biancheng/flatten, num_hidden=num_classes, name='fc_class_score')softmax1 = mx.symbol.SoftmaxOutput(data=https://www.huyubaike.com/biancheng/fc_class_score_s, label=class_labels, name='softmax_hard')import symbol_resnet# Teacher modelfc_class_score_t = symbol_resnet.get_symbol(net_depth=resnet_layer_num, num_class=num_classes, data=https://www.huyubaike.com/biancheng/data)s_input_for_softmax=fc_class_score_s/Temperaturet_input_for_softmax=fc_class_score_t/Temperaturet_soft_labels=mx.symbol.softmax(t_input_for_softmax, name='teacher_soft_labels')softmax2 = mx.symbol.SoftmaxOutput(data=https://www.huyubaike.com/biancheng/s_input_for_softmax, label=t_soft_labels, name='softmax_soft',grad_scale=mimic_weight)group=mx.symbol.Group([softmax1,softmax2])group.save('group2-symbol.json')return grouptensorflow代码示例如下:
# 将类别标签进行one-hot编码one_hot = tf.one_hot(y, n_classes,1.0,0.0) # n_classes为类别总数, n为类别标签# one_hot = tf.cast(one_hot_int, tf.float32)teacher_tau = tf.scalar_mul(1.0/args.tau, teacher) # teacher为teacher模型直接输出张量, tau为温度系数Tstudent_tau = tf.scalar_mul(1.0/args.tau, student) # 将模型直接输出logits张量student处于温度系数Tobjective1 = tf.nn.sigmoid_cross_entropy_with_logits(student_tau, one_hot)objective2 = tf.scalar_mul(0.5, tf.square(student_tau-teacher_tau))"""student模型最终的损失函数由两部分组成:第一项是由小模型的预测结果与大模型的“软标签”所构成的交叉熵(cross entroy);第二项为预测结果与普通类别标签的交叉熵 。"""tf_loss = (args.lamda*tf.reduce_sum(objective1) + (1-args.lamda)*tf.reduce_sum(objective2))/batch_size

经验总结扩展阅读