1. 程式人生 > >Theano中scan函式的使用

Theano中scan函式的使用

  Theano作為深度學習領域一個經典的python庫,可以先定義數學表示式再進行運算,開創了符號計算的先河。最近在科研過程中用到了其中的scan()函式,雖然以前也學過這個框架,但是tensorflow永久了,感覺就把一些知識點忘記了,於是想到了寫筆記的重要性,故在這裡記錄下來,以備隨時回頭翻看。
  在學習任何一種程式語言的時候,我們都會學習它的變數定義、資料型別、程式結構控制(順序結構,選擇結構,迴圈結構等),而theano作為一種符號計算框架,是用於解決迴圈結構的,只不過與一般的程式語言中的迴圈結構不同的是,它針對的是符號變數(symbol variable)。scan()函式中的引數有很多,包括sequences, outputs_info, non_sequences, n_steps等,剛開始學習時我們可能會想,為什麼會有這麼多引數呢?我所能想到的就是,迴圈結構本身其實包含了不同的情況,有的只需要一步平行計算即可,有的則需要在反覆迭代中進行變數自身的更新。另外,theano中的變數包括一般的tensor,shared variable等,需要不同的考慮。以下對scan函式進行詳細介紹。
theano.scan

(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False)
  輸入引數:
  ‘fn’:’fn’是一個lambda或者def函式,描述了一步scan操作的運算式,運算式的輸入引數按照sequences, outputs_info, non_sequences的順序,運算式的輸出作為theano.scan的返回值。
  ‘sequences’
:sequences是一個theano variables或者dictionaries的列表。字典物件的結構為{‘variable’:taps},其中taps是一個整數列表。’sequences’列表中的所有Theano variable會被自動封裝成一個字典,此時taps被設定成[0]。比如sequences = [ dict(input= Sequence1, taps = [-3,2,-1]), Sequence2, dict(input = Sequence3, taps = 3) ], 對映到scan輸入引數為Sequence1[t-3],Sequence1[t+2],Sequence1[t-1],Sequence2[t],Sequence3[t+3]。還有一點就是,如果序列的長度不一致,scan會裁剪成它們中最短的,這個性質方便我們傳遞一個很長的arange,比如sequences=[coefficients, theano.tensor.arange(max_coefficients_supported)]
  ‘outputs_info’
:outputs_info是一個theano variables或者dictionaries的列表,它描述了輸出的初始狀態,顯然應該和輸出有相同的shape,而且,每進行一步scan操作,outputs_info中的數值會被上一次迭代的輸出值更新掉。當然,如果當前迴圈結構不需要recursive,而僅僅是一個map操作的話,這個引數便可以省略;
  ‘non_sequences’non_sequences 是一個‘常量’引數列表,這裡所謂的‘常量’是相對於‘outputs_info’中的引數更新而言的,代表了一步scan操作中不會被更新的變數。計算圖中的有些變數在這裡也可以不顯式的指明,但顯式指明變數引數會得到一個簡化的計算圖,加速編譯器對圖的優化和執行。 常見的應用是,把shared variables作為non_sequences引數中的值.
  ‘n_steps’:‘n_steps’引數是一個int或者theano scalar,代表了scan操作的迭代次數。如果存在輸入序列,其中的元素個數小於n_steps,scan函式會報錯。如果n_steps引數未指定,scan會根據他的輸入引數自動計算出迭代步數;
  ‘truncate_gradient’:‘truncate_gradient’引數代表了使用BPTT(back propagation through time)演算法時,“梯度截斷”後的步數。“梯度截斷”的目的是在可接受的誤差範圍內,降低梯度的計算複雜度。常見的應用場景是RNN(recurrent neural network)。
  ‘strict’:’strict’是一個shared variable校驗標誌,用於檢驗是否fn函式用到的所有shared variabes都在non_sequences中,若不滿足則會Raise an error。
  其他引數:用的比較少,這裡就略了。
 返回引數:
  形如(outputs, updates)格式的元組型別。’outputs‘是一個theano變數,或者多個theano變數構成的list。並且,每一個theano變數包含了所有迭代步驟的輸出結果。’updates‘是形如(var, expression)的字典結構,指明瞭scan中用到的所有shared variables的更新規則 。

使用共享變數的例子-Gibbs取樣

import theano
from theano import tensor as T

W = theano.shared(W_values) # we assume that ``W_values`` contains the
                            # initial values of your weight matrix

bvis = theano.shared(bvis_values)
bhid = theano.shared(bhid_values)

trng = T.shared_randomstreams.RandomStreams(1234)

def OneStep(vsample) :
    hmean = T.nnet.sigmoid(theano.dot(vsample, W) + bhid)
    hsample = trng.binomial(size=hmean.shape, n=1, p=hmean)
    vmean = T.nnet.sigmoid(theano.dot(hsample, W.T) + bvis)
    return trng.binomial(size=vsample.shape, n=1, p=vmean,
                         dtype=theano.config.floatX)

sample = theano.tensor.vector()

values, updates = theano.scan(OneStep, outputs_info=sample, n_steps=10)

gibbs10 = theano.function([sample], values[-1], updates=updates)

  在這段樣例中,updates字典建立了shared variable到k steps後的更新值的對映,在這段程式碼中也就是random streams的更新情況。如果不傳遞updates引數到theano.function中,每次呼叫gibbs10函式時,會得到相同的10個隨機數sets,導致返回值values[-1]是相同的。

注:更多的關於theano.scan的使用樣例,請參見官網教程http://deeplearning.net/software/theano/library/scan.html