張仁斌,王 龍,周澤林,左藝聰,謝 昭
(1 合肥工業(yè)大學(xué) 計(jì)算機(jī)與信息學(xué)院,合肥 230601;2 合肥工業(yè)大學(xué) 大數(shù)據(jù)知識(shí)工程教育部重點(diǎn)實(shí)驗(yàn)室,合肥 230601;3 合肥工業(yè)大學(xué) 工業(yè)安全與應(yīng)急技術(shù)安徽省重點(diǎn)實(shí)驗(yàn)室,合肥 230601)
基于知識(shí)傳遞的知識(shí)蒸餾和參數(shù)遷移學(xué)習(xí)分別被廣泛使用于模型壓縮[1-2]和遷移學(xué)習(xí)[3]領(lǐng)域中,本文的目標(biāo)是基于知識(shí)傳遞實(shí)現(xiàn)網(wǎng)絡(luò)間的交流學(xué)習(xí),使網(wǎng)絡(luò)學(xué)習(xí)更加快速和充分。
盡管復(fù)雜龐大的網(wǎng)絡(luò)具有很高的性能,但是計(jì)算緩慢和網(wǎng)絡(luò)龐大不利于存儲(chǔ)的不足使其難以滿(mǎn)足在便攜設(shè)備上的應(yīng)用需求。模型壓縮是解決這個(gè)問(wèn)題的方法之一。Hinton 等 人[4]通過(guò)知識(shí)蒸餾(Knowledge Distillation,KD),首先利用大規(guī)模數(shù)據(jù)訓(xùn)練一個(gè)大模型作為教師網(wǎng)絡(luò),然后將小模型學(xué)生網(wǎng)絡(luò)向大模型學(xué)習(xí),知識(shí)從教師網(wǎng)絡(luò)傳遞到學(xué)生網(wǎng)絡(luò)上,以此得到的小網(wǎng)絡(luò)也具有大網(wǎng)絡(luò)相當(dāng)?shù)姆夯芰?,?shí)現(xiàn)模型壓縮。
在知識(shí)蒸餾的研究中,Zagoruyko 等人[5]提出將注意力作為知識(shí)從一個(gè)網(wǎng)絡(luò)轉(zhuǎn)移到另一個(gè)網(wǎng)絡(luò)中的學(xué)習(xí)方法,并且與將教師網(wǎng)絡(luò)的輸出作為學(xué)習(xí)對(duì)象的知識(shí)蒸餾方法進(jìn)行結(jié)合。Chen 等人[6]提出交叉樣本的相似性作為網(wǎng)絡(luò)間可轉(zhuǎn)移的知識(shí),并在多個(gè)圖像任務(wù)中進(jìn)行驗(yàn)證,轉(zhuǎn)移這種知識(shí)使行人識(shí)別任務(wù)相對(duì)基線(xiàn)取得明顯提升。Cho 等人[7]進(jìn)一步探索知識(shí)蒸餾的有效性,得出了教師網(wǎng)絡(luò)的效果越好并非意味著學(xué)生網(wǎng)絡(luò)效果就會(huì)越好的結(jié)論,這與Mirzadeh 等人[8]的實(shí)驗(yàn)結(jié)論相同。Heo 等人[9]將隱藏層特征作為知識(shí)進(jìn)行蒸餾,并在圖形分類(lèi)、檢測(cè)和分割三種任務(wù)上進(jìn)行實(shí)驗(yàn),驗(yàn)證了特征蒸餾的有效性。不同于分類(lèi)任務(wù),Saputra 等人[10]在回歸任務(wù)中成功應(yīng)用了知識(shí)蒸餾。Phuong 等人[11]從多個(gè)角度解釋了為什么知識(shí)蒸餾能夠成功地將知識(shí)在網(wǎng)絡(luò)間進(jìn)行轉(zhuǎn)移。近期,F(xiàn)acebook 團(tuán)隊(duì)提出的Deit[12]方法,探索了使用多種其他類(lèi)型的網(wǎng)絡(luò)來(lái)對(duì)圖像分類(lèi)網(wǎng)絡(luò)ViT[13]進(jìn)行注意力的教學(xué),達(dá)到了非常理想的效果。Deit 方法中,在訓(xùn)練時(shí)將基于Transformer[14]的ViT 作為學(xué)生網(wǎng)絡(luò),將其他類(lèi)型的網(wǎng)絡(luò),如以CNN 為基礎(chǔ)的ResNet[15]、EfficientNet[16]作為教師網(wǎng)絡(luò),借鑒知識(shí)蒸餾的方法,通過(guò)將學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)的輸出計(jì)算損失值并進(jìn)行反向傳播,實(shí)現(xiàn)將知識(shí)從教師傳遞給學(xué)生,以此顯著提高作為學(xué)生的ViT 網(wǎng)絡(luò)的性能。實(shí)驗(yàn)結(jié)果表明,相對(duì)于需要在大量數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練的ViT,Deit 不需要額外的數(shù)據(jù)做預(yù)訓(xùn)練,且用更少的計(jì)算資源生成更高性能的圖像分類(lèi)模型。Deit 通過(guò)將不同網(wǎng)絡(luò)的知識(shí)進(jìn)行傳遞,達(dá)到很好的學(xué)習(xí)效果。Lu 等人[17]分別在高分辨率和多分辨率模型中運(yùn)用知識(shí)蒸餾提煉知識(shí),通過(guò)交叉特征融合和多尺度訓(xùn)練等方式獲得了更優(yōu)的學(xué)生分辨率模型。Chen 等人[18]把神經(jīng)網(wǎng)絡(luò)實(shí)例的特征和節(jié)點(diǎn)的關(guān)系作為編碼知識(shí)從教師網(wǎng)絡(luò)傳遞給學(xué)生網(wǎng)絡(luò),在物體檢測(cè)的任務(wù)中取得了更好的模型效果。
把網(wǎng)絡(luò)的參數(shù)作為知識(shí)進(jìn)行轉(zhuǎn)移也有著非常經(jīng)典的應(yīng)用。Pan 等人[3]將源領(lǐng)域中模型的參數(shù)遷移到目標(biāo)領(lǐng)域的模型中的方法歸類(lèi)為參數(shù)遷移(Parameter-transfer)學(xué)習(xí)。Fan 等人[19]將少樣本檢測(cè)任務(wù)上學(xué)習(xí)到的知識(shí)遷移到檢測(cè)模型的最后一層,檢測(cè)效果相對(duì)基線(xiàn)得到了穩(wěn)定的提高。Jing 等人[20]通過(guò)知識(shí)參數(shù)遷移,把多個(gè)教師圖神經(jīng)網(wǎng)絡(luò)的知識(shí)傳遞給同一個(gè)學(xué)生圖神經(jīng)網(wǎng)絡(luò),以此得到的學(xué)生網(wǎng)絡(luò)在多個(gè)任務(wù)上取得了與教師網(wǎng)絡(luò)相當(dāng)?shù)男ЧT贛ean Teachers[21]中,教師網(wǎng)絡(luò)通過(guò)將學(xué)生網(wǎng)絡(luò)的參數(shù)進(jìn)行組合得到教師自身的網(wǎng)絡(luò)參數(shù),以此實(shí)現(xiàn)知識(shí)從學(xué)生網(wǎng)絡(luò)向教師網(wǎng)絡(luò)的傳遞。
Mean Teachers 作為半監(jiān)督的學(xué)習(xí)方法,同樣包括了教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)兩種結(jié)構(gòu)。其中,學(xué)生網(wǎng)絡(luò)的參數(shù)是通過(guò)梯度下降進(jìn)行更新,而教師網(wǎng)絡(luò)的參數(shù)則是僅僅通過(guò)組合學(xué)生網(wǎng)絡(luò)所學(xué)到的知識(shí)參數(shù)進(jìn)行更新,而不進(jìn)行梯度下降。在Mean Teachers中,知識(shí)通過(guò)網(wǎng)絡(luò)參數(shù)的形式從學(xué)生網(wǎng)絡(luò)流向教師網(wǎng)絡(luò)。進(jìn)一步地,教師網(wǎng)絡(luò)的輸出結(jié)果作為學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)目標(biāo),進(jìn)行對(duì)學(xué)生網(wǎng)絡(luò)的教學(xué)。
深度互助學(xué)習(xí)[22](Deep Mutual Learning,DML)中,K個(gè)網(wǎng)絡(luò)中每一個(gè)網(wǎng)絡(luò)既有學(xué)生的身份,也有教師的身份。當(dāng)對(duì)其中某個(gè)網(wǎng)絡(luò)傳遞知識(shí)時(shí),其他所有K -1 個(gè)網(wǎng)絡(luò)都作為教師。在每一輪互助中,每個(gè)網(wǎng)絡(luò)都會(huì)接收到其他K -1 個(gè)網(wǎng)絡(luò)傳遞的知識(shí)。在知識(shí)蒸餾的方法中,小的學(xué)生模型通過(guò)將大的教師模型輸出作為學(xué)習(xí)的軟目標(biāo)計(jì)算交叉熵進(jìn)行梯度下降,進(jìn)而完成知識(shí)從大模型向小模型的傳遞。DML 不以模型壓縮為目的,而是通過(guò)將學(xué)生網(wǎng)絡(luò)與其他K -1 個(gè)教師網(wǎng)絡(luò)的輸出結(jié)果的KL散度(Kullback-Leibler divergence,KL)取均值,并作為損失的一部分進(jìn)行反向傳播,依托多個(gè)網(wǎng)絡(luò)輸出結(jié)果的互相借鑒,以此達(dá)到更高的魯棒性,實(shí)現(xiàn)共同進(jìn)步。
利用知識(shí)蒸餾可以加快小模型的訓(xùn)練速度和效果,但是具有一定的局限性。比如蒸餾的前提是擁有一個(gè)性能足夠好的教師網(wǎng)絡(luò),且蒸餾的主要目的在于更好地訓(xùn)練出一個(gè)小模型,并不能夠提升教師網(wǎng)絡(luò)自身的性能。DML 中每個(gè)網(wǎng)絡(luò)都會(huì)利用其他網(wǎng)絡(luò)的知識(shí)來(lái)提高自己,但是DML 中實(shí)現(xiàn)互助的方式是利用網(wǎng)絡(luò)之間的差異度作為損失值進(jìn)行梯度下降,模型性能受梯度下降方法局限性的影響,如梯度消失和梯度爆炸等導(dǎo)致互助失敗。
針對(duì)以上問(wèn)題,本文提出一種基于深度交流學(xué)習(xí)(Deep Communication Learning,DCL)的網(wǎng)絡(luò)訓(xùn)練模式。在DCL 中,多個(gè)神經(jīng)網(wǎng)絡(luò)在各自獨(dú)立學(xué)習(xí)的同時(shí)將網(wǎng)絡(luò)參數(shù)作為知識(shí)進(jìn)行交流,單個(gè)神經(jīng)網(wǎng)絡(luò)在訓(xùn)練中將自身所學(xué)到的知識(shí)分享給其他網(wǎng)絡(luò),同時(shí)從其他網(wǎng)絡(luò)上吸納一定比例的學(xué)習(xí)成果,獨(dú)自學(xué)習(xí)和在集體中的知識(shí)交流是交替進(jìn)行的。
DCL 和Mean Teachers 都將網(wǎng)絡(luò)所學(xué)到的參數(shù)作為知識(shí),并將這些知識(shí)進(jìn)行傳遞。不同的是,Mean Teachers 中教師網(wǎng)絡(luò)的目的在于對(duì)無(wú)標(biāo)簽數(shù)據(jù)進(jìn)行標(biāo)記,且最終學(xué)生網(wǎng)絡(luò)向教師網(wǎng)絡(luò)的學(xué)習(xí)方式同樣類(lèi)似于知識(shí)蒸餾,是通過(guò)計(jì)算學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)輸出結(jié)果之間的差異度進(jìn)行反向傳播實(shí)現(xiàn)的。Mean Teachers 中教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的主體是固定的,而DCL 中每個(gè)網(wǎng)絡(luò)既會(huì)作為知識(shí)的傳授方,也會(huì)作為知識(shí)的接收方,這些網(wǎng)絡(luò)的身份是等同的,并且DCL 各個(gè)網(wǎng)絡(luò)間互相學(xué)習(xí)的策略與Deit 和DML 完全不同。Deit 和DML 借鑒知識(shí)蒸餾,以教師模型的輸出結(jié)果為目標(biāo),讓學(xué)生向教師模仿和學(xué)習(xí),而DCL 則是將各個(gè)網(wǎng)絡(luò)所學(xué)到的網(wǎng)絡(luò)參數(shù)作為知識(shí)進(jìn)行吸納和融合,交流的過(guò)程不使用梯度下降,而是對(duì)所學(xué)知識(shí)的直接交流。
本文利用經(jīng)典、成熟的圖像分類(lèi)神經(jīng)網(wǎng)絡(luò)來(lái)驗(yàn)證所提出的學(xué)習(xí)模式,使用Inception[23],ResNet,WRN[24],DenseNet[25],MobileNet[26],ResNeXt[27]和EfficientNet 等7 種經(jīng)典網(wǎng)絡(luò)在Fashion-MNIST[28],CIFAR-10 和CIFAR-100[29]等多個(gè)數(shù)據(jù)集上進(jìn)行實(shí)驗(yàn)。結(jié)果表明,利用DCL,這些網(wǎng)絡(luò)獲得了學(xué)習(xí)效果最高3.44%的提升。
論文內(nèi)容安排如下:本文第1 節(jié)提出了一種基于知識(shí)交流的深度神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)方式-DCL,并對(duì)該方法進(jìn)行了詳細(xì)說(shuō)明;第2 節(jié)通過(guò)使用多種網(wǎng)絡(luò)和數(shù)據(jù)集對(duì)DCL 進(jìn)行了實(shí)驗(yàn),驗(yàn)證了DCL 學(xué)習(xí)模式的有效性;第3 節(jié)對(duì)全文進(jìn)行總結(jié)并展望未來(lái)工作。本文將代碼和模型進(jìn)行了開(kāi)源[30]。
深度交流學(xué)習(xí)模式如圖1 所示。深度交流學(xué)習(xí)是對(duì)人類(lèi)社會(huì)學(xué)習(xí)進(jìn)步的一個(gè)仿照。正如人類(lèi)在個(gè)體單獨(dú)學(xué)習(xí)后進(jìn)入集體進(jìn)行知識(shí)的交流,并經(jīng)獨(dú)自的學(xué)習(xí)把從集體獲得的知識(shí)進(jìn)行消化和吸收,利用集體的知識(shí)提高自己,同時(shí)在獨(dú)自學(xué)習(xí)中探索和獲取新的知識(shí),再于此后的交流中對(duì)其進(jìn)行分享。
圖1 深度交流學(xué)習(xí)模式Fig. 1 The process of Deep Communication Learning
網(wǎng)絡(luò)學(xué)習(xí)到的知識(shí)存在于網(wǎng)絡(luò)的參數(shù)之中,讓深度神經(jīng)網(wǎng)絡(luò)在學(xué)習(xí)的同時(shí)進(jìn)行知識(shí)的交流是深度交流學(xué)習(xí)的核心。DCL 的具體策略是,網(wǎng)絡(luò)在獨(dú)自學(xué)習(xí)一定的迭代輪次T后,各個(gè)網(wǎng)絡(luò)把自己學(xué)習(xí)到的知識(shí)貢獻(xiàn)到集體中,并以一定的比例βi收納來(lái)自集體的知識(shí)。隨后各個(gè)網(wǎng)絡(luò)再獨(dú)自學(xué)習(xí)一段時(shí)間T,以適應(yīng)和吸收集體的知識(shí)經(jīng)驗(yàn),再獨(dú)自探索新知識(shí)用于下次和其他網(wǎng)絡(luò)的交流。DCL 用這樣的方式讓所有網(wǎng)絡(luò)不斷地在互相交流中進(jìn)行學(xué)習(xí)和進(jìn)步。
即使是同一類(lèi)型的網(wǎng)絡(luò),不同的初始化參數(shù)也會(huì)使網(wǎng)絡(luò)變得互不相同。雖然數(shù)據(jù)集和網(wǎng)絡(luò)結(jié)構(gòu)都一樣,但是額外的知識(shí)存在于不同的初始化參數(shù)之中。針對(duì)于Independent 學(xué)習(xí)中網(wǎng)絡(luò)知識(shí)量有限的問(wèn)題,DCL 模式中采用知識(shí)交流的方法,支持每個(gè)網(wǎng)絡(luò)擁有額外的知識(shí)量。
設(shè)DCL 中的網(wǎng)絡(luò)初始化數(shù)量為K,此K個(gè)網(wǎng)絡(luò)表示為:θ1,θ2,……,θK。設(shè)具有N個(gè)樣本且分為M類(lèi)的數(shù)據(jù)集為:
初始化DCL 后,每個(gè)網(wǎng)絡(luò)進(jìn)行隨機(jī)采樣和反向傳播學(xué)習(xí)。在經(jīng)過(guò)T次獨(dú)自學(xué)習(xí)的迭代后,所有網(wǎng)絡(luò)進(jìn)行一次知識(shí)交流。交流中,每個(gè)網(wǎng)絡(luò)首先將自己所學(xué)的參數(shù)知識(shí)以αi的比例貢獻(xiàn)到集體中,并存儲(chǔ)于θwavg中。對(duì)此過(guò)程可用式(3)進(jìn)行描述:
其中,對(duì)于每個(gè)網(wǎng)絡(luò)所貢獻(xiàn)的比例,具有以下約束:
然后,每個(gè)網(wǎng)絡(luò)從集體的知識(shí)中吸納比例為β的知識(shí)量,實(shí)現(xiàn)總體網(wǎng)絡(luò)的知識(shí)向每個(gè)網(wǎng)絡(luò)的傳遞:
對(duì)于個(gè)體向集體貢獻(xiàn)的知識(shí)量比例αi數(shù)值的確定,本文借鑒正則化(regularization)的思想,即如果網(wǎng)絡(luò)參數(shù)的絕對(duì)值|θi |越小,則讓其對(duì)集體貢獻(xiàn)更大比例的知識(shí),從而讓這些網(wǎng)絡(luò)在表現(xiàn)效果相當(dāng)?shù)那闆r下,更多地向參數(shù)小的網(wǎng)絡(luò)學(xué)習(xí),以此增加自身的魯棒性。本文采用的策略是令貢獻(xiàn)的比例與自身參數(shù)的絕對(duì)值大小成反比,即:
根據(jù)式(3)~(5),可以得出每次交流中單個(gè)網(wǎng)絡(luò)的參數(shù)對(duì)總體的貢獻(xiàn)比例為:
其中,ε為一個(gè)極小數(shù),用來(lái)避免當(dāng)網(wǎng)絡(luò)某層的參數(shù)全為0 時(shí)出現(xiàn)分母非法的情況,在實(shí)際使用中,本文對(duì)ε的取值為1×10-18。
算法1 描述了DCL 的具體流程。
算法1Deep Communication Learning
代碼中,E為訓(xùn)練的總迭代次數(shù),E和學(xué)習(xí)率衰減策略的具體設(shè)置見(jiàn)本文的2.3 節(jié)。
在本算法中,T的設(shè)置是關(guān)鍵之一,因?yàn)閬?lái)自于其他網(wǎng)絡(luò)的知識(shí)參數(shù),未必會(huì)在吸納后立即就能很好地適應(yīng)自身的網(wǎng)絡(luò)參數(shù)。因此,如果讓網(wǎng)絡(luò)一直交流而不給予足夠的獨(dú)自學(xué)習(xí)和適應(yīng)時(shí)間,很容易會(huì)出現(xiàn)這些網(wǎng)絡(luò)由于無(wú)法適應(yīng)其他網(wǎng)絡(luò)的知識(shí)參數(shù),而一直處于欠擬合狀態(tài)。
單個(gè)網(wǎng)絡(luò)每次從集體中吸納的知識(shí)比例β是一個(gè)超參數(shù)。如果β的取值過(guò)小,會(huì)使網(wǎng)絡(luò)向集體學(xué)習(xí)的知識(shí)量很少,這種情況下一方面會(huì)相對(duì)保持網(wǎng)絡(luò)的獨(dú)特性,即網(wǎng)絡(luò)之間不會(huì)非常相像,另一方面會(huì)降低交流學(xué)習(xí)給網(wǎng)絡(luò)所帶來(lái)的收益。在β取極值為0 時(shí),網(wǎng)絡(luò)之間停止交流,個(gè)體不再向集體學(xué)習(xí)。同樣,如果β的取值過(guò)大,則會(huì)使網(wǎng)絡(luò)之間隨著交流次數(shù)增多而變得更加相像,在一定程度上喪失自身的獨(dú)特性。因此將β取一個(gè)適當(dāng)大小的值是非常重要的,本文第2 節(jié)實(shí)驗(yàn)中將0.1 作為β的取值。
本文使用多種網(wǎng)絡(luò)在多個(gè)數(shù)據(jù)集上進(jìn)行實(shí)驗(yàn),所有源代碼、模型和實(shí)驗(yàn)結(jié)果均已開(kāi)源[30]。
本文使用3 個(gè)數(shù)據(jù)集進(jìn)行實(shí)驗(yàn)。CIFAR-10 和CIFAR-100 數(shù)據(jù)集由大小為32×32 的RGB 圖像組成,分別包含10 個(gè)和100 個(gè)類(lèi)別的物體。兩者都被劃分為50 000 張圖像作為訓(xùn)練集和10 000 張圖像作為測(cè)試集。Fashion-MNIST 是一個(gè)包含10 種服飾類(lèi)別的圖像數(shù)據(jù)集,圖像大小為28×28,并且以60 000張圖片作為訓(xùn)練集,10 000 張圖片作為測(cè)試集。本文將圖像分類(lèi)的正確率作為這3 個(gè)數(shù)據(jù)集的評(píng)價(jià)指標(biāo)。
本文使用7 種具有不同原理和參數(shù)量大小的經(jīng)典神經(jīng)網(wǎng)絡(luò)進(jìn)行實(shí)驗(yàn)。包括經(jīng)典卷積網(wǎng)絡(luò)Inception-V1 以及深度殘差網(wǎng)絡(luò)ResNet-18,以及以殘差為基礎(chǔ)進(jìn)一步發(fā)展而來(lái)的WRN-16-4、DenseNet-121 和MobileNet-V2。作為ResNet 和Inception 的結(jié)合,ResNeXt-50 也被用在本文的實(shí)驗(yàn)中。兼顧速度與精度的EfficientNet-B3 在圖像分類(lèi)領(lǐng)域有著優(yōu)秀的表現(xiàn),本文也采用這個(gè)網(wǎng)絡(luò)作為實(shí)驗(yàn)對(duì)象之一。
本文使用PyTorch 實(shí)現(xiàn)了所有網(wǎng)絡(luò),并且以NVIDIA Tesla V100 GPU 作為加速進(jìn)行實(shí)驗(yàn)。實(shí)驗(yàn)采用Nesterov 動(dòng)量設(shè)置為0.9 的SGD 作為模型優(yōu)化器,batch size設(shè)置為128,并且設(shè)置了0.000 1的L2正則損失。在訓(xùn)練時(shí),對(duì)于在ImageNet 進(jìn)行過(guò)預(yù)訓(xùn)練的模型,學(xué)習(xí)率被初始化為0.001,而沒(méi)有預(yù)訓(xùn)練過(guò)的模型,學(xué)習(xí)率被初始化為0.1。學(xué)習(xí)率每迭代60 個(gè)epoch會(huì)衰減為原來(lái)的0.1,并且200 個(gè)epoch被作為訓(xùn)練的總迭代次數(shù)。實(shí)驗(yàn)中,數(shù)據(jù)增強(qiáng)方法包括對(duì)圖像的隨機(jī)翻轉(zhuǎn)和每邊填充4 個(gè)像素后進(jìn)行的隨機(jī)裁剪,裁剪后缺失的像素被填充為0。
表1~表3 分別比較在3 個(gè)數(shù)據(jù)集上多種網(wǎng)絡(luò)在K =2 時(shí),通過(guò)Independent 學(xué)習(xí)和DCL 學(xué)習(xí)達(dá)到的Top-1 正確率。結(jié)果分析表明:
表1 CIFAR-100 數(shù)據(jù)集K=2 的Top-1 正確率Tab.1 Top-1 accuracy for CIFAR-100 dataset when K=2 %
表2 CIFAR-10 數(shù)據(jù)集K=2 的Top-1 正確率Tab.2 Top-1 accuracy for CIFAR-10 dataset when K=2 %
表3 Fashion-MNIST 數(shù)據(jù)集K=2 的Top-1 正確率Tab.3 Top-1 accuracy for Fashion -MNIST dataset when K=2%
(1)相對(duì)于Independent 學(xué)習(xí),所有這些網(wǎng)絡(luò)都可以通過(guò)DCL 來(lái)提高自己的學(xué)習(xí)效果,這些提高體現(xiàn)在DCL-Ind 一列中的數(shù)據(jù)都是正數(shù)。
(2)沒(méi)有經(jīng)過(guò)預(yù)訓(xùn)練的網(wǎng)絡(luò)結(jié)構(gòu),通過(guò)DCL 則更顯著地提升了學(xué)習(xí)效果。3 個(gè)數(shù)據(jù)集上最大的提升都來(lái)自于WRN-16-4 和EfficientNet-B3 這2 個(gè)未進(jìn)行預(yù)訓(xùn)練的網(wǎng)絡(luò),分別是3.44%,2.79%和1.16%。
(3)相對(duì)于單通道且圖片尺寸更小的Fashion-MNIST,DCL 的學(xué)習(xí)方式在三通道且圖片尺寸更大的CIFAR 數(shù)據(jù)上對(duì)學(xué)習(xí)效果的提升更加明顯。
在DML 中,本文通過(guò)使用KL散度作為損失的一部分,讓不同網(wǎng)絡(luò)的輸出更集中(不離群),而讓網(wǎng)絡(luò)各自都取得更好的學(xué)習(xí)效果。DML 能夠生效的原因是這種方式使輸出結(jié)果更加集中,提高網(wǎng)絡(luò)的魯棒性。在蒸餾學(xué)習(xí)中,通過(guò)讓小網(wǎng)絡(luò)把大網(wǎng)絡(luò)的輸出做軟目標(biāo)進(jìn)行學(xué)習(xí),實(shí)現(xiàn)大網(wǎng)絡(luò)向小網(wǎng)絡(luò)的知識(shí)傳遞。在Deit 中,通過(guò)讓ViT 網(wǎng)絡(luò)在某些層的輸出向和卷積網(wǎng)絡(luò)或者混合學(xué)習(xí),實(shí)現(xiàn)把知識(shí)向ViT 的傳遞,因而讓ViT 在圖像分類(lèi)中取得了更加優(yōu)秀的表現(xiàn)。
本文DCL 模式中,不同網(wǎng)絡(luò)通過(guò)分享一部分權(quán)重來(lái)進(jìn)行知識(shí)的溝通,在一定程度上使得這些網(wǎng)絡(luò)的最終輸出更加集中、即不離群,也有利于網(wǎng)絡(luò)獲得更高的魯棒性。
圖2 和圖3 比較EfficientNet-B3 和ResNet-18使用不同數(shù)量的網(wǎng)絡(luò)進(jìn)行DCL 學(xué)習(xí)的結(jié)果。學(xué)習(xí)效果對(duì)比表明:
圖2 CIFAR-10 上EfficientNet-B3 使用不同數(shù)量網(wǎng)絡(luò)進(jìn)行DCL學(xué)習(xí)的Top-1 正確率結(jié)果Fig. 2 Top-1 accuracy of DCL learning using different number of networks for EfficientNet-B3 on CIFAR-10
圖3 CIFAR-10 上ResNet-18 不同數(shù)量網(wǎng)絡(luò)進(jìn)行DCL 學(xué)習(xí)的Top-1 正確率結(jié)果Fig. 3 Top-1 accuracy of DCL learning using different number of networks for ResNet-18 on CIFAR-10
(1)DCL 的正確率曲線(xiàn)總是高于Independent的正確率曲線(xiàn),表明在相同的迭代學(xué)習(xí)次數(shù)下,DCL比Independent 學(xué)習(xí)得更加充分,并且在訓(xùn)練結(jié)束后,DCL 達(dá)到了Independent 所沒(méi)有達(dá)到的學(xué)習(xí)效果。
(2)K的值越大,圖中的正確率曲線(xiàn)就處在更高的位置,表明增加進(jìn)行交流和溝通的網(wǎng)絡(luò)個(gè)數(shù)K,將提高整體的學(xué)習(xí)效果。在DCL 中,進(jìn)行交流的學(xué)習(xí)者越多,集體會(huì)傾向于取得更加優(yōu)秀的學(xué)習(xí)表現(xiàn)。
本文提出了一種讓深度神經(jīng)網(wǎng)絡(luò)在學(xué)習(xí)中進(jìn)行互相交流的訓(xùn)練模式,利用經(jīng)典、成熟的圖像分類(lèi)神經(jīng)網(wǎng)絡(luò)對(duì)所提出的學(xué)習(xí)模式的驗(yàn)證結(jié)果表明,該模式使多種深度神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)效果獲得了明顯提高。利用DCL,深度神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)的效果更好。實(shí)驗(yàn)結(jié)果證明了DCL 模式對(duì)多類(lèi)神經(jīng)網(wǎng)絡(luò)都有效,且增加交流的網(wǎng)絡(luò)個(gè)數(shù),能進(jìn)一步提高學(xué)習(xí)效果。未來(lái)的工作將對(duì)分布式訓(xùn)練的交流方式進(jìn)行探索,以提高多個(gè)網(wǎng)絡(luò)進(jìn)行交流訓(xùn)練的時(shí)間效率。