1. 程式人生 > >vs2015呼叫tensorflow.dll實現mnist識別

vs2015呼叫tensorflow.dll實現mnist識別

1.準備檔案

一共需要三個檔案,dll檔案,lib檔案,和tensorflow原始碼,其中原始碼指的是build目錄底下已經生成過工程檔案併成功編譯的,因為有很多接下來專案中包含的標頭檔案是在編譯之後產生的。準備檔案
建議把三個檔案放到後面新建的vs2015專案的sln檔案所在目錄下。

2. 新建vs2015專案, 並配置專案的屬性,主要四個部分:

(1) 附加包含目錄(新增標頭檔案所在目錄)
(2) 附加庫目錄(新增lib檔案所在目錄)
(3) 附加依賴項(新增lib檔案具體檔名)
(4) 執行時環境(新增dll所在目錄)

新建vs2015專案後,右鍵專案->屬性:
專案屬性
選擇平臺

,由於我之前編譯tensorflow.dll檔案時選擇的是release64位,因此這裡配置平臺選擇release和x64:
選擇平臺
接下來是

(1) 附加包含目錄(新增標頭檔案所在目錄)

點選C/C++ ->附加包含目錄 ->編輯:
輸入下面的包含目錄

C:\eigen
$(SolutionDir)tensorflow-r1.6
$(SolutionDir)tensorflow-r1.6\tensorflow\contrib\cmake\build
$(SolutionDir)tensorflow-r1.6\tensorflow\contrib\cmake\build\protobuf\src\protobuf\src
$(SolutionDir)tensorflow-r1.6\tensorflow\contrib\cmake\build\external\nsync\public
$(SolutionDir)tensorflow-r1.6\tensorflow\contrib\cmake\build\external\eigen_archive

其中C:\eigen表示eigen所在目錄,eigen下載地址,下載完成解壓即可。
$(SolutionDir)表示專案解決方案sln檔案所在目錄。
tensorflow-r1.6是tensorflow編譯後原始碼的根目錄。

在這裡插入圖片描述
附加包含目錄

(2) 附加庫目錄(新增lib檔案所在目錄)

連結器 -> 常規 -> 附加庫目錄 -> 編輯,輸入:

$(SolutionDir)

這裡因為tensorflow.lib檔案在專案根目錄下,所以輸入$(SolutionDir),由於專案用到了opencv,因此還要輸入opencv的lib所在目錄。
附加庫目錄

(3) 附加依賴項(新增lib檔案具體檔名)

連結器 -> 輸入 -> 附加依賴項 -> 編輯,步驟與上面的附加庫目錄相同,輸入:

tensorflow.lib

(4) 執行時環境(新增dll所在目錄)

配置屬性 -> 除錯 -> 環境 -> 編輯,輸入

path = $(SolutionDir)

執行時環境

3.轉換tensorflow模型(ckpt->pb)

一般tensorflow通過saver進行模型儲存的時候,會將模型的引數和圖結構分開儲存。如圖:
ckpt格式
其中meta是模型圖結構,data是模型權值引數。而pb模型則是將圖結構和權值引數繫結到一起,具體來說就是將data裡的變數引數轉成常量引數寫到pb檔案中去。那麼如何將ckpt格式的模型轉換成pb模型呢?
理論和詳細介紹參見部落格。這裡具體講一下如何實現。

(1)確定輸出節點

在轉換ckpt檔案到pb檔案過程中,需要確定輸出節點。一般來說,訓練模型的原始碼裡面會定義節點名字。如果沒辦法獲得原始碼,或者原始碼裡面沒有定義節點名字的話,可以通過以下程式碼載入ckpt格式的模型,打印出所有的節點。

#encoding=utf-8
import tensorflow as tf
input_checkpoint = './model.ckpt'#模型路徑
#恢復圖的結構
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
with tf.Session() as sess:
    saver.restore(sess, input_checkpoint) #恢復圖並得到資料
    for op in tf.get_default_graph().get_operations():
        print(op.name, op.values()) #列印節點名字和張量維度資訊
        

本專案中mnist模型打印出的節點如圖所示(部分):
打印出的節點

tensorflow 每呼叫一次類似 tf.XXX()函式,就代表了生成一個節點,那麼tf.variable()就會生成一個Variable節點,如果呼叫了很多次,就生成Variable_1,Variable_2…以此類推。
我們看一個簡單例子:

#encoding=utf-8
import tensorflow as tf

#下面程式碼定義了兩個constant常量i1和i3,
#兩個Variable變數i2和i4
#兩個Add操作tf.add_n()

def test_ops():
	#定義名稱空間”input1"
	with tf.name_scope("input1"):
		i1 = tf.constant([1.0, 2.0, 3.0],name="input1")

	i2 = tf.Variable(tf.random_uniform([3]))
	i3 = tf.constant([1.0, 2.0, 4.0])
	i4 = tf.Variable(tf.random_uniform([3], name="input2"))

	output = tf.add_n([i1,i2],name = "add")
	output1 = tf.add_n([i3,i4])

	#將圖結構寫入log,以便tensorboard呼叫視覺化
	writer = tf.summary.FileWriter("./log",tf.get_default_graph())
	writer.close()

test_ops()

視覺化節點如下:
視覺化節點
每一個 tf.XXX()函式都產生了一個節點,包括tf.random_uniform(),所以i2和i4可以看成是兩個節點產生的一個張量。而i1由於定義了名字input,所以沒有顯示預設的Const節點名,第一個tf.add_n()函式定義了名字“add”所以直接顯示名字,第二個add_n()就顯示了預設的AddN。

從例子可以看出,我們需要從打印出的所有節點中,根據shape和變數名,推測出最終輸出的節點,將推理出該節點的所有節點涉及到的變數固化(常量化)。例如mnist的最終輸出肯定是一個維度為(,10)的張量。

文章目錄