Tensorflow lite up up~
阿新 • • 發佈:2018-11-08
簡介
在桌面PC或是伺服器上使用TensorFlow訓練出來的模型檔案,不能直接用在TFLite上執行,需要使用離線工具先轉成.tflite檔案。筆者發現官方文件中很多細節介紹的都不太明確,在使用過程中需要不斷嘗試。我把自己的嘗試過的步驟分享出來,希望能幫助大家節省時間。
具體說來,tflite檔案的生成大致分為3步:
1. 在演算法訓練的指令碼中儲存圖模型檔案(GraphDef)和變數檔案(CheckPoint)。
2. 利用freeze_graph工具生成frozen的graphdef檔案。
3. 利用toco(Tensorflow Optimizing COnverter)工具,生成最終的tflite檔案。
圖1. 生成tflite檔案的整個流程示意圖
第1步:匯出圖模型檔案和變數檔案
在你的演算法的訓練或推理任務的指令碼中,利用tensorflow.train中的write_graph和saver API來匯出GraphDef及Checkpoint檔案。
這樣我們可以拿到模型的pb檔案或ckpt檔案
第2步:freeze graph
可參考Tensorflow C++ API線上預測服務的文件。即我的另一篇部落格《使用TensorFlow C++ API構建線上預測服務》 https://blog.csdn.net/lsj1342/article/details/82752951
第3步:生成最終的tflite檔案
構建toco工具
bazel build tensorflow/contrib/lite/toco:toco
構建成功後,可在tensorflow/bazel-bin/tensorflow/contrib/lite/toco中看到可執行檔案toco,將路徑加入環境變數
toco --input_file=/data/liusijia/tensorflowTest/freeze_graph.pb \ //路徑自己指定,這裡是freeze_graph後pb檔案 --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --output_file=/data/liusijia/tensorflowTest/tmp.tflite \ //路徑自己指定 --inference_type=FLOAT \ --input_type=FLOAT \ --input_arrays=Placeholder \ //輸入節點的名稱 --output_arrays=Softmax \ //輸出節點的名稱 --input_shapes=1,784 //輸入節點的維度