基于中间层频域特征蒸馏的元学习算法

时间:2023-10-31 09:44:01 来源:网友投稿

张 灵, 郭林威

(广东工业大学 计算机学院, 广州 510000)

为解决大模型在有限资源设备上部署的难题,诸多学者开展相关研究,在保持DNNs表现不变或表现下降在可接受范围的情况下,通常采用缩小其规模的方式以实现在嵌入式系统中应用的目的.现阶段这方面的研究大致可分为4个方向:1)网络剪枝;2)网络量化;3)构建更有效的小型网络;4)知识蒸馏(DT).其中,知识蒸馏是将大型神经网络模型的信息迁移至小型神经网络模型的深度网络压缩方法[1],进而用更小模型获得更优的任务表现.根据当前知识蒸馏的发展,从知识类型上可将其划分为3种类型:1)基于模型输出的知识类型[2-4];2)基于模型中间层特征的知识类型[5-10];3)基于关系的知识类型[11-13].

目前,基于中间特征层知识的模型蒸馏方法较多是从特征图单个激活值层面上的知识迁移,并未考虑到特征图的全局特征.在学生模型训练的过程中,教师模型仅能传递固定的监督信号,因此无法根据学生模型当前的训练情况,对传递的知识做出适当调整.卷积神经网络(CNN)模型通常对图像的低频特征更为敏感性,且低频特征对视觉推理任务而言更具有信息性[14-15],因此,在CNN模型的知识蒸馏过程中,从教师模型提取更为敏感且信息性更高的低频特征作为知识传递给学生模型进行辅助训练会更为有效,故提出一种基于频域特征迁移的知识蒸馏方法,并结合元代理标签(MPL)进行模型训练,使教师模型能够在迁移中间层频域知识的同时,根据学生模型在验证集上的表现更新教师模型自身的参数,实现动态调整频域特征知识(DTM)的目的.

此外,基于频域特征的知识迁移仅局限在每个独立类别上的知识迁移,而忽视了类别之间的差异性信息.对此,本文使用Logistic模型对各个类别提取到的频域特征做线性二分类,再将分类边界正交向量上的各个元素作为学生模型对应频域特征值,使学生模型可根据输入图像的类别,有针对性地拟合该类别更具代表性的特征值,而在一定程度上忽略不重要的特征.

1.1 DCT频域特征提取

Yosinski等[16]指出:不同模型在浅层提取的特征差异较小,而模型深隐层提取的特征则具有更多的独特性.因此,对于处理同样任务的两个模型,更深层次的教师模型特征会为学生模型提供更多有用的信息.本文方法知识迁移的发起点是在模型最后一层的卷积输出层.在特征图提取频域特征的过程中,将两个模型最后卷积层输出的特征图看成特殊的二维信号矩阵,并用离散余弦变换(DCT)的变换核与信号矩阵相乘,得到代表特征图频域特征的DCT相关系数矩阵.最终通过缩小教师与学生模型特征图频率特征差异的方式完成知识迁移.该做法仅要求学生模型生成与教师模型数值分布相似的特征图即可,而不要求每个对应的激活值均完全相似.此外,由于离散余弦变换具有能量聚集的性质,故相对于其他特征图全局特征的迁移方法,DCT频域特征迁移能用更少的参数来表示更多的特征图知识,其计算表达式为

F=AfAT

(1)

(2)

(3)

式中:A为一维DCT变换的变换系数矩阵;f为模型的中间特征图;P为特征图维度大小;i、j分别为特征图维度的横、纵坐标编号,取值范围为0到P-1;c(i)为DCT变换核的补偿系数;F为离散余弦变换后得到的DCT相关系数矩阵,该矩阵中每个系数值表示频率分布与特征图数值分布的相似程度,相似程度越高,则对应的相关系数越大.

1.2 教师模型参数更新

在算法执行过程中,学生模型参数的更新表达式为

(4)

学生模型用于判别验证集的表达式为

(5)

式中,xval与yval分别为验证集输入与验证集标签.

