TensorFlow中對訓練後的神經網路引數(權重、偏置)提取
基於TensorFlow可以輕而易舉搭建一個神經網路,而且很好地支援GPU加速訓練。但基於TensorFlow的預測過程,往往需要在嵌入式裝置上才能得以應用。對於我目前做的工作而言,用TF搭建神經網路以及用GPU加速訓練過程的主要用處就是:獲取訓練後的引數(權重和偏置),將這些引數直接放到嵌入式板卡如FPGA中,以其低功耗、高效能、低延時等特點完成嵌入式AI工程。那麼,提取出TF訓練後的引數變成很重要的過程。
不少IDE可以提供視覺化的引數顯示,本文介紹的方法是不依賴IDE的神經網路引數提取。我們知道,在training之後,模型將會被儲存在一個特定的路徑下。model裡面包括包括加算圖結構、節點資訊、引數
在TensorFlow裡,提供了tf.train.NewCheckpointReader來檢視model.ckpt檔案中儲存的變數資訊。
引數就是trainable集合的變數,所以也可以通過這個tf.train.NewCheckpointReader來檢視,具體程式碼如下:
import tensorflow as tf import numpy as np reader = tf.train.NewCheckpointReader('llw/MNIST_model/mnist_model-29001') all_variables = reader.get_variable_to_shape_map() w1 = reader.get_tensor("layer1/weights") print(type(w1)) print(w1.shape) print(w1[0])
輸出為:
<class 'numpy.ndarray'>
(784, 500)
[ 2.24018339e-02 -2.00362392e-02 -1.12209506e-02 6.77579222e-03
-9.59016196e-03 1.21959345e-02 -9.51156951e-03 -1.60046462e-02
-1.37826744e-02 -1.76466629e-02 -2.11188430e-03 3.54206143e-03
-2.03107391e-02 2.13961536e-03 -4.41462384e-04 -1.93272587e-02
-3.71702737e-03 2.22449750e-03 2.98950635e-02 -2.47442089e-02
-7.97873642e-03 2.99713714e-03 -1.77890640e-02 2.59044971e-02
9.38970014e-04 1.46359997e-02 -2.18281448e-02 1.55605981e-02
-2.44196616e-02 -2.03805566e-02 -7.10553257e-03 -8.46040528e-03
-1.21834688e-02 -1.71028115e-02 -1.73374973e-02 1.58206956e-03
7.28264870e-03 -2.08463762e-02 -7.46442471e-03 7.55013386e-03
4.64899749e-05 3.26069025e-03 -1.22860866e-02 -2.33450923e-02
8.73958052e-04 -2.50798613e-02 -2.91012623e-03 2.18578596e-02
....
上述的檔案路徑llw/MNIST_model/mnist_model-29001,為checkpoint指定的路徑:
model_checkpoint_path: "mnist_model-29001"
上述程式只是輸出第一個節點的引數(500個),總引數光第一層引數就有784X500個,不太適合全部列印在螢幕上。
所以可以通過python 的file write()函式將引數寫到txt文字中。在這裡不做詳述。