1. 程式人生 > 實用技巧 >【516】keras 原始碼分析之 Dense

【516】keras 原始碼分析之 Dense

參考:keras原始碼分析之Layer

參考:keras原始碼分析之Dense


  本文主要講解一下 Dense 層的原始碼,Dense 層即最常用的全連線層,程式碼很簡單,主要是重寫了buildcall方法,在我們自定義 Layer 時,也可以參考該層的實現。但是不需要這麼複雜,只要寫出必要的部分就可以了,參見下一篇部落格。

1.Layer類的相關說明

參考:TensorFlow函式:tf.layers.Layer——W3CschoolTensorFlow官方文件

參考:關於 Keras 網路層——keras中文文件

  基礎層類。這是所有層都繼承的類,實現了通用的基礎結構功能。層是實現常見神經網路操作的類,例如卷積、批量規範等。這些操作需要管理變數、損失和更新,以及將 TensorFlow 操作應用於輸入張量。使用者只需例項化它,然後將其視為可呼叫的。

  我們建議 Layer 的子代實現以下方法:

  • __init__():在成員變數中儲存配置
  • build():當我們知道輸入和 dtype 的形狀時,從 __call__ 呼叫一次。應該有對 add_variable() 的呼叫,然後呼叫高階的 build() (設定為 self.built = True,這在使用者想要在第一個 __call__ 之前手動呼叫 build() 時很好)。
  • * call():確認 build() 已被呼叫一次後呼叫 __call__。實際上應該執行將層應用於輸入張量的邏輯(應該作為第一個引數傳入)。

2.Dense 原始碼解讀

2.1 __init__函式重寫

  構造方法沒什麼好說的,就是一些簡單的賦值。

from keras.layers import Layer

class Dense(Layer):
    def __init__(self, units,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(Dense, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

  

2.2build函式重寫

  build 方法中定義了兩個 Variable 即權重,最後把 built 引數置為 True。

    def build(self, input_shape):
        assert len(input_shape) >= 2
        # 維度取 input_shape 的最後一維
        # 正好進行後面的叉乘
        input_dim = input_shape[-1]

        # 設定權重矩陣,維度為 (input_dim, self.units),用於叉乘   
        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            # 設定偏置
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

  

2.3call函式重寫

  call 方法把輸入值與 build 方法中定義的權重進行了點積的操作,然後與 build 中的偏移量進行相加,最後經過啟用函式返回最終的輸出結果。

    def call(self, inputs):
        # 具體的 矩陣操作
        output = K.dot(inputs, self.kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias, data_format='channels_last')
        if self.activation is not None:
            output = self.activation(output)
        return output

  

2.4compute_output_shape函式重寫

  計算出輸出tensor的維度並返回。

    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

  

2.5get_config函式重寫

  保留一些中間值並以字典的形式返回。

    def get_config(self):
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(Dense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))