由式(5)可知,学生模型在验证集上的判别损失对于教师模型参数可导,因此,可通过最小化学生模型在验证集上的损失值来更新教师模型参数,使教师模型能够通过学生模型在验证集上的表现修改自身参数,进而调整传递给学生模型中间隐层的频域特征知识.教师模型参数更新表达式为

(6)

式中,ηT为教师模型学习率.

1.3 线性分类器获取类间差异性信息

在提取预训练教师模型最后残差块的特征图输出频域特征后,将特征向量与对应标签用Logistic模型进行二分类,得到分类边界的正交向量WLogistic.而正交向量上的每个元素在一定程度上表示了特征向量中每个对应特征值的重要程度,可以指导学生模型的频域特征向量的拟合.

以CIFAR-10数据集分类为例,将10个类别中1个类别的频域特征向量作为正类,而其他9个类别的频域特征向量作为负类.用二元线性分类器Logistic Regression对定义好的正负类进行分类,可得到分类边界的正交向量WLogistic.将正交向量中的每个数值按比例缩放至[0,2]范围内,然后以权值的形式加权到教师模型与学生模型频域特征的误差损失中.当权值小于1时,说明该权值在对应位置上的DCT频域特征重要程度较低,此时该位置上两模型匹配的损失值应乘以一个小于1的数以缩小拟合程度;而当权值大于1时,说明该权值在对应位置上的DCT频域特征需重视,则该位置上两模型匹配的损失值要乘以一个大于1的数以提高拟合程度.这样使学生模型在向教师模型学习过程中仍能对每个输入图像的类别有针对地拟合.

学生模型最后的训练损失函数为

(7)

式中:LTotal为总的损失函数;LOutput为学生模型的分类损失;α为分类损失与频域特征损失的平衡系数,取值范围为0~1;LDCT_Loss为模型频域特征损失.

实验采用数据集CIFAR-10、CIFAR-100及ImageNet 2012对基于频域特征迁移蒸馏方法的有效性进行验证.数据集CIFAR-10由10个类的60 000个32×32彩色图像组成,每类有6 000个图像,主要包含交通工具与动物两个大类的图像,如飞机、汽车、船、猫、鹿及青蛙等;数据集CIFAR-100有100个类别,每个类别包含600个图像,包含20个大类的图像,如哺乳动物类、水生动物类、花卉类及户外场景类等;数据集ImageNet 2012包含了1 000个类别,且不同图片的像素大小各不相同.由于设备条件限制,该实验无法进行全数据集运行,故从中抽取了200个类别作为实验数据集,在CIFAR-10与CIFAR-100数据集中存在较多的动物类别,因此在ImageNet 2012所抽取的类别集中选取交通工具、家具与球类等几个大类.实验中教师模型为ResNet-56,学生模型为ResNet-34.实验设备为单个GPU(GeForce GTX 1080 Ti),实验环境为python3.6和pytorch1.7.1.

此外,本文还进行了将类间差异性引入知识蒸馏的验证实验.通过使用各个类别分类的正交向量WLogistic相互计算距离,用来比较每个分类决策边界之间的相似度.在DCT变换能量聚集性质的实验中,进行了特征图匹配维度的研究.在实验中两个模型每个通道的DCT相关系数矩阵大小为8×8,本文除了进行8×8尺寸的相关系数矩阵迁移外,还进行了4×4、2×2以及1×1尺寸的相关系数矩阵知识迁移实验.

2.1 DCT相关系数矩阵尺寸的选择

由于DCT变换对图像特征有能量聚集的功能,因此进行了相关实验以进一步探究DCT特征图迁移尺寸的问题,以便能用更少的数值代表更多的特征知识.DCT相关系数矩阵迁移维度分别为8×8、4×4、2×2及1×1时学生模型的准确率如图1所示.

图1 不同尺寸系数矩阵的准确率

