使用@tf.function加快訓練速度
TensorFlow 2 預設的即時執行模式(Eager Execution)為我們帶來了靈活及易除錯的特性,但為了追求更快的速度與更高的效能,我們依然希望使用 TensorFlow 1.X 中預設的圖執行模式(Graph Execution)。此時,TensorFlow 2 為我們提供了 tf.function
模組,結合 AutoGraph 機制,使得我們僅需加入一個簡單的 @tf.function
修飾符,就能輕鬆將模型以圖執行模式執行。
實現方式
只需要將我們希望以圖執行模式執行的程式碼封裝在一個函式內,並在函式前加上 @tf.function
即可。
import tensorflow as tf from tensorflow import keras import numpy as np from matplotlib import pyplot as plt import time np.random.seed(42) # 設定numpy隨機數種子 tf.random.set_seed(42) # 設定tensorflow隨機數種子 # 生成訓練資料 x = np.linspace(-1, 1, 100) x = x.astype('float32') y = x * x + 1 + np.random.rand(100)*0.1 # y=x^2+1 + 隨機噪聲 x_train = np.expand_dims(x, 1) # 將一維資料擴充套件為二維 y_train = np.expand_dims(y, 1) # 將一維資料擴充套件為二維 plt.plot(x, y, '.') # 畫出訓練資料 def create_model(): inputs = keras.Input((1,)) x = keras.layers.Dense(10, activation='relu')(inputs) outputs = keras.layers.Dense(1)(x) model = keras.Model(inputs=inputs, outputs=outputs) return model model = create_model() # 建立一個模型 loss_fn = keras.losses.MeanSquaredError() # 定義損失函式 optimizer = keras.optimizers.SGD() # 定義優化器 @tf.function # 將訓練過程轉化為圖執行模式 def train(): with tf.GradientTape() as tape: y_pred = model(x_train, training=True) # 前向傳播,注意不要忘了training=True loss = loss_fn(y_train, y_pred) # 計算損失 tf.summary.scalar("loss", loss, epoch+1) # 將損失寫入tensorboard grads = tape.gradient(loss, model.trainable_variables) # 計算梯度 optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 使用優化器進行反向傳播 return loss epochs = 1000 begin_time = time.time() # 訓練開始時間 for epoch in range(epochs): loss = train() print('epoch:', epoch+1, '\t', 'loss:', loss.numpy()) # 列印訓練資訊 end_time = time.time() # 訓練結束時間 print("訓練時長:", end_time-begin_time) # 預測 y_pre = model.predict(x_train) # 畫出預測值 plt.plot(x, y_pre.squeeze()) plt.show()
通過實驗得出結論:如果不使用@tf.function
,那麼訓練時間大約為3秒。如果使用@tf.function
,訓練時間僅需要0.5秒。快了很多倍。
內在原理
使用@tf.function
的函式在執行時會生成一個計算圖,裡面的操作就是計算圖的每個節點。下次呼叫相同的函式,且引數型別相同時,則會直接使用這個計算圖計算。若函式名不同或引數型別不同時,則會另外生成一個新的計算圖。
注意點
建議在函式內只使用 TensorFlow 的原生操作,不要使用過於複雜的 Python 語句,函式引數最好只包括 TensorFlow 張量或 NumPy 陣列。
-
因為只有tf的原生操作才會在計算圖中生產節點。(如python的原生
print()
tf.print()
會) -
對於Tensorflow張量或Numpy陣列作為引數的函式,只要型別相同便可重用之前的計算圖。而對於python原聲資料(如原生的整數、浮點數 1,1.5等)必須引數的值一模一樣才會重用之前的計算圖,否則的話會建立新的計算圖。
另外,一般而言,當模型由較多小的操作組成的時候, @tf.function
帶來的提升效果較大。而當模型的運算元量較少,但單一操作均很耗時的時候,則 @tf.function
帶來的效能提升不會太大。
參考
https://tf.wiki/zh_hans/basic/tools.html#tf-function