張 兵,陳海燕,侯夏曄,袁立罡,劉振亞
1(南京航空航天大學 計算機科學與技術(shù)學院,南京211106)
2(南京航空航天大學 民航學院,南京211106)
3(軟件新技術(shù)與產(chǎn)業(yè)化協(xié)同創(chuàng)新中心,南京210023)
E-mail:chenhaiyan@nuaa.edu.cn
度量學習的目的是根據(jù)數(shù)據(jù)自身特點,學習一種最優(yōu)的度量方式來衡量樣本之間的相似性.傳統(tǒng)度量學習算法的提出,極大地改善了基于距離的機器學習算法的性能.近年來,隨著深度學習在視覺識別領(lǐng)域取得巨大成功,神經(jīng)網(wǎng)絡的端到端訓練和語義特征提取的優(yōu)勢被應用到度量學習中,形成了一種新的度量學習模式——深度度量學習[1].深度度量學習通過訓練一個神經(jīng)網(wǎng)絡將數(shù)據(jù)映射到一個嵌入空間中,在這個空間中,樣本越相似則其嵌入向量就越接近,越不相似則其嵌入向量就越遠離.相比于傳統(tǒng)的度量學習方法,深度度量學習最大的優(yōu)勢在于神經(jīng)網(wǎng)絡可以聯(lián)合學習特征表示和語義嵌入,因此被廣泛地應用到計算機視覺領(lǐng)域,例如圖像檢索[2]、人臉識別[3]、視覺跟蹤[4]、行人再識別[5]等.
損失函數(shù)在深度度量學習中起著至關(guān)重要的作用,根據(jù)計算損失時使用策略的不同,現(xiàn)有的深度度量學習損失大致可以分為兩類:基于對的損失(Pair-based Loss)和基于代理的損失(Proxy-based Loss).
基于對的損失函數(shù)旨在用一組成對的距離來表示兩個樣本之間的關(guān)系,如最早被提出的對比損失[6],目標是在嵌入空間中最小化同類樣本間的距離,而異類樣本則彼此推開.最近被提出的排序列表損失(Ranked List Loss,RLL)[7]也是一種基于成對約束建立的損失函數(shù),給定一個查詢點,對所有數(shù)據(jù)點進行相似度排序,獲得一個排序列表.該損失旨在探索基于集合的相似結(jié)構(gòu),相比基于點的方法能夠包含更豐富的信息.但是,它的這種相似結(jié)構(gòu)是基于每個訓練批次中所有的數(shù)據(jù)來建立的,當批數(shù)據(jù)量過大時,訓練復雜度高,收斂速度慢.
基于代理的損失函數(shù)通過為每個類分配一個代理來解決訓練復雜度高的問題.代理損失將每個數(shù)據(jù)點視為一個錨點,并約束錨點樣本更靠近同類代理點而遠離異類代理點.最近新提出的代理錨損失(Proxy Anchor Loss,PAL)[8]兼顧了代理損失和基于對的損失的優(yōu)點,它將每個代理作為錨點,并將同一類的數(shù)據(jù)拉近代理而其他類的數(shù)據(jù)盡量遠離代理.然而,代理錨損失在對正樣本對進行挖掘時試圖將同一類的正樣本壓縮到特征空間中的某個代理錨點,沒有考慮類內(nèi)數(shù)據(jù)的分布情況,這很容易造成同類樣本的相似結(jié)構(gòu)的丟失.
本文在RLL和PAL的啟發(fā)下,結(jié)合兩者的優(yōu)勢,提出了一種新的基于代理錨的排序列表損失函數(shù).該損失函數(shù)根據(jù)數(shù)據(jù)到給定代理錨點的距離,對所有樣本對進行排序得到一個排序列表,使所有的正樣本都排在負樣本之前,同時只強制正樣本到其同類代理錨點的距離小于閾值.通過這種方式盡可能地保留每個類內(nèi)部的相似結(jié)構(gòu),解決了代理錨損失中忽略類內(nèi)數(shù)據(jù)分布的問題.此外,該損失也具有代理錨損失的優(yōu)勢,訓練復雜度較低,收斂速度較快.最后在兩個標準數(shù)據(jù)集上驗證了該損失函數(shù)的有效性.
RLL是一種新的基于對的損失函數(shù),該方法考慮了一批數(shù)據(jù)中的所有正樣本和負樣本來構(gòu)建一個基于集合的相似結(jié)構(gòu).具體來說,給定一個查詢點,RLL根據(jù)相似性對所有數(shù)據(jù)點進行排序來獲得一個排名列表,將所有正樣本都排在負樣本之前,并且通過只強制正樣本對的距離小于閾值,來為每個類學習一個超球體,閾值是每個類的超球體的直徑,如圖1所示.這樣,RLL可以有效地保留每個類內(nèi)部的相似結(jié)構(gòu).另外,RLL用間隔m來分隔正集和負集.給定一個樣本xi,RLL的目標是把它的負樣本推到邊界α以外,把它的正樣本拉到邊界α-m以內(nèi),具體形式見式(1):
圖1 排序列表損失;圓A表示一個錨點,它們不同的形狀表示不同的類別Fig.1 Ranked list loss;the circle A indicates an anchor,their different shapes represent distinct classes
Lm(xi,xj;f)=(1-yij)[α-dij]++yij[dij-(α-m)]+
(1)
(2)
dij=‖f(xi)-f(xj)‖2
(3)
f表示嵌入函數(shù),dij表示兩個樣本點之間的歐氏距離.
(4)
(5)
RLL通過公式(1)對具有非零損失的非平凡樣本對進行挖掘,即違反損失約束的樣本對,并通過公式(5)對非平凡的負樣本對進行不同程度的加權(quán),目的是利用所有的數(shù)據(jù)點學習一個基于集合的相似結(jié)構(gòu),使得正負樣本分離,查詢點與正樣本的距離要比負樣本更近,且兩者之間保持一個m的間隔.RLL最大的問題是,計算復雜度高,收斂速度慢,無法較好地應對數(shù)據(jù)量大的應用場景.
為了解決基于對的損失中訓練復雜度高的問題,基于代理的度量學習損失被提出.這類方法的思想是為訓練集中的每個類生成一個代理,來體現(xiàn)嵌入空間的全局結(jié)構(gòu),并在訓練過程中將每個數(shù)據(jù)點與代理相關(guān)聯(lián).由于代理的數(shù)量遠遠小于訓練數(shù)據(jù)的數(shù)量,有效降低了訓練的復雜度.代理NCA損失(Proxy-NCA Loss,PNL)[9]是第一個基于代理的損失函數(shù),它借鑒了近鄰成分分析(Neighbourhood Components Analysis,NCA)[10]的思想,希望錨點樣本與其同類代理點的距離更近而與其異類代理點的距離更遠.但這種方法也存在一個固有的局限性:由于每個數(shù)據(jù)點只與代理相關(guān)聯(lián),因此損失了基于對的損失函數(shù)中大量用到的數(shù)據(jù)關(guān)系.
PAL借鑒基于對的損失的思想,利用了數(shù)據(jù)之間的關(guān)系,克服了PNL固有的局限性,具體如圖2所示.PAL將每個代理作為一個錨點,并將其與批處理中的所有數(shù)據(jù)聯(lián)系起來,PAL損失函數(shù)見式(6):
圖2 代理錨損失;圓P表示代理,它們不同的形狀表示不同的類別Fig.2 Proxy anchor loss;the circle P indicates a proxy,their different shapes represent distinct classes
(6)
與PNL不同的是,PAL通過在損失中增加間隔而產(chǎn)生了類內(nèi)緊致性和類間可分離性,從而構(gòu)造了一個更有鑒別性的嵌入空間.但PAL也是將同一類正樣本壓縮到特征空間中的某個代理錨點附近,沒有考慮類內(nèi)數(shù)據(jù)的分布情況,這很容易造成同類樣本相似結(jié)構(gòu)信息的丟失.
本文結(jié)合了PAL和RLL的優(yōu)勢,提出了一種基于代理錨的排序列表損失(Ranked Proxy Anchor Loss,RPAL),同時兼顧類內(nèi)數(shù)據(jù)分布問題和訓練復雜度問題.和PAL一樣,RPAL從每個類中選擇一個代理作為錨點樣本,然后將整批數(shù)據(jù)與其關(guān)聯(lián)起來,對于正樣本對則約束其距離小于一定的閾值,以此來盡量保留類內(nèi)數(shù)據(jù)的相似性結(jié)構(gòu),具體如圖3所示.圖中每個形狀代表一個類別,中心的圓P表示該類的一個代理錨點,RPAL的目的是以代理錨點為中心,將批量中所有的正樣本拉到邊界α-m以內(nèi),并將所有的負樣本盡可能地推到邊界α以外,使正負集之間保持一定間隔m.
圖3 RPAL說明Fig.3 Illustration of RPAL
RPAL按照代理NCA中的標準代理分配設置為每一個類分配一個代理,本文參照了RLL中損失函數(shù)的形式,并將其引入到PAL中,得到RPAL的總損失函數(shù)見式(7):
(7)
Lm(x,p)=(1-y)[α-d(x,p)]++y[d(x,p)-(α-m)]+
(8)
(9)
其中,僅當x和p類別相同時,y=1,否則y=0.wij和RLL中的設置一樣,是負樣本對的權(quán)重,參數(shù)T控制著加權(quán)的程度.當T=0時,它平等地對待所有非平凡的反例,如果T=+∞,它將成為最難挖掘的反例.d(x,p)表示嵌入向量x與代理錨點p的余弦距離.在該損失中,通過公式(8)來對非平凡樣本對進行挖掘,即違反了式中的約束,具有非零損失的數(shù)據(jù)點.此外,通過公式(9)對每個批次中大量的非平凡負樣本根據(jù)其違反約束的程度來進行加權(quán).
在每次批訓練中,RPAL首先為每個類選擇一個代理作為錨點;然后,通過根據(jù)距離對批處理數(shù)據(jù)中所有其他數(shù)據(jù)點進行排序來獲得一個排名列表,優(yōu)化的目標是將所有的正例都排在負例之前.并且,為了保留每個類內(nèi)部的相似結(jié)構(gòu),在RPAL中將每個正樣本約束到邊界α-m以內(nèi),即為每個類以代理錨點為圓心,以α-m為半徑,學習一個超球體;對于負樣本,則將它們推到另一個邊界α以外,使得正集與負集之間間隔m.算法1描述了基于RPAL的深度度量學習算法RPAL-DMLA.
算法1.RPAL-DMLA
輸入:所有訓練圖像數(shù)據(jù);預訓練網(wǎng)絡參數(shù);損失函數(shù)的超參數(shù):α,m,T;
輸出:更新后的網(wǎng)絡參數(shù);
過程:
1.通過采樣器構(gòu)造小批量數(shù)據(jù),并將其輸入到網(wǎng)絡,得到一批嵌入向量X
2.為批中的每個類分配一個代理p,構(gòu)造代理集P,正代理集P+
3.計算所有嵌入向量x與代理p的距離d(x,p)
4.for eachx∈Xdo
5. 基于d(x,p)和公式(8)對正負樣本對采樣;
6. 根據(jù)公式(9)計算負樣本權(quán)重wij;
7. 根據(jù)公式(7)計算損失L(X)
8.endfor
9.梯度計算并反向傳播更新的網(wǎng)絡參數(shù)
10.結(jié)束
設M表示訓練樣本數(shù),C表示樣本的類別數(shù),該損失的復雜度是O(MC),因為它在批處理中將每個代理與所有正樣本和負樣本聯(lián)系起來進行比較.在公式(7)中,第1項求和公式約束的是代理錨與正樣本的距離,使得它小于閾值α-m,復雜度是O(MC);第2項求和公式旨在將負樣本與代理錨推開,使它們之間的距離大于閾值α,復雜度也是O(MC),因此RPAL總的計算復雜度是O(MC),與PAL相同,收斂速度得到了保證.
本文在兩個流行的標準圖像檢索數(shù)據(jù)集上對所提方法進行了實驗評估:
1)CUB-200-2011[11]數(shù)據(jù)集,擁有200種鳥類的11788張圖片,實驗中將前100個類的5864張圖片用于訓練,其他100個類的5924張圖片用于測試.
2)Cars-196[12]數(shù)據(jù)集,包含196個車型的16185張圖片,實驗中使用前98個類的8054張圖片進行訓練,其余98個類的8131張圖片用于測試.
為了與之前的工作進行公平的比較,本文采用了在ImageNet數(shù)據(jù)集上預訓練且進行了批標準化的GoogleNet V2(BN-Inception)[13]作為嵌入網(wǎng)絡.實驗中根據(jù)嵌入向量的維度,對最后一層全連接層的大小進行了修改,并用L2標準化對最后的輸出進行了歸一化處理.
在訓練過程中,對輸入圖像通過水平翻轉(zhuǎn)和隨機裁剪進行了數(shù)據(jù)增強,而在測試中只對輸入圖像進行中心裁剪,輸入圖像的默認大小設置為224×224.在所有實驗中,使用AdamW優(yōu)化器[14],權(quán)重衰減率設置為10-4,在Cars-196和CUB-200-2001數(shù)據(jù)集上進行了60代訓練,初始學習率設置為10-4,且在訓練時對于每一批次的輸入圖像進行隨機抽樣.
對于代理點的選擇,實驗中按照代理NCA[9]中的設置,為每個類指定一個代理,并使用正態(tài)分布初始化代理,以確保它們均勻分布在單位超球體上.通過超參數(shù)影響實驗找到超參數(shù)取值α=1.4,m=0.4,T=20.
在兩個標準數(shù)據(jù)集上將本文所提出的RPAL與以下方法進行了比較:Lifted Struct[15],N-pair-mc[16],Clustering[17],Proxy-NCA[9],MS[18],SoftTriple[19],HTL[20],RLL-H[7],Proxy-Anchor[8].使用Recall@K作為損失函數(shù)圖像檢索性能的評價指標,它是由K個最近鄰中至少存在一個正確的檢索樣本來確定的.兩個數(shù)據(jù)集上的對比結(jié)果如表1和表2所示.
表1 在CUB-200-2011數(shù)據(jù)集上Recall@K(%)的比較Table 1 Comparison of Recall@K(%)on the CUB 200-2011 datasets
表2 在Cars-196數(shù)據(jù)集上Recall@K(%)的比較Table 2 Comparison of Recall@K(%)on the Cars-196 datasets
表1和表2展示了本文的方法和其他方法在小數(shù)據(jù)集(CUB-200-2011和Cars-196)的比較結(jié)果.從這兩個數(shù)據(jù)集上的結(jié)果來看,本文的方法優(yōu)于被比較的其他方法,并且在Cars-196數(shù)據(jù)集上,Recall@1指標提高了1%,這驗證了本文所提出的損失函數(shù)的有效性.
為了進一步驗證RPAL的泛化性能,除了上述使用的BN-Inception網(wǎng)絡模型,在實驗中還采用了現(xiàn)有深度度量學習方法中較為流行的其他網(wǎng)絡架構(gòu)(GoogleNet[21],ResNet-50[22],ResNet-101[22])作為嵌入網(wǎng)絡在Cars-196數(shù)據(jù)集上進行了訓練.實驗結(jié)果如表3所示,在這3種不同類型的網(wǎng)絡架構(gòu)上RPAL的性能都要優(yōu)于PAL.
表3 在Cars-196數(shù)據(jù)集上不同嵌入網(wǎng)絡的比較Table 3 Comparison of different embedding networks on the Cars-196 datasets
4.4.1 樣本挖掘超參數(shù)的影響
為了研究超參數(shù)α的影響,實驗中將負樣本權(quán)重T和邊界m設置為:T=20,m=0.4,觀察α的不同取值對最終圖像分類結(jié)果的影響,在Cars-196數(shù)據(jù)集上進行的實驗結(jié)果如表4所示.
表4 α對圖像分類結(jié)果的影響Table 4 Impact of α on image classification results
從實驗結(jié)果可以看出,α對RPAL學習判別嵌入有較大的影響,因為α控制著負樣本被推開的程度.
為了分析m的影響,實驗中將設置:T=20,α=1.4,觀察m的不同取值對最終圖像分類結(jié)果的影響,在Cars-196數(shù)據(jù)集上進行的實驗結(jié)果如表5所示.
表5 間隔m對圖像分類結(jié)果的影響Table 5 Impact of margin m on image classification results
從表5中可以觀察到當m>0時,RPAL的性能表現(xiàn)要比m=0時提高了10%左右,這說明邊界m對于增強RPAL的泛化能力具有重要意義.
4.4.2 負樣本權(quán)重的影響
在3.1節(jié)提出的負樣本權(quán)重公式(9)中,T是控制對負樣本加權(quán)程度的參數(shù).通過在Cars-196數(shù)據(jù)集上進行實驗來評估不同的T值對圖像分類結(jié)果的影響,實驗中其他參數(shù)設置為:m=0.4,α=1.4,結(jié)果如表6所示.
表6 負樣本權(quán)重T對圖像分類結(jié)果的影響Table 6 Impact of negative sample weight T on image classification results
從表6中可以觀察到當T=0時,因為沒有對負例進行加權(quán),Recall@1的結(jié)果較差,但當T>0時,RPAL的性能相對穩(wěn)定,并在T=20時達到最佳.
4.4.3 批量大小的影響
批量大小決定了在每次迭代訓練時數(shù)據(jù)量的大小,這直接影響著挖掘非平凡例子的數(shù)量,因此批量大小在深度度量學習中是很重要的.為了研究批量大小對RPAL性能的影響,本節(jié)在3個標準數(shù)據(jù)集上觀察不同批量大小時Recall@1指標的變化.結(jié)果如表7所示,其中可以觀察到,隨著批量大小的增加,RPAL的性能逐漸提高,因為更大的批量有利于挖掘出更多的非平凡例子.當批量大小為180時,RPAL達到了最好的性能.
表7 批量大小對圖像分類結(jié)果的影響Table 7 Impact of batch size on image classification results
本文將RLL和PAL有效地結(jié)合在一起,提出一種新的基于代理錨的排序列表損失RPAL.它體現(xiàn)了這兩種損失方法的優(yōu)勢,既能夠像基于代理的損失一樣,實現(xiàn)快速可靠的收斂,且訓練的復雜度低,也能夠像基于對的排序列表損失一樣,考慮類內(nèi)數(shù)據(jù)分布,完整保留了類內(nèi)數(shù)據(jù)的相似結(jié)構(gòu).在兩個標準數(shù)據(jù)集的實驗結(jié)果表明,本文所提出的基于代理錨的排序列表損失的深度度量學習算法具有更好的圖像分類性能.