1. 程式人生 > >keras自定義網路層

keras自定義網路層

在深度學習領域,Keras是一個高度封裝的庫並被廣泛應用,可以通過呼叫其內建網路模組(各種網路層)實現針對性的模型結構;當所需要的網路層功能不被包含時,則需要通過自定義網路層或模型實現。 如何在keras框架下自定義層,基本“套路”如下。 ####一般地,keras中的網路層是一個類,所以自定義層即編寫一個類,更為重要的是這個類(即自定義層)需要繼承Layer父類,而且需要實現以下四種方法: 1. __init __ (self, output_dim, **kwargs) 這個方法是用來初始化並自定義自定義層所需的屬性,比如output_dim; 此外,該方法需要執行super().__init __(**kwargs),這行程式碼是執行Layer類中的初始化函式; 當執行上述程式碼就沒有必要去管input_shape,weights,trainable等關鍵字引數,因為父類(Layer)的初始化函式實現了它們與layer例項的繫結。 2. build(self, input_shape) 這個方法是用來建立層的權重; 在該方法中,根據之前的繼承,通過Layer類的add_weight方法來自定義並新增一個權重矩陣,這個方法需要input_shape引數; 該方法必須設self.built = True,目的是為了保證這個層的權重定義函式build被執行過了; 在built函式中,需要說明這個權重各方面的屬性,比如shape、初始化方式以及可訓練性等資訊。 3. call(self, x) 這個方法是用來編寫層的功能邏輯; 在該方法中,需要關注傳入call的第一個引數:輸入張量x;x只能是一種形式變數,不能是具體的變數,即它不能被定義; 這個call函式就是該層的計算邏輯,當建立好這個層例項後,該例項可以執行call函式; 可見,這個層的核心應該是一段符號式的輸入張量到輸出張量的計算過程。 4. compute_output_shape(self, input_shape) 這個方法是用來保證輸出shape是正確的; 這裡重寫compute_output_shape方法去覆蓋父類中的同名方法,來保證輸出的shape符合實際; 父類Layer中的compute_output_shape方法直接返回的是input_shape這明顯是不對的,所以需要重寫該方法。 ###示例 結合官方文件的例子,給出如下一個自定義層的程式碼: ![](https://img2020.cnblogs.com/blog/1875829/202102/1875829-20210216154544398-200567612.png) 使用自定義層,就如同使用keras內建網路層一樣,如下圖所示:(另外,本例使用kears內建的啟用函式層ReLU承接自定義層的輸出,從而避免將啟用函式的功能加入到自定義層中) ![](https://img2020.cnblogs.com/blog/1875829/202102/1875829-20210216154657047-18184049