1. 程式人生 > >tensorflow模型優化技巧

tensorflow模型優化技巧

當把模型跑起來後,開始考慮如何優化model,提升效能,從網上找了找資料,並結合實際,整理一下分享給大家。

預處理資料

說道預處理資料,我覺得我自己做的還是不少,學習tensorflow時候,把mnist_soft.py跑起來以後,就開始思考mnist資料是什麼資料?帶著這個疑問我開始嘗試製作自己的資料集,期間使用了很多的方法,如二進位制檔案,直接讀取圖片進記憶體等等。其實仔細想想可以知道,如果預處理資料沒有製作好,會直接影響後續tensorflow讀取資料的速度。但如果你覺得提高機器效能可以的話,那我只能說,就像是蘭博基尼跑在泥濘的道路上。所有一定得使預處理資料乾淨, tensorflow官方提供的資料格式TFRecord 是一個很不錯的選擇,可以試著製作一下
最近釋出了預處理元件:tf.Transfrom() 有興趣點選如下地址瞭解

http://www.leiphone.com/news/201702/Yi4oU1mSwKLc8Rad.html

使用佇列

佇列的優勢就不說了,把預處理資料放進佇列,怎麼出自己控制。有一種發現昂貴的預處理管道的方法是檢視 Tensorboard 的佇列圖。如果你使用框架 QueueRunners並將摘要儲存在檔案中,這些圖都是自動生成的。這些圖會顯示你的計算機是否能夠保持佇列處在排滿的狀態。如果你發現圖當中出現了負峰值,則系統無法在計算機要處理一個批次的時間內生成新的資料。其中的一個原因上面已經說過了。根據我的經驗,最常見的原因是 min_after_dequeue 值很大。如果佇列試圖在記憶體中保留大量記錄,你的容量很容易就飽和了,這會導致交換(swapping),並且顯著降低佇列的速度。其他的原因還包括硬碟問題(例如磁碟速度慢),以及單純的是資料大,大過了你係統可以處理的程度。無論原因為何,修復這個問題都會加快你的訓練過程。

注意記憶體

確定整個模型的記憶體消耗沒有超出機器記憶體,如果超出了,必然使用swapping,而 swapping 肯定會讓輸入流程放慢,會讓你的 GPU 開始坐等新資料。如何偵探這個行為呢?一個簡單地 top,就像下文講到的 TensorBoard 佇列圖就應當足夠偵測到這樣的行為。

tensorboard

說道tensorboard,不得不說就是它對於tensorflow的視覺化分析太有用了,不僅可以對當前執行的graph
進行流式圖分析,還能進行效能監控。

# Collect tracing information during the fifth step.
if
global_step == 5: # Create an object to hold the tracing data run_metadata = tf.RunMetadata() # Run one step and collect the tracing data _, loss = sess.run([train_op, loss_op], options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE), run_metadata=run_metadata) # Add summary to the summary writer summary_writer.add_run_metadata(run_metadata, 'step%d', global_step)

之後,一個 timeline.json 檔案會被儲存到當前資料夾,跟蹤資料可以在 Tensorboard 找到。現在,你可以很容易地看到一個操作花了多長時間來計算,以及這個操作消耗了多少記憶體。開啟Tensorboard的圖檢視,選擇左側的最新執行,你就能在右邊看到效能的詳細資訊。一方面,這方便你調整模型,儘可能多地使用機器;另一方面,這方便你在訓練管道中發現瓶頸。如果你更喜歡時間軸檢視,在 Google Chromes 跟蹤事件分析工具(Trace Event Profiling Tool)中載入timeline.json 檔案就行了。
  另一個不錯的工具是 tfprof,tfprof 使用相同的功能做記憶體和執行時間分析,不過提供了更多的便利功能(feature)。額外的統計資訊需要更改程式碼。

Debug

作為開發人員,這個我就不說了。
提示:
TensorFlow 1.0 推出了新的 TFDebugger,應該很有用的,這是一篇關於它的介紹

設定運算超時時間

當我們點選執行的時候,session 也啟動了,但沒有事情都沒有什麼發生?這通常是由空佇列引起的。但是,如果你不知道是哪一個佇列導致的,那麼有一個簡單的修復方法:只需在建立會話時啟用一個操作執行超時,這樣當操作超過限制時,指令碼就會崩潰:

config = tf.ConfigProto()
config.operation_timeout_in_ms=5000
sess = tf.Session(config=config)