1. 程式人生 > 程式設計 >淺談tensorflow之記憶體暴漲問題

淺談tensorflow之記憶體暴漲問題

在用tensorflow實現一些模型的時候,有時候我們在執行程式的時候,會發現程式佔用的記憶體在不斷增長。最後記憶體溢位,程式被kill掉了。

這個問題,其實有兩個可能性。一個是比較常見,同時也是很難發現的。這個問題的解決,需要我們知道tensorflow在構圖的時候,是沒有所謂的臨時變數的,只要有operator。那麼tensorflow就會在構建的圖中增加這個operator所代表的節點。所以,在執行程式的過程中,記憶體不斷增長的原因就是在模型訓練迭代的過程中,tensorflow一直在幫你增加圖的節點。導致記憶體佔用越來越多。

那麼什麼情況下就會像上面說的那樣呢?我們舉個例子:

import tensorflow as tf

x = tf.Variable(tf.constant(1))
y = tf.constant(2)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

while True:
 print(sess.run(x+y))

如果你執行上面這段程式碼,會發現在執行的過程中,記憶體佔用越來越大。原因就在於sess.run(x+y)這個語句。我們知道在tensorflow中,所有的操作都是graph的節點。而在迭代的過程中,x+y這個operator(操作)是匿名的,所以它會不斷地重複,在graph中建立節點,導致記憶體佔用越來越大。

所以要對上面的程式碼進行修改:

z = x+y
while True:
 print(sess.run(z))

這樣就不會出現問題了。

上面只是一個簡單的例子,我們可以很快發現問題。但是有時候我們的模型比較複雜,很難判斷是否在迭代的過程中一直在增加節點。那怎麼辦呢?

其實在tensorflow裡面有個函式叫做:

sess.graph.finalize()

只要每一次構圖完成後,呼叫這個函式。然後執行程式,如果你的程式在執行的過程中還一直新建節點,這個函式就會檢測到,然後就會報錯。這樣你就知道你的程式中一定有不合理的地方。

另一個導致記憶體暴漲的原因是,資料的載入問題。tensorflow現在有一個API介面,tf.data.Dataset 。這個接口裡面有個函式叫做cache(filename)。cache函式的作用是將載入進來的資料存放到filename指定的地方。但是如果我們沒有指定filename,資料就是一直儲存在記憶體中。所以,隨著迭代次數的增加,儲存在記憶體中的資料越來越多,就會導致記憶體暴漲。所以要麼不要使用這個函式,要麼就要記得新增filename引數。

以上這篇淺談tensorflow之記憶體暴漲問題就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。