1. 程式人生 > >在 Android 中 整合 tensorflow 並使用訓練後的模型

在 Android 中 整合 tensorflow 並使用訓練後的模型

第一次整合使用 tensorflow ,內心還是有些激動的。開始時候,並不知道怎麼進行,其實是一臉茫然的,然後就看了不少文章,關於整合的,大致有了個思路,然後就開始整合測試。這次就總結下具體整合思路和步驟。

方式: tensorflow android 而不是 tensorflow lite

基本思路

首先,在 android 上整合 tensorflow ,我們可以確定使用 TensorFlowInferenceInterface 類,不知道的自己去查資料。通過閱讀官方給的原始碼,我們大概就可以知道所需要的內容和怎麼使用。

TensorFlowInferenceInterface

建構函式

首先,在例項化此類的時候,我們需要提供 assetsmodelName ,所以可以確定要將模型放在 assets 資料夾下,同時傳入模型名字即可。其次,在初始化的時候首先執行的 prepareNativeRuntime() 函式,可以確定要載入相關的 so 庫,並且在整合後初始化後,不需要在業務程式碼中去重複載入,因為這裡已經載入過了。

  • 模型
  • 模型放在 assets 資料夾下
  • so
  • so 庫 不需要去載入
 public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
    prepareNativeRuntime();

    this
.modelName = model; this.g = new Graph(); this.sess = new Session(g); this.runner = sess.runner(); .... } private void prepareNativeRuntime() { ···· try { System.loadLibrary("tensorflow_inference"); Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)"
); } catch (UnsatisfiedLinkError e2) { throw new RuntimeException( "Native TF methods not found; check that the correct native" + " libraries are present in the APK."); } } }

TensorFlowInferenceInterfacefeedrunfetch 函式

feed 過載函式有很多,根據需要傳入引數即可,必須傳入的為 intputName 和資料 src 。必須要的是傳入資料的型別是什麼,不然是不成功的。

作用:注入資料

這裡寫圖片描述

run 過載函式也有好幾個,是執行執行的,需要傳入 outputName 陣列,這裡的outputName 需要和 fetch 相關函式中的一致。

作用:執行

這裡寫圖片描述

fetch 過載的函式也有很多,也是需要傳出的即可,必須傳入的是 outputName 和 要儲存結果的陣列 dst。必須要確定傳出結果的資料型別。

作用:取出結果

這裡寫圖片描述

總結下操作函式

  • inputName
  • 驗證的資料
  • outputName : runfetch 均需要
  • 結果放置的陣列

實現步驟

1.準備模型 ( .pb 檔案)

這個要演算法工程師給訓練好的模型並打包成 .pb 檔案,當然自己可以,自己來。並確定使用模型的時候所需要的輸入、輸出引數 ,即上述 feedfetch 的引數。

注意將 .pb 檔案放入 assets 資料夾下。

2.在專案中引入 tensorflow 提供的 jar 包和 so

jar 包整合方式有兩種,一是放置 libs 資料夾下,進行整合,另一種為下面辦法。我這邊使用的版本是 1.6.0

dependencies {
    implementation 'org.tensorflow:tensorflow-android:+' // 1.6.0
}

so 庫整合方式,將下載的 so 庫相關資料夾放置到 libs 下,在 appbuild.gradle 檔案指定 jniLibs.srcDirs 目錄即可, 當然其它方式也可以。

 sourceSets {
        main {
            jniLibs.srcDirs = ['libs']
        }
    }

這時目錄結構為:

這裡寫圖片描述

3.呼叫

這邊是使用 kotlin 實現的

class MainActivity : AppCompatActivity() {

    // modelName
    private val MODEL_FILE: String = "lstm_150_2_50.pb"
    // inputName
    private val INPUT_NODE: String = "input"
    // outputName
    private val OUTPUT_NODE: String = "output"

    // TensorFlowInferenceInterface
    private lateinit var mTensorFlowInferenceInterface: TensorFlowInferenceInterface

    private lateinit var btnResult: AppCompatButton

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        btnResult = findViewById(R.id.btn_result)

        // 初始化模型
        mTensorFlowInferenceInterface = TensorFlowInferenceInterface(assets, MODEL_FILE)

        btnResult.setOnClickListener {
            // 整理資料
            val valuesList = resources.getString(R.string.tf_values).split(",")
            val valuesDatas: ArrayList<Int> = ArrayList()
            valuesList.forEach {
                valuesDatas.add(it.toInt())
            }
            val datas = valuesDatas.toIntArray()
            // 注入: 必須確定 inputName 和資料型別
            mTensorFlowInferenceInterface.feed(INPUT_NODE, datas,這裡根據模型定義)
            // 執行:outputName
            mTensorFlowInferenceInterface.run(arrayOf(OUTPUT_NODE))
            // 輸出: 必須確定 outputName 和資料型別
            val result = IntArray(1)
            mTensorFlowInferenceInterface.fetch(OUTPUT_NODE, result)
            Log.v("Main", "輸出結果: ${Arrays.toString(result)}")
            tv_result.text = Arrays.toString(result)
        }

    }
}

4.確認 so 庫 和 模型載入成功

so 庫載入成功,日誌中會出現

Successfully loaded TensorFlow native methods (....)

模型載入成功,日誌中出現下面內容

Successfully loaded model from $modelName

這裡寫圖片描述