由图1a可看出,在CIFAR-10数据集中,当DCT相关系数矩阵取2×2时,学生模型的准确率达到最优.由图1b则可以看出,在图像分类类别较多的CIFAR-100与ImageNet 2012数据集上,当DCT相关系数矩阵取1×1时,学生模型的准确率可达最优.这一结果说明DCT对特征图的频域特征提取操作在一定程度上还具备噪声过滤的作用,其可将表示图像主要信息的低频特征及表示噪声的高频特征独立地表征出来.而当仅取低频的特征作为知识进行迁移时,对模型表现的改进则更为明显.

根据上述实验规律可发现,随着数据集分类的增加,DCT对特征图的特征提取与压缩效果更加显著,且对学生模型准确率的提升效果更优.由此将上述实验得到的结论用于下文关于类间差异性信息及DT与DTM方法的实验中,将CIFAR-10数据集上原本8×8的DCT相关系数矩阵截取为左上角2×2的矩阵作为教师模型迁移的知识;而在CIFAR-100与ImageNet 2012数据集中则截取左上角1×1的矩阵作为教师模型迁移的知识.

2.2 类间差异性信息

在该项实验中,将从教师模型中获得的DCT频域特征作为Logistic分类器的输入,得到每个类别正交向量WLogistic相互计算欧氏距离,并比较每两个类别间对应的WLogistic.在CIFAR-10、CIFAR-100及ImageNet 2012数据集的实验中分别取部分具有代表性的分类样本,利用上述方法计算得到的相似性结果如表1~3所示.

表1 CIFAR-10正交向量的相似性比较

表2 CIFAR-100正交向量的相似性比较

表3 ImageNet 2012正交向量的相似性比较

由表1可看出,猫这个类别的正交向量与同为四肢动物的狗、马等类别的欧氏距离会比汽车、飞机的类别更接近.同样卡车这一类别的正交向量与汽车类别的欧氏距离会比马、狗等四肢动物更接近.而鸟这一类别与狗、汽车、马以及飞机这4类的相似度均较低,所以鸟这一类别的正交向量与其他4类正交向量的欧氏距离均较远.由此可得,每个类别的正交向量在一定程度上确实包含了与不同类别间差异性的信息.在下文实验中,DT及DTM方法得到的实验准确率均为加入类间差异性这一信息后所得到的实验结果.

2.3 对比实验

对比方法包括了基于最终输出的知识蒸馏方法KD,基于特征层的知识提取但缺少全局统计特征的方法AT、FT、EKD及结合了元学习的知识蒸馏算法MPL.在所有实验中,除了AT与EKD为多个残差块做知识迁移操作之外,其余均为基于中间特征图的知识迁移方法.对于所有的对比实验,除模型结构更换为教师模型ResNet-56与学生模型ResNet-34之外,其他实验步骤及相关参数基本沿用原文中的实验步骤和相关参数.方法KD用于平滑教师输出标签的蒸馏温度参数T设为4,向教师模型软标签学习的学习率设为0.9;AT中采用对所有通道对应激活值2次方求和的方法以获取注意力图;FT方法使用了原文中表现效果最优的通道压缩比例k=0.5,用于提取特征的自编码器为6层卷积层;EKD方法中在模型4个残差块的输出都接入了引导模块,引导模块为3层卷积层,后面再接上单层的全连接层和一层softmax层;MPL中教师模型根据学生模型反馈更新参数的学习率为0.05,用于平滑教师输出标签的蒸馏温度参数T与KD一样设为4.

DT与DTM中频域特征损失项的权重参数α设置为500.DT与其余未加入元学习训练方法的知识蒸馏算法共训练了150个epoch;每个batch大小为128;学习率初始化为0.1,后续在80~120 epoch下降为0.01,120~150 epoch处下降为0.001.MPL与DTM两个方法训练了150 000 epoch,学生模型每训练1 000 epoch,教师模型则根据学生模型在验证集的准确率更新一次参数;每个batch大小为128;学生模型学习率为0.05.

