C++呼叫tensorflow 訓練好的模型
阿新 • • 發佈:2019-02-09
這個東西我弄了好久!!!!
中間都想放棄了,但是我不服啊,還好弄出來了!!!分享給大家,希望可以幫助到大家哦~嘻嘻嘻
我看到有些說只能安裝32位的py,我開始也是這樣的,但是安裝TensorFlow做測試的時候,就一直有問題,所以呀,我就換成了ananconda
安裝這個網上一大堆,自己可以好好看看哦!
然後就是新建一個C++的工程
1.把ananconda的減壓後,將裡面的inlcude和libs兩個資料夾拷貝到sln的同一級目錄下
2.然後開啟libs,複製一份python35.lib,並命名為python35_d.lib
3.C++->常規->附加包含目錄,輸入..\include;
4.連結器->常規->附加目錄項,輸入..\libs;
5.連結器->輸入->附加依賴項,新增python35_d.lib;
6. python35.dll拷貝到Debug目錄下(與Test.exe同目錄)
7.將py拷貝到Debug目錄下(與Test.exe同目錄)
8.將你訓練好的模型新建一個資料夾拷貝到C++專案資料夾裡來
1-----測試圖片 2.3就是py裡面的東西 4就是你的模型
好了 我開始貼程式碼了~~~
C++:
void testImage(char * path) { try{ Py_Initialize(); PyEval_InitThreads(); PyObject*pFunc = NULL; PyObject*pArg = NULL; PyObject* module = NULL; module = PyImport_ImportModule("myModel");//myModel:Python檔名 if (!module) { printf("cannot open module!"); //Py_Finalize(); } pFunc = PyObject_GetAttrString(module, "test_one_image");//test_one_image:Python檔案中的函式名 if (!pFunc) { printf("cannot open FUNC!"); //Py_Finalize(); } //開始呼叫model pArg = Py_BuildValue("(s)", path); if (module != NULL) { PyGILState_STATE gstate; gstate = PyGILState_Ensure(); PyEval_CallObject(pFunc, pArg); PyGILState_Release(gstate); } } catch (exception& e) { cout << "Standard exception: " << e.what() << endl; } }
python:
def test_one_image(test_dir): image = Image.open(test_dir) plt.imshow(image) plt.axis('off') plt.show() image = image.resize([32, 32]) image_array = np.array(image) with tf.Graph().as_default(): image = tf.cast(image_array, tf.float32) image = tf.reshape(image, [1, 32, 32, 3])#調整image的形狀 p = mmodel(image, 1) logits = tf.nn.softmax(p) x = tf.placeholder(tf.float32, shape=[32, 32, 3]) saver = tf.train.Saver() model_path='E:/MyProject/MachineLearning/call64PY/test/model/' with tf.Session() as sess: sess.run(tf.global_variables_initializer()) ckpt = tf.train.get_checkpoint_state(model_path) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, tf.train.latest_checkpoint('E:/MyProject/MachineLearning/call64PY/test/model/')) saver.restore(sess, ckpt.model_checkpoint_path) print('載入ckpt成功!') else: print('error') prediction = sess.run(logits, feed_dict={x: image_array}) max_index = np.argmax(prediction) if max_index == 0: print('case0: %.6f' % prediction[:, 0]) return result else: print('-case1: %.6f' % prediction[:, 1]) return result2
這裡面好多坑啊~~~