遷移學(xué)習(xí)目錄 · 定義 · 為什么需要遷移學(xué)習(xí) · 遷移學(xué)習(xí)的研究領(lǐng)域 · 遷移學(xué)習(xí)的應(yīng)用 · 基礎(chǔ)知識 · 遷移學(xué)習(xí)的基本方法 · 遷移學(xué)習(xí)算法-TCA · 遷移學(xué)習(xí)算法-Deep Adaptation Networks 定義· 遷移學(xué)習(xí)是一種學(xué)習(xí)的思想和模式。 · 遷移學(xué)習(xí)作為機(jī)器學(xué)習(xí)的一個重要分支,側(cè)重于將已經(jīng)學(xué)習(xí)過的知識遷移應(yīng)用于新的問題中。 · 遷移學(xué)習(xí)的核心問題是,找到新問題和原問題之間的相似性,才可以順利地實現(xiàn)知識的遷移。 · 遷移學(xué)習(xí),是指利用數(shù)據(jù)、任務(wù)、或模型之間的相似性,將在舊領(lǐng)域?qū)W習(xí)過的模型,應(yīng)用于新領(lǐng)域的一種學(xué)習(xí)過程。 遷移學(xué)習(xí)例子 為什么需要遷移學(xué)習(xí)原因概括為以下四個方面:
1、大數(shù)據(jù)與少標(biāo)注之間的矛盾 我們正處在一個大數(shù)據(jù)時代,每天每時,社交網(wǎng)絡(luò)、智能交通、視頻監(jiān)控、行業(yè)物流等,都產(chǎn)生著海量的圖像、文本、語音等各類數(shù)據(jù)。數(shù)據(jù)的增多,使得機(jī)器學(xué)習(xí)和深度學(xué)習(xí)模型可以依賴于如此海量的數(shù)據(jù),持續(xù)不斷地訓(xùn)練和更新相應(yīng)的模型,使得模型的性能越來越好,越來越適合特定場景的應(yīng)用。然而,這些大數(shù)據(jù)帶來了嚴(yán)重的問題:總是缺乏完善的數(shù)據(jù)標(biāo)注。 眾所周知,機(jī)器學(xué)習(xí)模型的訓(xùn)練和更新,均依賴于數(shù)據(jù)的標(biāo)注。然而,盡管我們可以獲取到海量的數(shù)據(jù),這些數(shù)據(jù)往往是很初級的原始形態(tài),很少有數(shù)據(jù)被加以正確的人工標(biāo)注。數(shù)據(jù)的標(biāo)注是一個耗時且昂貴的操作,目前為止,尚未有行之有效的方式來解決這一問題。這給機(jī)器學(xué)習(xí)和深度學(xué)習(xí)的模型訓(xùn)練和更新帶來了挑戰(zhàn)。反過來說,特定的領(lǐng)域,因為沒有足夠的標(biāo)定數(shù)據(jù)用來學(xué)習(xí),使得這些領(lǐng)域一直不能很好的發(fā)展。 2、大數(shù)據(jù)與弱計算之間的矛盾 大數(shù)據(jù),就需要大設(shè)備、強(qiáng)計算能力的設(shè)備來進(jìn)行存儲和計算。然而,大數(shù)據(jù)的大計算能力,是' 有錢人' 才能玩得起的游戲。比如 Google,F(xiàn)acebook,Microsoft,這些巨無霸公司有著雄厚的計算能力去利用這些數(shù)據(jù)訓(xùn)練模型。例如,ResNet 需要很長的時間進(jìn)行訓(xùn)練。Google TPU 也都是有錢人的才可以用得起的。 絕大多數(shù)普通用戶是不可能具有這些強(qiáng)計算能力的。這就引發(fā)了大數(shù)據(jù)和弱計算之間的矛盾。在這種情況下,普通人想要利用這些海量的大數(shù)據(jù)去訓(xùn)練模型完成自己的任務(wù),基本上不太可能。那么如何讓普通人也能利用這些數(shù)據(jù)和模型? 3、普適化模型與個性化需求之間的矛盾 機(jī)器學(xué)習(xí)的目標(biāo)是構(gòu)建一個盡可能通用的模型,使得這個模型對于不同用戶、不同設(shè)備、不同環(huán)境、不同需求,都可以很好地進(jìn)行滿足。這是我們的美好愿景。這就是要盡可能地提高機(jī)器學(xué)習(xí)模型的泛化能力,使之適應(yīng)不同的數(shù)據(jù)情形?;谶@樣的愿望,我們構(gòu)建了多種多樣的普適化模型,來服務(wù)于現(xiàn)實應(yīng)用。然而,這只能是我們竭盡全力想要做的,目前卻始終無法徹底解決的問題。人們的個性化需求五花八門,短期內(nèi)根本無法用一個通用的模型去滿足。比如導(dǎo)航模型,可以定位及導(dǎo)航所有的路線。但是不同的人有不同的需求。比如有的人喜歡走高速,有的人喜歡走偏僻小路,這就是個性化需求。并且,不同的用戶,通常都有不同的隱私需求。這也是構(gòu)建應(yīng)用需要著重考慮的。 所以目前的情況是,我們對于每一個通用的任務(wù)都構(gòu)建了一個通用的模型。這個模型可以解決絕大多數(shù)的公共問題。但是具體到每個個體、每個需求,都存在其唯一性和特異性,一個普適化的通用模型根本無法滿足。那么,能否將這個通用的模型加以改造和適配,使其更好地服務(wù)于人們的個性化需求? 4、特定應(yīng)用的需求 機(jī)器學(xué)習(xí)已經(jīng)被廣泛應(yīng)用于現(xiàn)實生活中。在這些應(yīng)用中,也存在著一些特定的應(yīng)用,它們面臨著一些現(xiàn)實存在的問題。比如推薦系統(tǒng)的冷啟動問題。一個新的推薦系統(tǒng),沒有足夠的用戶數(shù)據(jù),如何進(jìn)行精準(zhǔn)的推薦? 一個嶄新的圖片標(biāo)注系統(tǒng),沒有足夠的標(biāo)簽,如何進(jìn)行精準(zhǔn)的服務(wù)?現(xiàn)實世界中的應(yīng)用驅(qū)動著我們?nèi)ラ_發(fā)更加便捷更加高效的機(jī)器學(xué)習(xí)方法來加以解決。 上述存在的幾個重要問題,使得傳統(tǒng)的機(jī)器學(xué)習(xí)方法疲于應(yīng)對。遷移學(xué)習(xí)則可以很好地進(jìn)行解決。 遷移學(xué)習(xí)是如何進(jìn)行解決的呢? 大數(shù)據(jù)與少標(biāo)注:遷移數(shù)據(jù)標(biāo)注
大數(shù)據(jù)與弱計算:模型遷移
普適化模型與個性化需求:自適應(yīng)學(xué)習(xí)
特定應(yīng)用的需求:相似領(lǐng)域知識遷移 為了滿足特定領(lǐng)域應(yīng)用的需求,我們可以利用上述介紹過的手段,從數(shù)據(jù)和模型方法上進(jìn)行遷移學(xué)習(xí)。 總結(jié) 遷移學(xué)習(xí) VS 傳統(tǒng)機(jī)器學(xué)習(xí) 遷移學(xué)習(xí)的研究領(lǐng)域 依據(jù)目前較流行的機(jī)器學(xué)習(xí)分類方法,機(jī)器學(xué)習(xí)主要可以分為有監(jiān)督、半監(jiān)督和無監(jiān)督機(jī)器學(xué)習(xí)三大類。同理,遷移學(xué)習(xí)也可以進(jìn)行這樣的分類。需要注意的是,依據(jù)的分類準(zhǔn)則不同,分類結(jié)果也不同。在這一點上,并沒有一個統(tǒng)一的說法。我們在這里僅根據(jù)目前較流行的方法,對遷移學(xué)習(xí)的研究領(lǐng)域進(jìn)行一個大致的劃分。 大體上講,遷移學(xué)習(xí)的分類可以按照四個準(zhǔn)則進(jìn)行:按目標(biāo)域有無標(biāo)簽分、按學(xué)習(xí)方法分、按特征分、按離線與在線形式分。不同的分類方式對應(yīng)著不同的專業(yè)名詞。當(dāng)然,即使是一個分類下的研究領(lǐng)域,也可能同時處于另一個分類下。下面我們對這些分類方法及相應(yīng)的領(lǐng)域作簡單描述。 按目標(biāo)域標(biāo)簽分 這種分類方式最為直觀。類比機(jī)器學(xué)習(xí),按照目標(biāo)領(lǐng)域有無標(biāo)簽,遷移學(xué)習(xí)可以分為以下三個大類:
顯然,少標(biāo)簽或無標(biāo)簽的問題 (半監(jiān)督和無監(jiān)督遷移學(xué)習(xí)),是研究的熱點和難點。 按學(xué)習(xí)方法分類 按學(xué)習(xí)方法的分類形式,最早在遷移學(xué)習(xí)領(lǐng)域的權(quán)威綜述文章 [Pan and Yang, 2010] 給出定義。它將遷移學(xué)習(xí)方法分為以下四個大類:
這是一個很直觀的分類方式,按照數(shù)據(jù)、特征、模型的機(jī)器學(xué)習(xí)邏輯進(jìn)行區(qū)分,再加上不屬于這三者中的關(guān)系模式。 · 基于實例的遷移,簡單來說就是通過權(quán)重重用,對源域和目標(biāo)域的樣例進(jìn)行遷移。就是說直接對不同的樣本賦予不同權(quán)重,比如說相似的樣本,我就給它高權(quán)重,這樣我就完成了遷移,非常簡單非常非常直接。 · 基于特征的遷移,就是更進(jìn)一步對特征進(jìn)行變換。意思是說,假設(shè)源域和目標(biāo)域的特征原來不在一個空間,或者說它們在原來那個空間上不相似,那我們就想辦法把它們變換到一個空間里面,那這些特征不就相似了?這個思路也非常直接。這個方法是用得非常多的,一直在研究,目前是感覺是研究最熱的。 · 基于模型的遷移,就是說構(gòu)建參數(shù)共享的模型。這個主要就是在神經(jīng)網(wǎng)絡(luò)里面用的特別多,因為神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)可以直接進(jìn)行遷移。比如說神經(jīng)網(wǎng)絡(luò)最經(jīng)典的 finetune 就是模型參數(shù)遷移的很好的體現(xiàn)。 · 基于關(guān)系的遷移,這個方法用的比較少,這個主要就是說挖掘和利用關(guān)系進(jìn)行類比遷移。比如老師上課、學(xué)生聽課就可以類比為公司開會的場景。這個就是一種關(guān)系的遷移。 · 目前最熱的就是基于特征還有模型的遷移,然后基于實例的遷移方法和他們結(jié)合起來使用。 按特征分類 按照特征的屬性進(jìn)行分類,也是一種常用的分類方法。這在最近的遷移學(xué)習(xí)綜述 [Weiss et al., 2016]中給出。按照特征屬性,遷移學(xué)習(xí)可以分為兩個大類:
這也是一種很直觀的方式:如果特征語義和維度都相同,那么就是同構(gòu);反之,如果特征完全不相同,那么就是異構(gòu)。舉個例子來說,不同圖片的遷移,就可以認(rèn)為是同構(gòu);而圖片到文本的遷移,則是異構(gòu)的。 按離線與在線形式分 按照離線學(xué)習(xí)與在線學(xué)習(xí)的方式,遷移學(xué)習(xí)還可以被分為:
目前,絕大多數(shù)的遷移學(xué)習(xí)方法,都采用了離線方式。即,源域和目標(biāo)域均是給定的,遷移一次即可。這種方式的缺點是顯而易見的:算法無法對新加入的數(shù)據(jù)進(jìn)行學(xué)習(xí),模型也無法得到更新。與之相對的,是在線的方式。即隨著數(shù)據(jù)的動態(tài)加入,遷移學(xué)習(xí)算法也可以不斷地更新。 遷移學(xué)習(xí)的應(yīng)用 遷移學(xué)習(xí)的應(yīng)用· 遷移學(xué)習(xí)是機(jī)器學(xué)習(xí)領(lǐng)域的一個重要分支。因此,其應(yīng)用并不局限于特定的領(lǐng)域。凡是滿足遷移學(xué)習(xí)問題情景的應(yīng)用,遷移學(xué)習(xí)都可以發(fā)揮作用。這些領(lǐng)域包括但不限于計算機(jī)視覺、文本分類、行為識別、自然語言處理、室內(nèi)定位、視頻監(jiān)控、輿情分析、人機(jī)交互等。 計算機(jī)視覺 遷移學(xué)習(xí)已被廣泛地應(yīng)用于計算機(jī)視覺的研究中。特別地,在計算機(jī)視覺中,遷移學(xué)習(xí)方法被稱為 Domain Adaptation。Domain adaptation 的應(yīng)用場景有很多,比如圖片分類、圖片哈希等。 同一類圖片,不同的拍攝角度、不同光照、不同背景,都會造成特征分布發(fā)生改變。因此,使用遷移學(xué)習(xí)構(gòu)建跨領(lǐng)域的魯棒分類器是十分重要的。 計算機(jī)視覺三大頂會 (CVPR、ICCV、ECCV) 每年都會發(fā)表大量的文章對遷移學(xué)習(xí)在視覺領(lǐng)域的應(yīng)用進(jìn)行介紹。 文本分類 由于文本數(shù)據(jù)有其領(lǐng)域特殊性,因此,在一個領(lǐng)域上訓(xùn)練的分類器,不能直接拿來作用到另一個領(lǐng)域上。這就需要用到遷移學(xué)習(xí)。例如,在電影評論文本數(shù)據(jù)集上訓(xùn)練好的分類器,不能直接用于圖書評論的預(yù)測。這就需要進(jìn)行遷移學(xué)習(xí)。下圖一個由電子產(chǎn)品評論遷移到 DVD 評論的遷移學(xué)習(xí)任務(wù)。 文本和網(wǎng)絡(luò)領(lǐng)域頂級會議 WWW 和 CIKM 每年有大量的文章對遷移學(xué)習(xí)在文本領(lǐng)域的應(yīng)用作介紹。 時間序列-行為識別 行為識別 (Activity Recognition) 主要通過佩戴在用戶身體上的傳感器,研究用戶的行為。行為數(shù)據(jù)是一種時間序列數(shù)據(jù)。不同用戶、不同環(huán)境、不同位置、不同設(shè)備,都會導(dǎo)致時間序列數(shù)據(jù)的分布發(fā)生變化。此時,也需要進(jìn)行遷移學(xué)習(xí)。下圖展示了同一用戶不同位置的信號差異性。在這個領(lǐng)域,華盛頓州立大學(xué)的 Diane Cook 等人在 2013 年發(fā)表的關(guān)于遷移學(xué)習(xí)在行為識別領(lǐng)域的綜述文章 [Cook et al., 2013] 是很好的參考資料。 時間序列-室內(nèi)定位 室內(nèi)定位 (Indoor Location) 與傳統(tǒng)的室外用 GPS 定位不同,它通過 WiFi、藍(lán)牙等設(shè)備研究人在室內(nèi)的位置。不同用戶、不同環(huán)境、不同時刻也會使得采集的信號分布發(fā)生變化。下圖展示了不同時間、不同設(shè)備的 WiFi 信號變化。 醫(yī)療健康 醫(yī)療健康領(lǐng)域的研究正變得越來越重要。不同于其他領(lǐng)域,醫(yī)療領(lǐng)域研究的難點問題是,無法獲取足夠有效的醫(yī)療數(shù)據(jù)。在這一領(lǐng)域,遷移學(xué)習(xí)同樣也變得越來越重要。 基礎(chǔ)知識
是進(jìn)行一切研究的前提。在遷移學(xué)習(xí)中,有兩個基本的概念:領(lǐng)域 (Domain) 和任務(wù) (Task)。它們是最基礎(chǔ)的概念。
是進(jìn)行學(xué)習(xí)的主體。領(lǐng)域主要由兩部分構(gòu)成:數(shù)據(jù)和生成這些數(shù)據(jù)的概率分布。通常我們用 D 來表示一個 domain,用大寫 P 來表示一個概率分布。 特別地,因為涉及到遷移,所以對應(yīng)于兩個基本的領(lǐng)域:源領(lǐng)域 (Source Domain) 和目標(biāo)領(lǐng)域 (Target Domain)。這兩個概念很好理解。源領(lǐng)域就是有知識、有大量數(shù)據(jù)標(biāo)注的領(lǐng)域,是我們要遷移的對象;目標(biāo)領(lǐng)域就是我們最終要賦予知識、賦予標(biāo)注的對象。知識從源領(lǐng)域傳遞到目標(biāo)領(lǐng)域,就完成了遷移。 領(lǐng)域上的數(shù)據(jù),我們通常用小寫粗體 x 來表示,它也是向量的表示形式。例如,xi 就表示第 i 個樣本或特征。用大寫的黑體 X 表示一個領(lǐng)域的數(shù)據(jù),這是一種矩陣形式。我們用大寫花體 X 來表示數(shù)據(jù)的特征空間。 通常我們用小寫下標(biāo) s 和 t 來分別指代兩個領(lǐng)域。結(jié)合領(lǐng)域的表示方式,則:Ds 表示源領(lǐng)域,Dt 表示目標(biāo)領(lǐng)域。 值得注意的是,概率分布 P 通常只是一個邏輯上的概念,即我們認(rèn)為不同領(lǐng)域有不同的概率分布,卻一般不給出(也難以給出)P 的具體形式。
任務(wù) (Task): 是學(xué)習(xí)的目標(biāo)。任務(wù)主要由兩部分組成:標(biāo)簽和標(biāo)簽對應(yīng)的函數(shù)。通常我們用花體 Y 來表示一個標(biāo)簽空間,用 f(·) 來表示一個學(xué)習(xí)函數(shù)。 相應(yīng)地,源領(lǐng)域和目標(biāo)領(lǐng)域的類別空間就可以分別表示為 Ys 和 Yt 。我們用小寫 ys 和yt 分別表示源領(lǐng)域和目標(biāo)領(lǐng)域的實際類別。
常用符號總結(jié) 遷移學(xué)習(xí)的核心是,找到源領(lǐng)域和目標(biāo)領(lǐng)域之間的相似性,并加以合理利用。這種相似性非常普遍。比如,不同人的身體構(gòu)造是相似的;自行車和摩托車的騎行方式是相似的;國際象棋和中國象棋是相似的;羽毛球和網(wǎng)球的打球方式是相似的。這種相似性也可以理解為不變量。以不變應(yīng)萬變,才能立于不敗之地。 找到相似性 (不變量),是進(jìn)行遷移學(xué)習(xí)的核心。 有了這種相似性后,下一步工作就是,如何度量和利用這種相似性。度量工作的目標(biāo)有兩點:一是很好地度量兩個領(lǐng)域的相似性,不僅定性地告訴我們它們是否相似,更定量地給出相似程度。二是以度量為準(zhǔn)則,通過我們所要采用的學(xué)習(xí)手段,增大兩個領(lǐng)域之間的相似性,從而完成遷移學(xué)習(xí)。
定義在兩個向量 (兩個點) 上,這兩個數(shù)據(jù)在同一個分布里。點 x 和點 y 的馬氏距離為:
最大均值差異是遷移學(xué)習(xí)中使用頻率最高的度量。Maximum mean discrepancy,它度量在再生希爾伯特空間中兩個分布的距離,是一種核學(xué)習(xí)方法。兩個隨機(jī)變量的 MMD 平方距離為 遷移學(xué)習(xí)的基本方法· 基于樣本的遷移 · 基于模型的遷移 · 基于特征的遷移 · 基于關(guān)系的遷移
基于樣本的遷移學(xué)習(xí)方法 (Instance based Transfer Learning) 根據(jù)一定的權(quán)重生成規(guī)則,對數(shù)據(jù)樣本進(jìn)行重用,來進(jìn)行遷移學(xué)習(xí)。圖片形象地表示了基于樣本遷移方法的思想。 源域中存在不同種類的動物,如狗、鳥、貓等,目標(biāo)域只有狗這一種類別。在遷移時,為了最大限度地和目標(biāo)域相似,我們可以人為地提高源域中屬于狗這個類別的樣本權(quán)重。 雖然實例權(quán)重法具有較好的理論支撐、容易推導(dǎo)泛化誤差上界,但這類方法通常只在領(lǐng)域間分布差異較小時有效,因此對自然語言處理、計算機(jī)視覺等任務(wù)效果并不理想。
基于特征的遷移方法 (Feature based Transfer Learning) 是指將通過特征變換的方式互相遷移 [Liu et al., 2011, Zheng et al., 2008, Hu and Yang, 2011],來減少源域和目標(biāo)域之間的差距;或者將源域和目標(biāo)域的數(shù)據(jù)特征變換到統(tǒng)一特征空間中 [Pan et al., 2011,Long et al., 2014b, Duan et al., 2012],然后利用傳統(tǒng)的機(jī)器學(xué)習(xí)方法進(jìn)行分類識別。根據(jù)特征的同構(gòu)和異構(gòu)性,又可以分為同構(gòu)和異構(gòu)遷移學(xué)習(xí)。圖片很形象地表示了兩種基于特征的遷移學(xué)習(xí)方法。 基于特征的遷移學(xué)習(xí)方法是遷移學(xué)習(xí)領(lǐng)域中最熱門的研究方法,這類方法通常假設(shè)源域和目標(biāo)域間有一些交叉的特征。
基于模型的遷移方法 (Parameter/Model based Transfer Learning) 是指從源域和目標(biāo)域中找到他們之間共享的參數(shù)信息,以實現(xiàn)遷移的方法。這種遷移方式要求的假設(shè)條件是:源域中的數(shù)據(jù)與目標(biāo)域中的數(shù)據(jù)可以共享一些模型的參數(shù)。
基于關(guān)系的遷移學(xué)習(xí)方法 (Relation Based Transfer Learning) 與上述三種方法具有截然不同的思路。這種方法比較關(guān)注源域和目標(biāo)域的樣本之間的關(guān)系。圖片形象地表示了不同領(lǐng)域之間相似的關(guān)系。 就目前來說,基于關(guān)系的遷移學(xué)習(xí)方法的相關(guān)研究工作非常少,大部分都借助于馬爾科夫邏輯網(wǎng)絡(luò) (Markov Logic Net) 來挖掘不同領(lǐng)域之間的關(guān)系相似性。 基于關(guān)系的遷移學(xué)習(xí)方法示意圖
遷移學(xué)習(xí)算法-TCA
數(shù)據(jù)分布自適應(yīng) (Distribution Adaptation) 是一類最常用的遷移學(xué)習(xí)方法。這種方法的基本思想是,由于源域和目標(biāo)域的數(shù)據(jù)概率分布不同,那么最直接的方式就是通過一些變換,將不同的數(shù)據(jù)分布的距離拉近。 根據(jù)數(shù)據(jù)分布的性質(zhì),這類方法又可以分為邊緣分布自適應(yīng)、條件分布自適應(yīng)、以及聯(lián)合分布自適應(yīng)。 圖片形象地表示了幾種數(shù)據(jù)分布的情況。簡單來說,數(shù)據(jù)的邊緣分布不同,就是數(shù)據(jù)整體不相似。數(shù)據(jù)的條件分布不同,就是數(shù)據(jù)整體相似,但是具體到每個類里,都不太相似。
遷移成分分析 (Transfer Component Analysis)是一種邊緣分布自適應(yīng)方法 (Marginal Distribution Adaptation) 其目標(biāo)是減小源域和目標(biāo)域的邊緣概率分布的距離,從而完成遷移學(xué)習(xí) 從形式上來說,邊緣分布自適應(yīng)方法是用 P(xs )和 P(xt ) 之間的距離來近似兩個領(lǐng)域之間的差異。即: 邊緣分布自適應(yīng)的方法最早由香港科技大學(xué)楊強(qiáng)教授團(tuán)隊提出 [Pan et al., 2011] 問題:但是世界上有無窮個這樣的 ?,我們肯定不能通過窮舉的方法來找 ? 的。那么怎么辦呢? 遷移學(xué)習(xí)的本質(zhì):最小化源域和目標(biāo)域的距離。 能否先假設(shè)這個? 是已知的,然后去求距離,看看能推出什么? 以上式子下面的條件是什么意思呢?那個 min 的目標(biāo)就是要最小化源域和目標(biāo)域的距離,加上 W 的約束讓它不能太復(fù)雜。下面的條件是是要實現(xiàn)第二個目標(biāo):維持各自的數(shù)據(jù)特征。 TCA 要維持的特征是scatter matrix,就是數(shù)據(jù)的散度。就是說,一個矩陣散度怎么計算?對于一個矩陣 A,它的 scatter matrix 就是AHA? 。這個 H 就是上面的中心矩陣。 TCA 和 PCA 的效果對比 可以很明顯地看出,對于概率分布不同的兩部分?jǐn)?shù)據(jù),在經(jīng)過 TCA處理后,概率分布更加接近。這說明了 TCA 在拉近數(shù)據(jù)分布距離上的優(yōu)勢。 遷移學(xué)習(xí)算法-Deep Adaptation Networks from __future__ import print_function import argparse import torch import torch.nn.functional as F import torch.optim as optim from torch.autograd import Variable import os import math import data_loader import ResNet as models from torch.utils import model_zoo os.environ['CUDA_VISIBLE_DEVICES'] = '1' # Training settings batch_size = 32 epochs = 200 lr = 0.01 momentum = 0.9 no_cuda =False seed = 8 log_interval = 10 l2_decay = 5e-4 root_path = './dataset/' source_name = 'amazon' target_name = 'webcam' cuda = not no_cuda and torch.cuda.is_available() torch.manual_seed(seed) if cuda: torch.cuda.manual_seed(seed) kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {} source_loader = data_loader.load_training(root_path, source_name, batch_size, kwargs) target_train_loader = data_loader.load_training(root_path, target_name, batch_size, kwargs) target_test_loader = data_loader.load_testing(root_path, target_name, batch_size, kwargs) len_source_dataset = len(source_loader.dataset) len_target_dataset = len(target_test_loader.dataset) len_source_loader = len(source_loader) len_target_loader = len(target_train_loader) def load_pretrain(model): url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth' pretrained_dict = model_zoo.load_url(url) model_dict = model.state_dict() for k, v in model_dict.items(): if not 'cls_fc' in k: model_dict[k] = pretrained_dict[k[k.find('.') + 1:]] model.load_state_dict(model_dict) return model def train(epoch, model): LEARNING_RATE = lr / math.pow((1 + 10 * (epoch - 1) / epochs), 0.75) print('learning rate{: .4f}'.format(LEARNING_RATE) ) optimizer = torch.optim.SGD([ {'params': model.sharedNet.parameters()}, {'params': model.cls_fc.parameters(), 'lr': LEARNING_RATE}, ], lr=LEARNING_RATE / 10, momentum=momentum, weight_decay=l2_decay) model.train() iter_source = iter(source_loader) iter_target = iter(target_train_loader) num_iter = len_source_loader for i in range(1, num_iter): data_source, label_source = iter_source.next() data_target, _ = iter_target.next() if i % len_target_loader == 0: iter_target = iter(target_train_loader) if cuda: data_source, label_source = data_source.cuda(), label_source.cuda() data_target = data_target.cuda() data_source, label_source = Variable(data_source), Variable(label_source) data_target = Variable(data_target) optimizer.zero_grad() label_source_pred, loss_mmd = model(data_source, data_target) loss_cls = F.nll_loss(F.log_softmax(label_source_pred, dim=1), label_source) gamma = 2 / (1 + math.exp(-10 * (epoch) / epochs)) - 1 loss = loss_cls + gamma * loss_mmd loss.backward() optimizer.step() if i % log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tsoft_Loss: {:.6f}\tmmd_Loss: {:.6f}'.format( epoch, i * len(data_source), len_source_dataset, 100. * i / len_source_loader, loss.data[0], loss_cls.data[0], loss_mmd.data[0])) def test(model): model.eval() test_loss = 0 correct = 0 for data, target in target_test_loader: if cuda: data, target = data.cuda(), target.cuda() data, target = Variable(data, volatile=True), Variable(target) s_output, t_output = model(data, data) test_loss += F.nll_loss(F.log_softmax(s_output, dim = 1), target, size_average=False).data[0] # sum up batch loss pred = s_output.data.max(1)[1] # get the index of the max log-probability correct += pred.eq(target.data.view_as(pred)).cpu().sum() test_loss /= len_target_dataset print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( target_name, test_loss, correct, len_target_dataset, 100. * correct / len_target_dataset)) return correct if __name__ == '__main__': model = models.DANNet(num_classes=31) correct = 0 print(model) if cuda: model.cuda() model = load_pretrain(model) for epoch in range(1, epochs + 1): train(epoch, model) t_correct = test(model) if t_correct > correct: correct = t_correct print('source: {} to target: {} max correct: {} max accuracy{: .2f}%\n'.format( source_name, target_name, correct, 100. * correct / len_target_dataset )) |
|
來自: taotao_2016 > 《AI》