DeepLearning tutorial(2)機器學習演算法在訓練過程中儲存引數
阿新 • • 發佈:2018-11-15
分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow
也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!
DeepLearning tutorial(2)機器學習演算法在訓練過程中儲存引數
@author:wepon
@blog:http://blog.csdn.net/u012162613/article/details/43169019
參考:pickle — Python object serialization、DeepLearning Getting started
一、python讀取"***.pkl.gz"檔案
用到python裡的gzip以及cPickle模組,簡單的使用程式碼如下,如果想詳細瞭解可以參考上面給出的連結。
#以讀取mnist.pkl.gz為例import cPickle, gzipf = gzip.open('mnist.pkl.gz', 'rb')train_set, valid_set, test_set = cPickle.load(f)f.close()
其實就是分兩步,先讀取gz檔案,再讀取pkl檔案。pkl檔案的應用正是下文要講的,我們用它來儲存機器學習演算法訓練過程中的引數。
二、機器學習演算法在訓練過程中如何儲存引數?
我們知道,機器學習演算法的計算量特別大,跑起程式來少則幾十分鐘,多則幾小時甚至幾天,中間如果有什麼狀況(比如電腦過熱重啟、程式出現一些小bug...)程式就會中斷,如果你沒把引數定時儲存下來,前面的訓練就當白費了,所以很有必要在程式中加入定時儲存引數的功能,這樣下次訓練就可以將引數初始化為上次儲存下來的結果,而不是從頭開始隨機初始化。
那麼如何儲存模型引數?可以將引數深複製,或者呼叫python的資料永久儲存cPickle模組,原理不多說,直接使用就行。(注:python裡有cPickle和pickle,cPickle基於c實現,比pickle快。)
a=[1,2,3]b={4:5,6:7}#儲存,cPickle.dump函式。/home/wepon/ab是路徑,ab是儲存的檔案的名字,如果/home/wepon/下本來就有ab這個檔案,將被覆寫#,如果沒有,則建立。'wb'表示以二進位制可寫的方式開啟。dump中的-1表示使用highest protocol。import cPicklewrite_file=open('/home/wepon/ab','wb')cPickle.dump(a,write_file,-1)cPickle.dump(b,write_file,-1)write_file.close()#讀取,cPickle.load函式。read_file=open('/home/wepon/ab','rb')a_1=cPickle.load(read_file)b_1=cPickle.load(read_file)print a,bread_file.close()
在deeplearning演算法中,因為用到GPU,經常是將引數宣告為shared變數,因此必須用上get_value()、set_value,例如有w、v、u三個shared變數,使用程式碼如下:
import cPickle#儲存write_file = open('path', 'wb') cPickle.dump(w.get_value(borrow=True), write_file, -1) cPickle.dump(v.get_value(borrow=True), write_file, -1) cPickle.dump(u.get_value(borrow=True), write_file, -1) write_file.close()#讀取read_file = open('path')w.set_value(cPickle.load(read_file), borrow=True)v.set_value(cPickle.load(read_file), borrow=True)u.set_value(cPickle.load(read_file), borrow=True)read_file.close()
一個例項
下面我以一個實際的例子來說明如何在程式中加入儲存引數的功能。以deeplearnig.net上的邏輯迴歸為例,它的程式碼地址:logistic_sgd.py。這個程式是將邏輯迴歸用於MNIST分類,程式執行過程並不會儲存引數,甚至執行結束時也不儲存引數。怎麼做可以儲存引數?
在logistic_sgd.py程式碼裡最後面的sgd_optimization_mnist()函式裡,有個while迴圈,裡面有一句程式碼:
if this_validation_loss < best_validation_loss:
這句程式碼的意思就是判斷當前的驗證損失是否小於最佳的驗證損失,是的話,下面會更新best_validation_loss,也就是說當前引數下,模型比之前的有了優化,因此我們可以在這個if語句後面加入儲存引數的程式碼:
save_params(classifier.W,classifier.b)
save_params函式定義如下:
def save_params(param1,param2): import cPickle write_file = open('params', 'wb') cPickle.dump(param1.get_value(borrow=True), write_file, -1) cPickle.dump(param2.get_value(borrow=True), write_file, -1) write_file.close()
當然引數的個數根據需要去定義。在logistic_sgd.py中引數只有classifier.W,classifier.b,因此這裡定義為save_params(param1,param2)。
在logistic_sgd.py裡我加入了save_params(classifier.W,classifier.b),運行了3次epoch,中斷掉程式,在程式碼所在的資料夾下,多出了一個params檔案,我們來看看這個檔案裡是什麼東西:
import cPicklef=open('params')w=cPickle.load(f)b=cPickle.load(f)#w大小是(n_in,n_out),b大小時(n_out,),b的值如下,因為MINST有10個類別,n_out=10,下面正是10個數array([-0.0888151 , 0.16875755, -0.03238435, -0.06493175, 0.05245609, 0.1754718 , -0.0155049 , 0.11216578, -0.26740651, -0.03980861])
也就是說,params檔案確實儲存了我們訓練過程中的引數。
那麼如何用儲存下來的引數來初始化我們的模型的引數呢?
在logistic_sgd.py中的class LogisticRegression(object)下,self.W和self.b本來是初始化為0的,我們可以在下面加上幾行程式碼,這樣就可以用我們儲存下來的params檔案來初始化引數了:
class LogisticRegression(object): def __init__(self, input, n_in, n_out): self.W = theano.shared( value=numpy.zeros( (n_in, n_out), dtype=theano.config.floatX ), name='W', borrow=True ) self.b = theano.shared( value=numpy.zeros( (n_out,), dtype=theano.config.floatX ), name='b', borrow=True )#!!!#加入的程式碼在這裡,程式執行到這裡將會判斷當前路徑下有沒有params檔案,有的話就拿來初始化W和b if os.path.exists('params'): f=open('params') self.W.set_value(cPickle.load(f), borrow=True) self.b.set_value(cPickle.load(f), borrow=True)