摘要: 為了解決數(shù)據(jù)的長(zhǎng)尾分布容易造成網(wǎng)絡(luò)模型識(shí)別準(zhǔn)確度下降的問(wèn)題,提出了一種基于因果推斷的兩階段長(zhǎng)尾分類(lèi)模型。首先采用重加權(quán)的方法去除特征和標(biāo)簽之間可能存在的虛假關(guān)聯(lián);其次通過(guò)平衡微調(diào)進(jìn)一步提升模型在少樣本尾部類(lèi)別識(shí)別的準(zhǔn)確率。模型可分為兩個(gè)階段:第一階段設(shè)計(jì)了具有迭代優(yōu)化效果的去相關(guān)樣本重加權(quán)算法以去除虛假相關(guān),達(dá)到穩(wěn)定預(yù)測(cè)的效果;第二階段設(shè)計(jì)了基于CAM的類(lèi)平衡采樣算法進(jìn)行平衡微調(diào)訓(xùn)練,使來(lái)自不平衡數(shù)據(jù)集的學(xué)習(xí)特征在所有類(lèi)別之間轉(zhuǎn)移和重新平衡,以提高模型在尾部類(lèi)別的分類(lèi)性能。實(shí)驗(yàn)結(jié)果證明了模型具有較優(yōu)的性能,同時(shí),無(wú)論從理論層面還是數(shù)據(jù)層面都具有較好的可解釋性。
關(guān)鍵詞: 長(zhǎng)尾分布; 因果推斷; 去相關(guān); 類(lèi)平衡采樣; 可解釋性
中圖分類(lèi)號(hào): TP391
文獻(xiàn)標(biāo)志碼: A
文章編號(hào): 1671-6841(2024)05-0031-08
DOI: 10.13705/j.issn.1671-6841.2023122
A Study of Two-stage Long-tail Classification Based on Causal Inference
CAO Xiaomin, LIU Jinfeng
(College of Information Engineering, Ningxia University, Yinchuan 750021,China)
Abstract: In order to solve the problem caused by long-tail distribution of data, which might decrease network model recognition accuracy to decrease, a two-stage long-tail classification model based on causal inference was proposed. Firstly, a re-weighting approach in the model was used to remove possible spurious associations between features and labels, and secondly the recognition accuracy of the model in tail categories with fewer samples was improved by balancing fine-tuning. The model was divided into two stages. In the first stage, a de-correlated sample reweighting algorithm with iterative optimization effect was designed to remove spurious correlation and achieve stable prediction; in the second stage, a CAM-based class balancing sampling algorithm was designed for balancing fine-tuning training, so that the learned features from unbalanced datasets were transferred and rebalanced among all classes to improve the classification performance of the model in the tail category. The experiments proved that the model had superior performance. Meanwhile, compared with other model, this model had better interpretability from the theoretical level as well as the data level.
Key words: long-tail distribution; causal inference; removal related; class balance sampling; interpretable
0 引言
近年來(lái),隨著大規(guī)模圖像數(shù)據(jù)集在深度神經(jīng)網(wǎng)絡(luò)(deeps neural networks,DNN)上的廣泛應(yīng)用,使得計(jì)算機(jī)在識(shí)別、監(jiān)控和跟蹤目標(biāo)方面超越人類(lèi)成為可能[1]。在計(jì)算機(jī)視覺(jué)研究中,通常假設(shè)數(shù)據(jù)集的分布是均衡的,例如ImageNet-2012[2]、MS COCO[3]和Places Dataset[4]。而在實(shí)際應(yīng)用中,數(shù)據(jù)集通常呈長(zhǎng)尾分布,即少數(shù)類(lèi)別(又稱(chēng)頭類(lèi))包含大量樣本,而大多數(shù)類(lèi)別(又稱(chēng)尾類(lèi))只有非常少量的樣本。許多標(biāo)準(zhǔn)高效的DNN在這種分布下訓(xùn)練時(shí),呈現(xiàn)在頭類(lèi)中表現(xiàn)良好,而在尾類(lèi)中表現(xiàn)不佳,從而導(dǎo)致整體識(shí)別精度顯著下降。緩解此類(lèi)問(wèn)題的主要方法為非平衡學(xué)習(xí)策略,主要包含數(shù)據(jù)級(jí)策略和算法級(jí)策略?xún)深?lèi)。數(shù)據(jù)級(jí)策略主要包括各種類(lèi)型的重采樣方法;算法級(jí)策略著重調(diào)整各個(gè)類(lèi)別的權(quán)重,引導(dǎo)網(wǎng)絡(luò)對(duì)尾類(lèi)給予更多的關(guān)注。除此之外,將頭類(lèi)數(shù)據(jù)中學(xué)習(xí)到的知識(shí)轉(zhuǎn)移到尾類(lèi)中也是一種行之有效的方法。
因果推斷是用于解釋分析的強(qiáng)大建模工具,可以幫助恢復(fù)數(shù)據(jù)中的因果關(guān)聯(lián),實(shí)現(xiàn)可解釋的穩(wěn)定預(yù)測(cè),且因果關(guān)系也能為模型提供較強(qiáng)的可解釋性。因此,本文結(jié)合因果推斷理論緩解長(zhǎng)尾分布數(shù)據(jù)分類(lèi)問(wèn)題。因果推斷中的去相關(guān)樣本重加權(quán)方法能夠去除樣本標(biāo)簽和特征之間的虛假相關(guān),使模型更注重樣本標(biāo)簽與特征之間的真正聯(lián)系,避免混雜因素對(duì)模型的影響,這樣不僅提升了模型的識(shí)別準(zhǔn)確率,加強(qiáng)模型預(yù)測(cè)時(shí)的穩(wěn)定性,同時(shí)具有較好的可解釋性。
本文提出了一種基于因果推斷的兩階段長(zhǎng)尾分類(lèi)模型,該模型第一階段采用改進(jìn)后的去相關(guān)樣本加權(quán)的方法進(jìn)行不平衡訓(xùn)練,以去除樣本標(biāo)簽和特征之間的虛假相關(guān);第二階段針對(duì)第一階段不平衡訓(xùn)練在尾類(lèi)上識(shí)別精度較差的缺點(diǎn),采用重采樣方法進(jìn)行平衡微調(diào)訓(xùn)練。
1 相關(guān)工作
1.1 傳統(tǒng)方法
解決長(zhǎng)尾分布數(shù)據(jù)的傳統(tǒng)方法主要有重采樣和重加權(quán)兩類(lèi)。重采樣即重新采樣數(shù)據(jù)集以實(shí)現(xiàn)更均衡的數(shù)據(jù)分布,這類(lèi)方法包括對(duì)少數(shù)類(lèi)進(jìn)行過(guò)采樣[5](通過(guò)添加數(shù)據(jù)副本)、對(duì)多數(shù)類(lèi)進(jìn)行欠采樣[6](通過(guò)移除數(shù)據(jù)),以及基于每類(lèi)樣本數(shù)量的類(lèi)平衡抽樣[7-8]。有學(xué)者對(duì)重采樣的方法進(jìn)行改進(jìn)或與其他方法相結(jié)合,獲得了優(yōu)于單一重采樣方法的性能,比如Zhou等[9]提出了一個(gè)統(tǒng)一的雙邊分支網(wǎng)絡(luò)(bilateral branch network, BBN),該網(wǎng)絡(luò)同時(shí)負(fù)責(zé)表征學(xué)習(xí)(此分支利用原始數(shù)據(jù)學(xué)習(xí))和分類(lèi)學(xué)習(xí)(此分支利用重采樣學(xué)習(xí)),以全面提高長(zhǎng)尾任務(wù)的識(shí)別性能。Kang等[10]將實(shí)例平衡采樣與分類(lèi)器相結(jié)合,發(fā)現(xiàn)使用最簡(jiǎn)單的實(shí)例平衡采樣學(xué)習(xí)到的表示,可以通過(guò)調(diào)整分類(lèi)器來(lái)實(shí)現(xiàn)較強(qiáng)的長(zhǎng)尾識(shí)別能力。
重加權(quán)方法的主要思想是給不同類(lèi)別分配不同的權(quán)重,引導(dǎo)網(wǎng)絡(luò)對(duì)少數(shù)類(lèi)別給予更多的關(guān)注,實(shí)際上是調(diào)整了每個(gè)類(lèi)別的損失在總損失中的占比,緩解了因長(zhǎng)尾分布導(dǎo)致的梯度占比失衡。Gui等[11]設(shè)計(jì)了一個(gè)重新加權(quán)方案,其利用每個(gè)類(lèi)別的有效樣本數(shù)來(lái)重新平衡損失,從而產(chǎn)生類(lèi)別平衡損失。Cao等[12]設(shè)計(jì)了一種兩步訓(xùn)練方法,第一步只用基于理論原則的標(biāo)簽分布感知邊際損失(label-distribution-aware margin loss, LDAM)進(jìn)行訓(xùn)練,取代訓(xùn)練過(guò)程中標(biāo)準(zhǔn)的交叉熵?fù)p失;第二步加上了傳統(tǒng)重加權(quán)操作。這種方法將重新加權(quán)置于初始階段之后,允許模型學(xué)習(xí)初始表示,同時(shí)避免與重新加權(quán)、重新抽樣相關(guān)的一些并發(fā)問(wèn)題。
傳統(tǒng)方法有較優(yōu)的分類(lèi)性能,但其可解釋性較差,這限制了深度學(xué)習(xí)方法的應(yīng)用領(lǐng)域。
1.2 因果推斷方法
因果推斷是研究如何更加科學(xué)地識(shí)別變量之間的因果關(guān)系。因果推斷要求原因先于結(jié)果,原因與結(jié)果同時(shí)變化或者相關(guān),結(jié)果不存在其他可能的解釋?zhuān)瑥?qiáng)調(diào)原因的唯一性。Pearl等[13]提出了“因果之梯”的概念,自下而上將問(wèn)題劃分為關(guān)聯(lián)、干預(yù)和反事實(shí),分別對(duì)應(yīng)于觀察、行動(dòng)和想象。對(duì)于這三個(gè)層次,因果推斷的方法主要包括重加權(quán)方法、分層方法、基于匹配方法、基于樹(shù)方法、基于表示方法、基于多任務(wù)學(xué)習(xí)方法以及元學(xué)習(xí)方法[14]。
在平衡分布數(shù)據(jù)分類(lèi)任務(wù)中,基于因果推斷的方法展示了其優(yōu)勢(shì)。Kuang等[15]提出了一種去相關(guān)加權(quán)回歸DWR算法,該算法聯(lián)合了優(yōu)化變量去相關(guān)正則化模型和加權(quán)回歸模型。Shen等[16]提出了一種新的因果正則化邏輯回歸CRLR算法,全局混雜因子平衡有助于識(shí)別因果特征,在不同域之間,這些因果特征對(duì)結(jié)果的影響具有穩(wěn)定性,然后對(duì)這些因果特征進(jìn)行邏輯回歸,構(gòu)建一個(gè)針對(duì)不可知性的魯棒預(yù)測(cè)模型,其可解釋性可以通過(guò)特征可視化得到充分描述。Li等[17]將因果分類(lèi)用于一組個(gè)性化決策問(wèn)題,并將其與分類(lèi)進(jìn)行區(qū)分,討論了通過(guò)增強(qiáng)型因果異質(zhì)性建模方法解決因果分類(lèi)的條件,同時(shí)還提出了一個(gè)因果分類(lèi)的一般框架,使用現(xiàn)有的監(jiān)督方法進(jìn)行靈活運(yùn)用。
雖然在平衡分布數(shù)據(jù)分類(lèi)任務(wù)中,因果推斷方法優(yōu)勢(shì)明顯,但將其應(yīng)用于長(zhǎng)尾分布數(shù)據(jù)分類(lèi)任務(wù)中會(huì)存在尾部類(lèi)別分類(lèi)精度較差的問(wèn)題,從而影響整體分類(lèi)精度。
2 基于因果推斷的兩階段長(zhǎng)尾分類(lèi)模型
2.1 去相關(guān)樣本重加權(quán)算法及改進(jìn)
在實(shí)際應(yīng)用中,不能保證未知測(cè)試數(shù)據(jù)與訓(xùn)練數(shù)據(jù)具有相同的分布。如果利用訓(xùn)練數(shù)據(jù)中存在的特征之間的偏差關(guān)系來(lái)改進(jìn)預(yù)測(cè),就會(huì)導(dǎo)致參數(shù)估計(jì)的不準(zhǔn)確性以及與不同分布數(shù)據(jù)集之間預(yù)測(cè)的不穩(wěn)定性。因此導(dǎo)致模型精度下降的主要原因是不相關(guān)特征和類(lèi)別標(biāo)簽之間的虛假相關(guān)。去相關(guān)樣本重加權(quán)方法[15]的目標(biāo)是去除特征之間的虛假相關(guān),本質(zhì)是通過(guò)對(duì)樣本進(jìn)行全局加權(quán),直接對(duì)每個(gè)輸入樣本的所有特征進(jìn)行去相關(guān)以解決分布偏移問(wèn)題,去相關(guān)樣本加權(quán)方法首先利用卷積神經(jīng)網(wǎng)絡(luò)(convolutional neural network,CNN)進(jìn)行特征提取,然后開(kāi)始去相關(guān)的樣本重加權(quán),以此來(lái)消除特征之間的線(xiàn)性、非線(xiàn)性依賴(lài)關(guān)系,再利用最終損失對(duì)分類(lèi)網(wǎng)絡(luò)進(jìn)行優(yōu)化并進(jìn)行圖片分類(lèi)。所用公式為
wb=argminw∑pj=1‖E[XTj∑wX-j]-E[XTjw]E[XT-jw]‖22,(1)
其中:w為樣本權(quán)重;wb表示最終學(xué)習(xí)到的樣本權(quán)重;∑w=diag(w1,w2,…,wn)和∑ni=1wi=n是權(quán)重對(duì)應(yīng)的對(duì)角矩陣,n表示樣本量;X表示變量集合(為n維行矩陣),X-j=X\{Xj}表示通過(guò)刪除變量集合X中第j個(gè)變量所得到的所有剩余變量;p表示變量的位數(shù)。
通過(guò)樣本重加權(quán)使X中的變量互不相干,從而減少訓(xùn)練環(huán)境中協(xié)變量之間的相關(guān)性,從而提高參數(shù)估計(jì)的準(zhǔn)確性。當(dāng)∑ni=1wi=n時(shí),公式(1)中的損失可以表示為
Loss=∑pj=1‖XTj∑wX-j/n-(XTjw/n)·(XT-jw/n)‖22,(2)
其中w為wi。
由于在重加權(quán)過(guò)程中會(huì)產(chǎn)生大量的額外空間,為解決這一問(wèn)題,本文在上述方法的尾端采用了迭代優(yōu)化機(jī)制,只保存最優(yōu)權(quán)重參數(shù)。對(duì)于每個(gè)批次,用于優(yōu)化樣本權(quán)重的特征生成為
ZO=Concat(ZG1,ZG2,…,ZGk,ZL),
wO=Concat(wG1,wG2,…,wGk,wL)。(3)
其中:ZO和wO分別表示優(yōu)化樣本特征和權(quán)重;ZG1,ZG2,…,ZGk,wG1,wG2,…,wGk表示整個(gè)訓(xùn)練集的全局信息,在每個(gè)批次結(jié)束時(shí)更新;ZL和wL是當(dāng)前批次中的特征和權(quán)重。例如批量大小為x時(shí),ZO是大小為((k+1)x)×mZ的矩陣,wO是(k+1)x維向量。通過(guò)這種方式將儲(chǔ)存成本從O(N)降到了O(kx)。在對(duì)每一批進(jìn)行訓(xùn)練時(shí),保持wGi不變,只有wL在本批次進(jìn)行特征學(xué)習(xí),在每次訓(xùn)練迭代結(jié)束時(shí),將全局信息(ZGi,wGi)和局部信息(ZL,wL)融合,所用公式為
Z′Gi=αiZ+(1-αi)ZL,
w′Gi=αiwGi+(1-αi)wL。(4)
對(duì)于每組全局信息(ZGi,wGi),使用k個(gè)不同的平滑參數(shù)αi來(lái)約束全局信息的長(zhǎng)期記憶(αi較大)和短期記憶(αi較?。?,最后將(ZGi,wGi)替換為(Z′Gi,w′Gi)。
在訓(xùn)練過(guò)程中,引入Mixup[18]數(shù)據(jù)增強(qiáng)方法可進(jìn)一步提高模型性能。Mixup數(shù)據(jù)增強(qiáng)方法簡(jiǎn)單來(lái)說(shuō)就是構(gòu)造虛擬訓(xùn)練樣本執(zhí)行數(shù)據(jù)增強(qiáng),并且在數(shù)據(jù)處理過(guò)程中引入較少的參數(shù)量來(lái)節(jié)約計(jì)算資源。
本階段網(wǎng)絡(luò)模型采用Resnet_34作為主干網(wǎng)絡(luò),并將輸出的特征圖譜進(jìn)行去相關(guān)的樣本重加權(quán)操作,并利用最終損失對(duì)分類(lèi)網(wǎng)絡(luò)進(jìn)行迭代優(yōu)化,從而實(shí)現(xiàn)圖片分類(lèi)任務(wù),其主要流程如圖1所示。
2.2 基于CAM的類(lèi)平衡采樣
第一階段模型在不均衡數(shù)據(jù)集上訓(xùn)練,能夠?qū)W習(xí)到好的特征表示,但是尾部類(lèi)別中識(shí)別準(zhǔn)確率較差。為了得到更均衡的數(shù)據(jù)分布,第二階段用重采樣方法進(jìn)行平衡微調(diào),使獲取的不平衡數(shù)據(jù)集特征值在所有類(lèi)別之間實(shí)現(xiàn)特征共享與特征重平衡。最終本文選擇基于類(lèi)激活映射(class activation mapping,CAM)[19]的類(lèi)平衡采樣方法作為平衡微調(diào)實(shí)驗(yàn)的模型。
2.2.1 類(lèi)平衡采樣
對(duì)于不同的采樣方式,概率pj的公式為
pj=nqj/(∑Ci=1nqi),(5)
其中:q∈[0,1],對(duì)于不同的q值,會(huì)出現(xiàn)不同的采樣策略;C是類(lèi)的數(shù)量。
本文所用的采樣方法為類(lèi)平衡采樣[10],每個(gè)類(lèi)被選中的樣本概率相等。q=0時(shí),概率pCBj公式為
pCBj=1/C。(6)
2.2.2 類(lèi)激活映射(CAM)
為了產(chǎn)生鑒別性的信息,本文受類(lèi)激活映射的啟發(fā),將CAM與類(lèi)平衡采樣相結(jié)合構(gòu)成第二階段實(shí)驗(yàn),使模型從數(shù)據(jù)層面具有可解釋性。
類(lèi)激活映射(CAM)[19]是將輸出層的權(quán)重投射回卷積特征圖,以識(shí)別圖像關(guān)注區(qū)域的重要性技術(shù)。通過(guò)全局平均池化輸出卷積層中每個(gè)單元特征圖的空間平均值,這些值的加權(quán)和生成最終輸出。類(lèi)似地,通過(guò)計(jì)算最后一個(gè)卷積層的特征圖的加權(quán)和獲得類(lèi)激活圖,生成類(lèi)激活圖的過(guò)程如圖2所示。
對(duì)于給定的圖像,設(shè)fk(x,y)表示在空間位置(x,y)處最后卷積層中單元k的激活。然后對(duì)單元k執(zhí)行全局平均池化,F(xiàn)k=∑x,yfk(x,y),因此,對(duì)于給定的c類(lèi),softmax的輸入是Sc,Sc=∑kwckFk,其中wck是對(duì)應(yīng)單元k的c類(lèi)的權(quán)重。最后,c類(lèi)softmax輸出為公式(7),通過(guò)將Fk=∑x,yfk(x,y)代入到Sc中,得到公式(8)。
Pc=exp(Sc)∑cexp(Sc),(7)
Sc=∑kwck∑x,yfk(x,y)=∑x,y∑kwckfk(x,y)。(8)
Mc被定義為c類(lèi)的類(lèi)激活映射,其中每個(gè)元素空間的公式為
Mc(x,y)=∑kwckfk(x,y),(9)
因此,Sc=∑x,yMc(x,y),Mc直接指示了網(wǎng)絡(luò)空間(x,y)處激活的重要性,從而圖像分類(lèi)為c類(lèi)。
2.2.3 基于CAM的類(lèi)平衡采樣
第二階段微調(diào)過(guò)程如圖3所示,首先應(yīng)用重新采樣來(lái)獲得平衡的采樣圖像,通過(guò)第一訓(xùn)練階段的參數(shù)化模型得到特征圖,再通過(guò)全連接層得到圖像的類(lèi)別標(biāo)簽。對(duì)于每個(gè)采樣的圖像,基于標(biāo)簽c的特征圖和第一階段訓(xùn)練得到的權(quán)重生成CAM。前景和背景根據(jù)CAM的平均值分開(kāi),其中前景包含大于平均值的像素,背景包含其余的像素。最后,在背景保持不變的情況下對(duì)前景進(jìn)行預(yù)處理,包括水平翻轉(zhuǎn)、縮放、旋轉(zhuǎn)和平移變換,對(duì)每張圖片隨機(jī)選擇一個(gè)變換,最終生成有信息的采樣數(shù)據(jù),并將生成的采樣數(shù)據(jù)增加到數(shù)據(jù)集,使用第一訓(xùn)練階段的參數(shù)化模型進(jìn)行訓(xùn)練。
2.3 去相關(guān)樣本重加權(quán)算法和CAM的可解釋
2.3.1 去相關(guān)樣本重加權(quán)算法的可解釋
雖然許多深度學(xué)習(xí)模型在其目標(biāo)任務(wù)上能夠取得良好的性能,但深度學(xué)習(xí)模型一直以來(lái)都被認(rèn)為是“黑箱”模型。近年來(lái),有學(xué)者嘗試使用因果推斷的方法去探究深度學(xué)習(xí)網(wǎng)絡(luò)的可解釋性。Pearl等[13]闡述了因果關(guān)系階梯中不同層級(jí)的可解釋性,因果關(guān)系階梯大致可以分為以下三層。
1) 統(tǒng)計(jì)相關(guān)的解釋?zhuān)搶蛹?jí)旨在利用相關(guān)性來(lái)解釋人類(lèi)是如何進(jìn)行判斷的。
2) 因果干預(yù)的解釋?zhuān)搶蛹?jí)旨在對(duì)相關(guān)行動(dòng)進(jìn)行人為干預(yù),從而得到干預(yù)后的結(jié)果,并通過(guò)這些結(jié)果進(jìn)行解釋。
3) 基于反事實(shí)的解釋?zhuān)搶蛹?jí)是三個(gè)層級(jí)中最高的,旨在利用一些反事實(shí)來(lái)進(jìn)行想象,并基于這些想象進(jìn)行解釋。
當(dāng)前的機(jī)器學(xué)習(xí)主要利用數(shù)據(jù)中的統(tǒng)計(jì)相關(guān)性進(jìn)行建模,相關(guān)性的來(lái)源主要有因果、混淆以及樣本選擇偏差三種,分別對(duì)應(yīng)圖4中的三種結(jié)構(gòu)。圖4中T表示原因,Y表示結(jié)果,X表示混淆變量,S表示選擇偏差,實(shí)心箭頭表示因果關(guān)系,虛線(xiàn)箭頭表示假性相關(guān)關(guān)系?;煜侵复嬖谝粋€(gè)變量X,該變量構(gòu)成了T和Y的共同原因,如果忽略了X的影響,那么T和Y之間存在假性相關(guān)關(guān)系,即T并非產(chǎn)生Y的直接原因。樣本選擇偏差也會(huì)產(chǎn)生相關(guān)性,當(dāng)兩個(gè)相互獨(dú)立的變量T和Y產(chǎn)生了一個(gè)共同結(jié)果S,引入S則為T(mén)和Y之間打開(kāi)了一條通路,從而誤以為T(mén)和Y之間存在關(guān)聯(lián)關(guān)系。上述兩種相關(guān)通常被稱(chēng)為虛假相關(guān),只有由因果產(chǎn)生的相關(guān)是一種穩(wěn)定的機(jī)制,不會(huì)受非標(biāo)簽特征影響,也只有這種穩(wěn)定的結(jié)構(gòu)是可解釋的。
傳統(tǒng)的可解釋技術(shù)多數(shù)會(huì)依賴(lài)于特征和結(jié)果之間的相關(guān)性,有可能會(huì)檢測(cè)出一些相反甚至病態(tài)的解釋關(guān)系。同時(shí),這些技術(shù)難以回答“如果某個(gè)干預(yù)改變了,模型的決策或判斷是什么?”這樣的反事實(shí)相關(guān)的問(wèn)題。而屬于可解釋性技術(shù)的因果推斷技術(shù)是專(zhuān)門(mén)研究干預(yù)結(jié)果效應(yīng)的方法。因果關(guān)系與其他關(guān)系相比受到的干擾較少,由因果產(chǎn)生的相關(guān)是一種穩(wěn)定的機(jī)制,不會(huì)受非標(biāo)簽特征所影響。
當(dāng)進(jìn)行因果推斷時(shí),需要考慮可能存在的混淆因素,這些因素可能導(dǎo)致因果關(guān)系被低估或高估。為了得出準(zhǔn)確的因果推斷結(jié)果,可采用去相關(guān)樣本重加權(quán)的方法消除混淆因素。
去相關(guān)樣本重加權(quán)方法通過(guò)重新加權(quán)樣本來(lái)減少某些特征對(duì)研究結(jié)果的影響,從而更準(zhǔn)確地確定因果關(guān)系。
2.3.2 CAM的可解釋
CAM是一種用于深度學(xué)習(xí)模型可視化和解釋的方法,可以幫助我們理解模型對(duì)不同類(lèi)別的判斷基于哪些特征。CAM通過(guò)對(duì)CNN模型的最后一層卷積層進(jìn)行修改,使其能夠輸出給定輸入圖像在特定類(lèi)別上的激活熱力圖。CAM將CNN最后一層卷積層的特征圖和全局平均池化層的特征權(quán)重相乘,得到每個(gè)類(lèi)別的特征映射,這些特征映射會(huì)被送入一個(gè)可視化工具中,并將它們轉(zhuǎn)換為彩色的熱力圖,這些熱力圖可以讓人們更直觀地理解模型的判斷過(guò)程,識(shí)別出模型可能出現(xiàn)的錯(cuò)誤,還可以用于優(yōu)化模型的訓(xùn)練和設(shè)計(jì),通過(guò)觀察熱力圖發(fā)現(xiàn)哪些區(qū)域?qū)τ诜诸?lèi)有用,進(jìn)而調(diào)整模型參數(shù),以提高模型的準(zhǔn)確性和可解釋性。
3 實(shí)驗(yàn)結(jié)果與分析
3.1 數(shù)據(jù)集設(shè)置
本文使用的CIFAR-10/100_LT[11]是CIFAR-10/100的長(zhǎng)尾版本。CIFAR-10和CIFAR-100都包含60 000張圖像,50 000張用于訓(xùn)練,10 000張用于驗(yàn)證,類(lèi)別分別為10和100。
本次實(shí)驗(yàn)根據(jù)數(shù)據(jù)不平衡率設(shè)計(jì)了CIFAR-10/100的長(zhǎng)尾版本,數(shù)據(jù)不平衡率控制了訓(xùn)練集的分布。不平衡率被廣泛用作長(zhǎng)尾性的度量,也是本文主要使用的長(zhǎng)尾性度量標(biāo)準(zhǔn)。Cui等[11]將數(shù)據(jù)集的不平衡率μ定義為最大類(lèi)中的訓(xùn)練樣本數(shù)除以最小樣本數(shù),其中N是每個(gè)類(lèi)別中的樣本數(shù)量,則
μ=Nmax/Nmin。(10)
對(duì)于長(zhǎng)尾CIFAR-10數(shù)據(jù)集,不平衡率分別設(shè)置為10、20、50、100時(shí)圖像數(shù)量如表1所示。同時(shí),也對(duì)CIFAR-100數(shù)據(jù)集做了類(lèi)似的處理。
3.2 實(shí)驗(yàn)設(shè)置
本文模型的特征提取器選用Resnet_34,第二階段的實(shí)驗(yàn)采用了第一階段不平衡訓(xùn)練得到的最優(yōu)參數(shù)化模型。其中第一階段實(shí)驗(yàn)參數(shù)設(shè)置如下:次數(shù)epoch=200;學(xué)習(xí)率lr=0.01;動(dòng)量momentum=0.9;
批量大小batch_size=128;權(quán)重衰減wd=1e-4。第二階段的實(shí)驗(yàn)參數(shù)除epoch設(shè)置為40以外,其他與第一階段實(shí)驗(yàn)參數(shù)設(shè)置相同。在進(jìn)行采樣方式的對(duì)比實(shí)驗(yàn)以及消融實(shí)驗(yàn)時(shí),參數(shù)設(shè)置均與上述參數(shù)設(shè)置相同。
本文所涉及的實(shí)驗(yàn)均在Windows 11操作系統(tǒng)以及NVIDIA GeForce RTX 3050 4 GB GPU上實(shí)現(xiàn),本文采用的深度學(xué)習(xí)的開(kāi)源框架為Pytorchcuda 1.13.0。
3.3 實(shí)驗(yàn)結(jié)果
3.3.1 對(duì)比實(shí)驗(yàn)
將本文所提出的模型與CIFAR10/100_LT數(shù)據(jù)集上的其他方法進(jìn)行評(píng)估,不平衡率分別設(shè)置為10、20、50及100。同時(shí)為了去除不同實(shí)驗(yàn)環(huán)境帶來(lái)的數(shù)據(jù)差異,采用的所有對(duì)比方法均在本文模型相同的實(shí)驗(yàn)環(huán)境下進(jìn)行。分類(lèi)精度結(jié)果如表2所示,其中黑體數(shù)據(jù)為最優(yōu)結(jié)果。
CIFAR10_LT數(shù)據(jù)集:當(dāng)不平衡率分別設(shè)置為10、20、50及100時(shí),相比于其他方法,本文模型取得了最優(yōu)分類(lèi)精度,分別為91.22%、86.01%、82.76%和79.28%。
CIFAR100_LT數(shù)據(jù)集:當(dāng)不平衡率設(shè)置為10和20時(shí),本文模型取得了最優(yōu)精度,分別為62.41%和55.44%。當(dāng)不平衡率設(shè)置為50和100時(shí),最優(yōu)分類(lèi)精度為BKD模型的47.25%和44.21%,本文模型的分類(lèi)精度為47.43%和43.39%。
本文模型在CIFAR10_LT數(shù)據(jù)集上有最優(yōu)的表現(xiàn),在CIFAR100_LT數(shù)據(jù)集相關(guān)實(shí)驗(yàn)中與最優(yōu)模型BKD表現(xiàn)基本持平,精度相差不足1%。
不相關(guān)特征和類(lèi)別標(biāo)簽之間的虛假相關(guān)是導(dǎo)致模型預(yù)測(cè)準(zhǔn)確率下降的主要原因,同時(shí)會(huì)導(dǎo)致模型預(yù)測(cè)的不穩(wěn)定性。本文模型通過(guò)去除變量之間的虛假相關(guān),提高模型的預(yù)測(cè)穩(wěn)定性以及準(zhǔn)確率。其次,通過(guò)增加平衡微調(diào)實(shí)驗(yàn),解決了不平衡數(shù)據(jù)導(dǎo)致模型在樣本數(shù)量較少的尾部類(lèi)別中識(shí)別精度較差的問(wèn)題,進(jìn)一步提高了模型性能。
為了證明選擇類(lèi)平衡采樣方法的優(yōu)越性,對(duì)不同采樣進(jìn)行了對(duì)比實(shí)驗(yàn),不平衡率取100和50,分類(lèi)精度結(jié)果如表3所示。
從表3中可以看出,基于CAM的類(lèi)平衡采樣方法優(yōu)于其他基于CAM的采樣方法。
3.3.2 消融實(shí)驗(yàn)
為了判斷各方法的有效性,本文進(jìn)行了消融實(shí)驗(yàn)來(lái)評(píng)價(jià)本文所提出兩階段模型的性能。消融實(shí)驗(yàn)只在不平衡率為100的CIFAR10_LT上進(jìn)行,其主干網(wǎng)絡(luò)均采用Resnet_34。具體實(shí)驗(yàn)結(jié)果如表4所示。
在Resnet_34網(wǎng)絡(luò)的基礎(chǔ)上增加去相關(guān)重加權(quán)方法之后,分類(lèi)精度增長(zhǎng)8.69%,以此可以證明因果推斷原理在長(zhǎng)尾分類(lèi)任務(wù)中的有效性。在②的基礎(chǔ)上,增加Mixup數(shù)據(jù)增強(qiáng)方法之后,分類(lèi)精度提升1.88%。添加第二階段微調(diào)實(shí)驗(yàn)(類(lèi)平衡采樣)之后,分類(lèi)精度比③增長(zhǎng)2.69%,但將類(lèi)平衡采樣更改為基于CAM的類(lèi)平衡采樣方法之后,分類(lèi)精度比③提升6.34%??梢?jiàn)使用了CAM的方法進(jìn)行類(lèi)平衡采樣,不僅能夠提升模型性能,還能夠使模型從數(shù)據(jù)層面具有可解釋性。
4 結(jié)束語(yǔ)
本文的主要貢獻(xiàn)如下:1) 本文提出的基于因果推斷的兩階段長(zhǎng)尾分類(lèi)模型在CIFAR10/100_LT數(shù)據(jù)集上取得了不錯(cuò)的分類(lèi)效果,并且通過(guò)對(duì)比實(shí)驗(yàn)以及消融實(shí)驗(yàn),證明了該方法的有效性;2) 本文所提出的方法不僅在整體模型上具有可解釋性,并且在微調(diào)訓(xùn)練階段采用了基于CAM的類(lèi)平衡采樣方法,CAM方法能夠顯示出特征的具體位置,使模型在數(shù)據(jù)層面也具有可解釋性;3) 本文將因果推斷理論應(yīng)用于長(zhǎng)尾分類(lèi)任務(wù)中,再次證明因果推斷理論在長(zhǎng)尾分類(lèi)任務(wù)中的有效性。
解決長(zhǎng)尾分布問(wèn)題在計(jì)算機(jī)視覺(jué)領(lǐng)域不僅非常重要,而且也是一項(xiàng)巨大的挑戰(zhàn)。我們認(rèn)為因果推斷是一個(gè)很好的發(fā)展方向,在未來(lái)的研究中,將深入研究因果推斷理論在長(zhǎng)尾分類(lèi)任務(wù)中的應(yīng)用。
參考文獻(xiàn):
[1] 王陽(yáng), 袁國(guó)武, 瞿睿, 等. 基于改進(jìn)YOLOv3的機(jī)場(chǎng)停機(jī)坪目標(biāo)檢測(cè)方法[J]. 鄭州大學(xué)學(xué)報(bào)(理學(xué)版), 2022, 54(5): 22-28.
WANG Y, YUAN G W, QU R, et al. Target detection method of airport apron based on improved YOLOv3[J]. Journal of Zhengzhou university (natural science edition), 2022, 54(5): 22-28.
[2] DENG J, DONG W, SOCHER R, et al. ImageNet: a large-scale hierarchical image database[C]∥2009 IEEE Conference on Computer Vision and Pattern Recognition. Piscataway:IEEE Press, 2009: 248-255.
[3] LIN T Y, MAIRE M, BELONGIE S, et al. Microsoft COCO: common objects in context[M]. Cham: Springer International Publishing, 2014.
[4] ZHOU B L, LAPEDRIZA A, KHOSLA A, et al. Places: a 10 million image database for scene recognition[J]. IEEE transactions on pattern analysis and machine intelligence, 2018, 40(6): 1452-1464.
[5] JUNSOMBOON N, PHIENTHRAKUL T. Combining over-sampling and under-sampling techniques for imbalance dataset[C]∥Proceedings of the 9th International Conference on Machine Learning and Computing. New York: ACM Press, 2017: 243-247.
[6] MOHAMMED R, RAWASHDEH J, ABDULLAH M. Machine learning with oversampling and undersampling techniques: overview study and experimental results[C]∥2020 11th International Conference on Information and Communication Systems. Piscataway:IEEE Press, 2020: 243-248.
[7] SHEN L, LIN Z C, HUANG Q M. Relay backpropagation for effective learning of deep convolutional neural networks[M]. Cham: Springer International Publishing, 2016.
[8] MAHAJAN D, GIRSHICK R, RAMANATHAN V, et al. Exploring the limits of weakly supervised pretraining[C]∥Computer Vision-ECCV 2018: 15th European Conference. New York: ACM Press, 2018: 185-201.
[9] ZHOU B Y, CUI Q, WEI X S, et al. BBN: bilateral-branch network with cumulative learning for long-tailed visual recognition[C]∥2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Piscataway:IEEE Press, 2020: 9716-9725.
[10]KANG B Y, XIE S N, ROHRBACH M, et al. Decoupling representation and classifier for long-tailed recognition[EB/OL].(2019-10-21)[2023-02-21]. https:∥arxiv.org/abs/1910.09217.
[11]CUI Y, JIA M L, LIN T Y, et al. Class-balanced loss based on effective number of samples[C]∥2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Piscataway:IEEE Press, 2020: 9260-9269.
[12]CAO K D, WEI C, GAIDON A, et al. Learning imbalanced datasets with label-distribution-aware margin loss[EB/OL]. (2019-07-18)[2023-02-21].https:∥arxiv.org/abs/1906.07413.
[13]PEARL J, MACKENZIE D. The book of why: the new science of cause and effect[M].New York: Basic Books Publishing, 2018.
[14]YAO L Y, CHU Z X, LI S, et al. A survey on causal inference[J]. ACM transactions on knowledge discovery from data, 2021, 15(5): 1-46.
[15]KUANG K, XIONG R X, CUI P, et al. Stable prediction with model misspecification and agnostic distribution shift[J]. Proceedings of the AAAI conference on artificial intelligence, 2020, 34(4): 4485-4492.
[16]SHEN Z Y, CUI P, KUANG K, et al. Causally regularized learning with agnostic data selection bias[C]∥Proceedings of the 26th ACM International Conference on Multimedia. New York: ACM Press, 2018: 411-419.
[17]LI J Y, ZHANG W J, LIU L, et al. A general framework for causal classification[J]. International journal of data science and analytics, 2021, 11(2): 127-139.
[18]ZHANG H Y, CISSE M, DAUPHIN Y N, et al. Mixup: beyond empirical risk minimization[EB/OL].(2017-10-25)[2023-02-21].https:∥arxiv.org/abs/1710.09412.
[19]ZHOU B L, KHOSLA A, LAPEDRIZA A, et al. Learning deep features for discriminative localization[C]∥2016 IEEE Conference on Computer Vision and Pattern Recognition. Piscataway:IEEE Press, 2016: 2921-2929.
[20]YANG Y Z, XU Z. Rethinking the value of labels for improving class-imbalanced learning[EB/OL]. (2020-07-13)[2023-02-21]. https:∥arxiv.org/abs/2006.07529.
[21]CHOU H P, CHANG S C, PAN J Y, et al. Remix: rebalanced mixup[M]. Cham: Springer International Publishing, 2020.
[22]ZHANG S Y, CHEN C, HU X Y, et al. Balanced knowledge distillation for long-tailed learning[J]. Neurocomputing, 2023, 527: 36-46.