Tensorflow.js執行Python下訓練的模型
一、引言
這兩天的專案需要用到Tensorflow.js來實現一個AI,儘管說Tensorflow.js本身是有訓練模型的功能的,不過考慮到javascript這個東西載入資源要考慮跨域問題等種種因素。。最終還是決定使用python的tensorflow來訓練模型,然後利用js端來使用模型進行運算,那麼關鍵問題就是:js如何載入python下訓練的模型
【webAI】Tensorflow.js載入預訓練的model
這位博主的部落格給了我很大的幫助,前兩步按照他的教程來做都是沒有什麼問題的,不過其實還是有一些潛在的坑或者對於我這種前端小白不太友好的地方,這裡我把我的整個過程都來敘述一遍吧。
注:首先在命令列中執行
pip install tensorflowjs
安裝模型轉換的部分,否則轉換可能會報錯
二、python部分
這裡我用了一個更加簡單的例子,and.py,讓神經網路來學習異或運算(忽略這個"and"emmm),這裡我直接把python程式碼貼出來:
#coding=utf-8# import tensorflow as tf import numpy as np x_data=[[0.0,0.0],[0.0,1.0],[1.0,0.0],[1.0,1.0]] #訓練資料 y_data=[[0.0],[1.0],[1.0],[0.0]] #標籤 x_test=[[0.0,1.0],[1.0,1.0]] #測試資料 xs=tf.placeholder(tf.float32,[None,2]) ys=tf.placeholder(tf.float32,[None,1]) #定義x和y的佔位符作為將要輸入神經網路的變數 #構建隱藏層,假設隱藏層有20個神經元 W1=tf.Variable(tf.random_normal([2,10])) B1=tf.Variable(tf.zeros([1,10])+0.1) out1=tf.nn.relu(tf.matmul(xs,W1)+B1) #構建輸出層,假設輸出層有一個神經元 W2=tf.Variable(tf.random_normal([10,1])) B2=tf.Variable(tf.zeros([1,1])+0.1) prediction=tf.add(tf.matmul(out1,W2),B2,name="model") #計算預測值和真實值之間的誤差 loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),reduction_indices=[1])) train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss) init=tf.global_variables_initializer() #初始化所有變數 sess=tf.Session() sess.run(init) for i in range(40): #訓練10次 sess.run(train_step,feed_dict={xs:x_data,ys:y_data}) print(sess.run(loss,feed_dict={xs:x_data,ys:y_data})) #列印損失值 re=sess.run(prediction,feed_dict={xs:x_test}) print(re) for x in re: if x[0]>0.5: print(1) else: print(0) # 儲存模型為saved_model tf.saved_model.simple_save(sess, "./saved_model",inputs={"x": xs, }, outputs={"model": prediction, })
這個程式碼非常簡單,由於異或運算就四種可能性,所以資料也很小,這裡應該也很好理解,儲存的部分也是照著那個博主的部分來寫的。
三、模型轉換
首先執行這個python程式,會得到如圖所示的資料夾:
然後,在控制檯中執行如下的命令:
tensorflowjs_converter --input_format=tf_saved_model --output_node_names="model" --saved_model_tags=serve ./saved_model ./web_model
就目前看來,這裡的 --output_node_names應該就和python檔案中的outputs中的字典鍵名一致。
總之,這一步應該也沒有什麼問題,執行成功後,會生成如圖所示的web_model資料夾。
好了,接下來是我踩的第一個坑,在web_model下有三個檔案,那位博主只說了其中兩個的作用,於是我傻乎乎的以為就需要這兩個,然後在最後一步瀏覽器執行的時候,一直輸出無窮大。。總之,只要記住,這個資料夾下的東西待會都是要用的就好了。
四、在web中執行
好了,這一步對我來說就是一個巨坑了,作為一個前端小白,我分別見識到了"瀏覽器跨域問題"和"ES6的import語句需要編譯才能被瀏覽器識別"兩大問題,第二個問題昨天花了兩個多小時學會了簡單的ES6標準編譯結果今天查資料發現不用import語句這程式也能在瀏覽器中跑emmm,這樣,我分別來講述一下這兩個問題吧:
1.避免使用Import語句
這個很好做到,只需要把它變成script的引用就行,這裡我把我測試的檔案的原始碼貼出來:
<!doctype html>
<html lang="en">
<head>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-converter"></script>
</head>
<body>
<img style="display: none" id="cat" src="high-detail.jpg" width="224" height="224">
<script>
const MODEL_URL = './tensorflowjs_model.pb'
const WEIGHTS_URL = './weights_manifest.json'
async function fun(){
const model = await tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL)
const cs = tf.tensor([[1.0,1.0],[0.0,0.0]])
cs.print()
model.predict(cs).print()
}
fun()
</script>
</body>
</html>
基本上也是按那個博主的部落格改的吧,有兩個地方需要注意
第一個是loadFrozenModel這個函式是被tf這個物件呼叫的,這一點也是我今天查資料時發現的,原文是這樣的:
The difference here is there is no more "tf_converter", only "tf". You call "tf.loadFrozenModel" like you would any other tf op.
大概意思就是說以後都會變成tf.loadxxx這樣子的了
第二個是,這個函式寫完以後,記得最後要呼叫,否則開啟網頁什麼都看不到,畢竟我也是花了五分鐘看著空白的控制檯懵逼的人。
2.瀏覽器跨域問題
這個世界就是這麼神奇,我寫部落格這會居然可以直接訪問同目錄下的pb檔案和json檔案了,不過之前一直都會報CORS問題,之前我是通過安裝配置一個web伺服器來訪問頁面的,我用了tomcat,將HTML檔案和pb、json等三個檔案置於其webapps的目錄下然後啟動伺服器就能正常訪問了:
當然,如果你根本就沒遇到跨域問題,只要忽略這一步即可。