1. 程式人生 > >TensorFlow Lite簡單分析

TensorFlow Lite簡單分析

簡介

在桌面PC或是伺服器上使用TensorFlow訓練出來的模型檔案,不能直接用在TFLite上執行,需要使用離線工具先轉成.tflite檔案。筆者發現官方文件中很多細節介紹的都不太明確,在使用過程中需要不斷嘗試。我把自己的嘗試過的步驟分享出來,希望能幫助大家節省時間。

具體說來,tflite檔案的生成大致分為3步:

1. 在演算法訓練的指令碼中儲存圖模型檔案(GraphDef)和變數檔案(CheckPoint)。

2. 利用freeze_graph工具生成frozen的graphdef檔案。

3. 利用toco工具,生成最終的tflite檔案。

圖1. 生成tflite檔案的整個流程示意圖

第1步:匯出圖模型檔案和變數檔案

在你的演算法的訓練或推理任務的指令碼中,利用tensorflow.train中的write_graph和saver API來匯出GraphDef及Checkpoint檔案。

圖2. TensorFlow中匯出GraphDef檔案和Checkpoint檔案

其中,tf.train.write_graph一行將匯出模型的GraphDef檔案,實際上儲存了訓練的神經網路的結構圖資訊。儲存格式為protobuffer,所以檔名字尾為pb。

圖3. 匯出的GraphDef檔案

tf.train.saver.save一行匯出的是模型的變數檔案,實際上儲存了整個圖中所有變數目前的取值。

圖4. 匯出的checkpoint檔案

如圖4所示,實際上產生了4個檔案。在後續步驟中需要用到的是nsfw_model.ckpt.data-00000-of-00001這個檔案,儲存了當前神經網路各引數的取值。

第2步:生成frozen的graphdef檔案

在此步驟中,使用Tensorflow原始碼中自帶的freeze_graph工具,生成一個frozen的GraphDef檔案。

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/data/deep_learning/nsfw/model/nsfw-graph.pb --input_checkpoint=/data/deep_learning/nsfw/model/nsfw_model.ckpt --input_binary=true --output_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb --output_node_names=predictions

這裡有兩個地方容易搞錯。第一個地方,input_checkpoint引數實際上用到的檔案應該是nsfw_model.ckpt.data-00000-of-00001,但是在指定檔名的時候只需要指定nsfw_model.ckpt即可。第二個地方,是output_node_names引數,此處指定的是神經網路圖中的輸出節點的名字,是在訓練階段的Python指令碼中定義的。如下圖所示,在定義網路結構時,輸出節點的名稱為"predictions"。則最終output_node_names需要指定為“predictions”。

圖5. output_node_names引數取值與網路模型定義時的名字要對應

當然,也可以利用summarize_graph打印出模型的輸入和輸出節點,如:

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb

圖6. 輸入節點名稱為input圖7. 輸出節點名稱為predictions

第3步:生成最終的tflite檔案

在此步驟中,使用Tensorflow原始碼中自帶的toco工具,生成一個可供TensorFlow Lite框架使用tflite檔案。其中input_arrays和output_arrays的名稱需要與定義網路型別時取的名稱保持一致。

bazel run --config=opt tensorflow/contrib/lite/toco:toco --input_file=/data/deep_learning/nsfw/model/frozen_nsfw.pb --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=/data/deep_learning/nsfw/model/nsfw.lite --inference_type=FLOAT --input_type=FLOAT --input_arrays=input --output_arrays=predictions --input_shapes=1,224,224,3

生成的nsfw.lite檔案即可用於TensorFlow Lite應用。