学生模型训练具体过程如图2所示,将模型最后一层卷积层的特征图输出经过DCT操作获得对应的DCT相关系数矩阵,截取左上部分低频特征后,将两个模型的DCT频域特征匹配做差进而完成知识迁移.其中,教师模型的频域特征会再经过Logistic模型分类器,来获取每个类别分类边界的正交向量,以加权的形式调整学生模型每个频域特征值需要向教师模型拟合的程度,使学生模型可根据图像类别有针对性地拟合对该类别更有代表性的特征值,忽略不重要的特征,融入类别之间的差异性信息.而并非与现有大多数知识蒸馏方法类似,将所有提取到的特征值逐一完整地进行匹配.

图2 学生模型训练过程

图3展示了实验中从模型输入到最后用于迁移的知识提取过程.图3中每个方格表示一个对应的数值,颜色越亮表示该方格对应的数值越大,且说明特征图的数值分布与该位置所表示的频率分布越相似.矩阵截取值根据上述关于DCT相关系数矩阵尺寸选择的实验结论中得出.

图3 频域特征提取过程

由图3b与3c对比可看出,相对于图像能量分散的特征图3b而言,图3c明显使图像能量更加集中于左上角的低频区域.该区域所表示的低频特征对视觉推理任务而言更具有信息性[14].相对于直接用图3b进行匹配的FitNets或其他基于空间域特征的方法而言,不仅缩小了知识表达所需的参数,并且包含了整个特征图的全局统计特征,而非仅局限于特征图单个激活值层面上的知识表达,不同方法在3个数据集的准确率如表4所示.

表4 不同数据集的平均准确率

由表4结果可看出,3个数据集上所有的知识蒸馏方法比原始的学生模型在准确率上均具有提升;相较原始知识蒸馏算法KD,AT、FT、EKD对学生模型精确度有一定的提升.而在未加入元学习训练方法的知识蒸馏算法中,基于频域知识迁移的DT准确率会高于前者对比方法,而低于MPL.将文中方法结合元学习训练方法后,DTM准确率在CIFAR-100数据集上比缺少了全局统计特征知识的MPL平均提高了约0.12%,在CIFAR-10以及ImageNet 2012数据集上平均提高了0.16%.

本文从教师模型中提取了频率域信息作为一种新的知识传递给学生模型,并用线性分类器Logistic模型对一个类别与其他全部类别进行分类,使其在知识迁移过程中兼顾了类别之间的差异性信息.最终结合MPL模型训练方法,使教师模型在对学生模型进行知识迁移时,可根据学生模型在验证集上的表现来修改教师模型自身的参数,以达到动态调整学生模型特征层知识的目的.CIFAR-10、CIFAR-100与ImageNet 2012图像分类数据集的实验结果也验证了该方法的有效性.

猜你喜欢频域类别准确率大型起重船在规则波中的频域响应分析舰船科学技术(2022年22期)2022-12-13乳腺超声检查诊断乳腺肿瘤的特异度及准确率分析健康之家(2021年19期)2021-05-23不同序列磁共振成像诊断脊柱损伤的临床准确率比较探讨医学食疗与健康(2021年27期)2021-05-132015—2017 年宁夏各天气预报参考产品质量检验分析农业科技与信息(2021年2期)2021-03-27高速公路车牌识别标识站准确率验证法中国交通信息化(2018年5期)2018-08-21频域稀疏毫米波人体安检成像处理和快速成像稀疏阵列设计雷达学报(2018年3期)2018-07-18基于改进Radon-Wigner变换的目标和拖曳式诱饵频域分离火控雷达技术(2016年1期)2016-02-06服务类别新校长(2016年8期)2016-01-10基于频域伸缩的改进DFT算法电测与仪表(2015年3期)2015-04-09论类别股东会商事法论集(2014年1期)2014-06-27

推荐访问:中间层 蒸馏 算法

版权所有:天海范文网 2010-2024 未经授权禁止复制或建立镜像[天海范文网]所有资源完全免费共享

Powered by 天海范文网 © All Rights Reserved.。鲁ICP备10209932号