1. 程式人生 > 其它 >第三課:把tensorflow,模型和測試資料匯入Android工程

第三課:把tensorflow,模型和測試資料匯入Android工程

關於Android專案的建立這裡就不做贅述了,我們直接進入主題,看下如何把機器學習庫和訓練的模型匯入一個安卓應用中。

匯入 Inference Interface

在上一課時中我們下載了 Inference Interface 的 nightly build 的 AAR 檔案,這個AAR其實就是庫檔案,需要把這個檔案匯入到專案中,通常會把這個 AAR 檔案放在 app/libs 下面:

為了匯入這個 AAR,首先需要在 app/build.gradle 中宣告一個本地的 flatDir 倉庫:

   repositories {
        flatDir {
           dirs 'libs'
      }
    }

然後指定依賴:

compile name: 'tensorflow', ext: 'aar'

最後再做一個 Project Sync 就完成了 Inference Interface 的匯入,完整的 app/build.gradle 應該是這樣的:

    ....
    repositories {
       flatDir {
          dirs 'libs'
       }
    }
    android{
      .....
    }
    dependencies{
        .....
        compile name: 'tensorflow', ext: 'aar'
        .....
    }

匯入 Pre-trained Model

在上一課時中已經下載了 Pre-trained model 的二進位制包,解壓縮這個包,會發現裡面有這些檔案:

其中model.ckpt.* 是我們在訓練自己的模型時會用到的檔案(下一系列課程我們會專門講解如何訓練自己的模型),這裡暫時忽略;frozen_inference_graph.pb 檔案正是我們需要的,開箱即用的模型檔案,把這個檔案作為一個 asset 匯入專案中。在 Android Studio 中,單擊 New | Folder | Assets Folder 命令建立一個 assets 目錄,將 frozen_inference_graph.pb 複製到 assets 目錄中,重新命名為 model.pb。

匯入測試資料

在機器學習的世界裡面,絕大部分的輸入和輸出資料都是數字,換句話來說,當訓練這個識別模型的時候,你不會告訴它這張圖片上的是人,而是告訴它這個圖片上面的物體代號是 1;模型在輸出識別結果的時候,也不會輸出人,汽車這樣的字元,而是輸出 1、2、3 這樣的數字,那麼這些數字代表的是什麼,去哪裡找這樣的對應關係呢?

首先我們要清楚一點,模型和訓練資料要存在一一對應關係。要麼是使用的資料提前訓練了模型,要麼資料當下訓練模型。因為TensorFlow Object Detection API 中的模型訓練時使用的是 MS COCO 的物體資料集合,所以我們可以在這裡(關注公眾號,後臺留言提供下載連結)下載到相應的標籤檔案, 我們開啟這個檔案:

0: unlabeled
1: person
2: bicycle
3: car
4: motorcycle
5: airplane
6: bus
7: train
8: truck
9: boat
10: traffic light
11: fire hydrant
12: street sign
13: stop sign
14: parking meter
15: bench
16: bird
17: cat
18: dog
19: horse
20: sheep
21: cow
22: elephant
23: bear

把這個檔案也存到 assets 目錄中,重新命名為 labels.txt,現在 assets 目錄應該是這樣的:

關鍵的程式碼

現在相關的資源都匯入到專案裡面了,接下來我們寫一點程式碼把模型和資料載入起來!

想一想我們需要做哪些工作:

  1. 載入模型 model.pb,獲取一個 TensorFlowInferenceInterface 來進行後續操作。
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "model.pb");

獲取到 TensorFlowInferenceInterface 的物件之後就可以在這個物件上面輸入圖片資料並獲取識別結果了。誇張的講人工智慧科學家大牛們現在研究的就是如何實現這個介面,這裡我們先不探究原理,只會用就好,以後的系列課程我們慢慢的去剖析機器學習的原理。

  1. 把資料集 labels.txt 的內容讀到陣列中,供查詢識別結果中的物體名稱;
    List<String> labels = new ArrayList<>();
    InputStream labelsInput = getAssets().open("labels.text");
    BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
    String line;
    while ((line = br.readLine()) != null) {
          labels.add(line);
    }
    br.close();

至此,相信你應該瞭解了tensorflow框架,訓練模型,測試資料是如何在一個應用程式中實現的。是不是很簡單!下一課程讓我們一起完成一個完整的apk程式,也會把完成的原始碼後臺發給大家!