在 Android 中 整合 tensorflow 並使用訓練後的模型
第一次整合使用 tensorflow
,內心還是有些激動的。開始時候,並不知道怎麼進行,其實是一臉茫然的,然後就看了不少文章,關於整合的,大致有了個思路,然後就開始整合測試。這次就總結下具體整合思路和步驟。
方式:
tensorflow android
而不是tensorflow lite
基本思路
首先,在 android
上整合 tensorflow
,我們可以確定使用 TensorFlowInferenceInterface 類,不知道的自己去查資料。通過閱讀官方給的原始碼,我們大概就可以知道所需要的內容和怎麼使用。
TensorFlowInferenceInterface
首先,在例項化此類的時候,我們需要提供 assets
和 modelName
,所以可以確定要將模型放在 assets
資料夾下,同時傳入模型名字即可。其次,在初始化的時候首先執行的 prepareNativeRuntime()
函式,可以確定要載入相關的 so
庫,並且在整合後初始化後,不需要在業務程式碼中去重複載入,因為這裡已經載入過了。
- 模型
- 模型放在
assets
資料夾下 so
庫so
庫 不需要去載入
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
prepareNativeRuntime();
this .modelName = model;
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
....
}
private void prepareNativeRuntime() {
····
try {
System.loadLibrary("tensorflow_inference");
Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)" );
} catch (UnsatisfiedLinkError e2) {
throw new RuntimeException(
"Native TF methods not found; check that the correct native"
+ " libraries are present in the APK.");
}
}
}
TensorFlowInferenceInterface
中 feed
、run
和 fetch
函式
feed
過載函式有很多,根據需要傳入引數即可,必須傳入的為 intputName
和資料 src
。必須要的是傳入資料的型別是什麼,不然是不成功的。
作用:注入資料
run
過載函式也有好幾個,是執行執行的,需要傳入 outputName
陣列,這裡的outputName
需要和 fetch
相關函式中的一致。
作用:執行
fetch
過載的函式也有很多,也是需要傳出的即可,必須傳入的是 outputName
和 要儲存結果的陣列 dst
。必須要確定傳出結果的資料型別。
作用:取出結果
總結下操作函式
inputName
- 驗證的資料
outputName
:run
和fetch
均需要- 結果放置的陣列
實現步驟
1.準備模型 ( .pb
檔案)
這個要演算法工程師給訓練好的模型並打包成 .pb
檔案,當然自己可以,自己來。並確定使用模型的時候所需要的輸入、輸出引數 ,即上述 feed
和 fetch
的引數。
注意將 .pb
檔案放入 assets
資料夾下。
2.在專案中引入 tensorflow
提供的 jar
包和 so
庫
jar
包整合方式有兩種,一是放置 libs
資料夾下,進行整合,另一種為下面辦法。我這邊使用的版本是 1.6.0
。
dependencies {
implementation 'org.tensorflow:tensorflow-android:+' // 1.6.0
}
so
庫整合方式,將下載的 so
庫相關資料夾放置到 libs
下,在 app
下 build.gradle
檔案指定 jniLibs.srcDirs
目錄即可, 當然其它方式也可以。
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
這時目錄結構為:
3.呼叫
這邊是使用 kotlin 實現的
class MainActivity : AppCompatActivity() {
// modelName
private val MODEL_FILE: String = "lstm_150_2_50.pb"
// inputName
private val INPUT_NODE: String = "input"
// outputName
private val OUTPUT_NODE: String = "output"
// TensorFlowInferenceInterface
private lateinit var mTensorFlowInferenceInterface: TensorFlowInferenceInterface
private lateinit var btnResult: AppCompatButton
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
btnResult = findViewById(R.id.btn_result)
// 初始化模型
mTensorFlowInferenceInterface = TensorFlowInferenceInterface(assets, MODEL_FILE)
btnResult.setOnClickListener {
// 整理資料
val valuesList = resources.getString(R.string.tf_values).split(",")
val valuesDatas: ArrayList<Int> = ArrayList()
valuesList.forEach {
valuesDatas.add(it.toInt())
}
val datas = valuesDatas.toIntArray()
// 注入: 必須確定 inputName 和資料型別
mTensorFlowInferenceInterface.feed(INPUT_NODE, datas,這裡根據模型定義)
// 執行:outputName
mTensorFlowInferenceInterface.run(arrayOf(OUTPUT_NODE))
// 輸出: 必須確定 outputName 和資料型別
val result = IntArray(1)
mTensorFlowInferenceInterface.fetch(OUTPUT_NODE, result)
Log.v("Main", "輸出結果: ${Arrays.toString(result)}")
tv_result.text = Arrays.toString(result)
}
}
}
4.確認 so
庫 和 模型載入成功
so
庫載入成功,日誌中會出現
Successfully loaded TensorFlow native methods (....)
模型載入成功,日誌中出現下面內容
Successfully loaded model from $modelName