王宥翔
(鄭州中糧科研設(shè)計(jì)院 電氣所,河南 鄭州 450000)
超分辨率(Super Resolution)通過硬件或軟件提高原有圖像的分辨率。圖像超分辨率研究大體分為3類:基于插值、基于重建、基于學(xué)習(xí);在技術(shù)層面則分為超分辨率復(fù)原和超分辨率重建。超分辨率重建是通過一系列低分辨率的圖像生成一幅高分辨率的圖像過程。
超分辨率重建是用時間帶寬換取空間分辨率,實(shí)現(xiàn)時間分辨率轉(zhuǎn)換為空間分辨率。超分辨率重建各種算法的區(qū)別主要在于網(wǎng)絡(luò)構(gòu)建的思路不同,而相同思路建構(gòu)的網(wǎng)絡(luò)也存在細(xì)微的差別。超分辨率重建大部分使用單純的卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Networks,CNN)完成任務(wù),但是CNN網(wǎng)絡(luò)在池化層和平移不變性方面容易出現(xiàn)問題,文獻(xiàn)[1]揭示并分析了卷積神經(jīng)網(wǎng)絡(luò)在變換兩種空間表征(笛卡爾空間坐標(biāo)和像素空間坐標(biāo))時的常見缺陷。本文基于深度學(xué)習(xí)的方案,選擇更為優(yōu)秀的生成對抗網(wǎng)絡(luò)(Generative Adversarial Networks,GAN)進(jìn)行超分辨率重建。
生成模型泛指在給定一些隱含參數(shù)的條件下隨機(jī)生成觀測數(shù)據(jù)的模型,主要分為兩類:一是建立有確切數(shù)據(jù)的分布函數(shù)模型;二是在無需完全明確數(shù)據(jù)分布函數(shù)模型的條件下直接生成一個新樣本[2],如GAN(圖1)。GAN通過對抗的方式,同時訓(xùn)練生成器(generator)和判別器(discriminator),生成器用于生成假樣本,讓這個假樣本無限逼近真實(shí)樣本,判別器則需要盡量準(zhǔn)確地判斷輸入的是真實(shí)樣本還是由生成器自己生成的假樣本。
圖1 GAN結(jié)構(gòu)
GAN的主要結(jié)構(gòu)由一個生成模型G(generator)和一個判別模型D(discriminator)組成。輸入圖片之后,程序提取輸入的圖片,并采樣轉(zhuǎn)化成數(shù)據(jù)tensor,數(shù)據(jù)輸入到網(wǎng)絡(luò)中開始計(jì)算,然后生成器G和判別器D開始它們的零和最大最小博弈。簡單來說,通過生成器,低分辨率的圖像可以重建一張高分辨率的圖像,然后由判別器網(wǎng)絡(luò)判斷。當(dāng)生成器網(wǎng)絡(luò)的生成圖能夠很好地“騙”過判別器網(wǎng)絡(luò),使判別器認(rèn)為這個生成圖是原數(shù)據(jù)集中的圖像,這里超分辨率重構(gòu)的網(wǎng)絡(luò)的目標(biāo)就達(dá)成了。生成器與判別器的工作原理如圖2所示,數(shù)據(jù)傳遞如圖3所示。
圖2 生成器與判別器的工作原理
圖3 生成器與判別器的數(shù)據(jù)傳遞
總體來說,在GAN中二者互相博弈,生成器不斷生成并輸出假的數(shù)據(jù),并與訓(xùn)練集一同輸入判別器中進(jìn)行判斷,繼續(xù)優(yōu)化學(xué)習(xí)。在這個過程中,生成器和判別器反復(fù)博弈,共同進(jìn)化,最終達(dá)到超進(jìn)化,經(jīng)過有限次迭代之后輸出數(shù)據(jù)并轉(zhuǎn)化為新的圖像輸出[3]。圖4是SRGAN的網(wǎng)絡(luò)結(jié)構(gòu),比較直觀的描述了GAN在解決圖像超分辨率的網(wǎng)絡(luò)運(yùn)行思路。
圖4 SRGAN的網(wǎng)絡(luò)結(jié)構(gòu)
GAN模型本質(zhì)上是一個最大最小博弈。目標(biāo)函數(shù)為
minGmaxDV(G,D)=Ex~pr(x)[logD(x)]+Ez~pr(z)[log(1-D(G(z)))],
(1)
其中,E代表期望,x~pr(x)代表x服從pr(x)分布,z是隨機(jī)噪聲,服從z~pr(z)的分布。而如何得出這個結(jié)論,就要關(guān)系到生成器和判別器的網(wǎng)絡(luò)原理。
2.2.1 判別器
判別器是程序需要優(yōu)先訓(xùn)練的模型,使它能夠判別一個輸入數(shù)據(jù)是否真的來自真實(shí)數(shù)據(jù)集,如果返回值大于0.5就為真,小于0.5則為假??梢钥闯?使用最簡單的二分類就可實(shí)現(xiàn),這里使用交叉熵的方法[4]。
給定一個樣本(x,y),y∈{1,0},表示其來自生成器還是真實(shí)數(shù)據(jù)。對于輸入的x,判別器會返回一個y,y表示x屬于真實(shí)數(shù)據(jù)的概率,
P(y=1|x)=D(x),
(2)
反之,x屬于生成的圖像數(shù)據(jù)概率
P(y=0|x)=1-D(x)。
(3)
判別器的目的是最小化交叉熵,交叉熵的表達(dá)式是[5]
minD(-Ex~p(x)(ylogP(y=1|x))+(1-y)logP(y=0|x)),
(4)
帶入式(2)和式(3),得到
minD(-Ex~p(x)(yD(x)+(1-y)(1-D(x))))。
(5)
假設(shè)整個樣本數(shù)據(jù)里面真實(shí)圖像數(shù)據(jù)和生成器生成的圖像數(shù)據(jù)是等比例的,
(6)
得到
(7)
然后最小化最大化互換,同時把負(fù)號變?yōu)檎?
maxDEx~pr(x)(D(x))+Ex~pg(x)(1-D(x))。
(8)
如果x~pg(x),代表x是生成器生成的,而生成器又是滿足z~p(z)分布而生成的,再次替換可得
maxDEx~pr(x)(D(x))+Ez~p(z)(1-D(G(z))),
(9)
即所需求的目標(biāo)函數(shù)。
2.2.2 生成器
生成器是判別器訓(xùn)練完成后才開始訓(xùn)練的模型,作用是在給定輸入的情況下得到一定的輸出,然后繼續(xù)送給判別器判斷,之后返回給自身一個誤差值,從而繼續(xù)學(xué)習(xí)。
生成器的目標(biāo)剛好和判別器相反,即讓判別器把自己生成的樣本判別為真實(shí)樣本。因?yàn)镚AN網(wǎng)絡(luò)的本質(zhì)數(shù)學(xué)模型是一個最大最小博弈,通過判別器得到了目標(biāo)函數(shù),從而得到最大值max,所以生成器的目的就是得到最小值min[6]。目標(biāo)函數(shù)
maxDEx~pr(x)(D(x))+Ez~p(z)(1-D(G(z)))
(10)
由兩部分構(gòu)成,由后一部分可得生成器目標(biāo)
minGEz~p(z)(1-D(G(z)))。
(11)
將生成器與判別器的函數(shù)結(jié)合,即得到生成對抗網(wǎng)絡(luò)的模型,
minGmaxDV(G,D)=Ex~pr(x)(logD(x))+Ez~p(z)(log(1-D(G(z))))。
(12)
訓(xùn)練時的優(yōu)化需要引入生成對抗網(wǎng)絡(luò)的損失函數(shù),
LossG=log(1-D(G(z)))or-log(D(G(z))),LossD=-log(D(x))or-log(1-D(G(z))),
(13)
LossG=log(1-D(G(z)))or-log(D(G(z)))。
(14)
由生成器的目標(biāo)式得
minGEz~p(z)(1-D(G(z)))。
(15)
后面一部分是原作者Ian Goodfellow提出的,效果等同于優(yōu)化前面那個而且梯度性質(zhì)更好。
LossD=-log(D(x))-log(1-D(G(z))),
(16)
maxDEx~pr(x)(D(x))+Ez~p(z)(1-D(G(z)))。
(17)
2.4.1 判別器越好,生成器梯度消失越嚴(yán)重
在最優(yōu)判別器的條件下,最小化生成器的損失函數(shù)和最小化P1與P2之間的JS散度是等價的[7],
(18)
對于P1與P2來說是完全對稱的,JS是兩個KL散度的疊加(KL散度又稱相對熵),一定是大于等于0的,所以JS散度一定大于等于0。在這里可能會出現(xiàn)嚴(yán)重的問題:如果兩個分布沒有重疊的話,JS散度就為0,而在訓(xùn)練初期,兩個分布必然是基本不會重疊,所以假如在這里判別器被訓(xùn)練得過于好,損失函數(shù)就會經(jīng)常收斂到固定的-2 log 2,從而產(chǎn)生沒有梯度的情況。然后網(wǎng)絡(luò)就沒法繼續(xù)訓(xùn)練下去了,對抗網(wǎng)絡(luò)中的生成器和判別器是要一起進(jìn)化變強(qiáng)的,一個過于強(qiáng)將會導(dǎo)致另一個無法繼續(xù)訓(xùn)練[8]。
2.4.2 可能出現(xiàn)梯度不穩(wěn)定和模式崩潰
GAN采用的是對抗訓(xùn)練的方式,判別器的梯度更新來自判別器,生成一個樣本,交給判別器去評判,判別器會輸出生成的假樣本是真樣本的概率。生成器會根據(jù)這個反饋不斷改善。但假如有一次生成器生成的并不真實(shí),判別器卻出了問題,給了正確評價,或者在一次生成器生成的結(jié)果中存在某一些特征被判別器所認(rèn)可了,這時候生成器就會認(rèn)為這里的輸出反而是正確的,接下來繼續(xù)輸出相同的數(shù)據(jù)判別器就還會給出高的評分,最終就會導(dǎo)致生成結(jié)果中的一些重要信息或特征殘缺[9]。
首先需要生成器(G)生成圖片模型,判別器(D)判斷圖片是否為真,如圖5所示。
圖5 GAN網(wǎng)絡(luò)架構(gòu)
首先需要向生成器輸入一個噪聲,生成隨機(jī)數(shù)組,繼續(xù)輸出一個數(shù)據(jù)轉(zhuǎn)換為一張圖片,輸入圖片之后,經(jīng)過判別器來輸出是一個數(shù)1或者0,代表圖片是否是狗。
然后通過訓(xùn)練網(wǎng)絡(luò),把真圖與假圖拼接,打上不同的標(biāo)簽,真圖為1,假圖為0,送到網(wǎng)絡(luò)中訓(xùn)練。
3.2.1 數(shù)據(jù)輸入
聲明集合dataloader,將訓(xùn)練和測試數(shù)據(jù)都放入其中。
3.2.2 訓(xùn)練網(wǎng)絡(luò)
先重寫構(gòu)造函數(shù),構(gòu)造一個父類的函數(shù) “super”,然后定義網(wǎng)絡(luò)結(jié)構(gòu)block,運(yùn)用nn.sequential將多個函數(shù),如卷積函數(shù)Conv2d和激活函數(shù)PReLU,并列放置,經(jīng)過多個ResidualBlock殘差網(wǎng)絡(luò)模塊處理。采樣之后,進(jìn)入前向傳播forward函數(shù),最后經(jīng)過tanh函數(shù)映射到-1到1,最后得到一個0到1的數(shù)據(jù)輸出[10]。
判別器是一個二分類的模型,先重寫構(gòu)造函數(shù)構(gòu)造父類函數(shù),然后進(jìn)入多層的網(wǎng)絡(luò),在進(jìn)入一層池化層之后,取平均值下采樣,得到1×1的數(shù)據(jù),最后只得到batchsize的數(shù)據(jù),然后通過sigmoid函數(shù)將實(shí)數(shù)域映射到0~1,即batchsize的概率,符合判別器二分類概率的原理[11]。
通過優(yōu)化器進(jìn)行判別器的訓(xùn)練。首先為了優(yōu)化判別器,將其梯度歸零,然后規(guī)定判斷真實(shí)圖片和虛假圖片的概率,接著規(guī)定判別器的損失函數(shù),計(jì)算出d_loss,然后執(zhí)行上面的步驟。
訓(xùn)練生成器時,將生成器的梯度置零后,生成一個假的圖片,輸入判別器,得出判別器判斷為假的概率,輸入給生成器的損失函數(shù),計(jì)算得出g_loss,再反向傳播backward,最終運(yùn)行開始訓(xùn)練。
完整的網(wǎng)絡(luò)架構(gòu)中日志記錄以及數(shù)據(jù)輸入輸出可視化不再贅述,可將生成模型記錄保存在字典文件pth之中,以供之后的測試或者訓(xùn)練使用。
完成了GAN構(gòu)造并經(jīng)過訓(xùn)練之后,進(jìn)行網(wǎng)絡(luò)性能測試。筆者下載了超分辨率重構(gòu)的數(shù)據(jù)集,包含×4和×8的每個大約3 000張圖片的測試用數(shù)據(jù)集,數(shù)據(jù)集文件列表如圖6所示。
圖6 超分辨率重構(gòu)數(shù)據(jù)集
因?yàn)樯窠?jīng)網(wǎng)絡(luò)訓(xùn)練運(yùn)算量巨大,且需要占用大量內(nèi)存,所以這里將其放到訓(xùn)練試驗(yàn)機(jī)上,運(yùn)用4塊RTX 3090顯卡進(jìn)行訓(xùn)練。訓(xùn)練整體大概1 000個迭代epoch,最終得到兩個記錄模型權(quán)重的pth文件,這兩個權(quán)重文件可以直接輸入測試網(wǎng)絡(luò),以下通過幾個測試圖片檢測訓(xùn)練的結(jié)果。
測試所用的一組原圖Ground truth,如圖7所示?!?超分的測試結(jié)果如圖8所示。×8超分的測試結(jié)果如圖9所示??梢钥闯?在×8的超分上,如果細(xì)節(jié)比較小的話,得出的超分圖會比較邊緣性的模糊,×4的超分結(jié)果已經(jīng)比較理想。
圖7 原圖
圖8 ×4測試結(jié)果
圖9 ×8測試結(jié)果
整體來說,網(wǎng)絡(luò)訓(xùn)練結(jié)果比較理想,成功收斂且沒有出現(xiàn)梯度消失以及模式崩潰的情況。說明利用深度學(xué)習(xí)的神經(jīng)網(wǎng)絡(luò)中的GAN生成對抗網(wǎng)絡(luò),能夠?qū)崿F(xiàn)圖像超分辨率的目標(biāo)。