莫煩課程Batch Normalization 批標準化
阿新 • • 發佈:2017-12-13
github cti mas pen get lin pytorch 生成 def
for i in range(N_HIDDEN): # build hidden layers and BN layers input_size = 1 if i == 0 else 10 fc = nn.Linear(input_size, 10) setattr(self, ‘fc%i‘ % i, fc) # IMPORTANT set layer to the Module self._set_init(fc) # parameters initialization self.fcs.append(fc) if self.do_bn: bn = nn.BatchNorm1d(10, momentum=0.5) setattr(self, ‘bn%i‘ % i, bn) # IMPORTANT set layer to the Module self.bns.append(bn)
上面的代碼對每個隱層進行批標準化,setattr(self, ‘fc%i‘ % i, fc)作用相當於self.fci=fc
每次生成的結果append到bns的最後面,結果的size 10×10,取出這些數據是非常方便
def forward(self, x): pre_activation = [x] if self.do_bn: x = self.bn_input(x) # input batch normalization layer_input = [x] for i in range(N_HIDDEN): x = self.fcs[i](x) pre_activation.append(x) if self.do_bn: x = self.bns[i](x) # batch normalization x = ACTIVATION(x) layer_input.append(x) out = self.predict(x) return out, layer_input, pre_activation
全部的源代碼
莫煩課程Batch Normalization 批標準化