楊晶東 李熠偉 江彪 姜泉 韓曼 宋夢歌
摘 要:為解決臨床醫(yī)學(xué)量表數(shù)據(jù)類別不均衡容易對模型產(chǎn)生影響,以及在處理量表數(shù)據(jù)任務(wù)時(shí)深度學(xué)習(xí)框架性能難以媲美傳統(tǒng)機(jī)器學(xué)習(xí)方法問題,提出了一種基于級(jí)聯(lián)欠采樣的Transformer網(wǎng)絡(luò)模型(layer by layer Transformer, LLT)。LLT通過級(jí)聯(lián)欠采樣方法對多數(shù)類數(shù)據(jù)逐層刪減,實(shí)現(xiàn)數(shù)據(jù)類別平衡,降低數(shù)據(jù)類別不均衡對分類器的影響,并利用注意力機(jī)制對輸入數(shù)據(jù)的特征進(jìn)行相關(guān)性評(píng)估實(shí)現(xiàn)特征選擇,細(xì)化特征提取能力,改善模型性能。采用類風(fēng)濕關(guān)節(jié)炎(RA)數(shù)據(jù)作為測試樣本,實(shí)驗(yàn)證明,在不改變樣本分布的情況下,提出的級(jí)聯(lián)欠采樣方法對少數(shù)類別的識(shí)別率增加了6.1%,與常用的NEARMISS和ADASYN相比,分別高出1.4%和10.4%;LLT在RA量表數(shù)據(jù)的準(zhǔn)確率和F1-score指標(biāo)上達(dá)到了72.6%和71.5%,AUC值為0.89,mAP值為0.79,性能超過目前RF、XGBoost和GBDT等主流量表數(shù)據(jù)分類模型。最后對模型過程進(jìn)行可視化,分析了影響RA的特征,對RA臨床診斷具有較好的指導(dǎo)意義。
關(guān)鍵詞:量表數(shù)據(jù)分類; 類別不均衡; 級(jí)聯(lián)欠采樣; Transformer
中圖分類號(hào):TP391 文獻(xiàn)標(biāo)志碼:A 文章編號(hào):1001-3695(2023)10-025-3047-06
doi:10.19734/j.issn.1001-3695.2023.01.0056
Application of layer by layer Transformer in class-imbalanced data
Yang Jingdong1, Li Yiwei1, Jiang Biao1, Jiang Quan2, Han Man2, Song Mengge2
(1.School of Optical-Electrical & Computer Engineering, University of Shanghai for Science & Technology, Shanghai 200093, China; 2.Guanganmen Hospital, China Academy of Chinese Medical Science, Beijing 100053, China)
Abstract:In order to solve the problem that class-imbalance data of clinical medical tables tend to have an impact on the model and that the performance of deep learning framework is difficult to match that of traditional machine learning methods when processing scale data tasks, this paper proposed a layer by layer Transformer (LLT) network model based on cascaded under-sampling. LLT deleted the most types of data layer by layer by cascade under-sampling method to achieve the balance of data categories and reduced the impact of class-imbalance data on the classifier. Moreover, LLT used attention mechanism to carry out correlation evaluation on the features of the input data to achieve feature selection, refined the feature extraction abi-lity and improved the model performance. This paper used RA (rheumatoid arthritis) data as test samples. Experimental results show that, on the premise of not changing the sample distribution, the recognition rate of a few categories is increased by 6.1% by the proposed cascade under-sampling method, which is 1.4% and 10.4% higher than that of the commonly used NEARMISS and ADASYN respectively. The accuracy of the RA tabular data and the F1-score index of LLT reach 72.6% and 71.5%, the AUC value is 0.89, the mAP value is 0.79, and the performance exceeds the current mainstream tabular data classification models such as RF, XGBoost and GBDT. This paper also visualized the model process and analyzed the characteristics affecting RA. It has a good guiding significance for the clinical diagnosis of RA.
Key words:tabular data classification; class-imbalance; cascaded under-sampling; Transformer
0 引言
臨床問診是指臨床醫(yī)生采用對話方式,向就醫(yī)患者及其陪同人員了解疾病的發(fā)生、發(fā)展及現(xiàn)狀過程,是醫(yī)生了解患者病情的重要方式。隨著各種醫(yī)學(xué)信息數(shù)據(jù)庫的不斷建立,問診數(shù)據(jù)逐漸積累形成了大量的臨床量表數(shù)據(jù),這些臨床量表數(shù)據(jù)多具有類別不均衡(class-imbalance)的特點(diǎn)。類別不均衡是指在一個(gè)數(shù)據(jù)集中某些類別樣本數(shù)量多于其他類別[1~3]。這種類型的數(shù)據(jù)會(huì)對分類器預(yù)測結(jié)果有較大影響,使預(yù)測偏向多數(shù)類,表現(xiàn)為多數(shù)類的先驗(yàn)概率增加、準(zhǔn)確率降低,少數(shù)類樣本漏檢率增加,部分少數(shù)類樣本被預(yù)測為多數(shù)類,分類結(jié)果偏向多數(shù)類別,嚴(yán)重影響了分類精度和泛化性能[4~6]。
近年來,研究人員提出各種方法來減少樣本不均衡對分類結(jié)果的影響。常用的樣本均衡化方法包括隨機(jī)過采樣(ROS)和隨機(jī)欠采樣(RUS)。為了獲得平衡的數(shù)據(jù),ROS隨機(jī)復(fù)制一些正樣本,而RUS隨機(jī)丟棄一些負(fù)樣本。在更高級(jí)的方法中,如合成少數(shù)過采樣技術(shù)(SMOTE)系列[7,8],通過線性插值生成新的正樣本。類似地,可以通過丟棄信息較少的樣本改進(jìn)RUS。例如Lin等人[9]使用K-means算法將負(fù)樣本分成k個(gè)聚類,負(fù)類由聚類中心表示;Hoyos-Osorio等人[10]引入了信息論學(xué)習(xí),以更少的樣本保持負(fù)類的相關(guān)結(jié)構(gòu);Koziarski[11]采用相互類勢(MCP)的方法改進(jìn)過采樣和過采樣過程。此外還有一種將集成學(xué)習(xí)技術(shù)和數(shù)據(jù)級(jí)方法結(jié)合的混合策略,如SMOTEBoost[12]、Balanced Cascade[13]和AdaC1-AdaC3[14]。這些方法通過將權(quán)重集成到集成學(xué)習(xí)算法中,迭代地增強(qiáng)正樣本的影響,雖然一定程度上緩解了數(shù)據(jù)集類別不均衡帶來的問題,但通過生成或者刪減數(shù)據(jù)的方式改變了原有的數(shù)據(jù)分布,降低了分類性能的可信度。
決策樹(DT)通常用于量表類數(shù)據(jù)分類,主要優(yōu)點(diǎn)是可以有效地選取具有最多統(tǒng)計(jì)信息增益的全局特征,從而提高標(biāo)準(zhǔn)分類性能。隨機(jī)森林(random forest)[15]則根據(jù)隨機(jī)選擇特征生成若干森林樹,統(tǒng)計(jì)若干森林樹的分類結(jié)果投票,生成一個(gè)強(qiáng)分類器。XGBoost[16]和LightGBM[17]是近幾年流行的DT方法,并且在大部分?jǐn)?shù)據(jù)分類比賽中占據(jù)主導(dǎo)地位。
深度學(xué)習(xí)(deep learning, DL)模型憑借良好的全局關(guān)聯(lián)特征提取能力,已經(jīng)在CV/NLP等非結(jié)構(gòu)化數(shù)據(jù)上得到了充分研究,涌現(xiàn)了大批突破性的研究成果[18~20]。但在量表數(shù)據(jù)中并未表現(xiàn)出優(yōu)越的分類性能,主要是因?yàn)榱勘頂?shù)據(jù)具有特征多源性和稀疏性,缺乏先驗(yàn)知識(shí)和可解釋性,使DL無法有效兼容[21]。集成決策樹對于量表類數(shù)據(jù)雖然具有較好的分類精度和泛化性能,但是無法提取量表數(shù)據(jù)中有效的全局關(guān)聯(lián)性的深層特征,如能改進(jìn)DL模型結(jié)構(gòu),將多源數(shù)據(jù)和量表數(shù)據(jù)一起編碼,可以有效減輕DT中煩瑣的特征工程需求[22]。近年來人工智能技術(shù)越來越多地應(yīng)用于醫(yī)學(xué)臨床輔助決策,Khanam等人[23]通過醫(yī)院提供的患者數(shù)據(jù),訓(xùn)練和設(shè)計(jì)準(zhǔn)確的ML/DL分類器,實(shí)現(xiàn)糖尿病早期檢測的自動(dòng)分類;Islam等人[24]使用ML分類模型,包括樸素貝葉斯(NB)[25]、邏輯回歸(LR)[26]和隨機(jī)森林,使用10倍交叉驗(yàn)證和80∶20訓(xùn)練測試分割方法對ESDRPD量表數(shù)據(jù)集進(jìn)行訓(xùn)練。隨著深度學(xué)習(xí)研究的不斷深入,許多研究人員嘗試使用神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)量表類數(shù)據(jù)分類。Humbird等人[27]開始采用DNN模擬DT;Cortes等人[28]提出一種裝箱函數(shù),通過枚舉所有可能的決策方法,采用DNN來模擬DT;Tanno等人[29]提出一種DNN架構(gòu),通過從原始?jí)K自適應(yīng)生長,從邊緣路由函數(shù)和葉節(jié)點(diǎn)進(jìn)行表征學(xué)習(xí)。但這些通過窮舉的方法會(huì)產(chǎn)生冗余的表征導(dǎo)致學(xué)習(xí)效率低下。
為了解決上述問題,本文提出一種級(jí)聯(lián)欠采樣方法,并結(jié)合改進(jìn)后的基于Transformer的分類模型(LLT)。LLT順序多步地提取數(shù)據(jù)中有用的信息,提升模型對少數(shù)類樣本的特征學(xué)習(xí)能力。本文的主要貢獻(xiàn)如下:a)提出了一種級(jí)聯(lián)采樣的類別不均衡的數(shù)據(jù)處理方法,相較于原始的類別不均衡數(shù)據(jù),可以讓分類器模型更加有效地學(xué)習(xí)少數(shù)類數(shù)據(jù)的特征,減少類別不均衡對分類器的性能影響;b)通過在分類模型中添加注意力模塊,更加有效地學(xué)習(xí)特征對類別標(biāo)簽的影響,評(píng)估特征之間的相關(guān)性,提取特征之間的高維信息,從而進(jìn)行過濾,減少輸入特征的維度,一定程度上提升了模型推理速度,提高了模型的分類精度。本文方法在廣安門醫(yī)院提供的RA數(shù)據(jù)集進(jìn)行實(shí)驗(yàn),與現(xiàn)有的量表分類模型和采樣方法進(jìn)行對比,性能指標(biāo)優(yōu)于現(xiàn)有方法,并通過可視化mask掩碼機(jī)制分析各個(gè)特征對于分類性能的影響,挖掘影響RA的特征因素,對RA臨床診斷具有較好的指導(dǎo)意義。
1 LLT方法
LLT方法通過分批定量刪除分類器預(yù)測結(jié)果中正確分類的多數(shù)類樣本,保留少數(shù)類樣本和難分樣本,然后將級(jí)聯(lián)欠采樣后的數(shù)據(jù)打包輸入下一輪分類器訓(xùn)練,分層學(xué)習(xí)少數(shù)類以及難區(qū)分的數(shù)據(jù)樣本規(guī)律。具體方法整體結(jié)構(gòu)如圖1所示。通過組合分類器,在每個(gè)分類器后連接一個(gè)欠采樣模塊,從多數(shù)類樣本中隨機(jī)剔除一定數(shù)量的樣本,再傳遞給下一個(gè)分類器,并在分類模型中融入注意力機(jī)制,提取特征的多維信息,對數(shù)據(jù)進(jìn)行降維。以此往復(fù),直到進(jìn)行完最后一輪學(xué)習(xí),將之前的分類器權(quán)重進(jìn)行整合加權(quán),求得最終模型權(quán)重參數(shù)。
1.1 級(jí)聯(lián)采樣
欠采樣是解決樣本類別不均衡的常用方法之一,其核心思想是按照某種抽樣法則,每次從多數(shù)類樣本中抽取和少數(shù)類樣本相同數(shù)量的樣本,與少數(shù)類樣本構(gòu)成訓(xùn)練數(shù)據(jù)。由于常規(guī)欠采樣方法拋棄了大量的多數(shù)類樣本,導(dǎo)致丟失了原始數(shù)據(jù)信息,弱化了一些重要信息對模型的影響,最終結(jié)果使得預(yù)測模型有較大的偏差。過采樣通過對少數(shù)類樣本進(jìn)行復(fù)制來均衡數(shù)據(jù)集,但會(huì)造成模型訓(xùn)練復(fù)雜度加大,另一方面也容易造成模型過擬合的問題,不利于分類器的泛化性能。本文提出的級(jí)聯(lián)欠采樣方法,通過從分類器預(yù)測結(jié)果中刪除一定數(shù)量預(yù)測正確的多數(shù)類樣本,處理后的數(shù)據(jù)重新傳遞給下一個(gè)分類器訓(xùn)練,最終組合分類器。級(jí)聯(lián)欠采樣方法結(jié)構(gòu)如圖2所示。
在級(jí)聯(lián)欠采樣方法中有兩個(gè)主要過程:a)在隨機(jī)采樣過程中,對分類模型輸出進(jìn)行判斷,統(tǒng)計(jì)被正確分類以及錯(cuò)誤分類樣本,藍(lán)色和黃色分別表示分類正確的多數(shù)類(記為P)和少數(shù)類(記為Q),紅色表示錯(cuò)誤分類的樣本,灰色線條表示分類器的分類邊界H(x)(參見電子版),該過程抽取被正確分類的多數(shù)類樣本(記為Pselected);b)在剔除多數(shù)類過程中,將所選擇的樣本從多數(shù)類別中刪除,然后將剩余樣本傳入下一次分類器作為其輸入數(shù)據(jù)。每次刪除樣本的數(shù)量為一個(gè)常數(shù)S,由多數(shù)類和少數(shù)類之間的樣本數(shù)量差決定,其公式表示為
分類器為H(x),分類器參數(shù)以及權(quán)重為θk、wk,預(yù)測函數(shù)為h(xi,θk),最終的分類器公式表示為
由于被正確分類的樣本被逐步剔除,一方面降低了數(shù)據(jù)類別不平衡的程度,另一方面使得難分樣本被保留,可以進(jìn)行更多輪的學(xué)習(xí)??壳暗挠?xùn)練次數(shù)可以利用更多的全局信息,靠后的訓(xùn)練次數(shù)可以學(xué)習(xí)更多的少數(shù)類樣本的特征表達(dá)。本文方法通過不斷地進(jìn)行級(jí)聯(lián)欠采樣,刪除被正確分類的多數(shù)類樣本,最后使得輸入樣本中的少數(shù)類和多數(shù)類樣本數(shù)量相等,達(dá)到平衡數(shù)據(jù)的目的。
1.2 分類器
TabNet[30]是基于Transformer的模型,原模型通過對linear layers(線性層)和batch normalization的不斷堆疊,在取得不遜于傳統(tǒng)算法性能的同時(shí)可以更加有效地編碼多種數(shù)據(jù)類型,但隨著特征維度和層數(shù)的增多,模型的參數(shù)會(huì)有明顯上升。在此基礎(chǔ)上,本文在原模型結(jié)構(gòu)中加入注意力機(jī)制,并結(jié)合門控線性單元(gate linear unit,GLU)對每個(gè)step輸入的特征進(jìn)行降維,在降低每個(gè)step輸入特征維度的同時(shí)提升了對數(shù)據(jù)的特征提取能力。分類器網(wǎng)絡(luò)整體結(jié)構(gòu)如圖3所示。按照從左往右、自上向下的順序,將整個(gè)預(yù)測過程分為多步。分類器網(wǎng)絡(luò)整體結(jié)構(gòu)由FT模塊、AF模塊與mask機(jī)制堆疊而成。數(shù)據(jù)首先經(jīng)過網(wǎng)絡(luò)模型中的GBN[31]層,GBN層的輸出作為模型中后續(xù)階段的輸入features;通過重復(fù)的結(jié)構(gòu)(step1、step2…),在每一輪step根據(jù)輸入特征以及學(xué)習(xí)到的特征表變換和權(quán)重系數(shù),得到當(dāng)前step的輸出向量;最終將多個(gè)step輸出向量累加之后,利用全連接層做一次變換,得到最終的輸出,即
如圖4(a)所示,F(xiàn)T由四個(gè)相同的FT block模塊組成,F(xiàn)T block由FC、MHA、GLU、GBN串聯(lián)組成,負(fù)責(zé)將輸入的特征進(jìn)行線性層變換,提取數(shù)據(jù)信息。其中GBN為ghost batch normalization,與普通的batch normalization相比更具有魯棒性;MHA為multi-head attention機(jī)制,在結(jié)構(gòu)中和GLU串聯(lián),用來計(jì)算和過濾特征;GLU的變換公式為hl(X)=(XW+b)σ(XV+c)。融入多頭注意力機(jī)制后有助于模型捕捉到更豐富的特征表示,計(jì)算數(shù)據(jù)中各個(gè)特征的依賴關(guān)系,最后將得到的多維信息進(jìn)行集成,通過GLU進(jìn)行特征的過濾、降維,從而縮短模型推理時(shí)間。以RA患者數(shù)據(jù)為例,多數(shù)患者存在關(guān)節(jié)疼痛與疼痛關(guān)節(jié)發(fā)熱的問題,模型會(huì)分析特征之間的依賴性,若患者不存在關(guān)節(jié)疼痛,自然就不會(huì)出現(xiàn)疼痛關(guān)節(jié)發(fā)熱的病癥,模型會(huì)在當(dāng)前step中將疼痛關(guān)節(jié)發(fā)熱的特征過濾,減少輸入特征的維度。各個(gè)block之間通過跳躍鏈接(skip-connection)相連,并乘以0.5,用來防止網(wǎng)絡(luò)中輸出的方差變換波動(dòng)太大,造成學(xué)習(xí)過程的不穩(wěn)定。圖4(b)為AF結(jié)構(gòu),用來計(jì)算特征重要性。為了方便對模型內(nèi)部進(jìn)行可視化,在每個(gè)AF與FT模塊之間并行mask,用于記錄當(dāng)前step對特征的關(guān)注點(diǎn)。對于樣本的特征重
2 實(shí)驗(yàn)與結(jié)果分析
2.1 實(shí)驗(yàn)數(shù)據(jù)集及預(yù)處理
實(shí)驗(yàn)數(shù)據(jù)集由中國中醫(yī)科學(xué)院廣安門醫(yī)院提供。數(shù)據(jù)總計(jì)條目27條,包括病人就診時(shí)間、發(fā)病時(shí)間等基本信息,關(guān)節(jié)疼痛、關(guān)節(jié)腫脹等關(guān)節(jié)信息和相應(yīng)化驗(yàn)指標(biāo)共計(jì)26條,RA關(guān)節(jié)活動(dòng)等級(jí)分類標(biāo)簽1條,分類級(jí)別由輕到重依次為緩解疾病活動(dòng)、輕微疾病活動(dòng)、中度疾病活動(dòng)、重度疾病活動(dòng)階段,分別編碼為0、1、2、3。數(shù)據(jù)樣本共計(jì)10 514例,緩解疾病活動(dòng)等級(jí)911例,輕微疾病活動(dòng)等級(jí)749例,中度疾病活動(dòng)等級(jí)3 713例,重度疾病活動(dòng)等級(jí)5 141例,存在數(shù)據(jù)類別不平衡問題。對于原始RA數(shù)據(jù)樣本,數(shù)據(jù)質(zhì)量存在較大差別,如圖5(a)(b)所示,其中,(a)所示特征具有良好區(qū)分性,特征應(yīng)予以保留;(b)所示特征不具有區(qū)分性,應(yīng)予以刪除。再剔除與RA疾病等級(jí)分類明顯無關(guān)的特征,如患者ID、發(fā)病時(shí)間等。對于缺失值采用前值填充法進(jìn)行填充,刪除了異常樣本,最后采用one-hot編碼對類別變量進(jìn)行轉(zhuǎn)換。處理后的數(shù)據(jù)保留了10 110例樣本,每例樣本包含14個(gè)特征,處理后的數(shù)據(jù)各類別占比如圖5(c)所示。實(shí)驗(yàn)采用5折交叉驗(yàn)證,將預(yù)處理之后的數(shù)據(jù)等分成5份,在5輪實(shí)驗(yàn)中,取其中的4份用于訓(xùn)練,1份用于測試,并將最終的結(jié)果取均值作為最終的分類結(jié)果。實(shí)驗(yàn)流程如圖6所示。
2.2 評(píng)價(jià)指標(biāo)
對于分類任務(wù),通常選擇整體分類準(zhǔn)確度(overall accuracy)作為對模型的評(píng)價(jià)指標(biāo),從模型預(yù)測的混淆矩陣中獲取四個(gè)基本指標(biāo),分別是真陽性(TP)、假陽性(FP)、真陰性(TN)以及假陰性(FN),進(jìn)而得到針對類別不均衡樣本分類時(shí)常用的三個(gè)評(píng)估指標(biāo),即真陽性率(TPR)、真陰性率(TNR)以及G-mean。其中TPR即靈敏度(recall),衡量了分類器對正樣本的識(shí)別程度,如果TPR值過小,表明分類器將大量的正樣本預(yù)測為負(fù)樣本,導(dǎo)致漏診;TNR即特異性(specificity),衡量了分類器對負(fù)樣本的識(shí)別程度,如果TNR值過小,表明分類器將大量負(fù)樣本預(yù)測為正樣本,導(dǎo)致誤診。因此,泛化性能較好的分類器需要同時(shí)具有較高的TPR和TNR值。G-mean常用來表示算法的平衡程度,該值越大說明模型的表現(xiàn)性能越好。本文實(shí)驗(yàn)中還用到了評(píng)估指標(biāo)F1-score,該指標(biāo)是精確度(precision)和recall的調(diào)和平均值。各模型評(píng)估指標(biāo)定義如下:
2.3 實(shí)驗(yàn)結(jié)果及分析
1)采樣方法對比 為了驗(yàn)證級(jí)聯(lián)欠采樣方法的有效性,本文采用改進(jìn)后的分類器結(jié)合目前主流的采樣方法進(jìn)行對比,包括兩種過采樣方法SMOTE、ADASYN和一種欠采樣方法NearMiss。衡量指標(biāo)采用真陽性率(TPR)、真陰性率(TNR)和G-mean三種分類評(píng)估指標(biāo)。實(shí)驗(yàn)結(jié)果如圖7所示。相較于不進(jìn)行采樣處理的原始分類器,本文提出的級(jí)聯(lián)采樣方法能在保持原有的TNR不減退的前提下,對少數(shù)類樣本的識(shí)別率由0.558提升到了0.619。與欠采樣NearMiss方法相比,后者通過直接丟棄部分多數(shù)類樣本而丟失過多的原有數(shù)據(jù)信息,最終
導(dǎo)致衡量指標(biāo)的下降。與過采樣方法SMOTE和ADASYN比較,過采樣方法通過增加少數(shù)類樣本使數(shù)據(jù)類別達(dá)到均衡,但新增樣本會(huì)引入冗余信息增加噪聲,同時(shí)也可能造成過擬合。雖然SMOTE過采樣方法整體TPR和TNR指標(biāo)略高,但在一定程度上都會(huì)改變樣本原始分布,降低了原始模型對于輕微疾病類患者和中度疾病類患者的G-mean值,導(dǎo)致整體模型預(yù)測偏差,減弱了模型泛化能力。而本文提出的級(jí)聯(lián)欠采樣方法既能有效利用所有原始樣本信息,保持了原始模型對于多數(shù)類類別的TNR指標(biāo),又能減少類別不均衡對模型分類的影響,因此三項(xiàng)評(píng)估指標(biāo)均具有較高精度。
2)與典型分類模型性能對比 本文選取了九種主流的機(jī)器學(xué)習(xí)方法和神經(jīng)網(wǎng)絡(luò)展開對比實(shí)驗(yàn),包括幾種基于ML的分類方法,如人工神經(jīng)網(wǎng)絡(luò)中的多層感知機(jī)(MLP)、隨機(jī)森林(RF)、XGBoost、GBDT、邏輯回歸(LR)、支持向量機(jī)(SVM)等。表1為5折交叉驗(yàn)證的準(zhǔn)確度結(jié)果。
圖8依次為各模型的recall、precison以及F1-score的交叉驗(yàn)證結(jié)果,由圖可見,本文方法在三個(gè)指標(biāo)上均高于其他分類器。本實(shí)驗(yàn)還采用ROC和P-R曲線來可視化模型分類性能。圖9給出了各分類器的ROC和P-R曲線。ROC反映的是真陽性率(TPR)隨著假陽性率(FPR)的變化情況,同時(shí)兼顧了對正樣例和負(fù)樣例的預(yù)測情況;P-R曲線則反映的是precision隨著recall的變化,只關(guān)注對正樣例的預(yù)測情況。從圖9可以看出,當(dāng)前主流的機(jī)器學(xué)習(xí)分類模型雖然有著整體不錯(cuò)的最終分類準(zhǔn)確度,但多數(shù)屬于盲目將少數(shù)類歸為多數(shù)類獲得的虛高,反映在ROC與P-R上卻并不理想。本文方法在ROC曲線的AUC值達(dá)到0.89,比次高的LR方法高出3%,對比當(dāng)下流行的XGBoost和GBDT的AUC值分別高出7%和17%;與剩余方法分類器的AUC值相比也有不同程度的提升。本文方法的P-R曲線的mAP值達(dá)到0.79,比LR方法的mAP高出16%,并且大幅度高于其他分類方法的mAP值。這是由于改進(jìn)后的分類模型融合了注意力機(jī)制,細(xì)化了模型的信息,提取了特征與特征之間的相關(guān)性,過濾了患者不重要的信息,提升了模型的性能;此外,通過級(jí)聯(lián)欠采樣方法,多次學(xué)習(xí)難分樣本以及少數(shù)類樣本規(guī)律,最后加權(quán)后的分類器對于少數(shù)類具有更出色的判別能力。
圖10為本文方法最終分類效果展示,其中(a)為RA數(shù)據(jù)集中少數(shù)類樣本病例1和難區(qū)分的病例2在主流分類器的分類結(jié)果與本文方法分類結(jié)果的對比??梢钥闯觯瑢τ诓±?,主流的RF、XGB和改進(jìn)前的TabNet由于同類數(shù)據(jù)樣本少,無法學(xué)到合理的特征表達(dá),最終沒有預(yù)測到正確的結(jié)果,而LR方法分類正確;在本文方法中,靠前的分類器雖然沒有預(yù)測對少數(shù)類樣本,但隨著級(jí)聯(lián)欠采樣次數(shù)的增加,數(shù)據(jù)集類別不平衡的程度得到了改善,靠后的分類器可以學(xué)習(xí)到難分樣本和少數(shù)類樣本規(guī)律做到正確分類,最后的分類器通過加權(quán)獲得正確的分類結(jié)果。在病例2的測試中,RF、LR由于樣本的區(qū)分度
不高,沒能正確分類,XGB、TabNet和本文的LLT都預(yù)測正確,雖然本文方法中間的分類器2分類錯(cuò)誤,但在最后通過對所有分類器權(quán)重加權(quán)后,最后的模型分類結(jié)果正確。圖10(b)為本文方法的分類結(jié)果的混淆矩陣,雖然該方法對每類樣本的分類結(jié)果沒能達(dá)到完全正確,但卻通過級(jí)聯(lián)欠采樣模塊后加權(quán)融合改進(jìn)后的分類器,做到有效地減弱了數(shù)據(jù)樣本不均衡對分類器造成的影響。
3)級(jí)聯(lián)采樣方法下不同模型對比 在基于本文提出的級(jí)聯(lián)采樣的方法中,結(jié)合上節(jié)對比的分類模型對RA數(shù)據(jù)進(jìn)行實(shí)驗(yàn),計(jì)算不同方法下的模型分類的TPR、TNR以及G-mean指標(biāo)。實(shí)驗(yàn)結(jié)果如表2所示,粗體數(shù)據(jù)為最優(yōu)值,其可視化如圖11所示。
由表2可以看到在級(jí)聯(lián)采樣方法下的GBDT在TNR達(dá)到了0.882,雖然高于本文的LLT,但在TPR和G-mean指標(biāo)上本文方法高于前者。這是因?yàn)镚BDT是一種樹模型,雖然每一次采樣后降低了類別不均衡的程度,但GBDT也會(huì)生成更多的樹,與之前的樹進(jìn)行合并,決策投票時(shí)會(huì)有更多的樹投票給多數(shù)類,造成多數(shù)類檢測率虛高;本文方法通過采樣平衡數(shù)據(jù)不均衡程度,分類器再對樣本進(jìn)行特征轉(zhuǎn)換,更好地學(xué)習(xí)少數(shù)類樣本,因此有著更好的少數(shù)類檢測率和G-mean指標(biāo)。綜上所述,對于不均衡數(shù)據(jù),LLT方法具有更好的分類性能和綜合指標(biāo),提出的級(jí)聯(lián)欠采樣方法也具有一定的通用性。
4)模型mask可視化 對本文模型的掩碼模塊進(jìn)行可視化,探討分析模型在每個(gè)step的關(guān)注點(diǎn),結(jié)果如圖12所示。圖中共展示了50個(gè)樣本,從上往下依次是樣本順序0~49,從左往右依次是14個(gè)特征條目(0:關(guān)節(jié)疼痛;1:關(guān)節(jié)腫脹;2:關(guān)節(jié)晨僵;3:疼痛關(guān)節(jié)發(fā)熱;4:能否自己洗頭;5:能自己使用筷子;6:能否端起盛滿的杯子送到嘴邊;7:能否伸手摘下衣架上的衣帽;8:能否上5級(jí)臺(tái)階么;9:神疲乏力;10:胃口如何;11:心煩不安;12:能否加做家務(wù);13:失眠多夢),mask 0~2分別對應(yīng)模型設(shè)置的3個(gè)step,顏色越亮,代表在該step下模型更關(guān)注此特征。
從掩碼的可視化熱力圖可以看到,模型中的每一個(gè)step的關(guān)注特征都有所不同,第一個(gè)step更關(guān)注RA患者的第0、1、12個(gè)特征;第二個(gè)step更關(guān)注RA患者的第4~6個(gè)特征;第三個(gè)step更關(guān)注患者的第2~6、8、9個(gè)特征。臨床研究表明,絕大多數(shù)RA患者都會(huì)出現(xiàn)關(guān)節(jié)疼痛發(fā)熱、關(guān)節(jié)腫脹以及肢體活動(dòng)受限,進(jìn)而常常感覺身體疲勞的病癥。本文的LLT模型分步分層地將更多的關(guān)注點(diǎn)放在了這些癥狀上,提高了模型的預(yù)測效果。模型的關(guān)注點(diǎn)與臨床結(jié)論一致,充分說明改進(jìn)后的方法具有良好的可信度。
3 結(jié)束語
量表數(shù)據(jù)因?yàn)槠浯嬖诘臄?shù)據(jù)類別不均衡以及常見的混合數(shù)據(jù)屬性的問題,導(dǎo)致在深度學(xué)習(xí)領(lǐng)域沒有得到充分的探索發(fā)展。本文提出了一種基于Transformer網(wǎng)絡(luò)和級(jí)聯(lián)采樣方法的網(wǎng)絡(luò)模型LLT并且應(yīng)用于中醫(yī)RA數(shù)據(jù)集。將注意力機(jī)制融合進(jìn)原始模型細(xì)化特征之間的信息,提升了模型的精度,通過級(jí)聯(lián)采樣方法層層刪減多數(shù)類樣本以減緩樣本不均衡對模型的影響,提升模型對少數(shù)類的識(shí)別率。與當(dāng)前主流的針對量表型數(shù)據(jù)的RF、LR、XGBoost、GBDT等模型相比較,改進(jìn)后的模型具有更好的精度和泛化性能。本文提出的級(jí)聯(lián)采樣方法能更好地學(xué)習(xí)少數(shù)類樣本的特征,對于其他類別不均衡的采樣方法,指標(biāo)也有所提升,對臨床診斷有著一定的指導(dǎo)作用。
在類別不均衡的量表數(shù)據(jù)研究領(lǐng)域,仍存在一些需要進(jìn)一步研究的方法,如通過增量式學(xué)習(xí)方法,利用新數(shù)據(jù)更新模型或利用基于規(guī)則的處理策略,對特定的問題和場景設(shè)計(jì)出有效的類別不平衡策略等。未來研究會(huì)綜合考慮以上因素,并更好地結(jié)合實(shí)際應(yīng)用場景來設(shè)計(jì)和改良模型。
參考文獻(xiàn):
[1]李勇,劉戰(zhàn)東,張海軍.不平衡數(shù)據(jù)的集成分類算法綜述[J].計(jì)算機(jī)應(yīng)用研究,2014,31(5):1287-1291.(Li Yong, Liu Zhandong, Zhang Haijun. Review on ensemble algorithms for imbalanced data classification[J].Application Research of Computers,2014,31(5):1287-1291.)
[2]Herland M, Khoshgoftaar T M, Bauder R A. Big data fraud detection using multiple medicare data sources[J].Journal of Big Data,2018,5(1):article No.29.
[3]Mohd H A H, Marina Y, Azlinah M. Survey on highly imbalanced multi-class data[J/OL].International Journal of Advanced Computer Science and Applications,2022,13(6).http://dx.doi.org/10.14569/IJACSA.2022.0130627.
[4]Peng Lizhi, Zhang Haibo, Chen Yuehui, et al. Imbalanced traffic identification using an imbalanced data gravitation-based classification model[J].Computer Communications,2017,102(4):177-189.
[5]Chawla N V. Data mining for imbalanced datasets: an overview [M]// Data Mining and Knowledge Discovery Handbook. Boston, MA: Springer, 2010: 853-867.
[6]Mahdiyah U, Irawan M I, Imah E M. Integrating data selection and extreme learning machine for imbalanced data[J].Procedia Computer Science,2015,59:221-229.
[7]Tesfahun A, Bhaskari D L. Intrusion detection using random forests classifier with SMOTE and feature reduction [C]// Proc of International Conference on Cloud & Ubiquitous Computing & Emerging Technologies. Piscataway, NJ: IEEE Press, 2013: 127-132.
[8]Li Junnan, Zhu Qingsheng, Wu Quanwang, et al. A novel oversampling technique for class-imbalanced learning based on SMOTE and natural neighbors[J].Information Sciences,2021,565(7):438-455.
[9]Lin Weichao, Tsai C F, Hu Yahan, et al. Clustering-based undersampling in class-imbalanced data-ScienceDirect[J].Information Sciences,2017,409-410(10): 17-26.
[10]Hoyos-Osorio J, lvarez-Meza A, Daza-Santacoloma G, et al. Relevant information undersampling to support imbalanced data classification[J].Neurocomputing,2021,436(5):136-146.
[11]Koziarski M. Radial-based undersampling for imbalanced data classification[J].Pattern Recognition,2020,102(6):107262.
[12]Chawla N, Lazarevic A, Hall L O, et al. SMOTEBoost: improving prediction of the minority class in boosting [C]// Proc of the 7th European Conference on Principles and Practice of Knowledge Discovery in Databases. Berlin: Springer, 2003: 107-119.
[13]Liu Xuying, Wu Jianxin, Zhou Zhihua. Exploratory under-sampling for class-imbalance learning[J].IEEE Trans on Systems, Man,and Cybernetics,2009,39(2):539-550.
[14]Sun Yanmin, Kamel M S, Wong A K C, et al. Cost-sensitive boosting for classification of imbalanced data[J].Pattern Recognition,2007,40(12):3358-3378.
[15]Liu Yanli, Wang Yourong, Zhang Jian. New machine learning algorithm: random forest [C]// Proc of the 3rd International Conference Information Computing and Applications. Berlin: Springer, 2012: 246-252.
[16]Chen Tianqi, Guestrin C. XGBoost: a scalable tree boosting system [C]// Proc of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. New York: ACM Press, 2016: 785-794.
[17]Ke Guolin, Meng Qi, Finley T, et al. LightGBM: a highly efficient gradient boosting decision tree [C]// Proc of the 31st International Conference on Neural Information Processing Systems. Red Hook, NY: Curran Associates Inc., 2017: 3149-3157.
[18]He Kaiming, Zhang Xiangyu, Ren Shaoqing, et al. Deep residual learning for image recognition [C]// Proc of IEEE Conference on Computer Vision and Pattern Recognition. Washington DC: IEEE Computer Society, 2016: 770-778.
[19]Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: transformers for image recognition at scale [EB/OL].(2021-06-03). https://arxiv.org/abs/2010.11929.
[20]Howard A, Sandler M, Chu G, et al. Searching for MobileNetV3 [EB/OL]. (2019-11-20). https://arxiv.org/abs/1905.02244.
[21]Shwartz-Ziv R, Armon A. Tabular data: deep learning is not all you need[J].Information Fusion,2022,81(5): 84-90.
[22]Hestness J, Narang S, Ardalani N, et al. Deep learning scaling is predictable, empirically [EB/OL]. (2017-12-01). https://arxiv.org/abs/1712.00409.
[23]Khanam J J, Foo S Y. A comparison of machine learning algorithms for diabetes prediction[J].ICT Express,2021,7(4):432-439.
[24]Islam M M F, Ferdousi R, Rahman S, et al. Likelihood prediction of diabetes at early stage using data mining techniques [C]// Proc of International Symposium on Computer Vision and Machine Intelligence in Medical Image Analysis. Singapore: Springer, 2019: 113-125.
[25]Best L, Foo E, Tian Hui. A hybrid approach: utilizing K-means clustering and naive Bayes for IoT anomaly detection [M]// Secure and Trusted Cyber Physical Systems. Cham: Springer, 2022: 177-214.
[26]Bredt L C, Peres L A B, Risso M, et al. Risk factors and prediction of acute kidney injury after liver transplantation:logistic regression and artificial neural network approaches[J].World Journal of Hepatology: English Edition,2022,14(3):570-582.
[27]Humbird K D, Peterson J L, McClarren R G. Deep neural network initialization with decision trees[J].IEEE Trans on Neural Networks and Learning Systems,2019,30(5):1286-1295.
[28]Cortes C, Gonzalvo X, Kuznetsov V, et al. AdaNet: adaptive structural learning of artificial neural networks [EB/OL]. (2017-02-28). https://arxiv.org/abs/1607.01097.
[29]Tanno R, Arulkumaran, K, Alexander D C, et al. Adaptive neural trees [EB/OL]. (2019-06-09). https://arxiv.org/abs/1807.06699.
[30]Arik S , Pfister T. TabNet: attentive interpretable tabular learning[J].Proceedings of the AAAI Conference on Artificial Intelligence,2021,35(8):6679-6687.
[31]Dimitriou N, Arandjelovic O. A new look at ghost normalization [EB/OL]. (2020-07-16). https://arxiv.org/abs/2007.08554.
[32]Zhao Puning, Lai Lifeng. Minimax rate optimal adaptive nearest neighbor classification and regression[J].IEEE Trans on Information Theory,2021,67(5):3155-3182.
[33]De Nogueira T O, Palacio G B A, Braga F D, et al. Imbalance classification in a scaled-down wind turbine using radial basis function kernel and support vector machines[J].Energy,2022,238(1):122064.
[34]Liu Wanan, Fan Hong, Xia Meng. Credit scoring based on tree-enhanced gradient boosting decision trees[J].Expert Systems with Application,2022,189(3):116034.
[35]Li Wenshuo, Chen Hanting, Guo Jianyuan, et al. Brain-inspired multilayer perceptron with spiking neurons [C]// Proc of IEEE/CVF Conference on Computer Vision and Pattern Recognition. Piscataway, NJ: IEEE Press, 2022: 773-783.
收稿日期:2023-01-31;修回日期:2023-03-24
基金項(xiàng)目:國家自然科學(xué)基金資助項(xiàng)目(81973749);中國中醫(yī)科學(xué)院科技創(chuàng)新工程項(xiàng)目(CI2021A01503)
作者簡介:楊晶東(1973-),男,黑龍江齊齊哈爾人,副教授,碩導(dǎo),博士,主要研究方向?yàn)槿斯ぶ悄堋C(jī)器學(xué)習(xí)與大數(shù)據(jù)分析、機(jī)器視覺等;李熠偉(1997-),男(通信作者),江蘇徐州人,碩士研究生,主要研究方向?yàn)槿斯ぶ悄?、機(jī)器學(xué)習(xí)等(eerfriend@yeah.net);江彪(1998-),男,安徽宣城人,碩士研究生,主要研究方向?yàn)槿斯ぶ悄?、機(jī)器學(xué)習(xí)等;姜泉(1961-),女,主任醫(yī)師,博士,主要研究方向?yàn)轱L(fēng)濕免疫病的中醫(yī)、中西醫(yī)結(jié)合臨床及基礎(chǔ)研究;韓曼(1984-),女,副主任醫(yī)師,博士,主要研究方向?yàn)轱L(fēng)濕免疫病的中醫(yī)、中西醫(yī)結(jié)合臨床及基礎(chǔ)研究;宋夢歌(1993-),女,博士研究生,主要研究方向?yàn)轱L(fēng)濕免疫疾病的臨床與基礎(chǔ)研究.