1. 程式人生 > 實用技巧 >使用@tf.function加快訓練速度

使用@tf.function加快訓練速度

TensorFlow 2 預設的即時執行模式(Eager Execution)為我們帶來了靈活及易除錯的特性,但為了追求更快的速度與更高的效能,我們依然希望使用 TensorFlow 1.X 中預設的圖執行模式(Graph Execution)。此時,TensorFlow 2 為我們提供了 tf.function 模組,結合 AutoGraph 機制,使得我們僅需加入一個簡單的 @tf.function 修飾符,就能輕鬆將模型以圖執行模式執行。

實現方式

只需要將我們希望以圖執行模式執行的程式碼封裝在一個函式內,並在函式前加上 @tf.function 即可。

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import time

np.random.seed(42)  # 設定numpy隨機數種子
tf.random.set_seed(42)  # 設定tensorflow隨機數種子

# 生成訓練資料
x = np.linspace(-1, 1, 100)
x = x.astype('float32')
y = x * x + 1 + np.random.rand(100)*0.1  # y=x^2+1 + 隨機噪聲
x_train = np.expand_dims(x, 1)  # 將一維資料擴充套件為二維
y_train = np.expand_dims(y, 1)  # 將一維資料擴充套件為二維
plt.plot(x, y, '.')  # 畫出訓練資料


def create_model():
    inputs = keras.Input((1,))
    x = keras.layers.Dense(10, activation='relu')(inputs)
    outputs = keras.layers.Dense(1)(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = create_model()  # 建立一個模型
loss_fn = keras.losses.MeanSquaredError()  # 定義損失函式
optimizer = keras.optimizers.SGD()  # 定義優化器


@tf.function  # 將訓練過程轉化為圖執行模式
def train():
    with tf.GradientTape() as tape:
        y_pred = model(x_train, training=True)  # 前向傳播,注意不要忘了training=True
        loss = loss_fn(y_train, y_pred)  # 計算損失
        tf.summary.scalar("loss", loss, epoch+1)  # 將損失寫入tensorboard
    grads = tape.gradient(loss, model.trainable_variables)  # 計算梯度
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  # 使用優化器進行反向傳播
    return loss


epochs = 1000
begin_time = time.time()  # 訓練開始時間
for epoch in range(epochs):
    loss = train()

    print('epoch:', epoch+1, '\t', 'loss:', loss.numpy())  # 列印訓練資訊
end_time = time.time()  # 訓練結束時間

print("訓練時長:", end_time-begin_time)

# 預測
y_pre = model.predict(x_train)

# 畫出預測值
plt.plot(x, y_pre.squeeze())
plt.show()

通過實驗得出結論:如果不使用@tf.function,那麼訓練時間大約為3秒。如果使用@tf.function,訓練時間僅需要0.5秒。快了很多倍。

內在原理

使用@tf.function的函式在執行時會生成一個計算圖,裡面的操作就是計算圖的每個節點。下次呼叫相同的函式,且引數型別相同時,則會直接使用這個計算圖計算。若函式名不同或引數型別不同時,則會另外生成一個新的計算圖。

注意點

建議在函式內只使用 TensorFlow 的原生操作,不要使用過於複雜的 Python 語句,函式引數最好只包括 TensorFlow 張量或 NumPy 陣列。

  • 因為只有tf的原生操作才會在計算圖中生產節點。(如python的原生print()

    函式不會生成節點,而tensorflow的tf.print()會)

  • 對於Tensorflow張量或Numpy陣列作為引數的函式,只要型別相同便可重用之前的計算圖。而對於python原聲資料(如原生的整數、浮點數 1,1.5等)必須引數的值一模一樣才會重用之前的計算圖,否則的話會建立新的計算圖。

另外,一般而言,當模型由較多小的操作組成的時候, @tf.function 帶來的提升效果較大。而當模型的運算元量較少,但單一操作均很耗時的時候,則 @tf.function 帶來的效能提升不會太大。

參考

https://tf.wiki/zh_hans/basic/tools.html#tf-function