1. 程式人生 > >torch學習(二) nn類結構-Module

torch學習(二) nn類結構-Module

Module |-Container    |-Sequential    |-Parallel    |-Concat       |-DepthConcat 安裝使用torch以後,接下來就可以開始構建模型和訓練了,在這之前,一定要先花半個小時的時間瞭解一下torch裡面nn部分使用到的基本型別,否則容易混亂。 一、Module Module是神經網路的基本組成部分,作為一個抽象類,可以通過定義成員函式實現不同的神經網路結構,nn.*都是Module的子類。 成員變數有output和gradInput,成員函式包括: (一)、[output] forward(input) 根據input計算更新output,不建議重寫該函式 1.呼叫updateOutput(input) 對module的輸出進行處理,可以重寫該函式 (二)、[gradInput] backward(input, gradOutput) 在使用backward之前需要首先呼叫forward,不建議重寫,可以重寫它呼叫的兩個子函式 input:該層的輸入 gradOutput:對該層輸出的梯度 gradInput:對該層輸入的梯度 1.updateGradInput(input, gradOutput) 計算相對於輸入input的梯度,更新gradInput 2. accGradParameters(input, gradOutput, scale) 計算相對於該Module的引數的梯度 gradParameters:對該層引數的梯度,比如sigmoid層沒有引數就沒有該變數 scale: gradParameters乘以係數scale,然後累加 3.zeroGradParameters() 重置gradParameters為0 4.updateParameters(learningRate) 根據backward()累加的gradParameters,更新Module的引數 parameters = parameters - learningRate * gradients_wrt_parameters 5.accUpdateCradParameters(input,gradOutput,learningRate) 累積引數梯度,同時更新引數 ...... 二、Container 父類Module,子類包括Sequential、Parallel、Concat等 新增成員函式包括: 1.add(module) 增加Module 2.get(index) 根據下標獲取Module 3.size() 包含的modules 三、Sequential 以feed-forward全連線的形式組織網路 1.remove([index]) 刪除第index層 2.insert(module,[index]) 在中間插入層 四、Parallel module = Parallel(inputDimension,outputDimension) Creates a container module that applies its ith child module to the ith slice of the input Tensor by using 
select
 on dimension inputDimension. It concatenates the results of its contained modules together along dimensionoutputDimension. 五、Concat module = nn.Concat(dim) Concat concatenates the output of one layer of "parallel" modules along the provided dimension dim: they take the same inputs, and their output is concatenated. 六、DepthConcat DepthConcat concatenates the output of one layer of "parallel" modules along the provided dimension dim: they take the same inputs, and their output is concatenated. For dimensions other than dim having different sizes, the smaller tensors are copied in the center of the output tensor, effectively padding the borders with zeros. 參考: