Distilling the Knowledge in a Neural Network

这篇文章讲的是怎么提取训练好的模型的知识。考虑这样的场景,在很多应用场景中为了提升性能往往需要做集成,但是用集成的模型的话首先部署不够灵活,其次计算量会比较大;或者在深度学习里我们一个模型参数动则上百兆,要把这些模型部署到一些嵌入式设备也不太现实。这篇文章就是对复杂模型的输出【soft target】做一些调整,作为监督信息训练小模型。作者称之为”Knowledge Distillation”。

分类上我们最常用的目标函数就是softmax,在预测的时候一般就是取预测值最高的那个类别,然而在其他类别上的输出其实蕴含了更多的信息——它反映了当前样本在类别上的相似性。所以我们可以用复杂模型的概率输出——“soft targets”作为监督信息。如果soft targets的分布蕴含较多的信息【熵比较高】,那么soft targets能提供比hard targets更多的信息,梯度的variance也会减少,这样我们训练小模型的时候就可以用更少的样本,也能用更高的学习率。如果复杂模型是多个简单模型的集成,那么我们可以对多个简单模型的概率输出做算术或者几何平均,再作为soft targets。

然而在一些比较简单的任务上,如手写字体识别,我们现有的模型往往能对正确类别给予非常高的置信度,导致在其他类别上的概率输出非常得小,这些非常小的概率会对cross entropy loss有非常小的贡献。为解决这样的问题,在[1]中,作者用的是softmax的输入,即logits作为soft targets,然后最小化复杂模型与小模型的logits之间的均方误差。在这篇论文中,作者是在softmax函数里加入一个称为”temperature”的超参$T$,即
$$q_i = \frac{exp(z_i/T)}{\sum_jexp(z_j/T)}$$
调节$T$使得复杂模型产生一个合适的soft targets,然后用同样的$T$训练小模型;在测试的时候,$T$还是设为1。如果部分训练样本的标签已知,那么我们可以用多目标的联合训练:分别跟soft targets和正确标签算cross entropy loss。因为soft targets的梯度会除以$1/T^2$,所以联合训练的时候需要对soft targets的cross entropy loss乘上$1/T^2$。