Deep learning:四十一(Dropout簡單理解)
前言
訓練神經網路模型時,如果訓練樣本較少,為了防止模型過擬合,Dropout可以作為一種trikc供選擇。Dropout是hintion最近2年提出的,源於其文章Improving neural networks by preventing co-adaptation of feature detectors.中文大意為:通過阻止特徵檢測器的共同作用來提高神經網路的效能。本篇博文就是按照這篇論文簡單介紹下Dropout的思想,以及從用一個簡單的例子來說明該如何使用dropout。
基礎知識:
Dropout是指在模型訓練時隨機讓網路某些隱含層節點的權重不工作,不工作的那些節點可以暫時認為不是網路結構的一部分,但是它的權重得保留下來(只是暫時不更新而已),因為下次樣本輸入時它可能又得工作了(有點抽象,具體實現看後面的實驗部分)。
按照hinton的文章,他使用Dropout時訓練階段和測試階段做了如下操作:
在樣本的訓練階段,在沒有采用pre-training的網路時(Dropout當然可以結合pre-training一起使用),hintion並不是像通常那樣對權值採用L2範數懲罰,而是對每個隱含節點的權值L2範數設定一個上限bound,當訓練過程中如果該節點不滿足bound約束,則用該bound值對權值進行一個規範化操作(即同時除以該L2範數值),說是這樣可以讓權值更新初始的時候有個大的學習率供衰減,並且可以搜尋更多的權值空間(沒理解)。
在模型的測試階段,使用”mean network(均值網路)”來得到隱含層的輸出,其實就是在網路前向傳播到輸出層前時隱含層節點的輸出值都要減半(如果dropout的比例為50%),其理由文章說了一些,可以去檢視(沒理解)。
關於Dropout,文章中沒有給出任何數學解釋,Hintion的直觀解釋和理由如下:
1. 由於每次用輸入網路的樣本進行權值更新時,隱含節點都是以一定概率隨機出現,因此不能保證每2個隱含節點每次都同時出現,這樣權值的更新不再依賴於有固定關係隱含節點的共同作用,阻止了某些特徵僅僅在其它特定特徵下才有效果的情況。
2. 可以將dropout看作是模型平均的一種。對於每次輸入到網路中的樣本(可能是一個樣本,也可能是一個batch的樣本),其對應的網路結構都是不同的,但所有的這些不同的網路結構又同時share隱含節點的權值。這樣不同的樣本就對應不同的模型,是bagging的一種極端情況。個人感覺這個解釋稍微靠譜些,和bagging,boosting理論有點像,但又不完全相同。
3. native bayes是dropout的一個特例。Native bayes有個錯誤的前提,即假設各個特徵之間相互獨立,這樣在訓練樣本比較少的情況下,單獨對每個特徵進行學習,測試時將所有的特徵都相乘,且在實際應用時效果還不錯。而Droput每次不是訓練一個特徵,而是一部分隱含層特徵。
4. 還有一個比較有意思的解釋是,Dropout類似於性別在生物進化中的角色,物種為了使適應不斷變化的環境,性別的出現有效的阻止了過擬合,即避免環境改變時物種可能面臨的滅亡。
文章最後當然是show了一大把的實驗來說明dropout可以阻止過擬合。這些實驗都是些常見的benchmark,比如Mnist, Timit, Reuters, CIFAR-10, ImageNet.
實驗過程:
本文實驗時用mnist庫進行手寫數字識別,訓練樣本2000個,測試樣本1000個,用的是matlab的https://github.com/rasmusbergpalm/DeepLearnToolbox,程式碼在test_example_NN.m上修改得到。關於該toolbox的介紹可以參考網友的博文【面向程式碼】學習 Deep Learning(一)Neural Network。這裡我只用了個簡單的單個隱含層神經網路,隱含層節點的個數為100,所以輸入層-隱含層-輸出層節點依次為784-100-10. 為了使本例子簡單話,沒用對權值w進行規則化,採用mini-batch訓練,每個mini-batch樣本大小為100,迭代20次。權值採用隨機初始化。
實驗結果:
沒用Dropout時:
訓練樣本錯誤率(均方誤差):0.032355
測試樣本錯誤率:15.500%
使用Dropout時:
訓練樣本錯誤率(均方誤差):0.075819
測試樣本錯誤率:13.000%
可以看出使用Dropout後,雖然訓練樣本的錯誤率較高,但是訓練樣本的錯誤率降低了,說明Dropout的泛化能力不錯,可以防止過擬合。
實驗主要程式碼及註釋:
test_dropout.m:
%% //匯入minst資料並歸一化 load mnist_uint8; train_x = double(train_x(1:2000,:)) / 255; test_x = double(test_x(1:1000,:)) / 255; train_y = double(train_y(1:2000,:)); test_y = double(test_y(1:1000,:)); % //normalize [train_x, mu, sigma] = zscore(train_x);% //歸一化train_x,其中mu是個行向量,mu是個列向量 test_x = normalize(test_x, mu, sigma);% //線上測試時,歸一化用的是訓練樣本的均值和方差,需要特別注意 %% //without dropout rng(0); nn = nnsetup([784 100 10]);% //初步構造了一個輸入-隱含-輸出層網路,其中包括了 % //權值的初始化,學習率,momentum,激發函式型別, % //懲罰係數,dropout等 opts.numepochs = 20; % //Number of full sweeps through data opts.batchsize = 100; % //Take a mean gradient step over this many samples [nn, L] = nntrain(nn, train_x, train_y, opts); [er, bad] = nntest(nn, test_x, test_y); str = sprintf('testing error rate is: %f',er); disp(str) %% //with dropout rng(0); nn = nnsetup([784 100 10]); nn.dropoutFraction = 0.5; % //Dropout fraction,每一次mini-batch樣本輸入訓練時,隨機扔掉50%的隱含層節點 opts.numepochs = 20; % //Number of full sweeps through data opts.batchsize = 100; % //Take a mean gradient step over this many samples nn = nntrain(nn, train_x, train_y, opts); [er, bad] = nntest(nn, test_x, test_y); str = sprintf('testing error rate is: %f',er); disp(str)
下面來分析與dropout相關的程式碼,集中在上面test.m程式碼的後面with drop部分。首先在訓練過程中需要將神經網路結構nn的dropoutFraction設定為一定比例,這裡設定為50%:nn.dropoutFraction = 0.5;
然後進入test_dropout.m中的nntrain()函式,沒有發現與dropoutFraction相關的程式碼,繼續進入網路前向傳播函式nnff()函式中,在網路的隱含層節點激發函式值被計算出來後,有下面的程式碼:
if(nn.dropoutFraction > 0) if(nn.testing) nn.a{i} = nn.a{i}.*(1 - nn.dropoutFraction); else nn.dropOutMask{i} = (rand(size(nn.a{i}))>nn.dropoutFraction); nn.a{i} = nn.a{i}.*nn.dropOutMask{i}; end end
由上面的程式碼可知,隱含層節點的輸出值以dropoutFraction百分比的機率被隨機清0(注意此時是在訓練階段,所以是else那部分的程式碼),既然前向傳播時有些隱含節點值被清0了,那麼在誤差方向傳播時也應該有相應的處理,果然,在反向傳播函式nnbp()中,有下面的程式碼:
if(nn.dropoutFraction>0) d{i} = d{i} .* [ones(size(d{i},1),1) nn.dropOutMask{i}]; end
也就是說計算節點誤差那一項時,其誤差項也應該清0。從上面可以看出,使用dropout時,其訓練部分的程式碼更改很少。
(有網友發私信說,反向傳播計算誤差項時可以不用乘以dropOutMask{i}矩陣,後面我仔細看了下bp的公式,一開始也感覺不用乘有道理。因為原始碼中有為:
for i = 1 : (n - 1) if i+1==n nn.dW{i} = (d{i + 1}' * nn.a{i}) / size(d{i + 1}, 1); else nn.dW{i} = (d{i + 1}(:,2:end)' * nn.a{i}) / size(d{i + 1}, 1); end end
程式碼進行權重更新時,由於需要乘以nn.a{i},而nn.a{i}在前向過程中如果被mask清掉的話(使用了dropout前提下),則已經為0了。但其實這時錯誤的,因為對誤差
敏感值作用的是與它相連線的前一層權值,並不是本層的權值,而本層的輸出a只對它的下一層權值更新有效。)
再來看看測試部分,測試部分如hintion論文所說的,採用mean network,也就是說前向傳播時隱含層所有節點的輸出同時減小dropoutFraction百分比,即保留(1- dropoutFraction)百分比,程式碼依舊是上面貼出的nnff()函式裡滿足if(nn.testing)的部分:
if(nn.dropoutFraction > 0) if(nn.testing) nn.a{i} = nn.a{i}.*(1 - nn.dropoutFraction); else nn.dropOutMask{i} = (rand(size(nn.a{i}))>nn.dropoutFraction); nn.a{i} = nn.a{i}.*nn.dropOutMask{i}; end end
上面只是個簡單的droput實驗,可以用來幫助大家理解dropout的思想和使用步驟。其中網路的引數都是採用toolbox預設的,並沒有去調整它,如果該實驗將訓練樣本增大,比如6w張,則引數不變的情況下使用了dropout的識別率還有可能會降低(當然這很有可能是其它引數沒調到最優,另一方面也說明在樣本比較少的情況下,droput確實可以防止過擬合),為了體現droput的優勢,這裡我只用了2000張訓練樣本。
參考資料:
Hinton, G. E., et al. (2012). "Improving neural networks by preventing co-adaptation of feature detectors." arXiv preprint arXiv:1207.0580.
相關推薦
Deep learning:四十一(Dropout簡單理解)
前言 訓練神經網路模型時,如果訓練樣本較少,為了防止模型過擬合,Dropout可以作為一種trikc供選擇。Dropout是hintion最近2年提出的,源於其文章Improving neural networks by preventing co-adaptation of feature d
Deep learning:四十五(maxout簡單理解)
maxout出現在ICML2013上,作者Goodfellow將maxout和dropout結合後,號稱在MNIST, CIFAR-10, CIFAR-100, SVHN這4個數據上都取得了start-of-art的識別率。 從論文中可以看出,maxout其實一種激發函式形式。通常情況下,如果激
Deep learning:五十一(CNN的反向求導及練習)
前言: CNN作為DL中最成功的模型之一,有必要對其更進一步研究它。雖然在前面的博文Stacked CNN簡單介紹中有大概介紹過CNN的使用,不過那是有個前提的:CNN中的引數必須已提前學習好。而本文的主要目的是介紹CNN引數在使用bp演算法時該怎麼訓練,畢竟CNN中有卷積層和下采樣層,雖然
Deep learning:四十(龍星計劃2013深度學習課程小總結)
頭腦一熱,坐幾十個小時的硬座北上去天津大學去聽了門4天的深度學習課程,課程預先的計劃內容見:http://cs.tju.edu.cn/web/courseIntro.html。上課老師為微軟研究院的大牛——鄧力,群(qq群介紹見:Deep learning高質量交流群)裡
論文閱讀筆記四十一:Very Deep Convolutional Networks For Large-Scale Image Recongnition(VGG ICLR2015)
結合 等價 選擇 mac 不同的 works info 內存 enc 論文原址:https://arxiv.org/abs/1409.1556 代碼原址:https://github.com/machrisaa/tensorflow-vgg 摘要 本
Deep learning:十九(RBM簡單理解)
這篇部落格主要用來簡單介紹下RBM網路,因為deep learning中的一個重要網路結構DBN就可以由RBM網路疊加而成,所以對RBM的理解有利於我們對DBN演算法以及deep learning演算法的進一步理解。Deep learning是從06年開始火得,得益於大牛Hinton的文章,不過這位大牛的
【leetcode 簡單】第四十一題 階乘後的零
時間 時間復雜度 ron elf 說明 輸入 數量 n) 復雜度 給定一個整數 n,返回 n! 結果尾數中零的數量。 示例 1: 輸入: 3 輸出: 0 解釋: 3! = 6, 尾數中沒有零。 示例 2: 輸入: 5 輸出: 1 解釋: 5! = 120, 尾數中有 1
Java並發編程原理與實戰四十一:重排序 和 happens-before
而已 註意 ron 不知道 load chm title 並行 ola 一、概念理解 首先我們先來了解一下什麽是重排序:重排序是指編譯器和處理器為了優化程序性能而對指令序列進行重新排序的一種手段。 從Java源代碼到最終實際執行的指令序列,會分別經歷下面3種重排序,如下
Java基礎系列(四十一):集合之List
前言 List是繼承自Collection的一個子介面,它提供了一個有序的集合,在這個集合中我們可以使用索引去獲取集合中的值,同時,我們也可以通過迭代器去訪問集合中的元素,第一種方法被稱為隨機訪問,因為我們可以按照任意的順序去訪問元素,而使用迭代器就必須順序的去訪問元素。 結構圖
第四十一章:Maven自動化構建依賴管理工具
1Maven簡介 Maven是一款自動化“構建”和“依賴”管理的工具。 2提出問題 2.1jar包管理問題 jar包從哪來的? 官網下載 CSND…… 51CTO…… …… jar包之間存在依賴關係 A jar包中用到了B jar包中的類,就說A依賴B。jar包之間普遍存在的依
“全棧2019”Java第四十一章:static關鍵字
難度 初級 學習時間 10分鐘 適合人群 零基礎 開發語言 Java 開發環境 JDK v11 IntelliJ IDEA v2018.3 文章原文連結 “全棧2019”Java第四十一章:static關鍵字 下一章 “全棧2019”Java第四十二章:靜態程式碼塊與初
Scrum立會報告+燃盡圖(十二月九日總第四十一次):使用者推廣
此作業要求參見:https://edu.cnblogs.com/campus/nenu/2018fall/homework/2484 專案地址:https://git.coding.net/zhangjy982/QuJianBang.git Scrum立會master:於洋 一、小組介紹
練習四十一:百錢百雞
公雞5文錢一隻,母雞3文錢一隻,小雞3只一文錢,用100文錢買100只雞,其中公雞、母雞、小雞都必須有,那麼可以買公雞、母雞、小雞各多少隻,正好是100文錢了 1 for x in range(1,21): 2 for y in range(1,34): 3 z =
環形連結串列(leetcode簡單篇一百四十一題)
給定一個連結串列,判斷連結串列中是否有環 拿到這道題其實我相信很多同學已經看過官方的解答了 快慢指標法:在返回單鏈表的中間節點其實我們就已經使用過這個辦法了,讓fast指標一次走倆步,讓slow指標一次走一步。你可以假想倆個人同時在操場上跑步,快的那個人假設一直比慢的快,那麼在不久後
演算法題練習系列之(四十一):童年生活二三事
----------------------------------------------------------------------------------------------------
ABP原始碼分析四十一:ZERO的Audit,Setting,Background Job
AuditLog: 繼承自Entity<long>的實體類。封裝AuditLog的資訊。 AuditingStore: 實現了IAuditingStore介面,實現了將AuditLog的資訊儲存到資料庫的功能。其通過IRepository<AuditLog, long>例項完成對資料
Deep Learning讀書筆記(一):Reducing the Dimensionality of Data with Neural Networks
這是發表在Science上的一篇文章,是Deep Learning的開山之作,同樣也是我讀的第一篇文章,我的第一篇讀書筆記也從這開始吧。 文章的主要工作是資料的降維,等於說這裡使用深度學習網路主要提取資料中的特徵,但卻並沒有將這個特徵應用到分類等
史上最簡單的 MySQL 教程(四十一)「觸發器」
溫馨提示:本系列博文已經同步到 GitHub,地址為「mysql-tutorial」,歡迎感興趣的童鞋Star、Fork,糾錯。 案例:網上購物,根據生產訂單的型別,商品的庫存量對應的進行增和減。此案例涉及兩張表,分別為訂單表和商品表,下單時,商
Android問題集錦之四十一:getActionBar為null的解決以及ActionBar的Back鍵
我們用getActionBar獲得ActionBar例項,將其隱藏或做其他操作。 ActionBar actionBar = getActionBar(); 不過我這樣做了,但是actionBar是null。 問題是由Activity基類引起的,當使用Ap
性能測試四十一:sql案例之慢sql配置、執行計劃和索引
xpl 通過 數據庫 data 設計 增加 應該 windows 情況下 MYSQL 慢查詢使用方法MYSQL慢查詢介紹分析MySQL語句查詢性能的問題時候,可以在MySQL記錄中查詢超過指定時間的語句,我們將超過指定時間的SQL語句查詢稱為“慢查詢&rdquo