蔣博文
(安徽理工大學(xué) 計(jì)算機(jī)科學(xué)與工程學(xué)院,安徽 淮南 232001)
圖像分類是機(jī)器視覺(jué)研究熱點(diǎn)之一。顧名思義,圖像分類即給定輸入圖像,卷積神經(jīng)網(wǎng)絡(luò)對(duì)輸入進(jìn)行圖像預(yù)處理、特征圖特征提取以及使用分類器進(jìn)行分類,最終輸出預(yù)測(cè)類別標(biāo)簽,其中特征圖的有效信息提取是至關(guān)重要的一步。傳統(tǒng)的圖像分類算法提取圖像的色彩、紋理和角點(diǎn)等特征信息,其在早期較為簡(jiǎn)單的圖像分類任務(wù)中具有較好得表現(xiàn),但在復(fù)雜場(chǎng)景下卻不能滿足要求 。
注意力機(jī)制作為捕捉特征圖顯著特征、提高卷積神經(jīng)網(wǎng)絡(luò)特征提取能力的新方法。隨著現(xiàn)代科技的發(fā)展,海量復(fù)雜的信息不斷地向人們襲來(lái),信息無(wú)處不在。然而人類接受信息的能力是有限的,研究發(fā)現(xiàn)在人類接受視覺(jué)數(shù)據(jù)的初始,人類的視覺(jué)處理系統(tǒng)會(huì)快速地將自己的大部分注意力集中在場(chǎng)景中相對(duì)重要的區(qū)域上,這種選擇處理機(jī)制可以極大地減少人類視覺(jué)系統(tǒng)需要處理的數(shù)據(jù)量,并在復(fù)雜信息環(huán)境中,抑制不重要的視覺(jué)刺激,從而將更多的精力分配給現(xiàn)實(shí)場(chǎng)景中更重要的部分,提取更重要的信息以便于大腦進(jìn)行更高層次的決策。接觸人類視覺(jué)研究,研究者們提出了注意力機(jī)制的思想。對(duì)于現(xiàn)實(shí)中的事物其所具有的特征是不同的,在卷積神經(jīng)網(wǎng)絡(luò)中反映為每張?zhí)卣鲌D的差異性。注意力機(jī)制就是通過(guò)一系列手段捕捉每張?zhí)卣鲌D顯著特征的像素或通道信息,具體反映在將重要的通道或者像素信息的權(quán)重增大并抑制不重要的信息權(quán)重,
本文以圖像分類任務(wù)為載體,通過(guò)結(jié)合現(xiàn)有的SE 通道注意力機(jī)制,將其嵌入到原始的ResNet 網(wǎng)絡(luò)中,提高網(wǎng)絡(luò)特征提取能力,捕捉特征圖中的顯著特征信息。通過(guò)在CIFAR-10 和CIFAR-100 數(shù)據(jù)集上使用基準(zhǔn)網(wǎng)絡(luò)進(jìn)行了實(shí)驗(yàn),驗(yàn)證了其有效性。
深度神經(jīng)網(wǎng)絡(luò)的深度對(duì)于網(wǎng)絡(luò)性能的提升是最直接的方法,但實(shí)踐證明網(wǎng)絡(luò)并不是越深越好,這是由于隨著網(wǎng)絡(luò)層數(shù)的增加,在網(wǎng)絡(luò)回歸的過(guò)程中梯度消失的現(xiàn)象就會(huì)越來(lái)越明顯,相應(yīng)的網(wǎng)絡(luò)訓(xùn)練的效果也會(huì)變差。為了解決加深網(wǎng)絡(luò)深度帶來(lái)的梯度消失和網(wǎng)絡(luò)退化問(wèn)題,ResNet網(wǎng)絡(luò)應(yīng)運(yùn)而生。
如圖1所示,34 層的ResNet 網(wǎng)絡(luò)由一系列的殘差模塊、全連接層和下采樣層組成。殘差模塊分為恒等殘差和非恒等殘差模塊兩種,分別對(duì)應(yīng)著圖中快捷連接(shortcut connections)的實(shí)線和虛線兩種。恒等殘差模塊中的實(shí)線表示對(duì)于本殘差模塊的輸入和輸出特征圖通道數(shù)是相同的,可以直接進(jìn)行相加。非恒等卷積殘差塊中的虛線表示輸入和輸出的特征圖通道數(shù)是不同的,需要先通過(guò)1×1 的卷積改變通道數(shù),然后再相加。
圖1 ResNet34
如表1所示,本文列出18 層、34 層、50 層、101 層和152 層五種深度的原始ResNe 網(wǎng)絡(luò)結(jié)構(gòu),其中conv代表普通卷積、stride 代表步長(zhǎng)、Global average pool 代表全局平均池化、fc 代表全連接層。五種深度的原始ResNet 網(wǎng)絡(luò)性能隨著層數(shù)的增加而增加,同時(shí)計(jì)算量也隨之增加。
表1 不同深度的ResNet 網(wǎng)絡(luò)結(jié)構(gòu)配置
損失函數(shù)用于在訓(xùn)練過(guò)程中的模型反向傳播時(shí)計(jì)算模型預(yù)測(cè)值和真實(shí)標(biāo)簽值之間的不一樣程度,以便進(jìn)行梯度更新。原始ResNet 網(wǎng)絡(luò)中使用交叉熵函數(shù)(Cross Entropy)作為最終的損失函數(shù)值,即將輸入網(wǎng)絡(luò)中的每個(gè)樣本的交叉熵進(jìn)行加權(quán)平均,具體計(jì)算公式為:
SE 通道注意力機(jī)制如圖2所示,其中為輸入特征圖,為高,為寬,為通道數(shù),GAP 為全局平均池化操作,fc 為全連接層。先通過(guò)Squeeze 操作壓縮特征,沿著×方向進(jìn)行壓縮特征圖,用1 個(gè)實(shí)數(shù)表示×特征平面,某種程度上該實(shí)數(shù)具有一定的全局感受野;然后通過(guò)全連接層,實(shí)現(xiàn)對(duì)1×1×特征圖進(jìn)行跨信道信息交互,充分融合不同信道之間的信息;通過(guò)Sigmoid 函數(shù)獲得每個(gè)通道權(quán)重信息的一維向量,其代表著每個(gè)通道的重要性;最后使用一維特征向量對(duì)原特征圖進(jìn)行縮放。
圖2 SE 模塊
SE 通道注意力機(jī)制的核心思想在于通過(guò)全連接層和下采樣層構(gòu)建壓縮和激勵(lì)模塊以便于獲取特征圖通道權(quán)重信息,讓網(wǎng)絡(luò)學(xué)習(xí)特征圖中更重要的地方,放大顯著特征的權(quán)重的同時(shí)縮小不重要特征權(quán)重,從而使訓(xùn)練模型達(dá)到更好的效果。SE 通道注意力機(jī)制作為一種軟注意力機(jī)制,屬于一個(gè)即插即用模塊,可以無(wú)縫嵌入多種CNN 網(wǎng)絡(luò)中并進(jìn)行端到端訓(xùn)練,在模型參數(shù)和計(jì)算復(fù)雜度少量增加的前提下,大幅提升網(wǎng)絡(luò)性能。
為了評(píng)估不同深度得ResNet 網(wǎng)絡(luò)嵌入SE 通道注意力機(jī)制之后的效果,本文在CIFAR-100 和CIFAR-10 圖像分類數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn),CIFAR-10 數(shù)據(jù)集包含10 個(gè)類別,共有60 000張彩色圖片,尺寸大小為32×32 像素,有50 000 張訓(xùn)練圖像和10 000 驗(yàn)證圖像,每個(gè)類別包含6 000 張圖像。CIFAR-100數(shù)據(jù)集包含100 個(gè)類別,共有60 000 張彩色圖片,尺寸大小為32×32 像素,有50 000 張訓(xùn)練圖像和10 000 驗(yàn)證圖像,每個(gè)類別包含600 張圖像。本文在CIFAR-100 和CIFAR-10 驗(yàn)證集上統(tǒng)計(jì)Top-1 Error、Top-5 Error 和Top-1 Acc 并作為評(píng)價(jià)標(biāo)準(zhǔn)。Top-1 Error 是指取概率向量里面最大的作為最終預(yù)測(cè)結(jié)果,且預(yù)測(cè)結(jié)果和真實(shí)標(biāo)簽不同,Top-1 Acc 則是預(yù)測(cè)結(jié)果和真實(shí)標(biāo)簽相同。Top-5 Error 是取概率向量里面最大的前五位作為最終預(yù)測(cè)結(jié)果,且預(yù)測(cè)結(jié)果和真實(shí)標(biāo)簽都不同。
操作系統(tǒng)及環(huán)境:Ubuntu18.04、Python3.7、CUDA11.0、PyTorch1.7.1。
框架:PyTorch。
GPU:NVIDIA GeForce RTX 2080 Ti。
具體實(shí)驗(yàn)設(shè)置:在訓(xùn)練的過(guò)程中,將SGD 作為優(yōu)化器,訓(xùn)練動(dòng)量設(shè)置為0.9,訓(xùn)練權(quán)重衰減設(shè)置為5e-4,使用單GPU 進(jìn)行訓(xùn)練,批量大小為128,學(xué)習(xí)率初始值設(shè)置為0.1。所有模型均設(shè)置200 個(gè)epoch 進(jìn)行訓(xùn)練,使用等間隔調(diào)整學(xué)習(xí)率,初始學(xué)習(xí)率在第60、120、160 個(gè)epoch 乘以0.2。
本文使用ResNet-50 和ResNet-101 網(wǎng)絡(luò)為基準(zhǔn),評(píng)估了改進(jìn)的ResNet 算法模型在CIFAR-100 和CIFAR-10 數(shù)據(jù)集上的表現(xiàn)。
分別比較嵌入SE 通道注意力機(jī)制的SE-ResNet 和原始ResNet 在CIFAR-100 數(shù)據(jù)集上的Top-1 Error、Top-5 Error、參數(shù)量,結(jié)果如表2所示??梢杂^察到,相較于原始的ResNet,添加SE 模塊的ResNet 模型在不同的網(wǎng)絡(luò)深度上都有明顯的提升,錯(cuò)誤率都降低了一個(gè)百分點(diǎn)左右,而模型參數(shù)只增加了極小。特別的是,SE-ResNet50 Top-1 Error 和Top-5 Error 分別為20.29%和4.89%,相對(duì)于ResNet50 降低了8.40%和14.36%,比更深層次的ResNet101 錯(cuò)誤率還要低。
表2 在CIFAR-100 數(shù)據(jù)集上SE-ResNet 與原始ResNet比較
CIFAR-10 數(shù)據(jù)集上,由于CIFAR-10 數(shù)據(jù)集僅有10 個(gè)類別,所以模型之間錯(cuò)誤率差別不大。本文僅選取SEResNet50 和ResNet50 進(jìn)行比較,如表3所示。ResNet50的Top-1 Error 為4.88,SE-ResNet50 僅為4.39,相較于原始ResNet50 在Top-1 Error 上降低了10.04%,錯(cuò)誤率有大幅降低。
表3 在CIFAR-10 數(shù)據(jù)集上SE-ResNet 與原始ResNet錯(cuò)誤率比較
對(duì)于原始ResNet 網(wǎng)絡(luò)提取特征能力的不足,SE-ResNet通過(guò)使用下采樣層和池化層構(gòu)建的壓縮和激勵(lì)模塊可以有效地捕捉特征圖通道或像素的顯著信息,提高網(wǎng)絡(luò)的特征提取能力,并讓網(wǎng)絡(luò)關(guān)注更加重要的地方。
為了解決原始ResNet 網(wǎng)絡(luò)特征提取能力不足的問(wèn)題,本文結(jié)合現(xiàn)有注意力機(jī)制SENet,提出一種基于改進(jìn)ResNet模型的圖像分類方法。本研究將SE 模塊嵌入到原始ResNet網(wǎng)絡(luò)每個(gè)殘差結(jié)構(gòu)的末端,通過(guò)壓縮和激勵(lì)模塊對(duì)原始特征圖進(jìn)行跨通道信息交互,增強(qiáng)網(wǎng)絡(luò)特征提取和通道信息融合。實(shí)驗(yàn)結(jié)果表明,與原始ResNet 算法相比,嵌入SE 通道注意力機(jī)制的SE-ResNet 在CIFAR-100 和CIFAR-10 數(shù)據(jù)集上以增加少量的模型參數(shù)為代價(jià)獲得了更高的識(shí)別準(zhǔn)確率。后續(xù)的工作可以從自注意力機(jī)制、軟注意力機(jī)制和硬注意力機(jī)制等其他注意力機(jī)制入手,進(jìn)一步提高ResNet 網(wǎng)絡(luò)的精度和泛化能力。