張輝宜, 張 進(jìn), 黃 俊
(安徽工業(yè)大學(xué) 計(jì)算機(jī)科學(xué)與技術(shù)學(xué)院,安徽 馬鞍山 243000)
傳統(tǒng)監(jiān)督學(xué)習(xí)中每個(gè)樣本只含有一個(gè)語(yǔ)義信息,但是現(xiàn)實(shí)世界的數(shù)據(jù)往往含有多個(gè)類別的語(yǔ)義信息,即單個(gè)樣本關(guān)聯(lián)著多個(gè)語(yǔ)義標(biāo)簽。例如,一幅天空的圖像可以同時(shí)標(biāo)注“藍(lán)天”、“白云”等語(yǔ)義標(biāo)簽;一段新聞文檔可以同時(shí)屬于“時(shí)事”、“政策”等多個(gè)類別。針對(duì)這些含有多個(gè)語(yǔ)義標(biāo)簽的多標(biāo)簽數(shù)據(jù),如果只考慮單一語(yǔ)義標(biāo)簽對(duì)其進(jìn)行學(xué)習(xí),就很難獲得很好的分類效果。多標(biāo)簽學(xué)習(xí)的應(yīng)用領(lǐng)域十分廣泛,包含了圖像分類[1]、文本分類[2]、音樂(lè)分類[3]以及生物學(xué)分類[4]等多個(gè)領(lǐng)域。隨著現(xiàn)實(shí)生活中多標(biāo)簽圖像數(shù)量越來(lái)越多、種類越來(lái)越復(fù)雜,多標(biāo)簽學(xué)習(xí)在圖像分類上的應(yīng)用也顯得更加重要。利用標(biāo)簽之間的相關(guān)性可以提升多標(biāo)簽?zāi)P偷姆诸愋阅躘5]。根據(jù)標(biāo)簽相關(guān)性挖掘的程度,可以將多標(biāo)簽分類模型分為3類:沒(méi)有利用到標(biāo)簽相關(guān)性的一階方法[6-7];挖掘標(biāo)簽對(duì)之間關(guān)系的二階方法[8-9];和利用所有標(biāo)簽或類別標(biāo)簽子集中標(biāo)簽關(guān)系的高階方法[10-11]。初期采用淺層分類模型[12-13]對(duì)多標(biāo)簽圖像進(jìn)行分類,在人工干預(yù)提取數(shù)據(jù)特征的情況下,淺層模型一般都能取得較好的分類結(jié)果。近十年來(lái)應(yīng)用深度學(xué)習(xí)理論[14]尤其是卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network, CNN)構(gòu)建了一批經(jīng)典深度卷積神經(jīng)網(wǎng)絡(luò)模型,例如AlexNet[15]、VGG[16]、ResNet[17],這些模型可以對(duì)大量多標(biāo)簽圖像樣本進(jìn)行有效的深層特征學(xué)習(xí),但單獨(dú)利用卷積神經(jīng)網(wǎng)絡(luò)對(duì)多標(biāo)簽圖像進(jìn)行分類缺乏對(duì)標(biāo)簽相關(guān)性的利用,這會(huì)影響模型的分類性能,因此多標(biāo)簽分類的現(xiàn)有工作往往會(huì)利用標(biāo)簽相關(guān)性以提高性能。在標(biāo)簽相關(guān)性中標(biāo)簽的共現(xiàn)關(guān)系可以通過(guò)概率圖模型很好地表述,在以往的研究工作中,有很多基于這種數(shù)學(xué)理論的方法可以對(duì)標(biāo)簽關(guān)系進(jìn)行建模[18,19],但是概率圖模型的計(jì)算成本過(guò)高,為了解決這個(gè)問(wèn)題,使用遞歸網(wǎng)絡(luò)將標(biāo)簽編碼為嵌入向量,以實(shí)現(xiàn)標(biāo)簽間相關(guān)性建模的方法被提出[20],該方法也存在著遞歸神經(jīng)網(wǎng)絡(luò)模型依賴于預(yù)定義或?qū)W習(xí)的標(biāo)簽順序的不足,且無(wú)法很好地獲得標(biāo)簽全局依賴性。2019年Chen等提出了ML-GCN模型[21],ML-GCN利用訓(xùn)練數(shù)據(jù)集中所有標(biāo)簽類別的標(biāo)簽共現(xiàn)關(guān)系建立了整體的標(biāo)簽相關(guān)性,在最終分類階段使用圖卷積網(wǎng)絡(luò)(Graph Convolutional Network,GCN)[22]傳播標(biāo)簽共現(xiàn)嵌入并將標(biāo)簽共現(xiàn)嵌入與CNN特征合并,但是ML-GCN學(xué)習(xí)到的標(biāo)簽共現(xiàn)嵌入維度遠(yuǎn)遠(yuǎn)高于需要分類的標(biāo)簽類別數(shù),這會(huì)影響模型的分類性能。提出基于圖注意力網(wǎng)絡(luò)(Graph Attention Network,GAT)[23]的多標(biāo)簽圖像分類模型ML-GAT,ML-GAT采用降維[24]的方法對(duì)ML-GCN中標(biāo)簽共現(xiàn)嵌入維度過(guò)高的問(wèn)題進(jìn)行改進(jìn),同時(shí)采用圖注意力網(wǎng)絡(luò)對(duì)標(biāo)簽之間的關(guān)系進(jìn)行更加精確的建模。
針對(duì)通過(guò)圖卷積神經(jīng)網(wǎng)絡(luò)得到標(biāo)簽共現(xiàn)嵌入維度過(guò)高的問(wèn)題,ML-GAT采用詞嵌入降維模塊對(duì)高維雙向Transformer 的表征編碼器(Bidirectional Encoder Representation from Transformers,BERT)[25]標(biāo)簽語(yǔ)義嵌入表示矩陣進(jìn)行降維,得到低維標(biāo)簽語(yǔ)義嵌入表示矩陣。為了學(xué)習(xí)標(biāo)簽之間非對(duì)稱的關(guān)系特征,將低維標(biāo)簽語(yǔ)義嵌入表示矩陣和標(biāo)簽類別共現(xiàn)圖輸入GAT,獲取標(biāo)簽共現(xiàn)嵌入模塊,得到維度合適的低維標(biāo)簽共現(xiàn)嵌入。同時(shí)ML-GAT采用圖像特征提取模塊提取圖像特征。為了匹配低維標(biāo)簽共現(xiàn)嵌入維度,圖像特征需要經(jīng)過(guò)圖像特征降維模塊進(jìn)行降維,在降維的同時(shí)也減少了圖像特征中的冗余部分。最后,將標(biāo)簽共現(xiàn)嵌入與降維后的圖像特征通過(guò)圖像特征與標(biāo)簽共現(xiàn)嵌入融合模塊進(jìn)行融合,得到多標(biāo)簽預(yù)測(cè)評(píng)分。多標(biāo)簽圖注意力網(wǎng)絡(luò)模型結(jié)構(gòu)如圖1所示。
圖1 多標(biāo)簽圖注意力網(wǎng)絡(luò)模型結(jié)構(gòu)Fig. 1 Model structure of multi-label graph attention network
ML-GAT中圖像通用特征提取模塊使用101層ResNet,即ResNet-101模型。ResNet-101是目前主流的CNN之一,其優(yōu)點(diǎn)是易于調(diào)整,可以比較方便地利用在多標(biāo)簽圖像分類任務(wù)上,并且有較強(qiáng)的特征提取能力。因?yàn)镸L-GAT采用的是在ImageNet上預(yù)訓(xùn)練的ResNet-101,所以需要去除用來(lái)對(duì)ImageNet進(jìn)行分類的全連接層,為了控制圖像維度,需要同時(shí)去除ResNet-101的自適應(yīng)池化層,這樣可以得到多標(biāo)簽圖像特征提取器。將解析度為448×448的多標(biāo)簽圖像樣本I輸入多標(biāo)簽圖像特征提取器,可提取多標(biāo)簽圖像的特征圖F:
F=fResNet(I;θResNet)∈RW×H×D
其中,特征圖F的長(zhǎng)寬為W、H,通道數(shù)為D,fResNet表示圖像通用特征提取模塊,θResNet是ResNet-101模型參數(shù)。
因?yàn)樵趫D像通用特征與標(biāo)簽共現(xiàn)嵌入融合模塊中需要將圖像特征與標(biāo)簽共現(xiàn)嵌入維度進(jìn)行匹配,同時(shí)對(duì)圖像特征進(jìn)行降維,可以一定程度上提高圖像特征的判別力,故在ML-GAT中采取以下步驟對(duì)對(duì)特征圖F的長(zhǎng)寬W,H以及通道數(shù)D進(jìn)行降維,F(xiàn)首先通過(guò)卷積層conv1下采樣,得到F′∈RW′×H′×D,W′和H′代表降維后特征圖F′的長(zhǎng)與寬,再通過(guò)一層卷積層conv2對(duì)特征圖F′的通道數(shù)D進(jìn)行降維,得到F″∈RW′×H′×d″,d″為降維后F″的通道數(shù),最后經(jīng)過(guò)全局最大值池化層GMP,提取多標(biāo)簽圖像的特征紋理,去除無(wú)用特征。這樣可以為每一張圖像提取一個(gè)維度為Rd″的圖像特征向量x:
x=fGMP(fconv2(fconv1(F);θconv1);θconv2)∈Rd″
其中,fGMP為全局最大值池化運(yùn)算,fconv1和fconv2分別為卷積層conv1與conv2進(jìn)行的卷積運(yùn)算,θconv1,θconv2分別為卷積層conv1與卷積層conv2的模型參數(shù)。
GAT首先針對(duì)每一個(gè)標(biāo)簽節(jié)點(diǎn)i,計(jì)算標(biāo)簽節(jié)點(diǎn)i與包括自身節(jié)點(diǎn)自身在內(nèi)的所有鄰居節(jié)點(diǎn)j之間的相關(guān)系數(shù)eij:
其中,W∈RM×d′是共享參數(shù),對(duì)標(biāo)簽節(jié)點(diǎn)i的特征zi和標(biāo)簽節(jié)點(diǎn)j的特征zj進(jìn)行增維,增維后的維度為M,在ML-GAT中最后一層M=d″,[·‖·]表示對(duì)標(biāo)簽節(jié)點(diǎn)i,j的特征進(jìn)行拼接可以將兩個(gè)維度為RM的向量拼接為R2M的向量,j∈Ni是與標(biāo)簽節(jié)點(diǎn)i存在共現(xiàn)關(guān)系的一跳鄰居節(jié)點(diǎn)。a∈R2M運(yùn)算將拼接特征映射到一個(gè)實(shí)數(shù)上。
對(duì)相關(guān)系數(shù)采用歸一化運(yùn)算得到注意力系數(shù):
其中,LeakyReLU是激活函數(shù)。用注意力系數(shù)aij用來(lái)計(jì)算每個(gè)節(jié)點(diǎn)的最終輸出特征:
其中,σ為非線性激活函數(shù)。因?yàn)橥ㄟ^(guò)計(jì)算得到的標(biāo)簽i對(duì)標(biāo)簽j的注意力系數(shù)與標(biāo)簽j對(duì)標(biāo)簽i的注意力系數(shù)不同,所以GAT得到的標(biāo)簽節(jié)點(diǎn)特征可以一定程度上表示多標(biāo)簽學(xué)習(xí)中標(biāo)簽與標(biāo)簽之間的非對(duì)稱關(guān)系,例如"飛機(jī)”和“天空”這一對(duì)標(biāo)簽,有“天空”這一標(biāo)簽時(shí)“飛機(jī)”有著很小的概率同時(shí)出現(xiàn),而“飛機(jī)”標(biāo)簽出現(xiàn)時(shí)則會(huì)大概率伴隨著“天空”標(biāo)簽的出現(xiàn)。ML-GAT采用GAT可以單獨(dú)為每一對(duì)標(biāo)簽計(jì)算注意力系數(shù),得到能更加準(zhǔn)確表達(dá)標(biāo)簽之間關(guān)系的標(biāo)簽共現(xiàn)嵌入。
在GAT獲取標(biāo)簽共現(xiàn)嵌入模塊,經(jīng)過(guò)兩層GAT的計(jì)算,可以得到一個(gè)帶有類別標(biāo)簽間非對(duì)稱關(guān)系,維度為RC×d″的標(biāo)簽共現(xiàn)嵌入Zl+2。每一層的GAT計(jì)算為
Zl+1=fGAT(Zl,U)+Zl
其中,fGAT表示一層GAT計(jì)算,U∈RC×C表示標(biāo)簽節(jié)點(diǎn)從標(biāo)簽類別共現(xiàn)圖中獲得的相關(guān)矩陣建立方式與ML-GCN中相同,U中元素uij取值取決于類別標(biāo)簽i與類別標(biāo)簽j之間的共現(xiàn)次數(shù),為了能更好地將上一層的信息傳遞到下一層,因此在計(jì)算時(shí)將加上之前一層GAT的計(jì)算結(jié)果Zl。
對(duì)于一張多標(biāo)簽圖像樣本,本模型使用的多標(biāo)簽分類損失函數(shù)(Multi-label Soft Margin Loss):
實(shí)驗(yàn)所采用的軟硬件環(huán)境為Intel Pentium G4560 @ 3.50 GHz,NVIDIA GeForece GTX 1080Ti 11 GB顯卡,12 GB內(nèi)存,操作系統(tǒng)為Ubuntu 16.04,編程語(yǔ)言為Python,深度學(xué)習(xí)框架為Pytorch 1.5。
ML-GAT在兩種常用多標(biāo)簽圖像數(shù)據(jù)集上進(jìn)行對(duì)比實(shí)驗(yàn),分別是:Microsoft COCO 2014(MS-COCO 2014)[26]和PASCAL Visual Object Classes Challenge(VOC 2007)[27]。MS-COCO 2014擁有80個(gè)類別的多標(biāo)簽圖像,包含82 081張圖像組成的訓(xùn)練集和 40 504張圖像組成的驗(yàn)證集,平均每張圖像都擁有2.9個(gè)類別標(biāo)簽。VOC 2007數(shù)據(jù)集包含9 963張圖像組成的訓(xùn)練集、驗(yàn)證集和測(cè)試集,包含20個(gè)常見(jiàn)物體類別標(biāo)簽。
在ML-GAT中,將維度為RC×L的高維BERT標(biāo)簽語(yǔ)義嵌入矩陣Z0輸入到詞嵌入降維模塊,預(yù)訓(xùn)練BERT標(biāo)簽次嵌入矩陣維度L取值為1 024,經(jīng)過(guò)一層卷積核長(zhǎng)度4寬度為1的卷積層進(jìn)行下采樣,水平步長(zhǎng)為4垂直步長(zhǎng)為1,得到低維標(biāo)簽語(yǔ)義嵌入表示矩陣Zl∈RC×d′,d′此時(shí)為256,將Zl輸入GAT獲取標(biāo)簽共現(xiàn)嵌入模塊,經(jīng)過(guò)兩層GAT計(jì)算得到標(biāo)簽共現(xiàn)嵌入Zl+2∈RC×d″。為了將標(biāo)簽共現(xiàn)嵌入應(yīng)用在圖像特征上,將多標(biāo)簽圖像解析度設(shè)置為448×448,將其輸入圖像通用特征提取模塊,得到多標(biāo)簽圖像特征圖F∈RW×H×D,D為2 048,W、H均為14,針對(duì)VOC 2007數(shù)據(jù)集模型,采用的卷積層conv1不改變其W、H,使W′、H′與W、H相等。MS-COCO 2014數(shù)據(jù)集中W,H通過(guò)長(zhǎng)寬為5卷積核的conv1計(jì)算得到值均為10的W′、H′,在兩種數(shù)據(jù)集上均經(jīng)過(guò)長(zhǎng)寬為1的卷積核的卷積層conv2,對(duì)特征圖通道數(shù)D進(jìn)行降維,得到F″∈RW′×H′×d″,最后采用池化核大小為W′×H′的全局最大值池化層GMP得到維度為Rd″的圖像特征向量x,d″是圖像特征向量x的維度,同時(shí)也是標(biāo)簽共現(xiàn)嵌入的列維度,在VOC 2007數(shù)據(jù)集上的取值分別為{300,512,768},而在MS-COCO 2014數(shù)據(jù)集上設(shè)置d″為{1 024,1 280,1 536},d″參數(shù)設(shè)置由參數(shù)搜索和數(shù)據(jù)集中的標(biāo)簽類別標(biāo)簽個(gè)數(shù)共同決定,參數(shù)搜索策略為試錯(cuò)法,由于MS-COCO 2014所含有的類別標(biāo)簽數(shù)是VOC 2007中所含類別標(biāo)簽數(shù)的4倍,因此d″也同步增加。設(shè)置初始學(xué)習(xí)率為0.005,采用隨機(jī)梯度下降作為優(yōu)化器,權(quán)重衰減設(shè)置為10-4,動(dòng)量設(shè)置為0.9,總共訓(xùn)練100輪。
測(cè)試采用的評(píng)價(jià)指標(biāo)有:平均每類精度(CP)、平均每類召回率(CR)和平均每類F1(CF1)值。另外針對(duì)整體分類結(jié)果使用平均整體精度(OP),平均整體召回率(OR),平均整體F1(OF1)進(jìn)行評(píng)價(jià)。針對(duì)每個(gè)類別的分類準(zhǔn)確度,取平均值得到平均精度均值(mAP)[28],評(píng)價(jià)指標(biāo)定義如下:
在MS-COCO 2014數(shù)據(jù)集的實(shí)驗(yàn)中,因?yàn)閷?shí)驗(yàn)設(shè)備條件有限,且數(shù)據(jù)集中樣本相對(duì)較多,故進(jìn)行實(shí)驗(yàn)時(shí),采用隨機(jī)抽取部分訓(xùn)練樣本用作訓(xùn)練模型,再將訓(xùn)練出的模型在全部測(cè)試樣本上進(jìn)行測(cè)試的方法。對(duì)于ML-GCN和ResNet-101進(jìn)行同樣的采樣、訓(xùn)練、測(cè)試方法進(jìn)行實(shí)驗(yàn),在MS-COCO 2014訓(xùn)練樣本列表中采用Python的Random模塊,從82 081張訓(xùn)練樣本中隨機(jī)抽取4 000個(gè)樣本,采樣3次,訓(xùn)練出3個(gè)模型分別測(cè)試,對(duì)所有測(cè)試產(chǎn)生的評(píng)價(jià)指標(biāo),取3次測(cè)試的均值作為實(shí)驗(yàn)結(jié)果。VOC 2007數(shù)據(jù)集采用全部訓(xùn)練樣本和測(cè)試樣本進(jìn)行實(shí)驗(yàn)。實(shí)驗(yàn)結(jié)果中各評(píng)價(jià)指標(biāo)中最佳值均已加粗。
ML-GAT在VOC2007上的測(cè)試結(jié)果如表1所示,經(jīng)過(guò)與近幾年來(lái)的主流深度多標(biāo)簽圖像分類模型進(jìn)行對(duì)比(實(shí)驗(yàn)數(shù)據(jù)來(lái)源中除ResNet-101、ML-GAT,其他方法數(shù)據(jù)均來(lái)自各論文中給出的測(cè)試結(jié)果),在d″設(shè)置為512的情況下,ML-GAT在mAP這一指標(biāo)上達(dá)到了94.3,在14個(gè)類別的分類上為最佳值。在MS-COCO 2014數(shù)據(jù)集上ML-GAT的測(cè)試結(jié)果如表2所示,此時(shí)d″設(shè)置為1 280,在所有標(biāo)簽上的預(yù)測(cè)與前3個(gè)標(biāo)簽上的預(yù)測(cè)結(jié)果中,有7個(gè)主要分類指標(biāo)超過(guò)或持平ML-GCN,說(shuō)明ML-GAT模型可以在多個(gè)常用數(shù)據(jù)集上取得較好的分類結(jié)果。
表1 在VOC 2007上的實(shí)驗(yàn)結(jié)果Table 1 Experimental results on VOC 2007
表2 在MS-COCO 2014上的實(shí)驗(yàn)結(jié)果Table 2 Experimental results on MS-COCO 2014
為了比較不同數(shù)據(jù)集上標(biāo)簽共現(xiàn)嵌入列維度d″對(duì)分類性能的影響,分別對(duì)兩個(gè)數(shù)據(jù)集設(shè)置不同的d″進(jìn)行對(duì)比實(shí)驗(yàn),如圖2所示,在VOC 2007數(shù)據(jù)集中,d″取值為512時(shí),ML-GAT在mAP評(píng)價(jià)指標(biāo)上達(dá)到最佳,在MS-COCO 2014數(shù)據(jù)集上進(jìn)行一次采樣測(cè)試,d″大小為1 280時(shí)得到最佳mAP,這說(shuō)明MS-COCO 2014數(shù)據(jù)集中的標(biāo)簽類別更多,標(biāo)簽共現(xiàn)嵌入中冗余部分較少。而VOC 2007因?yàn)闃?biāo)簽類別較少,因此標(biāo)簽共現(xiàn)嵌入冗余部分較多。通過(guò)在這兩種數(shù)據(jù)集上進(jìn)行對(duì)比實(shí)驗(yàn),驗(yàn)證了ML-GAT在標(biāo)簽共現(xiàn)嵌入降維,與對(duì)標(biāo)簽之間非對(duì)稱關(guān)系的提取上采取的策略是有效的。
(a) VOC 2007
(b) MS-COCO 2014
圖卷積神經(jīng)網(wǎng)絡(luò)與CNN結(jié)合的深度多標(biāo)簽圖像分類模型ML-GCN在多標(biāo)簽圖像的分類上取得了很好的效果,但是ML-GCN中通過(guò)GCN獲取到的標(biāo)簽共現(xiàn)嵌入維度過(guò)高,標(biāo)簽共現(xiàn)嵌入沒(méi)有很好的反應(yīng)標(biāo)簽之間非對(duì)稱關(guān)系,針對(duì)ML-GCN存在的這兩點(diǎn)不足,提出一種基于圖注意力網(wǎng)絡(luò)的多標(biāo)簽圖像分類模型ML-GAT。ML-GAT通過(guò)對(duì)輸入GAT的高維標(biāo)簽語(yǔ)義嵌入表示矩陣進(jìn)行降維,解決了ML-GCN利用GCN獲取標(biāo)簽共現(xiàn)嵌入時(shí),冗余部分降低模型分類準(zhǔn)確度問(wèn)題,同時(shí)GAT可以對(duì)標(biāo)簽鄰居之間計(jì)算不同注意力系數(shù),學(xué)習(xí)標(biāo)簽之間非對(duì)稱關(guān)系特征,促進(jìn)模型分類。通過(guò)在主流數(shù)據(jù)集上與多標(biāo)簽深度學(xué)習(xí)經(jīng)典模型進(jìn)行對(duì)比實(shí)驗(yàn),ML-GAT模型在多標(biāo)簽圖像分類主要評(píng)價(jià)指標(biāo)上,相較經(jīng)典深度多標(biāo)簽圖像分類模型有一定的改進(jìn),實(shí)驗(yàn)證明了ML-GAT模型的有效性。