祁生勇,臧月進,呂國云,杜 明
(1.西北工業(yè)大學電子信息學院,陜西西安 710072;2.上海機電工程研究所,上海 201109)
2014年,Goodfellow 等[1]提出無監(jiān)督生成對抗網(wǎng)絡(luò)(generative adversarial networks,GAN),得益于GAN 巧妙的構(gòu)思,其在圖像生成領(lǐng)域取得了巨大的成功。
2015年,Denton 等[2]提出了拉普拉斯金字塔生成對抗網(wǎng)絡(luò)(Laplacian pyramid GAN,LAPGAN),該算法將拉普拉斯金字塔和條件生成對抗網(wǎng)絡(luò)[3]相結(jié)合,提升了生成圖像的質(zhì)量。2015年,Radford[4]提出了DCGAN,該模型結(jié)合卷積神經(jīng)網(wǎng)絡(luò)對于圖像特征的提取能力和GAN 對數(shù)據(jù)建模的能力,進一步提升了生成圖像的質(zhì)量。2017年,Arjovsky[5]等提出了WGAN(Wasserstein GAN),從理論層面分析了原始GAN 訓(xùn)練不穩(wěn)定的問題。 WGAN 使用Wasserstein 距離衡量真實數(shù)據(jù)分布和生成數(shù)據(jù)分布的距離,能夠在任何情況下為生成器提供梯度信息以更新參數(shù)。
目前,生成對抗網(wǎng)絡(luò)在圖像生成、圖像轉(zhuǎn)換、圖像超分辨率等[6]領(lǐng)域取得了巨大的成功??罩心繕朔N類繁多,并且各種機型姿態(tài)各異,公開的數(shù)據(jù)集較少,因此針對空中目標圖像生成的難度較大。本文基于DCGAN架構(gòu),通過優(yōu)化判別器損失函數(shù),提高了模型訓(xùn)練穩(wěn)定性,同時提高了生成圖像的質(zhì)量。
DCGAN 首次將GAN 和卷積神經(jīng)網(wǎng)絡(luò)結(jié)合起來,同時設(shè)計了一些優(yōu)化訓(xùn)練的技巧防止模式崩塌,一定程度上解決了原始GAN 訓(xùn)練不穩(wěn)定的問題,在MNIST[7]和LSUN[8]數(shù)據(jù)集中取得了較好的結(jié)果。
DCGAN 訓(xùn)練過程如圖1所示,首先隨機向量輸入生成器得到生成圖像,判別器負責區(qū)分圖像為真的概率。判別器輸出概率越接近1 說明輸入圖像越真實,接近0 則說明與真實圖像差距較大。生成器和判別器不斷地對抗迭代優(yōu)化,理論上,最終判別器無法區(qū)分輸入的是真實還是生成的圖像,對任何輸入得到的輸出都為0.5,這時模型就達到了最優(yōu),生成器完整捕獲了數(shù)據(jù)的真實分布。
圖1 DCGAN訓(xùn)練過程Fig.1 DCGAN training process
圖2 展示了DCGAN 生成器網(wǎng)絡(luò)結(jié)構(gòu),首先輸入大小為100且服從均勻分布的隨機向量z,接著將其映射為1 024 個4×4 大小的特征圖,特征圖通過4 個不同的步幅卷積步驟后得到大小為64×64 的彩色圖像G(z)。
圖2 DCGAN生成器網(wǎng)絡(luò)結(jié)構(gòu)Fig.2 DCGAN generator network structure
DCGAN 判別器網(wǎng)絡(luò)結(jié)構(gòu)如圖3所示,判別器由4個卷積層和1 個全連接層構(gòu)成,輸入真實圖像或生成圖像,輸出圖像為真的概率,若為真實圖像則概率接近1,若為生成圖像則概率接近0。同生成器一樣,除了輸入層,其他所有層都進行批歸一化[9]處理。
圖3 DCGAN判別器網(wǎng)絡(luò)結(jié)構(gòu)Fig.3 DCGAN discriminator network structure
DCGAN 的損失函數(shù)和原始GAN 的損失函數(shù)一樣,都為交叉熵損失函數(shù),如式(1)所示。
式中:z和x分別表示隨機向量和真實圖像;pdata(x)和pz(z)分別表示真實圖像和隨機向量的概率分布;G和D分別表示生成器網(wǎng)絡(luò)和判別器網(wǎng)絡(luò);E表示數(shù)學期望;V(D,G)表示目標函數(shù)。判別器D的目標是區(qū)分真實圖像和生成圖像:對于式(1)右側(cè)的第1 項,輸入是真實圖像,判別器D希望輸出的概率接近1;對于式(1)右側(cè)的第2 項,輸入為生成圖像G(z),判別器希望輸出趨近于0,取反之后也是越大越好,這就是max(D)的含義。生成器訓(xùn)練時,D保持不變,為了“欺騙”判別器,希望D(G(z))接近1,這時生成的圖像會更接近真實圖像,整體越小代表生成效果越好,這就是min(G)的含義。
空中目標種類繁多、姿態(tài)各異,圖像特征復(fù)雜,DCGAN 使用交叉熵損失函數(shù)易導(dǎo)致模型梯度消失,陷入局部最優(yōu)解,不能完整地捕捉數(shù)據(jù)真實的分布。圖4 給出了生成器數(shù)據(jù)分布與真實數(shù)據(jù)分布示意圖,藍色實線表示生成器捕捉的數(shù)據(jù)分布,黑色虛線表示真實的數(shù)據(jù)分布:(a)表示訓(xùn)練剛開始,兩者距離較遠;(b)表示隨著模型不斷地訓(xùn)練,生成器學習到的數(shù)據(jù)分布向真實分布靠近;(c)表示理想狀況下生成器完全學習到了真實分布。
圖4 生成器數(shù)據(jù)分布與真實數(shù)據(jù)分布示意圖Fig.4 Schematic diagram of generator data distribution and real data distribution
DCGAN 損失函數(shù)為交叉熵損失,生成器很容易收斂到局部最優(yōu),參數(shù)無法更新,最終生成器的數(shù)據(jù)分布與真實數(shù)據(jù)分布如圖4(b)所示,這會導(dǎo)致生成圖像有偽影、圖像模糊等問題,因此需要對損失函數(shù)進行優(yōu)化。
基于第1 章的分析,DCGAN 的損失函數(shù)包含最大化判別器和最小化生成器。Goodfellow[1]等證明,當損失函數(shù)為交叉熵損失時,假設(shè)最優(yōu)的判別器固定,生成器G更新的目標如式(2)所示。
式中:pdata(x)和分別代表真實數(shù)據(jù)的分布和生成器的數(shù)據(jù)分布;x和分別表示真實圖像和生成圖像;LG表示生成器損失函數(shù)。對式(2)進一步化簡得到式(3)。
所以DCGAN 的生成器損失函數(shù)近似于最小化pdata(x)和之間的JS 散度(Jensen-Shannon divergence)。Arjovsky[5]指出,若pdata(x)和不重疊,則兩個分布之間的JS 散度為常數(shù)2 lg2,因此生成器的梯度將變?yōu)?,而WGAN 中使用Wasserstein 距離去判斷兩者之間的距離,每次更新判別器權(quán)重時都強制映射到一個區(qū)間,保證了反向傳播過程永遠有梯度信息,兩者之間的關(guān)系如圖5所示。
圖5 WGAN梯度示意圖Fig.5 Schematic diagram of WGAN gradient
針對原始GAN 存在的問題,WGAN 使用Wasserstein 距離反映兩個分布之間的距離,如式(4)所示。
式中:γ表示pdata(x)和聯(lián)合分布;Π表示所有聯(lián)合分布的集合;表示兩個數(shù)學分布的Wasserstein距離。
對于每個可能的聯(lián)合分布γ,從中采樣得到真實圖像x和生成圖像,計算該聯(lián)合分布下的期望值,在所有可能的期望值中取下界,就得到了Wasserstein 距離。相較于JS 散度,即使兩個分布不重疊,Wasserstein距離也能用來衡量兩者之間的距離關(guān)系。
雖然Wasserstein 距離更準確地度量了生成圖像和真實圖像的分布距離,但是式(4)無法直接求解,因此進一步化簡為式(5)。
式中,D∈1-Lipschitz 表示判別器D滿足1-Lipschitz連續(xù)性條件,所以每次判別器權(quán)重參數(shù)會被強制截斷到[-0.01,0.01]之間。
總體來說,原始GAN 中交叉熵損失函數(shù)不能很好地判別兩個分布的距離。所以,WGAN 提出使用Wasserstein 距離衡量兩個分布之間的距離,無論兩者距離多遠,都能提供有效的梯度信息以更新網(wǎng)絡(luò)參數(shù)。
為了更清楚地說明判別器的作用,將式(1)中判別器的部分改寫為式(6)。
式中,LD表示判別器損失。根據(jù)2.2節(jié)的分析,為了保證模型訓(xùn)練過程中梯度不為0,WGAN 在梯度更新時把判別器權(quán)重截斷在[-0.01,0.01]之間,使其滿足Lipschitz連續(xù)性條件。但是強制性截斷處理會使得判別器丟失一部分圖像信息,只能學習到一個簡單的分布,這時判別器對復(fù)雜的飛機圖像特征分辨能力較弱,因此本文對WGAN 判別器的權(quán)重參數(shù)使用梯度懲罰[10]替代強制性截斷,同時滿足Lipschitz 連續(xù)性條件,如式(7)所示。
式(7)中:y表示在x和的連線上隨機插值采樣得到的一個新樣本;ppenalty表示新樣本的分布;λ表示懲罰系數(shù)。進一步化簡可以得到式(8)。
梯度懲罰過程如下:首先隨機選擇一個真實樣本x~pdata(x)和一個生成樣本;然后在x和的連線上隨機插值采樣得到的一個新的樣本y~ppenalty。最后一項懲罰項表示:判別器D對采樣得到的樣本y求梯度,梯度大于1 的時候,懲罰項會使梯度信息接近1,這樣就會將梯度限制在一定范圍,pdata(x)和的距離也會越來越近,生成的樣本越來越符合真實場景。
最終加入懲罰項的判別器損失函數(shù),如式(9)所示。
1)實驗環(huán)境配置及數(shù)據(jù)集。實驗環(huán)境基于酷睿i7 處理器與英偉達GTX1080Ti GPU 環(huán)境以及Pytorch 1.4.0深度學習框架。
構(gòu)建空中機動目標數(shù)據(jù)集,該數(shù)據(jù)集包含6 800張各種類型的空中目標圖片,網(wǎng)絡(luò)模型的訓(xùn)練集與交叉驗證集包含4 760 張圖片,測試數(shù)據(jù)集包含2 040 張空中目標圖片。
2)訓(xùn)練方式設(shè)計。由于網(wǎng)絡(luò)參數(shù)較多,為了防止模型對于訓(xùn)練數(shù)據(jù)集過擬合,在模型訓(xùn)練時使用Dropout[11]技術(shù),隨機固定某些參數(shù)不更新。
生成器和判別器組成的DCGAN 網(wǎng)絡(luò)體現(xiàn)的是一種相互對抗學習的過程,如果判別器訓(xùn)練效果足夠好,生成器梯度會消失;判別器訓(xùn)練效果不好,生成器梯度又會不夠準確。為了權(quán)衡兩者的關(guān)系,在訓(xùn)練過程中,判別器參數(shù)更新多次,生成器參數(shù)更新一次。
3)實驗參數(shù)設(shè)置。本文所有的訓(xùn)練過程中模型的學習率均為0.000 2;batch size 為64;梯度懲罰λ=10;生成器和判別器參數(shù)更新次數(shù)比例為1∶5,即判別器參數(shù)更新5次,生成器參數(shù)更新1次。
為評價改進后算法的性能,從模型訓(xùn)練過程穩(wěn)定性以及生成圖像FID[12]和IS[13]得分兩方面進行分析比較。改進前后生成器損失函數(shù)對比如圖6所示。
圖6 改進前后生成器損失函數(shù)對比Fig.6 Comparison of generator loss function before and after improvement
訓(xùn)練穩(wěn)定性評估:為了評估改進DCGAN 訓(xùn)練過程的穩(wěn)定性,本文采用相同的數(shù)據(jù)集對改進前后的損失函數(shù)進行比較,由圖6可知,改進后模型的生成器損失函數(shù)整體波動更小,訓(xùn)練過程更穩(wěn)定。
圖像生成質(zhì)量評估:FID 分析于2017年被提出,該方法首先把真實圖像和生成圖像輸入訓(xùn)練好的分類模型中(如Inception Net-V3網(wǎng)絡(luò)),去除了最后的池化層得到一個高維特征向量,通過計算生成圖像和真實圖像高維特征向量的距離,就可以得到FID 分數(shù),F(xiàn)ID 越小,表示生成的圖像質(zhì)量越好、多樣性越好,數(shù)學表示如式(10)所示。
式中:μ為經(jīng)驗均值;Σ為經(jīng)驗協(xié)方差;Tr為矩陣的跡;x代表真實圖像,g代表生成圖像。
IS 是另一個通用的GAN 模型評價標準,IS 評價的思路也是使用一個訓(xùn)練好的網(wǎng)絡(luò)對生成圖像進行分類,如果分類的準確性越高,說明生成的圖像越真實;同時生成每個種類圖像的概率越平均,說明模型生成圖像的多樣性越高。綜合得分越高表明生成器生成圖像質(zhì)量越好,數(shù)學表示如式(11)所示。
式中:表示生成圖像;m表示標簽信息;是這兩個分布的KL 散度,對其求指數(shù)就得到了最終IS分數(shù)。
本文使用空中機動目標數(shù)據(jù)集,基于Pytorch深度學習框架分別對原始的DCGAN和改進的DCGAN網(wǎng)絡(luò)進行訓(xùn)練,測試了32×32 和64×64 兩種分辨率的圖像。
圖6 為訓(xùn)練損失函數(shù)曲線,由圖6 可知,改進后的生成器損失函數(shù)波動明顯下降,訓(xùn)練過程更加穩(wěn)定。圖7(a)和圖7(b)分別為改進前后32×32 分辨率的圖像生成結(jié)果,可以看出,改進后的生成圖像的邊緣細節(jié)更清楚,圖像噪點明顯減少并且目標主體和背景區(qū)分明顯。圖8(a)表示真實圖像,圖8(b)和圖8(c)分別表示改進前后64×64分辨率的圖像生成結(jié)果。與圖7相比,當分辨率增大時,更多的圖像細節(jié)顯示了出來,如:直升機機翼、戰(zhàn)斗機尾翼以及起落架都更加清晰明顯。從圖8(b)可以看出改進前生成的飛機圖像輪廓不明顯,并且容易與背景混合,會出現(xiàn)很多不真實的紋理;從圖8(c)可以看出改進后的飛機主體與背景更容易區(qū)分,當飛機顏色與背景相近時也不會出現(xiàn)兩者混在一起的情況,生成圖像更加接近真實場景。
圖7 改進前后模型生成結(jié)果對比(32×32)Fig.7 Comparison of generated results(32×32)before and after improvement
圖8 真實圖像與改進前后模型生成結(jié)果對比(64×64)Fig.8 Comparison of real images and generated results(64×64)before and after improvement
綜上所述,改進后模型生成的圖像能夠展示更多細節(jié),飛機主體與背景差異顯著增強。通過觀察生成結(jié)果還可以看出,改進后模型生成的空中目標圖像與真實圖像更接近,虛假紋理減少,圖像邊緣細節(jié)更加豐富,可為空中目標檢測識別任務(wù)提供更強的數(shù)據(jù)支持。表1給出了兩種算法的FID和IS得分情況。
從表1 中可以看出,改進后的模型在32×32 分辨率下FID 和IS 得分分別提高了9.4%和7.6%;64×64分辨率下FID 和IS 得分分別提高了5.9%和4.8%。由此可得,改進后的模型生成圖像的質(zhì)量更高,圖像細節(jié)更豐富,生成種類更加多樣化,沒有出現(xiàn)原始DCGAN模式崩潰的情況。
表1 兩種算法FID和IS得分比較Tab.1 Comparison of FID and IS scores between two algorithms
本文提出一種改進的DCGAN 圖像生成算法。模型訓(xùn)練過程中使用改進的Wasserstein 距離衡量生成數(shù)據(jù)分布和真實數(shù)據(jù)分布,優(yōu)化了原始DCGAN 的判別器損失函數(shù),能夠在任何情況下為生成器提供梯度信息。實驗結(jié)果表明,針對空中目標數(shù)據(jù)集,改進后的模型訓(xùn)練過程更加穩(wěn)定,生成圖像更清晰,并且能夠根據(jù)需求生成不同分辨率的圖像,可以有效擴充空中目標檢測任務(wù)的數(shù)據(jù)樣本。