麻永田 齊晶 張秋實 羅大為 方建軍
[摘 要]? 為了提高小樣本學習的準確率和抗干擾能力,提出了一種基于二階統(tǒng)計量的小樣本學習模型,以CNN最后一層卷積輸出的一階特征向量為輸入,通過計算協(xié)方差矩陣和二階池化獲得具有較高區(qū)分度的二階統(tǒng)計量,采用奇異值(SVD)分解將二階特征映射到低維仿射子空間并據(jù)此分類。本算法在Omniglot和minilmageNet數(shù)據(jù)集上進行了測試,實驗結(jié)果表明,在minilmageNet上的5-way 5-shot模型準確率達到了73.6%,比Prototypical Networks高出5.4%,在Omniglot上的20-way 1-shot模型準確率則獲得了2.4%的提升,本算法性能優(yōu)于Prototypical Networks等算法。在異常值測試中,本算法也展現(xiàn)出比Matching Networks和Prototypical Networks算法更強的魯棒性。
[關(guān)鍵詞] 小樣本學習;協(xié)方差矩陣;二階統(tǒng)計量;低維仿射;SVD分解
[中圖分類號] TP 391.1? [文獻標志碼] A? [文章編號] 1005-0310(2021)04-0073-06
Research on? Few-shot Learning Algorithm Based on
Second-order Statistics
MA? Yongtian1, QI? Jing2, ZHANG Qiushi 1, LUO Dawei 1, FANG? Jianjun
(1.College of Urban Rail Transit and Logistics, Beijing Union University, Beijing 100101, China;2.Tourism College,
Beijing Union University, Beijing 100101, China)
Abstract:? To improve the accuracy and anti-interference ability of few-shot learning, this paper proposes a?few-shot learning model based on second-order statistics. In the model, CNN is used to extract features and its output of the last convolutional layer is obtained to compute high-resolution second-order features by means of covariance matrix and second-order pooling operation. Meanwhile, the obtained second-order features are mapped to low-dimensional affine subspace by operating singular value decomposition (SVD) for classification. The proposed model is tested on Omniglot and minilmageNet datasets. The results reveal that the performance of the proposed model is better than other models including Prototypical Networks. The accuracy of the 5-way 5-shot model on minilmageNet dataset reaches up to 73.6%, which is 5.4% higher than Prototypical Networks. The 20-way 1-shot model on Omniglot dataset gets 2.4% accuracy improvement. As for outlier test, the proposed model also shows stronger robustness than those of Matching Networks and Prototypical Networks.
Keywords: Few-shot learning;Covariance matrix;Second-order statistics;Low-dimensional affine;Singular value decomposition
0 引言
機器學習是一種需要大量數(shù)據(jù)驅(qū)動的科學方法,其相關(guān)研究已取得了很大成功。但是,對于小數(shù)據(jù)集或者弱標注的應(yīng)用場景,例如缺陷檢測、故障檢測等,深度學習就顯得捉襟見肘。近年來,小樣本學習作為一種新的機器學習方法被提出來,成為機器學習研究領(lǐng)域的熱點問題之一[1]。
與一階統(tǒng)計量相比,二階統(tǒng)計量能夠獲得更加豐富的特征表達。文獻[2]證明了在大規(guī)模目標識別中,使用二階統(tǒng)計量所表現(xiàn)出的性能要優(yōu)于使用一階統(tǒng)計量。文獻[3]在動作識別中使用高階特征量獲得更豐富的動作特征及其高階相關(guān)性,更好地區(qū)分動作屬性,一階特征則作為噪聲而被忽略。文獻[4]將二階統(tǒng)計量拓展到注意力機制中,研究表明二階統(tǒng)計量可以獲得層間特征的內(nèi)在相關(guān)性,這使得網(wǎng)絡(luò)能夠?qū)W⒂诟嗟男畔⑻卣?,增強分類學習能力。文獻[5]在詞袋模型中分別對一階、二階和三階統(tǒng)計量的性能進行評估,證明高階統(tǒng)計量具有更豐富的特征表達能力。二階統(tǒng)計量在語義分割[6]、物體檢測[7]及動作識別[8]等計算機視覺領(lǐng)域的研究中都表現(xiàn)出顯著的效果。
與常見的一維向量特征相比,協(xié)方差二階矩陣擁有行和列兩個方向的數(shù)據(jù)關(guān)聯(lián)性,比只有一個方向的一維向量特征蘊含更豐富的信息。因此,本文提出在小樣本學習模型中采用二階特征矩陣作為分類依據(jù)。在小樣本學習的相似匹配部分,一些模型直接將各類別的均值作為它們的原型表示[9],這種策略容易受到異常值的干擾。為了降低噪聲干擾,本文采用低維仿射子空間的策略對分類器進行建模。
1 基于二階統(tǒng)計量的小樣本學習
圖1是本文設(shè)計的網(wǎng)絡(luò)結(jié)構(gòu),它由特征提取和相似匹配兩部分組成。以卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Network, CNN)為主干網(wǎng)絡(luò),將其最后一層卷積輸出的特征圖進行協(xié)方差計算,獲取二階矩陣特征。在相似匹配部分,將特征映射到低維子空間進行處理,以增強模型的魯棒性,學習同類圖像之間的關(guān)系,實現(xiàn)圖像分類。
1.1 特征提取網(wǎng)絡(luò)
CNN被廣泛應(yīng)用在計算機視覺研究任務(wù)中,并不斷取得突破。研究表明,基于CNN的特征提取網(wǎng)絡(luò)能夠較好地提取圖像特征,并進行端到端的分類。本文采用圖2所示的特征提取網(wǎng)絡(luò),它是一個4階段的CNN網(wǎng)絡(luò):將輸入圖像喂入CNN網(wǎng)絡(luò),經(jīng)過4個卷積塊(每個卷積塊由核數(shù)為64的3×3卷積和一個2×2的Max Pooling組成,每次卷積前都進行BatchNorm處理,采用ReLU激活函數(shù))的下采樣處理,輸出特征圖。
1.2 二階統(tǒng)計量
把CNN的最后一層卷積輸出特征展開成一維向量作為輸入,通過協(xié)方差池獲取二階特征分布,捕獲了比一階更高的特征統(tǒng)計量,這種二階特征包含層間特征分布及其相關(guān)性,具有較強的類別區(qū)分能力。基于二階統(tǒng)計量的特征提取示意圖如圖3所示。
令xn∈RD表示圖像中的數(shù)據(jù)點,RD表示圖像,D表示圖像的維度,則圖像經(jīng)CNN最后一層卷積層輸出的特征圖可表示為式(1)。
其中,f(xn)表示CNN特征提取,即RD→RK,K表示特征圖的維度。φn表示特征圖上的第n個特征向量。N表示輸出特征圖上特征向量的數(shù)量,且滿足式(2)。
1.3 基于低維仿射子空間的分類器
Softmax憑借其優(yōu)異的性能被廣泛應(yīng)用于機器學習中。本文擬采用Softmax作為小樣本學習的分類器,如式(5)所示。
式(5)中,c表示support集的樣本類別,q表示query集的樣本類別。
由于小樣本學習的訓練樣本數(shù)量有限,若用每類樣本的特征向量均值作為類原型,使用直接度量計算進行匹配,會對異常點和噪聲過于敏感,如圖4(a)所示。因此,本文將二階特征映射到一個低維仿射子空間,然后與原二階特征做歐氏距離計算來進行匹配,如圖4(b)所示。
其中,Wc表示c類樣本的線性子空間,主要是通過奇異值分解(Singular Value Decomposition, SVD)[10]將c類樣本的二階特征矩陣進行分解,左奇異矩陣是原特征矩陣的線性子空間正交基,因此本文將其視為原特征矩陣映射的低維子空間,并借此求得fΘ(q)。
1.4 算法流程
令S表示支撐集(support sets),X表示支撐集中的一個圖像樣本,c1表示類別1,C表示類別數(shù)量。M表示查詢集(query sets)中每類圖像的數(shù)量?;诙A統(tǒng)計量的小樣本學習的算法流程如圖5所示。
2 實驗
為了測試基于二階統(tǒng)計量的小樣本學習算法的準確性和魯棒性,本文在不同的公開圖像數(shù)據(jù)集上對算法進行了對比實驗。
2.1 實驗設(shè)置
2.1.1 實驗環(huán)境
本文所有實驗均在Ubuntu 16.04系統(tǒng)下進行,選擇Pytorch深度學習框架,采用Python 3.5語言編譯,CPU型號為英特爾i7-9700,GPU型號為GeForce RTX 2080 Ti。
2.1.2 實驗數(shù)據(jù)集
為驗證基于二階統(tǒng)計量的小樣本學習算法的性能,本文選擇Omniglot和minilmageNet兩個數(shù)據(jù)集進行實驗[11-12]。
Omniglot是一個手寫字符識別的數(shù)據(jù)集,是最常用的小樣本數(shù)據(jù)集之一,該數(shù)據(jù)集包含5 050個字母,共計16 231 623個手寫字符。實驗將Omniglot數(shù)據(jù)集中圖像的大小調(diào)整到28×28并以90度的倍數(shù)旋轉(zhuǎn)來增加字符類,訓練episode設(shè)置為60個類別,每個類別包括5個query查詢樣本。
minilmageNet是大型圖像數(shù)據(jù)庫lmageNet的簡化版,相比于Omniglot,它具有更豐富的圖像信息。minilmageNet數(shù)據(jù)集包含60 000張84×84大小的彩色圖像,分為100個類別,每個類別中有600張圖像。實驗將minilmageNet數(shù)據(jù)集的100個類別進行了拆分,選擇其中的64個類別數(shù)據(jù)作為訓練集,16個類別作為驗證集,20個類別作為測試集。
2.1.3 實驗樣本
小樣本學習訓練集中包含了很多的類別,每個
類別中有多個樣本。在訓練階段,從訓練集中隨機抽取C種類別,每個類別K個樣本(共C×K個)作為支撐集;再從這C種類別剩余的數(shù)據(jù)中抽取一批(batch)樣本作為查詢集。
2.2 實驗結(jié)果分析
2.2.1 模型準確率分析
基于二階統(tǒng)計量的小樣本學習算法與Matching Networks、Prototypical Networks算法在Omniglot數(shù)據(jù)集上的分類任務(wù)的對比實驗結(jié)果見表1。
從實驗結(jié)果可看出,基于二階統(tǒng)計量的網(wǎng)絡(luò)(Second-order Networks)通過協(xié)方差池獲取二階特征分布,捕獲了圖像更高維的特征理解,通過在低維仿射子空間進行匹配計算的方法,充分利用了圖像的高維特征來擴大類間差異的優(yōu)點,具有較高的準確率。相比于Prototypical Networks算法,本文算法在20-way 1-shot中的準確率達到了98.4%,獲得了2.4%的提升;20-way 5-shot的準確率達到了99.7%,提升了0.8%。但是5-way 1-shot的準確率僅提升了0.5%,5-way 5-shot幾乎沒有得到提升。據(jù)分析認為,Omniglot是一個手寫字符數(shù)據(jù)集,圖像相對簡單,Prototypical Networks等算法已經(jīng)達到了一個較高的識別率,因此提升不明顯。
為了充分證明Second-order Networks在復雜圖像上的分類效果,本文還在minilmageNet數(shù)據(jù)集上進行了對比實驗,實驗結(jié)果如表2所示?;诙A統(tǒng)計量的算法在5-way 1-shot和5-shot中的準確率分別達到了52.3%和72.1%,相比于Prototypical Networks算法,分別提升了2.9%和3.9%,說明二階統(tǒng)計量在復雜圖像的分類任務(wù)中仍然可以有效提升小樣本學習的準確率。
為了證明加入低維仿射子空間進行匹配計算的有效性,本文還在minilmageNet數(shù)據(jù)集上進行了直接距離度量和通過仿射子空間進行距離度量的對比實驗,如表3所示。從實驗結(jié)果可知,加入仿射子空間后,模型的準確率在1-shot和5-shot中
分別獲得0.9%和1.5%的提升。這表明,相比利用歐氏距離計算方法的直接距離度量進行匹配,本文通過SVD將圖像特征映射到一個子空間,然后求得圖像特征間的相關(guān)關(guān)系并據(jù)此進行圖像匹配,能
2.2.2 魯棒性測試
深度學習方法的有效性依賴于高質(zhì)量的訓練數(shù)據(jù)集,當訓練集呈現(xiàn)顯著復雜噪聲、異常點入侵及類別不均衡等問題時,其有效性往往無法得以保證。為評估本文算法對于異常值的魯棒性,本實驗從數(shù)據(jù)集外隨機選取幾張圖像作為異常值插入支
持集中,對異常值圖像的選取和處理須遵循以下兩條規(guī)則:異常值的圖像數(shù)量不得超過標記類別的樣本數(shù)量;異常值圖像不屬于支持集中任何類別,但在訓練時將其隨機標記為支持集的某一類別。
本文采用5-shot對不同異常值進行測試,并以異常值數(shù)量為橫軸、模型準確率為縱軸將測試結(jié)果可視化,如圖6所示。從圖6中可看出,隨著插入異常值數(shù)量的增加,3種算法的準確率均出現(xiàn)了不同程度的下降,這說明3種算法都不可避免會受到異常值的干擾。但從下降幅度可知,本文算法的下降幅度比Matching Networks和Prototypical Networks算法要小,這是由于二階統(tǒng)計量具有較強的類別區(qū)分能力,為分類計算能夠提供更多的匹配計算的維度。因此,本文算法對于異常值干擾的魯棒性方面要強于Matching Networks和Prototypical Networks算法。
3 結(jié)束語
本文提出在小樣本學習算法中引入二階統(tǒng)計量,基于此方法,可以在深度神經(jīng)網(wǎng)絡(luò)學習的表示空間中充分利用每一類支持集中圖像的高階深度特征表示類別,并通過迭代訓練,使其在少量樣本的情況下獲得更好的分類效果。本文提出的方法在Omniglot和minilmageNet數(shù)據(jù)集上進行測試,其準確率均比Matching Networks和Prototypical Networks等算法要高,在minilmageNet數(shù)據(jù)集測試中的5-way 5-shot模型準確率達到了73.6%,比Prototypical Networks高出5.4%,在Omniglot數(shù)據(jù)集測試中的20-way 1-shot模型準確率則獲得了2.4%的提升。實驗結(jié)果表明,通過低維仿射子空間處理方法進一步提高了模型準確率;同時,基于二階統(tǒng)計量的小樣本學習算法具有更好的分類效果,且應(yīng)對異常值等噪聲的能力更強。
[參考文獻]
[1] 汪榮貴,鄭巖,楊娟,等.代表特征網(wǎng)絡(luò)的小樣本學習方法[J].中國圖象圖形學報, 2019, 24(9):1514-1527.
[2] LI P H, XIE J T, WANG Q L, et al. Is second-order information helpful for large-scale visual recognition? [C]//Proceedings of the IEEE International Conference on Computer Vision (ICCV). Venice:IEEE, 2017: 2070-2078.
[3] CHERIAN A, KONIUSZ P, GOULD S. Higher-order pooling of CNN features via kernel linearization for action recognition[C]// 2017 IEEE Winter Conference on Applications of Computer Vision (WACV). Santa Rosa:IEEE, 2017: 130-138.
[4] DAI T, CAI J, ZHANG Y B, et al. Second-order attention network for single image super-resolution[C]// 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). Long Beach:IEEE, 2019: 11065-11074.
[5] KONIUSZ P, YAN F, GOSSELIN P, et al. Higher-order occurrence pooling for Bags-of-Words: visual concept detection[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017, 39(2): 313-326.
[6] BAO L C, WU B Y, LIU W, et al. CNN in MRF: video object segmentation via inference in a CNN-based higher-order spatio-temporal MRF[C] //2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition(CVPR). Salt Lake City:IEEE, 2018: 5977-5986.
[7] KIM T, JEONG M, KIM S, et al. Diversify and match: a domain adaptive representation learning paradigm for object detection[C]//2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). Long Beach:IEEE, 2019: 12456-12465.
[8] CHOUTAS V, WEINZAEPFEL P, REVAUD J, et al. PoTion: pose motion representation for action recognition[C]//2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition(CVPR).Salt Lake City:IEEE, 2018: 7024-7033.
[9] SNELL J, SWERSKY K, ZEMEL R S, et al. Prototypical networks for few-shot learning[C]//Proceedings of the 31st International Conference on Neural Information Processing Systems. Long Beach:NIPS, 2017: 4077-4087.
[10] DADKHAH S, MANAF A A, HORI Y, et al. An effective SVD-based image tampering detection and self-recovery using active watermarking[J]. Signal Processing:Image Communication, 2014, 29(10): 1197-1210.
[11] LAKE B M, SALAKHUTDINOV R, TENENBAUM J B, et al. The Omniglot challenge: a 3-year progress report[J]. Current Opinion in Behavioral Sciences, 2019,29: 97-104.
[12] QIAO S Y, LIU C X, SHEN W, et al. Few-shot image recognition by predicting parameters from activations[C]// 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition(CVPR).Salt Lake City:IEEE, 2018: 7229-7238.
(責任編輯 白麗媛)