1. 程式人生 > 其它 >tensorflow1.x版本程式碼遷移到2.0

tensorflow1.x版本程式碼遷移到2.0

由於3090顯示卡只支援tf2.0以後的版本,而且隨著顯示卡的更新換代,tf1.x版本也不支援更高階的顯示卡,所以有必要將1.x版本的程式碼轉成2.0後的版本。

Tf2.0版本和tf1.0版本的主要區別

  主要區別在於tf1.x是靜態圖,需要先把模型結構先定好,再進行訓練

Tf2.0版本則是動態圖,訓練前不用先構建完整的結構,而是按流程一步步構建,因此在訓練的時候tf1.x相比於tf2.0佔cpu記憶體大,訓練的速度更快

程式碼轉換主要分幾個方面:輸入、模型網路、訓練、模型儲存

1.輸入

  1.x的程式碼中,對於輸入需要首先加placeholder,作為整個網路的入口。而tf2.0取消了這個部分,因此修改的方法是去掉這部分程式碼,直接在訓練的時候輸入資料,例如: 

 

Tf1.x

self.inputs = tf.placeholder(tf.int32, [None, None], name="inputs")  # 資料輸入

self.labels = tf.placeholder(tf.float32, [None, None], name="labels")  # 標籤

 

修改後直接在訓練的時候賦值就行:

self.inputs = batch["x"]

self.labels = batch["y"]

self.keep_prob = dropout_prob

  

2.模型網路

這部分比較好改,因為很多api可以在tensorflow官方文件上找到相應的替換函式,幾個常用的如下:

tf.get_variable()變成tf.variable()

Initializer的改變

# embedding_w = tf.compat.v1.get_variable("embedding_w", shape=[self.vocab_size, self.config["embedding_size"]],
#                               initializer=tf.compat.v1.contrib.layers.xavier_initializer())

embedding_w = tf.Variable(tf.keras.initializers.glorot_normal()(shape=[self.vocab_size, self.config["embedding_size"]],
                                          dtype=tf.float32), name='embedding')

  

3.訓練

訓練過程包括梯度的操作、優化演算法的選擇,主要的操作如下:

模型訓練要繼承tf.Module這個api,因為訓練的時候要選擇狀態容器以便儲存模型的引數,如果用keras或estimator模組寫模型也可以繼承其他的api,具體的繼承規則可以參考這個樹形結構:

https://zhuanlan.zhihu.com/p/73575776

Trackable

|

|-- tf.Variable

|

|-- MutableHashTable

|

|-- AutoTrackable

|

|-- ListWrapper/DictWrapper

|

|-- tf.train.Checkpoint

|

|-- tf.Module

|

|-- tf.keras.layers.Layer

|

|-- tf.keras.Model

|

|-- tf.keras.Sequential

幾種狀態容器的選擇準則一般為:

僅在學習和深入研究狀態容器(或基於物件的儲存)時使用Trackable和AutoTrackable

tf.Module: 適合自定義訓練迴圈時使用

tf.keras.layers.Layer:適合實現一些中間層,比如Attention之類的,可以配合tf.keras.Sequential使用,極少看見大的模型繼承自這個型別。

tf.keras.Model:適合一些固定套路的模型(使用compile+fit)。雖然也可以自定義訓練迴圈,但是有一種殺雞用牛刀的感覺。

tf.keras.Sequential:適合一條路走到黑的(子)模型。

選擇完狀態容器後則要進行對應的訓練迴圈,也就是梯度下降的操作:

Tf1.x首先定義好train_op,然後session.run

Tf2.0則直接在epoch迴圈內使用

with tf.GradientTape() as t:

    grads = t.gradient(self.model.loss, self.model.trainable_variables)

optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

也就是將sess.run裡面的操作換成一步步執行的函式流程

4.模型儲存

  Tf1.x和2.0的模型儲存變化不大,都可以儲存成checkepoint和pb這兩種格式,根據文件將api換一下就可以了,但是需要注意的是儲存的模型載入的時候版本需要和之前一致,否則在模型載入的時候可能會報錯。Summary的儲存也是一樣,需要把api替換掉。