張 靈, 郭林威
(廣東工業(yè)大學(xué) 計(jì)算機(jī)學(xué)院, 廣州 510000)
為解決大模型在有限資源設(shè)備上部署的難題,諸多學(xué)者開(kāi)展相關(guān)研究,在保持DNNs表現(xiàn)不變或表現(xiàn)下降在可接受范圍的情況下,通常采用縮小其規(guī)模的方式以實(shí)現(xiàn)在嵌入式系統(tǒng)中應(yīng)用的目的.現(xiàn)階段這方面的研究大致可分為4個(gè)方向:1)網(wǎng)絡(luò)剪枝;2)網(wǎng)絡(luò)量化;3)構(gòu)建更有效的小型網(wǎng)絡(luò);4)知識(shí)蒸餾(DT).其中,知識(shí)蒸餾是將大型神經(jīng)網(wǎng)絡(luò)模型的信息遷移至小型神經(jīng)網(wǎng)絡(luò)模型的深度網(wǎng)絡(luò)壓縮方法[1],進(jìn)而用更小模型獲得更優(yōu)的任務(wù)表現(xiàn).根據(jù)當(dāng)前知識(shí)蒸餾的發(fā)展,從知識(shí)類型上可將其劃分為3種類型:1)基于模型輸出的知識(shí)類型[2-4];2)基于模型中間層特征的知識(shí)類型[5-10];3)基于關(guān)系的知識(shí)類型[11-13].
目前,基于中間特征層知識(shí)的模型蒸餾方法較多是從特征圖單個(gè)激活值層面上的知識(shí)遷移,并未考慮到特征圖的全局特征.在學(xué)生模型訓(xùn)練的過(guò)程中,教師模型僅能傳遞固定的監(jiān)督信號(hào),因此無(wú)法根據(jù)學(xué)生模型當(dāng)前的訓(xùn)練情況,對(duì)傳遞的知識(shí)做出適當(dāng)調(diào)整.卷積神經(jīng)網(wǎng)絡(luò)(CNN)模型通常對(duì)圖像的低頻特征更為敏感性,且低頻特征對(duì)視覺(jué)推理任務(wù)而言更具有信息性[14-15],因此,在CNN模型的知識(shí)蒸餾過(guò)程中,從教師模型提取更為敏感且信息性更高的低頻特征作為知識(shí)傳遞給學(xué)生模型進(jìn)行輔助訓(xùn)練會(huì)更為有效,故提出一種基于頻域特征遷移的知識(shí)蒸餾方法,并結(jié)合元代理標(biāo)簽(MPL)進(jìn)行模型訓(xùn)練,使教師模型能夠在遷移中間層頻域知識(shí)的同時(shí),根據(jù)學(xué)生模型在驗(yàn)證集上的表現(xiàn)更新教師模型自身的參數(shù),實(shí)現(xiàn)動(dòng)態(tài)調(diào)整頻域特征知識(shí)(DTM)的目的.
此外,基于頻域特征的知識(shí)遷移僅局限在每個(gè)獨(dú)立類別上的知識(shí)遷移,而忽視了類別之間的差異性信息.對(duì)此,本文使用Logistic模型對(duì)各個(gè)類別提取到的頻域特征做線性二分類,再將分類邊界正交向量上的各個(gè)元素作為學(xué)生模型對(duì)應(yīng)頻域特征值,使學(xué)生模型可根據(jù)輸入圖像的類別,有針對(duì)性地?cái)M合該類別更具代表性的特征值,而在一定程度上忽略不重要的特征.
Yosinski等[16]指出:不同模型在淺層提取的特征差異較小,而模型深隱層提取的特征則具有更多的獨(dú)特性.因此,對(duì)于處理同樣任務(wù)的兩個(gè)模型,更深層次的教師模型特征會(huì)為學(xué)生模型提供更多有用的信息.本文方法知識(shí)遷移的發(fā)起點(diǎn)是在模型最后一層的卷積輸出層.在特征圖提取頻域特征的過(guò)程中,將兩個(gè)模型最后卷積層輸出的特征圖看成特殊的二維信號(hào)矩陣,并用離散余弦變換(DCT)的變換核與信號(hào)矩陣相乘,得到代表特征圖頻域特征的DCT相關(guān)系數(shù)矩陣.最終通過(guò)縮小教師與學(xué)生模型特征圖頻率特征差異的方式完成知識(shí)遷移.該做法僅要求學(xué)生模型生成與教師模型數(shù)值分布相似的特征圖即可,而不要求每個(gè)對(duì)應(yīng)的激活值均完全相似.此外,由于離散余弦變換具有能量聚集的性質(zhì),故相對(duì)于其他特征圖全局特征的遷移方法,DCT頻域特征遷移能用更少的參數(shù)來(lái)表示更多的特征圖知識(shí),其計(jì)算表達(dá)式為
F=AfAT
(1)
(2)
(3)
式中:A為一維DCT變換的變換系數(shù)矩陣;f為模型的中間特征圖;P為特征圖維度大小;i、j分別為特征圖維度的橫、縱坐標(biāo)編號(hào),取值范圍為0到P-1;c(i)為DCT變換核的補(bǔ)償系數(shù);F為離散余弦變換后得到的DCT相關(guān)系數(shù)矩陣,該矩陣中每個(gè)系數(shù)值表示頻率分布與特征圖數(shù)值分布的相似程度,相似程度越高,則對(duì)應(yīng)的相關(guān)系數(shù)越大.
在算法執(zhí)行過(guò)程中,學(xué)生模型參數(shù)的更新表達(dá)式為
(4)
學(xué)生模型用于判別驗(yàn)證集的表達(dá)式為
(5)
式中,xval與yval分別為驗(yàn)證集輸入與驗(yàn)證集標(biāo)簽.
由式(5)可知,學(xué)生模型在驗(yàn)證集上的判別損失對(duì)于教師模型參數(shù)可導(dǎo),因此,可通過(guò)最小化學(xué)生模型在驗(yàn)證集上的損失值來(lái)更新教師模型參數(shù),使教師模型能夠通過(guò)學(xué)生模型在驗(yàn)證集上的表現(xiàn)修改自身參數(shù),進(jìn)而調(diào)整傳遞給學(xué)生模型中間隱層的頻域特征知識(shí).教師模型參數(shù)更新表達(dá)式為
(6)
式中,ηT為教師模型學(xué)習(xí)率.
在提取預(yù)訓(xùn)練教師模型最后殘差塊的特征圖輸出頻域特征后,將特征向量與對(duì)應(yīng)標(biāo)簽用Logistic模型進(jìn)行二分類,得到分類邊界的正交向量WLogistic.而正交向量上的每個(gè)元素在一定程度上表示了特征向量中每個(gè)對(duì)應(yīng)特征值的重要程度,可以指導(dǎo)學(xué)生模型的頻域特征向量的擬合.
以CIFAR-10數(shù)據(jù)集分類為例,將10個(gè)類別中1個(gè)類別的頻域特征向量作為正類,而其他9個(gè)類別的頻域特征向量作為負(fù)類.用二元線性分類器Logistic Regression對(duì)定義好的正負(fù)類進(jìn)行分類,可得到分類邊界的正交向量WLogistic.將正交向量中的每個(gè)數(shù)值按比例縮放至[0,2]范圍內(nèi),然后以權(quán)值的形式加權(quán)到教師模型與學(xué)生模型頻域特征的誤差損失中.當(dāng)權(quán)值小于1時(shí),說(shuō)明該權(quán)值在對(duì)應(yīng)位置上的DCT頻域特征重要程度較低,此時(shí)該位置上兩模型匹配的損失值應(yīng)乘以一個(gè)小于1的數(shù)以縮小擬合程度;而當(dāng)權(quán)值大于1時(shí),說(shuō)明該權(quán)值在對(duì)應(yīng)位置上的DCT頻域特征需重視,則該位置上兩模型匹配的損失值要乘以一個(gè)大于1的數(shù)以提高擬合程度.這樣使學(xué)生模型在向教師模型學(xué)習(xí)過(guò)程中仍能對(duì)每個(gè)輸入圖像的類別有針對(duì)地?cái)M合.
學(xué)生模型最后的訓(xùn)練損失函數(shù)為
(7)
式中:LTotal為總的損失函數(shù);LOutput為學(xué)生模型的分類損失;α為分類損失與頻域特征損失的平衡系數(shù),取值范圍為0~1;LDCT_Loss為模型頻域特征損失.
實(shí)驗(yàn)采用數(shù)據(jù)集CIFAR-10、CIFAR-100及ImageNet 2012對(duì)基于頻域特征遷移蒸餾方法的有效性進(jìn)行驗(yàn)證.數(shù)據(jù)集CIFAR-10由10個(gè)類的60 000個(gè)32×32彩色圖像組成,每類有6 000個(gè)圖像,主要包含交通工具與動(dòng)物兩個(gè)大類的圖像,如飛機(jī)、汽車、船、貓、鹿及青蛙等;數(shù)據(jù)集CIFAR-100有100個(gè)類別,每個(gè)類別包含600個(gè)圖像,包含20個(gè)大類的圖像,如哺乳動(dòng)物類、水生動(dòng)物類、花卉類及戶外場(chǎng)景類等;數(shù)據(jù)集ImageNet 2012包含了1 000個(gè)類別,且不同圖片的像素大小各不相同.由于設(shè)備條件限制,該實(shí)驗(yàn)無(wú)法進(jìn)行全數(shù)據(jù)集運(yùn)行,故從中抽取了200個(gè)類別作為實(shí)驗(yàn)數(shù)據(jù)集,在CIFAR-10與CIFAR-100數(shù)據(jù)集中存在較多的動(dòng)物類別,因此在ImageNet 2012所抽取的類別集中選取交通工具、家具與球類等幾個(gè)大類.實(shí)驗(yàn)中教師模型為ResNet-56,學(xué)生模型為ResNet-34.實(shí)驗(yàn)設(shè)備為單個(gè)GPU(GeForce GTX 1080 Ti),實(shí)驗(yàn)環(huán)境為python3.6和pytorch1.7.1.
此外,本文還進(jìn)行了將類間差異性引入知識(shí)蒸餾的驗(yàn)證實(shí)驗(yàn).通過(guò)使用各個(gè)類別分類的正交向量WLogistic相互計(jì)算距離,用來(lái)比較每個(gè)分類決策邊界之間的相似度.在DCT變換能量聚集性質(zhì)的實(shí)驗(yàn)中,進(jìn)行了特征圖匹配維度的研究.在實(shí)驗(yàn)中兩個(gè)模型每個(gè)通道的DCT相關(guān)系數(shù)矩陣大小為8×8,本文除了進(jìn)行8×8尺寸的相關(guān)系數(shù)矩陣遷移外,還進(jìn)行了4×4、2×2以及1×1尺寸的相關(guān)系數(shù)矩陣知識(shí)遷移實(shí)驗(yàn).
由于DCT變換對(duì)圖像特征有能量聚集的功能,因此進(jìn)行了相關(guān)實(shí)驗(yàn)以進(jìn)一步探究DCT特征圖遷移尺寸的問(wèn)題,以便能用更少的數(shù)值代表更多的特征知識(shí).DCT相關(guān)系數(shù)矩陣遷移維度分別為8×8、4×4、2×2及1×1時(shí)學(xué)生模型的準(zhǔn)確率如圖1所示.
圖1 不同尺寸系數(shù)矩陣的準(zhǔn)確率
由圖1a可看出,在CIFAR-10數(shù)據(jù)集中,當(dāng)DCT相關(guān)系數(shù)矩陣取2×2時(shí),學(xué)生模型的準(zhǔn)確率達(dá)到最優(yōu).由圖1b則可以看出,在圖像分類類別較多的CIFAR-100與ImageNet 2012數(shù)據(jù)集上,當(dāng)DCT相關(guān)系數(shù)矩陣取1×1時(shí),學(xué)生模型的準(zhǔn)確率可達(dá)最優(yōu).這一結(jié)果說(shuō)明DCT對(duì)特征圖的頻域特征提取操作在一定程度上還具備噪聲過(guò)濾的作用,其可將表示圖像主要信息的低頻特征及表示噪聲的高頻特征獨(dú)立地表征出來(lái).而當(dāng)僅取低頻的特征作為知識(shí)進(jìn)行遷移時(shí),對(duì)模型表現(xiàn)的改進(jìn)則更為明顯.
根據(jù)上述實(shí)驗(yàn)規(guī)律可發(fā)現(xiàn),隨著數(shù)據(jù)集分類的增加,DCT對(duì)特征圖的特征提取與壓縮效果更加顯著,且對(duì)學(xué)生模型準(zhǔn)確率的提升效果更優(yōu).由此將上述實(shí)驗(yàn)得到的結(jié)論用于下文關(guān)于類間差異性信息及DT與DTM方法的實(shí)驗(yàn)中,將CIFAR-10數(shù)據(jù)集上原本8×8的DCT相關(guān)系數(shù)矩陣截取為左上角2×2的矩陣作為教師模型遷移的知識(shí);而在CIFAR-100與ImageNet 2012數(shù)據(jù)集中則截取左上角1×1的矩陣作為教師模型遷移的知識(shí).
在該項(xiàng)實(shí)驗(yàn)中,將從教師模型中獲得的DCT頻域特征作為L(zhǎng)ogistic分類器的輸入,得到每個(gè)類別正交向量WLogistic相互計(jì)算歐氏距離,并比較每?jī)蓚€(gè)類別間對(duì)應(yīng)的WLogistic.在CIFAR-10、CIFAR-100及ImageNet 2012數(shù)據(jù)集的實(shí)驗(yàn)中分別取部分具有代表性的分類樣本,利用上述方法計(jì)算得到的相似性結(jié)果如表1~3所示.
表1 CIFAR-10正交向量的相似性比較
表2 CIFAR-100正交向量的相似性比較
表3 ImageNet 2012正交向量的相似性比較
由表1可看出,貓這個(gè)類別的正交向量與同為四肢動(dòng)物的狗、馬等類別的歐氏距離會(huì)比汽車、飛機(jī)的類別更接近.同樣卡車這一類別的正交向量與汽車類別的歐氏距離會(huì)比馬、狗等四肢動(dòng)物更接近.而鳥(niǎo)這一類別與狗、汽車、馬以及飛機(jī)這4類的相似度均較低,所以鳥(niǎo)這一類別的正交向量與其他4類正交向量的歐氏距離均較遠(yuǎn).由此可得,每個(gè)類別的正交向量在一定程度上確實(shí)包含了與不同類別間差異性的信息.在下文實(shí)驗(yàn)中,DT及DTM方法得到的實(shí)驗(yàn)準(zhǔn)確率均為加入類間差異性這一信息后所得到的實(shí)驗(yàn)結(jié)果.
對(duì)比方法包括了基于最終輸出的知識(shí)蒸餾方法KD,基于特征層的知識(shí)提取但缺少全局統(tǒng)計(jì)特征的方法AT、FT、EKD及結(jié)合了元學(xué)習(xí)的知識(shí)蒸餾算法MPL.在所有實(shí)驗(yàn)中,除了AT與EKD為多個(gè)殘差塊做知識(shí)遷移操作之外,其余均為基于中間特征圖的知識(shí)遷移方法.對(duì)于所有的對(duì)比實(shí)驗(yàn),除模型結(jié)構(gòu)更換為教師模型ResNet-56與學(xué)生模型ResNet-34之外,其他實(shí)驗(yàn)步驟及相關(guān)參數(shù)基本沿用原文中的實(shí)驗(yàn)步驟和相關(guān)參數(shù).方法KD用于平滑教師輸出標(biāo)簽的蒸餾溫度參數(shù)T設(shè)為4,向教師模型軟標(biāo)簽學(xué)習(xí)的學(xué)習(xí)率設(shè)為0.9;AT中采用對(duì)所有通道對(duì)應(yīng)激活值2次方求和的方法以獲取注意力圖;FT方法使用了原文中表現(xiàn)效果最優(yōu)的通道壓縮比例k=0.5,用于提取特征的自編碼器為6層卷積層;EKD方法中在模型4個(gè)殘差塊的輸出都接入了引導(dǎo)模塊,引導(dǎo)模塊為3層卷積層,后面再接上單層的全連接層和一層softmax層;MPL中教師模型根據(jù)學(xué)生模型反饋更新參數(shù)的學(xué)習(xí)率為0.05,用于平滑教師輸出標(biāo)簽的蒸餾溫度參數(shù)T與KD一樣設(shè)為4.
DT與DTM中頻域特征損失項(xiàng)的權(quán)重參數(shù)α設(shè)置為500.DT與其余未加入元學(xué)習(xí)訓(xùn)練方法的知識(shí)蒸餾算法共訓(xùn)練了150個(gè)epoch;每個(gè)batch大小為128;學(xué)習(xí)率初始化為0.1,后續(xù)在80~120 epoch下降為0.01,120~150 epoch處下降為0.001.MPL與DTM兩個(gè)方法訓(xùn)練了150 000 epoch,學(xué)生模型每訓(xùn)練1 000 epoch,教師模型則根據(jù)學(xué)生模型在驗(yàn)證集的準(zhǔn)確率更新一次參數(shù);每個(gè)batch大小為128;學(xué)生模型學(xué)習(xí)率為0.05.
學(xué)生模型訓(xùn)練具體過(guò)程如圖2所示,將模型最后一層卷積層的特征圖輸出經(jīng)過(guò)DCT操作獲得對(duì)應(yīng)的DCT相關(guān)系數(shù)矩陣,截取左上部分低頻特征后,將兩個(gè)模型的DCT頻域特征匹配做差進(jìn)而完成知識(shí)遷移.其中,教師模型的頻域特征會(huì)再經(jīng)過(guò)Logistic模型分類器,來(lái)獲取每個(gè)類別分類邊界的正交向量,以加權(quán)的形式調(diào)整學(xué)生模型每個(gè)頻域特征值需要向教師模型擬合的程度,使學(xué)生模型可根據(jù)圖像類別有針對(duì)性地?cái)M合對(duì)該類別更有代表性的特征值,忽略不重要的特征,融入類別之間的差異性信息.而并非與現(xiàn)有大多數(shù)知識(shí)蒸餾方法類似,將所有提取到的特征值逐一完整地進(jìn)行匹配.
圖2 學(xué)生模型訓(xùn)練過(guò)程
圖3展示了實(shí)驗(yàn)中從模型輸入到最后用于遷移的知識(shí)提取過(guò)程.圖3中每個(gè)方格表示一個(gè)對(duì)應(yīng)的數(shù)值,顏色越亮表示該方格對(duì)應(yīng)的數(shù)值越大,且說(shuō)明特征圖的數(shù)值分布與該位置所表示的頻率分布越相似.矩陣截取值根據(jù)上述關(guān)于DCT相關(guān)系數(shù)矩陣尺寸選擇的實(shí)驗(yàn)結(jié)論中得出.
圖3 頻域特征提取過(guò)程
由圖3b與3c對(duì)比可看出,相對(duì)于圖像能量分散的特征圖3b而言,圖3c明顯使圖像能量更加集中于左上角的低頻區(qū)域.該區(qū)域所表示的低頻特征對(duì)視覺(jué)推理任務(wù)而言更具有信息性[14].相對(duì)于直接用圖3b進(jìn)行匹配的FitNets或其他基于空間域特征的方法而言,不僅縮小了知識(shí)表達(dá)所需的參數(shù),并且包含了整個(gè)特征圖的全局統(tǒng)計(jì)特征,而非僅局限于特征圖單個(gè)激活值層面上的知識(shí)表達(dá),不同方法在3個(gè)數(shù)據(jù)集的準(zhǔn)確率如表4所示.
表4 不同數(shù)據(jù)集的平均準(zhǔn)確率
由表4結(jié)果可看出,3個(gè)數(shù)據(jù)集上所有的知識(shí)蒸餾方法比原始的學(xué)生模型在準(zhǔn)確率上均具有提升;相較原始知識(shí)蒸餾算法KD,AT、FT、EKD對(duì)學(xué)生模型精確度有一定的提升.而在未加入元學(xué)習(xí)訓(xùn)練方法的知識(shí)蒸餾算法中,基于頻域知識(shí)遷移的DT準(zhǔn)確率會(huì)高于前者對(duì)比方法,而低于MPL.將文中方法結(jié)合元學(xué)習(xí)訓(xùn)練方法后,DTM準(zhǔn)確率在CIFAR-100數(shù)據(jù)集上比缺少了全局統(tǒng)計(jì)特征知識(shí)的MPL平均提高了約0.12%,在CIFAR-10以及ImageNet 2012數(shù)據(jù)集上平均提高了0.16%.
本文從教師模型中提取了頻率域信息作為一種新的知識(shí)傳遞給學(xué)生模型,并用線性分類器Logistic模型對(duì)一個(gè)類別與其他全部類別進(jìn)行分類,使其在知識(shí)遷移過(guò)程中兼顧了類別之間的差異性信息.最終結(jié)合MPL模型訓(xùn)練方法,使教師模型在對(duì)學(xué)生模型進(jìn)行知識(shí)遷移時(shí),可根據(jù)學(xué)生模型在驗(yàn)證集上的表現(xiàn)來(lái)修改教師模型自身的參數(shù),以達(dá)到動(dòng)態(tài)調(diào)整學(xué)生模型特征層知識(shí)的目的.CIFAR-10、CIFAR-100與ImageNet 2012圖像分類數(shù)據(jù)集的實(shí)驗(yàn)結(jié)果也驗(yàn)證了該方法的有效性.