【516】keras 原始碼分析之 Dense
阿新 • • 發佈:2021-01-01
本文主要講解一下 Dense 層的原始碼,Dense 層即最常用的全連線層,程式碼很簡單,主要是重寫了build
與call
方法,在我們自定義 Layer 時,也可以參考該層的實現。但是不需要這麼複雜,只要寫出必要的部分就可以了,參見下一篇部落格。
1.Layer類的相關說明
參考:TensorFlow函式:tf.layers.Layer——W3CschoolTensorFlow官方文件
參考:關於 Keras 網路層——keras中文文件
基礎層類。這是所有層都繼承的類,實現了通用的基礎結構功能。層是實現常見神經網路操作的類,例如卷積、批量規範等。這些操作需要管理變數、損失和更新,以及將 TensorFlow 操作應用於輸入張量。使用者只需例項化它,然後將其視為可呼叫的。
我們建議 Layer 的子代實現以下方法:
- __init__():在成員變數中儲存配置
- build():當我們知道輸入和 dtype 的形狀時,從 __call__ 呼叫一次。應該有對 add_variable() 的呼叫,然後呼叫高階的 build() (設定為 self.built = True,這在使用者想要在第一個 __call__ 之前手動呼叫 build() 時很好)。
- * call():確認 build() 已被呼叫一次後呼叫 __call__。實際上應該執行將層應用於輸入張量的邏輯(應該作為第一個引數傳入)。
2.Dense 原始碼解讀
2.1 __init__函式重寫
構造方法沒什麼好說的,就是一些簡單的賦值。
from keras.layers import Layer class Dense(Layer): def __init__(self, units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, **kwargs): if 'input_shape' not in kwargs and 'input_dim' in kwargs: kwargs['input_shape'] = (kwargs.pop('input_dim'),) super(Dense, self).__init__(**kwargs) self.units = units self.activation = activations.get(activation) self.use_bias = use_bias self.kernel_initializer = initializers.get(kernel_initializer) self.bias_initializer = initializers.get(bias_initializer) self.kernel_regularizer = regularizers.get(kernel_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) self.activity_regularizer = regularizers.get(activity_regularizer) self.kernel_constraint = constraints.get(kernel_constraint) self.bias_constraint = constraints.get(bias_constraint) self.input_spec = InputSpec(min_ndim=2) self.supports_masking = True
2.2build函式重寫
build 方法中定義了兩個 Variable 即權重,最後把 built 引數置為 True。
def build(self, input_shape): assert len(input_shape) >= 2 # 維度取 input_shape 的最後一維 # 正好進行後面的叉乘 input_dim = input_shape[-1] # 設定權重矩陣,維度為 (input_dim, self.units),用於叉乘 self.kernel = self.add_weight(shape=(input_dim, self.units), initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: # 設定偏置 self.bias = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, constraint=self.bias_constraint) else: self.bias = None self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim}) self.built = True
2.3call函式重寫
call 方法把輸入值與 build 方法中定義的權重進行了點積的操作,然後與 build 中的偏移量進行相加,最後經過啟用函式返回最終的輸出結果。
def call(self, inputs): # 具體的 矩陣操作 output = K.dot(inputs, self.kernel) if self.use_bias: output = K.bias_add(output, self.bias, data_format='channels_last') if self.activation is not None: output = self.activation(output) return output
2.4compute_output_shape函式重寫
計算出輸出tensor的維度並返回。
def compute_output_shape(self, input_shape): assert input_shape and len(input_shape) >= 2 assert input_shape[-1] output_shape = list(input_shape) output_shape[-1] = self.units return tuple(output_shape)
2.5get_config函式重寫
保留一些中間值並以字典的形式返回。
def get_config(self): config = { 'units': self.units, 'activation': activations.serialize(self.activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint) } base_config = super(Dense, self).get_config() return dict(list(base_config.items()) + list(config.items()))