蔣光峰,胡鵬程,葉 樺,仰燕蘭
東南大學(xué)自動(dòng)化學(xué)院,南京210096
近年來(lái),由于結(jié)構(gòu)化和半結(jié)構(gòu)化數(shù)據(jù)的爆炸式增長(zhǎng),人們?cè)絹?lái)越重視對(duì)圖神經(jīng)網(wǎng)絡(luò)(graph neural network,GNN)的研究。來(lái)自生物信息學(xué)、化學(xué)信息學(xué)、社交網(wǎng)絡(luò)、城市計(jì)算和網(wǎng)絡(luò)空間安全等數(shù)據(jù)都可以自然地表示成標(biāo)簽圖數(shù)據(jù)。此外,企業(yè)常常希望以知識(shí)圖譜的形式存儲(chǔ)信息并應(yīng)用各種機(jī)器學(xué)習(xí)技術(shù)來(lái)充分挖掘數(shù)據(jù)價(jià)值。
深度學(xué)習(xí)的發(fā)展促進(jìn)了數(shù)據(jù)挖掘技術(shù)的更新,確切地說(shuō)是通過(guò)對(duì)卷積神經(jīng)網(wǎng)絡(luò)(convolutional neural network,CNN)和循環(huán)神經(jīng)網(wǎng)絡(luò)(recurrent neural network,RNN)結(jié)構(gòu)的復(fù)用,分別在兩種歐幾里德空間數(shù)據(jù)(如圖片和序列)上取得了較好的成績(jī),因此人們嘗試在非歐幾里德空間應(yīng)用CNN。以往的研究大多是重新定義卷積層和池化層來(lái)處理圖數(shù)據(jù)。
現(xiàn)有的圖神經(jīng)網(wǎng)絡(luò)(GNN)方法,通??煞譃榛谧V圖卷積和基于空間域卷積。典型的基于譜圖卷積有Spectral CNN、ChebNet和GCN等;基于空間域卷積有GraphSAGE、GAT等。
圖池化比圖卷積的方法要少很多,以往的方法只考慮網(wǎng)絡(luò)的拓?fù)浣Y(jié)構(gòu)。隨著人們對(duì)圖神經(jīng)網(wǎng)絡(luò)領(lǐng)域的研究加深,更多的圖池化方法被提出,如SortPool利用最后一維節(jié)點(diǎn)特征排序節(jié)點(diǎn);DiffPool引入基于分配學(xué)習(xí)的可微池化;SAGPool利用注意力機(jī)制學(xué)習(xí)節(jié)點(diǎn)的得分,并保留Top節(jié)點(diǎn)。
全局圖池化,又稱(chēng)圖讀出層(readout layer),目的是將圖映射到有限維的歐幾里德空間?;诮y(tǒng)計(jì)的方 法 有MeanPooling、MaxPooling 和SumPooling 等。這些方法簡(jiǎn)單有效,不會(huì)增加額外的參數(shù),但損失了大量的結(jié)構(gòu)信息。
上述圖卷積和池化方法設(shè)計(jì)的初衷是解決節(jié)點(diǎn)分類(lèi)、連接預(yù)測(cè)等圖節(jié)點(diǎn)級(jí)別問(wèn)題,對(duì)圖分類(lèi)等圖級(jí)別任務(wù)所需要的拓?fù)浣Y(jié)構(gòu)信息提取不充分。此外現(xiàn)有的圖讀出層設(shè)計(jì)過(guò)于簡(jiǎn)單,同樣損失了大量的拓?fù)浣Y(jié)構(gòu)信息。
本文提出一種基于重構(gòu)誤差的多重注意力機(jī)制同構(gòu)圖分類(lèi)模型(multi-heads attention wave graph isomorphic convolution based on reconstruction error,RMAWaveGIC),主要貢獻(xiàn)如下:
(1)提出同構(gòu)圖卷積WaveGIC。對(duì)比現(xiàn)有的算法(GCN、GraphSAGE),WaveGIC 能更有效地提取圖拓?fù)浣Y(jié)構(gòu)信息,更適合圖分類(lèi)任務(wù)。
(2)提出基于多重注意力機(jī)制的圖全局池化(圖讀出)方法,能更全面地表征整張圖,減少全局池化帶來(lái)的信息損失。
(3)提出基于重構(gòu)誤差的圖分類(lèi)訓(xùn)練算法,訓(xùn)練不僅考慮分類(lèi)器分類(lèi)性能,同時(shí)考慮圖卷積過(guò)程中對(duì)拓?fù)浣Y(jié)構(gòu)信息的損失。
實(shí)驗(yàn)表明,本文方法在現(xiàn)有的基準(zhǔn)數(shù)據(jù)集上取得了最好的識(shí)別準(zhǔn)確率。
近年來(lái)大量的圖神經(jīng)網(wǎng)絡(luò)模型被提出,包括圖卷積神經(jīng)網(wǎng)絡(luò)、圖殘差網(wǎng)絡(luò)以及圖循環(huán)神經(jīng)網(wǎng)絡(luò)。大部分方法都遵循Gilmer 等提出的神經(jīng)信息傳遞(neural message passing)框架,即節(jié)點(diǎn)被鄰近節(jié)點(diǎn)使用可微的聚合函數(shù)表達(dá)。文獻(xiàn)[18-19]詳細(xì)敘述了GNN 領(lǐng)域近年來(lái)的重要研究成果。文獻(xiàn)[20]分析了主流GNN對(duì)捕獲圖結(jié)構(gòu)的表達(dá)能力,并提出圖同構(gòu)網(wǎng)絡(luò)(graph isomorphism network,GIN)。
在CNN 中池化層也稱(chēng)下采樣卷積核,其操作與卷積操作基本相同,不同的是下采樣的卷積核只取對(duì)應(yīng)位置的最大值或平均值,并且在反向傳播時(shí)參數(shù)不需要更新。圖池化方法可以劃分成三種類(lèi)別:基于拓?fù)涑鼗?、全局池化和分層池化?/p>
基于拓?fù)涞某鼗缙谕ǔJ褂脠D粗化算法(graph coarsening algorithms)?;谧V聚類(lèi)算法使用特征分解來(lái)得到粗化的圖,但是特征分解的時(shí)間復(fù)雜度高,需要更快更簡(jiǎn)潔的算法。
全局池化方法考慮圖節(jié)點(diǎn)特征,使用求和、最大值、均值或神經(jīng)網(wǎng)絡(luò)等方法表達(dá)整張圖。Vinyals 等利用Set2Set方法獲取整張圖的表征,提出一種通用的圖分類(lèi)框架。SortPool 利用GCN 網(wǎng)絡(luò)提取的特征值排序節(jié)點(diǎn)并傳遞給下一層網(wǎng)絡(luò),首次實(shí)現(xiàn)了端到端訓(xùn)練的圖分類(lèi)模型。
全局池化方法無(wú)法學(xué)習(xí)到圖結(jié)構(gòu)信息及其關(guān)鍵的分層表征。現(xiàn)實(shí)應(yīng)用中,很多圖信息都是層級(jí)表征的,例如地圖、概念圖和流程圖等。分層池化方法的設(shè)計(jì)初衷是使網(wǎng)絡(luò)既能學(xué)習(xí)節(jié)點(diǎn)特征,也能學(xué)習(xí)圖的拓?fù)浣Y(jié)構(gòu)信息。可微圖池化方法DiffPool學(xué)習(xí)分類(lèi)矩陣,將節(jié)點(diǎn)映射到一組簇中,輸入到GNN 下一層。DiffPool 的空間復(fù)雜度高,當(dāng)處理大圖時(shí)對(duì)硬件要求較高。gPool 使用可學(xué)習(xí)的向量計(jì)算節(jié)點(diǎn)得分,并選擇得分排名靠前的節(jié)點(diǎn),數(shù)學(xué)表達(dá)式如式(1)~(3)。
其中,表征第層節(jié)點(diǎn)特征矩陣;表征第層的鄰接矩陣。gPool 雖解決了DiffPool 的空間復(fù)雜度問(wèn)題,但忽略了圖的拓?fù)浣Y(jié)構(gòu)信息。
SAGPool 在合理的空間和時(shí)間復(fù)雜度下同時(shí)考慮拓?fù)浣Y(jié)構(gòu)和特征信息生成分層表達(dá),式(4)~(9)是其數(shù)學(xué)表達(dá)。
其中,表征節(jié)點(diǎn)特征矩陣;表征第層的節(jié)點(diǎn)特征矩陣;表征第層的鄰接矩陣。
文獻(xiàn)[23]提出一種正則化對(duì)抗圖自編碼模型,該模型同時(shí)學(xué)習(xí)圖的拓?fù)浣Y(jié)構(gòu)和節(jié)點(diǎn)特征隱向量表示(embedding),并在此基礎(chǔ)上訓(xùn)練解碼器來(lái)重構(gòu)圖的結(jié)構(gòu)。若能從高層特征矩陣中恢復(fù)原始鄰接矩陣,說(shuō)明網(wǎng)絡(luò)能夠?qū)W習(xí)圖的拓?fù)浣Y(jié)構(gòu)信息。
使用(,)表示圖,其中∈{0,1}代表鄰接矩陣,∈R代表節(jié)點(diǎn)特征矩陣,每個(gè)節(jié)點(diǎn)有維特征。給定圖數(shù)據(jù)集D={(,),(,),…,(G,y)},其中y∈У 表示圖G∈G 的標(biāo)簽,圖分類(lèi)的目標(biāo)是學(xué)習(xí)映射:G →У 將圖映射到標(biāo)簽。和標(biāo)準(zhǔn)的有監(jiān)督學(xué)習(xí)相比,圖分類(lèi)的困難是在使用常規(guī)的機(jī)器學(xué)習(xí)方法(如SVM、DNN 等)分類(lèi)時(shí),需要從輸入圖中提取高效的有限維特征向量R。
同構(gòu)圖指的是圖中的節(jié)點(diǎn)類(lèi)型和關(guān)系類(lèi)型都僅有一種的圖。GNN 利用圖鄰接矩陣和節(jié)點(diǎn)特征矩陣X來(lái)學(xué)習(xí)節(jié)點(diǎn)的表征h或圖的表征h。GNN 遵循鄰居聚合策略,通過(guò)式(12)、式(13)聚合鄰節(jié)點(diǎn)的表示迭代更新當(dāng)前節(jié)點(diǎn)的表示。
其中,可設(shè)置成固定值或可學(xué)習(xí)的變量,表征中心節(jié)點(diǎn)的衰減程度。
本文提出一種能更好學(xué)習(xí)圖拓?fù)浣Y(jié)構(gòu)特征的卷積單元WaveGIC,其表達(dá)形式如式(20)、式(21)。
圖1 不同聚合策略的表達(dá)能力Fig.1 Expressive ability of different aggregation strategies
在COMBINE 階段,使用圖2 所示結(jié)構(gòu),門(mén)控激活單元使用tanh 和sigmoid 控制信息選通來(lái)調(diào)整單元狀態(tài)。ReLU 激活所引起的稀疏性適用卷積神經(jīng)網(wǎng)絡(luò)而不適合圖數(shù)據(jù),因?yàn)閳D數(shù)據(jù)需要以更平滑的梯度在多層圖卷積體系結(jié)構(gòu)上流動(dòng)。在實(shí)驗(yàn)中發(fā)現(xiàn),采用WaveNet中的門(mén)控激活函數(shù),即用一個(gè)非線(xiàn)性的tanh 函數(shù)選通sigmoid 函數(shù)的激活函數(shù)比使用單一的激活函數(shù),如ReLU、LeakyReLU 等效果好。
圖2 WaveGIC 層Fig.2 WaveGIC layer
2.2節(jié)的WaveGIC 相比GCN 等具有更強(qiáng)的拓?fù)浣Y(jié)構(gòu)表征能力,但隨著模型深度的加深,局部拓?fù)浣Y(jié)構(gòu)的特征表達(dá)越來(lái)越不明顯。因此本文提出一種基于重構(gòu)誤差的圖分類(lèi)學(xué)習(xí),即利用WaveGIC 高層表征嘗試恢復(fù)圖的鄰接矩陣,如式(22)、式(23)所示:
模型訓(xùn)練的損失函數(shù)包含二分類(lèi)交叉熵?fù)p失和圖重構(gòu)誤差損失,如式(24)所示:
式中,是超參數(shù),實(shí)驗(yàn)中設(shè)置為0.5。
圖讀出操作用來(lái)生成圖的表示,它要求操作本身對(duì)節(jié)點(diǎn)的順序不敏感。在歐式空間,旋轉(zhuǎn)圖像形成新的圖像;但在非歐式空間,圖的旋轉(zhuǎn)(如對(duì)節(jié)點(diǎn)重新編號(hào))不形成新的圖,這是典型的圖重構(gòu)。要使重構(gòu)圖與原圖表示一致,圖讀出操作就需要對(duì)節(jié)點(diǎn)順序不敏感。在數(shù)學(xué)上,能夠表達(dá)這種操作的函數(shù)稱(chēng)為對(duì)稱(chēng)函數(shù)。
本文使用注意力機(jī)制以不同權(quán)重組合節(jié)點(diǎn)特征生成圖的表征,使用Multi-Heads 從多角度以不同的注意力權(quán)重生成多種圖表征,合并作為最終的表征結(jié)果。
式(25)到式(27)是多重注意力機(jī)制讀出層的數(shù)學(xué)表達(dá)式。
其中,F∈R是可學(xué)習(xí)參數(shù),att代表每個(gè)節(jié)點(diǎn)的得分,readout代表第重讀出層的圖向量表示,表示重圖讀出表示。通過(guò)多重注意力機(jī)制讀出層,將圖節(jié)點(diǎn)以不同的權(quán)重組合,最終形成整張圖的表示。
在圖3中,輸入D=(G,y),首先經(jīng)過(guò)層WaveGIC提取節(jié)點(diǎn)的高層特征,并且每層的輸出連接多重注意力機(jī)制讀出層,用于表征不同層次的圖嵌入;隨后通過(guò)Concat 層拼接每層的圖嵌入表示,送入后續(xù)的1-D 卷積神經(jīng)網(wǎng)絡(luò);最后用DNN 網(wǎng)絡(luò)學(xué)習(xí)分類(lèi)提取的圖表示向量。同時(shí),利用WaveGIC 提取的節(jié)點(diǎn)高階表征矩陣,重構(gòu)原始圖的鄰接矩陣。
圖3 RMAWaveGIC 模型框架Fig.3 RMAWaveGIC model framework
表1 是5 個(gè)來(lái)自醫(yī)學(xué)、化工等領(lǐng)域的基準(zhǔn)數(shù)據(jù)集。AIDS是一組化合物數(shù)據(jù)集,標(biāo)簽表示化合物是否可以抗艾滋病毒活性。FRANKENSTEIN是一組包含節(jié)點(diǎn)特征的分子圖,標(biāo)簽表示一個(gè)分子是誘變劑還是非誘變劑。NCI數(shù)據(jù)集中每個(gè)圖表示一個(gè)化合物,節(jié)點(diǎn)和邊分別表示原子和化學(xué)鍵,標(biāo)簽表示抗癌活性分類(lèi)。NCI1 和NCI109 通常被用來(lái)當(dāng)作圖分類(lèi)的基準(zhǔn)數(shù)據(jù)集。PROTEINS是一組蛋白質(zhì)圖,具有邊的節(jié)點(diǎn)處于氨基酸序列中或在封閉的三維空間中。
表1 基準(zhǔn)數(shù)據(jù)集基本信息Table 1 Basic information of benchmark data set
常規(guī)的圖分類(lèi)方法分為三部分,即卷積層提取節(jié)點(diǎn)高層表征、讀出層表征整張圖和頂層分類(lèi)網(wǎng)絡(luò)。把現(xiàn)有的解決圖分類(lèi)的基礎(chǔ)算法如Set2Set、SortPool和SAGPool 等作為基準(zhǔn)模型,與本文所提WaveGIC、重構(gòu)誤差訓(xùn)練與多重注意力機(jī)制讀出層構(gòu)成模型的結(jié)構(gòu)對(duì)比。對(duì)比模型的基本信息如表2 所示。
表2 基準(zhǔn)模型與RMAWaveGIC 模型對(duì)比Table 2 Comparison of Base and RMAWaveGIC model
實(shí)驗(yàn)中采用相同的迭代次數(shù)、早停輪數(shù)、學(xué)習(xí)率以及GNN 層數(shù)等訓(xùn)練超參。模型評(píng)估采取準(zhǔn)確率和ROC_AUC 評(píng)價(jià)指標(biāo)。準(zhǔn)確率的定義如式(28)所示,反映分類(lèi)器正確分類(lèi)占總樣本的百分比;ROC 全稱(chēng)“受試者工作特征”曲線(xiàn),AUC 則是該曲線(xiàn)與橫坐標(biāo)的面積,ROC_AUC 反映了分類(lèi)器將某個(gè)隨機(jī)正類(lèi)別樣本排列在某個(gè)隨機(jī)負(fù)類(lèi)別樣本之上的概率。
其中,是正例,是反例,是真正例,是真反例。
文獻(xiàn)[30]論證了不同的數(shù)據(jù)劃分影響GNN 模型的表現(xiàn)。實(shí)驗(yàn)中,使用十折交叉驗(yàn)證評(píng)估訓(xùn)練結(jié)果,減少因訓(xùn)練集劃分造成的結(jié)果評(píng)估不準(zhǔn)確的問(wèn)題。在圖4 中,全局模型結(jié)構(gòu)僅使用最后一層表征進(jìn)行分類(lèi)學(xué)習(xí);分層模型拼接每一層表征進(jìn)行分類(lèi)學(xué)習(xí)。針對(duì)每種方法,同時(shí)訓(xùn)練全局結(jié)構(gòu)和分層結(jié)構(gòu)模型。
圖4 全局模型結(jié)構(gòu)(左)和分層模型結(jié)構(gòu)(右)Fig.4 Global model structure(left)and hierarchical model structure(right)
實(shí)驗(yàn)在NVIDIA GTX1080Ti GPU 上進(jìn)行,使用PyTorch和DGL圖深度學(xué)習(xí)庫(kù)實(shí)現(xiàn)所有的基準(zhǔn)模型和所提模型。
從圖5 上可以看出,在使用相同的圖讀出層,WaveGIC 分類(lèi)損失遠(yuǎn)遠(yuǎn)低于GCN;對(duì)比表3、表4 中WaveGIC與GCN模型的實(shí)驗(yàn)結(jié)果可以發(fā)現(xiàn),WaveGIC在5 個(gè)基準(zhǔn)數(shù)據(jù)上ROC 平均提升5.13%,并且在節(jié)點(diǎn)特征豐富的數(shù)據(jù)集(PROTEINS)上的結(jié)果提升更顯著(13.55%)。表明門(mén)控激活函數(shù)使得梯度更平滑地在多層圖卷積體系結(jié)構(gòu)上流動(dòng),模型能夠得到更充分的學(xué)習(xí)。此外,門(mén)控激活函數(shù)增強(qiáng)了模型的復(fù)雜度,WaveGIC 可以更好地從拓?fù)浣Y(jié)構(gòu)與節(jié)點(diǎn)特征中提取更加高效的節(jié)點(diǎn)表征。
表3 全局模型實(shí)驗(yàn)結(jié)果Table 3 Experimental results of global model
表4 分層模型實(shí)驗(yàn)結(jié)果Table 4 Experimental results of hierarchical model
對(duì)比WaveGIC 與MHAWaveGIC 模型實(shí)驗(yàn)結(jié)果,基于多重注意力機(jī)制讀出層的分類(lèi)模型比Sum&MaxPooling 讀出層模型的分類(lèi)損失更?。▓D5),基準(zhǔn)數(shù)據(jù)集上的準(zhǔn)確率和AUC 值得到提升,分類(lèi)性能進(jìn)一步提高。
圖5 全局結(jié)構(gòu)(左)和分層結(jié)構(gòu)(右)在PROTEINS 數(shù)據(jù)集上訓(xùn)練損失Fig.5 Train loss of global structure(left)and hierarchical structure(right)on PROTEINS dataset
注意力機(jī)制是一種能讓模型對(duì)重要信息重點(diǎn)關(guān)注并充分學(xué)習(xí)吸收的技術(shù),本文注意力機(jī)制主要是讓模型能夠更加關(guān)注關(guān)鍵節(jié)點(diǎn)。只從一個(gè)角度去學(xué)習(xí)關(guān)鍵節(jié)點(diǎn),存在偏差,因此設(shè)計(jì)種不同角度的注意力權(quán)重。本文提出的多重注意力機(jī)制讀出層可以為圖的分類(lèi)任務(wù)提供新的圖讀出方法。
圖6 展示了ReWaveGIC 模型在PROTEINS 上重構(gòu)誤差損失的收斂曲線(xiàn)。隨著迭代輪數(shù)增加,損失逐漸減小,說(shuō)明高層表征學(xué)習(xí)到圖的拓?fù)浣Y(jié)構(gòu)信息。對(duì)比MHAWaveGIC 與RMAWaveGIC 模型,使用結(jié)構(gòu)誤差和分類(lèi)誤差共同指導(dǎo)訓(xùn)練(RMAWaveGIC)比僅使用分類(lèi)誤差訓(xùn)練(MHAWaveGIC)的模型結(jié)果更加準(zhǔn)確。表3 全局結(jié)構(gòu)準(zhǔn)確率平均提升1.58 個(gè)百分點(diǎn),表4 分層結(jié)構(gòu)準(zhǔn)確率平均提升1.876 個(gè)百分點(diǎn)。在分類(lèi)誤差損失保證分類(lèi)精度的基礎(chǔ)上,重構(gòu)誤差損失指導(dǎo)使得高層表征能夠恢復(fù)圖的鄰接矩陣,賦予模型學(xué)習(xí)圖的拓?fù)浣Y(jié)構(gòu)的能力。圖由拓?fù)浣Y(jié)構(gòu)和節(jié)點(diǎn)特征組成,綜合考慮拓?fù)浣Y(jié)構(gòu)和節(jié)點(diǎn)特征對(duì)圖的分類(lèi)是極有裨益的。
圖6 ReWaveGIC 全局模型的重構(gòu)損失Fig.6 Reconstruction loss of ReWaveGIC global model
對(duì)比RMAWaveGIC 的全局結(jié)構(gòu)與分層結(jié)構(gòu)模型,分層結(jié)構(gòu)模型的分類(lèi)準(zhǔn)確率在FRANKENSTEIN數(shù)據(jù)集上比全局結(jié)構(gòu)高0.73 個(gè)百分點(diǎn),在PROTEINS上高1.63 個(gè)百分點(diǎn)。節(jié)點(diǎn)特征豐富的圖數(shù)據(jù),各層WaveGIC 都能提取到有效的節(jié)點(diǎn)特征與局部結(jié)構(gòu)信息。而節(jié)點(diǎn)特征不足的圖數(shù)據(jù),網(wǎng)絡(luò)提取的更多是結(jié)構(gòu)信息,高層結(jié)構(gòu)表征可以覆蓋低層結(jié)構(gòu)表征。故分層結(jié)構(gòu)適合節(jié)點(diǎn)特征豐富的圖數(shù)據(jù),全局結(jié)構(gòu)適合節(jié)點(diǎn)特征不足的圖數(shù)據(jù)。
本文提出WaveGIC 卷積層、基于重構(gòu)誤差訓(xùn)練和多重注意力機(jī)制讀出層,可以同時(shí)進(jìn)行節(jié)點(diǎn)特征和拓?fù)浣Y(jié)構(gòu)信息的端到端學(xué)習(xí)。RMAWaveGIC 的泛化性能較好,在不同的數(shù)據(jù)集上都能取得最優(yōu)的分類(lèi)準(zhǔn)確度和ROC-AUC 得分,尤其適用于小規(guī)模并且節(jié)點(diǎn)信息豐富的同構(gòu)圖場(chǎng)景,如化合物、蛋白質(zhì)等分類(lèi)任務(wù)。
目前存在的局限是:(1)重構(gòu)誤差時(shí)需要利用特征矩陣重構(gòu)鄰接矩陣,當(dāng)圖結(jié)構(gòu)十分龐大時(shí)極其消耗內(nèi)存空間;(2)對(duì)于大規(guī)模圖,邊的連接是稀疏的,這導(dǎo)致計(jì)算重構(gòu)誤差時(shí)正負(fù)樣本不均衡,訓(xùn)練偏向無(wú)連接邊,不利于模型學(xué)習(xí)。