孫 紅,黃甌嚴(yán)
(上海理工大學(xué)光電信息與計(jì)算機(jī)工程學(xué)院,上海 200093)
隨著深度學(xué)習(xí)在自然語(yǔ)言處理(Natural Language Processing,NLP)任務(wù)中的不斷發(fā)展與性能指標(biāo)的不斷提高,NLP 任務(wù)的工業(yè)落地成為可能。然而,深度學(xué)習(xí)模型結(jié)構(gòu)復(fù)雜,占用較大的存儲(chǔ)空間且計(jì)算資源消耗大,因此高性能模型很難直接部署在移動(dòng)端。為了解決這些問(wèn)題,需對(duì)模型進(jìn)行壓縮以減小模型在計(jì)算時(shí)間和空間上的消耗。
模型壓縮的目的是在保證模型預(yù)測(cè)效果的前提下,盡可能減小模型體積,提升模型推演速度。常用模型壓縮方法有剪枝(Pruning)、權(quán)重分解(Weight Factorization)、削減精度(Quantization)、權(quán)重共享(Weight Sharing)及知識(shí)蒸餾(Knowledge Distillation)。其中,知識(shí)蒸餾方法能將模型壓縮至最小規(guī)模,使性能效果最佳,近年來(lái)備受關(guān)注。Nakashole 等[1]采用零次學(xué)習(xí)(Zero-shot)的方式在翻譯任務(wù)上使用知識(shí)蒸餾;Wang 等[2]在基于強(qiáng)化學(xué)習(xí)的對(duì)話系統(tǒng)中使用知識(shí)蒸餾,提高系統(tǒng)可維護(hù)性與拓展性;Siddhartha等[3]使用多種知識(shí)蒸餾策略結(jié)合的方式應(yīng)用于問(wèn)答系統(tǒng)(QA)響應(yīng)預(yù)測(cè);Sun 等[4]利用知識(shí)蒸餾壓縮BERT 模型,壓縮得到的模型可以用于移動(dòng)端;Liu 等[5]采用樣本自適應(yīng)機(jī)制,不依賴標(biāo)注數(shù)據(jù),利用自蒸餾的方式訓(xùn)練模型;廖勝蘭等[6]利用卷積神經(jīng)網(wǎng)絡(luò)蒸餾模型(Bidirectional Encoder Representation from Transformers,BERT),用于意圖識(shí)別分類;Subhabrata 等[7]在多語(yǔ)言命名實(shí)體識(shí)別任務(wù)中采用多階段蒸餾框架,使用階段優(yōu)化的方式提高性能。
本文在現(xiàn)有研究基礎(chǔ)上,提出一種基于知識(shí)蒸餾的短文本分類模型,其中教師模型為BERT 模型,學(xué)生模型為雙向長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)模型(Bi-directional Long Short-Term Memory,BiLSTM)。實(shí)驗(yàn)結(jié)果表明,經(jīng)過(guò)知識(shí)蒸餾的學(xué)生模型比單獨(dú)訓(xùn)練的學(xué)生模型分類效果更佳,并且與復(fù)雜的教師模型相比,本文模型可極大降低預(yù)測(cè)所需的響應(yīng)時(shí)間,有利于模型在工業(yè)場(chǎng)景中有效部署與使用。
很多NLP 任務(wù)場(chǎng)景可以歸結(jié)為文本分類任務(wù)。依據(jù)分類對(duì)象,文本分類可分為短文本分類(標(biāo)題、評(píng)論)和長(zhǎng)文本分類(文章、文檔);依據(jù)分類體系,文本分類可分為新聞分類、情感分析和意圖識(shí)別等;依據(jù)分類模式,文本分類可分為二分類問(wèn)題、多分類問(wèn)題以及多標(biāo)簽問(wèn)題(一個(gè)文本屬于多個(gè)類別)。
傳統(tǒng)文本分類方法常使用基于規(guī)則的特征匹配,或依賴專家系統(tǒng),往往能做到快速分類,然而與數(shù)據(jù)集所屬領(lǐng)域高度相關(guān),需要不同領(lǐng)域的專家構(gòu)建特定規(guī)則,耗力費(fèi)時(shí),且準(zhǔn)確率并不高,無(wú)法達(dá)到工業(yè)要求。
隨著統(tǒng)計(jì)學(xué)習(xí)方法的興起及互聯(lián)網(wǎng)文本數(shù)據(jù)集數(shù)量爆炸式增長(zhǎng),利用機(jī)器學(xué)習(xí)處理文本分類問(wèn)題成為主流。機(jī)器學(xué)習(xí)方法一般包含3 個(gè)步驟:文本預(yù)處理、文本特征提取和分類模型分類。文本預(yù)處理首先對(duì)文本進(jìn)行分詞,再建立停用詞詞典,去除副詞、形容詞以及連接詞,有些任務(wù)還需要進(jìn)行詞性標(biāo)注,對(duì)分詞后得到的詞直接判斷詞性;文本特征提取方式可以考慮詞頻,在一段文本中反復(fù)出現(xiàn)越多的詞越重要,權(quán)重越大,也可以考慮詞的重要性,以TF-IDF(Term Frequency-inverse Document Frequency)作為特征,表征詞重要程度;分類模型通常有邏輯回歸模型(Logistic Regression,LR)、支持向量機(jī)(Support Vector Machine,SVM)、隨機(jī)森林(RandomForest,RF)等。
機(jī)器學(xué)習(xí)的方法雖然在文本分類上取得了較好效果,但也存在問(wèn)題,文本特征提取得到的文本表示是高緯度、高稀疏的,表達(dá)特征能力很弱,且往往需要人工進(jìn)行特征工程,成本很高。于是深度學(xué)習(xí)的方法被應(yīng)用于文本分類任務(wù),用端到端的方式解決復(fù)雜耗時(shí)的人工特征工程。深度學(xué)習(xí)文本分類模型包括訓(xùn)練速度快的FastText 模型[8]、利用CNN 提取句子關(guān)鍵信息的TextCNN 模型[9]、利用雙向RNN 得到每個(gè)詞上下文表示的TextRNN 模型[10]及基于層次注意力機(jī)制網(wǎng)絡(luò)的HAN 模型[11]等。
知識(shí)蒸餾短文本分類模型主要由兩個(gè)子模型組成:教師模型與學(xué)生模型。其中教師模型直接學(xué)習(xí)真實(shí)數(shù)據(jù)標(biāo)簽,學(xué)生模型為結(jié)構(gòu)精簡(jiǎn)的小模型,蒸餾模型由學(xué)生模型通過(guò)學(xué)習(xí)教師模型的結(jié)果并結(jié)合真實(shí)標(biāo)簽的分布構(gòu)建而成。
教師模型(Teacher Model)通常為結(jié)構(gòu)相對(duì)復(fù)雜的模型,具有很好的泛化能力。本文選用BERT 模型[12]作為教師模型。
雙向編碼模型BERT 采用多層雙向Transformer 編碼器為主體進(jìn)行訓(xùn)練,舍棄RNN 等循環(huán)神經(jīng)網(wǎng)絡(luò),采用注意力機(jī)制對(duì)文本進(jìn)行建模,可捕捉更長(zhǎng)距離的依賴。BERT 使用深而窄的神經(jīng)網(wǎng)絡(luò),中間層有1 024 個(gè)神經(jīng)元,層數(shù)有12層,并采用無(wú)監(jiān)督學(xué)習(xí)的方式,無(wú)需人工干預(yù)和標(biāo)注,使用大規(guī)模語(yǔ)料進(jìn)行訓(xùn)練,其模型結(jié)構(gòu)如圖1 所示。
Fig.1 The structure of the BERT model圖1 BERT 模型結(jié)構(gòu)
文獻(xiàn)[12]將上文信息和下文信息獨(dú)立編碼再進(jìn)行拼接,但Devlin 等[13]說(shuō)明了同時(shí)編碼上下文信息的重要性。BERT 模型聯(lián)合所有層上下文進(jìn)行訓(xùn)練,使模型能很好地結(jié)合上下文理解語(yǔ)義。預(yù)訓(xùn)練好的模型只需進(jìn)行參數(shù)微調(diào)即可快速適應(yīng)多種類型的下游具體任務(wù)。
本文選用哈工大訊飛聯(lián)合實(shí)驗(yàn)室(HFL)發(fā)布的基于全詞Mask 的中文預(yù)訓(xùn)練模型BERT-wwm-ext1。該預(yù)訓(xùn)練模型收集超大量語(yǔ)料用于預(yù)訓(xùn)練,包括百科、問(wèn)答、新聞等通用語(yǔ)料,總詞數(shù)達(dá)到5.4B。BERT-wwm-ext 采用與BERT 相同的模型結(jié)構(gòu),由12 層Transformer 構(gòu)成,訓(xùn)練第一階段(最大長(zhǎng)度為128)采用的batchsize 為2 560,訓(xùn)練1M 步;訓(xùn)練第二階段(最大長(zhǎng)度為512)采用的batchsize 為384,訓(xùn)練400K步。
為了更直觀測(cè)試模型蒸餾效果,本文實(shí)驗(yàn)僅選用1 層全連接層作為分類器,對(duì)短文本類別進(jìn)行分類,并對(duì)教師模型BERT 最后4 層進(jìn)行微調(diào)。
盡管教師模型性能良好,但其模型規(guī)模往往很大,訓(xùn)練過(guò)程需消耗大量計(jì)算資源,甚至由多個(gè)模型集成而成。由于教師模型推斷速度慢,對(duì)內(nèi)存、顯存等資源要求高,因此需構(gòu)建結(jié)構(gòu)相對(duì)簡(jiǎn)單的學(xué)生模型(Student Model)學(xué)習(xí)教師模型學(xué)到的知識(shí)。
單獨(dú)訓(xùn)練學(xué)生模型往往無(wú)法達(dá)到與教師模型一樣或相當(dāng)?shù)男Ч?,因此本文將學(xué)生模型與教師模型建立聯(lián)系,通過(guò)學(xué)習(xí)教師模型的輸出訓(xùn)練學(xué)生模型。
本文選用單層雙向長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)(BiLSTM)作為學(xué)生模型,采用1 層全連接層作分類器,模型結(jié)構(gòu)如圖2 所示。輸入為短文本句向量x,hl和hr分別為雙向LSTM 隱層輸出,預(yù)測(cè)結(jié)果為輸出y。
Fig.2 Student model structure diagram圖2 學(xué)生模型結(jié)構(gòu)
在學(xué)習(xí)上下文相關(guān)信息時(shí),通常使用循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network,RNN),然而標(biāo)準(zhǔn)RNN 存儲(chǔ)的上下文信息有限,并在網(wǎng)絡(luò)結(jié)構(gòu)較深時(shí)存在梯度消失的問(wèn)題。為了解決這些問(wèn)題,Hochreiter 等[14]提出了長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)(LSTM),通過(guò)訓(xùn)練可以使LSTM 學(xué)習(xí)記憶有效信息并遺忘無(wú)效信息,更好地捕捉長(zhǎng)距離依賴關(guān)系。而雙向長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)可從后往前地對(duì)信息進(jìn)行編碼,更好地捕捉雙向語(yǔ)義依賴。
對(duì)于單獨(dú)訓(xùn)練學(xué)生模型和結(jié)合教師模型訓(xùn)練的知識(shí)蒸餾模型,本文對(duì)學(xué)生模型采取相同的結(jié)構(gòu),以便進(jìn)行性能對(duì)比。
知識(shí)蒸餾是一種模型壓縮方法,最早由Hinton 等[15]在計(jì)算機(jī)視覺(jué)領(lǐng)域提出。由于計(jì)算資源昂貴,因此本文選用規(guī)模更小的模型,消耗更小的計(jì)算代價(jià)達(dá)到期望的性能。但單獨(dú)訓(xùn)練規(guī)模小的模型很難達(dá)到預(yù)期效果,所以將大規(guī)模教師模型學(xué)習(xí)到的細(xì)粒度知識(shí)遷移至學(xué)生模型訓(xùn)練中。
對(duì)于分類問(wèn)題,本文將真實(shí)的標(biāo)簽數(shù)據(jù)稱為“硬標(biāo)簽”,即每1 個(gè)數(shù)據(jù)屬于某類別的概率為1,屬于其他類別的概率為0。然而硬標(biāo)簽包含的信息量很低,真實(shí)數(shù)據(jù)往往包含一定量其他標(biāo)簽信息。例如在圖像分類識(shí)別的任務(wù)中,由于狗和貓具有相似特征,狗被預(yù)測(cè)為貓的概率遠(yuǎn)大于預(yù)測(cè)為手機(jī)的概率。具體來(lái)說(shuō),1 張長(zhǎng)得像貓的狗圖片則蘊(yùn)含更多信息量,而硬標(biāo)簽僅給出了這張照片屬于狗這一類別的分類信息。Hinton 將教師模型輸出的softmax 結(jié)果作為“軟標(biāo)簽”,軟標(biāo)簽有較高的信息熵,學(xué)生模型可通過(guò)學(xué)習(xí)軟標(biāo)簽提高自身泛化能力。
知識(shí)蒸餾模型采用教師-學(xué)生結(jié)構(gòu),如圖3 所示,教師模型輸出知識(shí),學(xué)生模型接受知識(shí)。預(yù)訓(xùn)練教師模型使用的數(shù)據(jù)集與知識(shí)蒸餾模型使用的數(shù)據(jù)集相同,模型具體實(shí)現(xiàn)步驟如下。
Step 1.使用真實(shí)數(shù)據(jù)集D中的硬標(biāo)簽訓(xùn)練教師模型T,超參調(diào)優(yōu)得到性能較好的模型。
Step 2.利用訓(xùn)練好的教師模型T計(jì)算軟標(biāo)簽。
Step 3.結(jié)合真實(shí)數(shù)據(jù)集D中的硬標(biāo)簽以及上一步驟計(jì)算得到的軟標(biāo)簽,訓(xùn)練學(xué)生模型S,損失函數(shù)如公式(1)所示。
Step 4.學(xué)生模型S的預(yù)測(cè)與常規(guī)方式相同。
Fig.3 Structure diagram of knowledge distillation model圖3 知識(shí)蒸餾模型結(jié)構(gòu)
本文選用的知識(shí)蒸餾模型損失函數(shù)Loss 由兩部分構(gòu)成。第一部分為硬標(biāo)簽與學(xué)生模型輸出的交叉熵LossCE,第二部分為軟標(biāo)簽與學(xué)生模型輸出logits 的均方差Lossdistill。
其中α為兩部分損失的平衡參數(shù),si為學(xué)生模型輸出,yi為真實(shí)標(biāo)簽數(shù)據(jù),zt為教師模型輸出的logits,zs為學(xué)生模型輸出的logits。
為驗(yàn)證該模型,本文使用CLUE 上的短文本分類公開(kāi)數(shù)據(jù)集TNEWS 作為實(shí)驗(yàn)數(shù)據(jù)。該數(shù)據(jù)集由今日頭條中文新聞標(biāo)題采集得到,包含380 000 條新聞標(biāo)題,共有15 個(gè)新聞?lì)悇e。實(shí)驗(yàn)環(huán)境如表1 所示。
Table 1 Experimental environment表1 實(shí)驗(yàn)環(huán)境
本文使用macroF1值作為模型評(píng)價(jià)指標(biāo)。
首先分別計(jì)算每個(gè)類別精度。
macro精度為所有精度平均值。
同理分別計(jì)算每個(gè)類別的召回率。
macro召回為所有召回平均值。
最后macroF1計(jì)算公式為:
其中,n為類別總數(shù),TPi、FPi和FNi分別表示第i類對(duì)應(yīng)的真正例、假正例和假反例。
3.3.1 教師模型參數(shù)
在本文實(shí)驗(yàn)中,教師模型選用在超大量語(yǔ)料上訓(xùn)練的預(yù)訓(xùn)練模型BERT-wwm-ext,并后接一層全連接層作分類。BERT-wwm-ext 模型共有12 層,隱層含有768 個(gè)神經(jīng)元,使用12 頭自注意力模式。模型采用Adam 優(yōu)化器進(jìn)行優(yōu)化,學(xué)習(xí)率為5e-4,每句話處理的長(zhǎng)度(短填長(zhǎng)切)為32。訓(xùn)練時(shí)采用批量處理的方法,批處理大小為64。教師模型共有參數(shù)102 424 805 個(gè)。
3.3.2 學(xué)生模型參數(shù)
在本文實(shí)驗(yàn)中,學(xué)生模型選用雙向長(zhǎng)短時(shí)記憶(BiLSTM)模型,也后接一層全連接層作分類。BiLSTM 為單層雙向模型,隱層含有256 個(gè)神經(jīng)元。模型輸入的句向量由其組成的詞的詞向量求和取平均得到,組成句子的詞由結(jié)巴分詞工具分詞后得到,詞向量選取用人民日?qǐng)?bào)預(yù)訓(xùn)練好的300 維詞向量[16]。模型使用SGD 作為優(yōu)化器,學(xué)習(xí)率為0.05。訓(xùn)練時(shí)采用批量處理的方法,批處理大小為64。學(xué)生模型共有參數(shù)1 209 093 個(gè)。
3.3.3 蒸餾模型參數(shù)
在本文實(shí)驗(yàn)中,蒸餾模型聯(lián)合教師-學(xué)生模型進(jìn)行訓(xùn)練,硬標(biāo)簽采用交叉熵作為損失函數(shù),軟標(biāo)簽采用均方差(MSE)作為損失函數(shù),平衡參數(shù)α選取為0.2。
本文使用TNEWS 數(shù)據(jù)集進(jìn)行分類實(shí)驗(yàn),該數(shù)據(jù)集在各類別中存在非常嚴(yán)重的不平衡問(wèn)題,且本文關(guān)注的重點(diǎn)為知識(shí)蒸餾模型效果,過(guò)多的類別數(shù)量也會(huì)產(chǎn)生影響。為了防止上述問(wèn)題對(duì)實(shí)驗(yàn)產(chǎn)生影響,本文選擇汽車、文化、教育、游戲、體育5 類數(shù)據(jù)進(jìn)行實(shí)驗(yàn),如表2 所示。
Table 2 Statistics of balanced datasets表2 實(shí)驗(yàn)平衡數(shù)據(jù)
首先,分別將教師模型與學(xué)生模型進(jìn)行單獨(dú)訓(xùn)練,得到原始模型macroF1結(jié)果,再用蒸餾模型對(duì)知識(shí)進(jìn)行蒸餾,結(jié)果如表3 所示。由表3 可知教師模型BERT-wwm-ext 在微調(diào)后準(zhǔn)確率可達(dá)81.00%,而學(xué)生模型BiLSTM 只有75.67%,即教師模型具有更深的網(wǎng)絡(luò)層數(shù)和更多參數(shù),從原始數(shù)據(jù)中學(xué)習(xí)到了更多知識(shí),具有更好的模型泛化能力;而學(xué)生模型結(jié)構(gòu)簡(jiǎn)單,僅通過(guò)超參優(yōu)化無(wú)法達(dá)到較好的效果。在教師模型的指導(dǎo)學(xué)習(xí)下,蒸餾模型macroF1值可達(dá)78.83%,比單獨(dú)訓(xùn)練學(xué)生模型提高3.16%??梢?jiàn)學(xué)生模型不僅可自主地從硬標(biāo)簽學(xué)習(xí)知識(shí),還從教師模型獲取了一部分知識(shí)。由于模型結(jié)構(gòu)簡(jiǎn)單等原因,蒸餾模型分類結(jié)果無(wú)法超越教師模型,但與單獨(dú)訓(xùn)練學(xué)生模型相比,性能明顯提升。
Table 3 Classification results of each model表3 各模型分類效果
各模型在不同類別中的分類結(jié)果如圖4 所示。其中,橫坐標(biāo)為測(cè)試數(shù)據(jù)的不同類別,縱坐標(biāo)為測(cè)試性能指標(biāo)macroF1值。由圖4 可知,蒸餾模型在教師模型的指導(dǎo)下,在所有5 個(gè)類別上分類效果均優(yōu)于單獨(dú)訓(xùn)練的學(xué)生模型,并且在汽車類別上十分接近教師模型分類結(jié)果,這表明在一些區(qū)分度高的類別上,簡(jiǎn)單模型通過(guò)一定方式學(xué)習(xí)后可達(dá)到復(fù)雜模型效果。所有模型在文化類別中分類效果均相對(duì)較差,這是因?yàn)槲幕愓Z(yǔ)句相對(duì)其他類別的語(yǔ)句更加抽象復(fù)雜,往往包含其他類別的含義,模型無(wú)法學(xué)習(xí)到深層次特征從而抽象表達(dá)這些語(yǔ)句,導(dǎo)致分類結(jié)果難以提升。
Fig.4 Test results in various categories圖4 模型在各個(gè)類別測(cè)試結(jié)果
本文使用不同模型在相同測(cè)試集中進(jìn)行實(shí)驗(yàn),分析模型時(shí)間性能。教師模型與學(xué)生模型在3 000 條測(cè)試數(shù)據(jù)中進(jìn)行實(shí)驗(yàn)的推理時(shí)間對(duì)比結(jié)果如表4 所示。
Table 4 Runtime to complete one itevation表4 完成1 次迭代的推理時(shí)間
從表4 可以看出,學(xué)生模型在完成1 次推理的時(shí)間遠(yuǎn)少于教師模型,所需時(shí)間僅有教師模型的1/725,這主要是因?yàn)榕c教師模型相比,學(xué)生模型結(jié)構(gòu)相對(duì)簡(jiǎn)單,模型參數(shù)只有教師模型的1/85 倍。知識(shí)蒸餾模型在準(zhǔn)確率接近教師模型的情況下,推理時(shí)間更短,有利于在真實(shí)場(chǎng)景(如移動(dòng)端)的部署。
本文針對(duì)結(jié)構(gòu)復(fù)雜模型難以落地應(yīng)用的現(xiàn)狀,提出一種基于教師—學(xué)生框架的知識(shí)蒸餾模型,應(yīng)用于短文本分類任務(wù)。該模型首先預(yù)訓(xùn)練1 個(gè)分類性能好、結(jié)構(gòu)復(fù)雜的大模型,再將大模型所學(xué)知識(shí)遷移至結(jié)構(gòu)簡(jiǎn)單的小模型中,以此彌補(bǔ)小模型單獨(dú)訓(xùn)練時(shí)泛化能力不足的問(wèn)題。實(shí)驗(yàn)結(jié)果表明,知識(shí)蒸餾小模型性能顯著改善,同時(shí),模型迭代推理時(shí)間大幅縮短,使模型在工業(yè)場(chǎng)景中進(jìn)行應(yīng)用成為可能。
目前開(kāi)源中文數(shù)據(jù)集較少,多為爬蟲(chóng)獲得,數(shù)據(jù)集質(zhì)量較差,人工標(biāo)記數(shù)據(jù)又有耗力費(fèi)時(shí)等問(wèn)題。本文使用的分類器僅使用了1 層全連接層,學(xué)生模型也選擇了較為簡(jiǎn)單的單層雙向長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)。針對(duì)上述問(wèn)題,下一步工作是尋找合適的生成方式以產(chǎn)生高質(zhì)量數(shù)據(jù),以及如何選擇較為復(fù)雜的適用于工業(yè)場(chǎng)景的學(xué)生模型。