第三課:把tensorflow,模型和測試資料匯入Android工程
關於Android專案的建立這裡就不做贅述了,我們直接進入主題,看下如何把機器學習庫和訓練的模型匯入一個安卓應用中。
匯入 Inference Interface
在上一課時中我們下載了 Inference Interface 的 nightly build 的 AAR 檔案,這個AAR其實就是庫檔案,需要把這個檔案匯入到專案中,通常會把這個 AAR 檔案放在 app/libs 下面:
為了匯入這個 AAR,首先需要在 app/build.gradle 中宣告一個本地的 flatDir 倉庫:
repositories { flatDir { dirs 'libs' } }
然後指定依賴:
compile name: 'tensorflow', ext: 'aar'
最後再做一個 Project Sync 就完成了 Inference Interface 的匯入,完整的 app/build.gradle 應該是這樣的:
.... repositories { flatDir { dirs 'libs' } } android{ ..... } dependencies{ ..... compile name: 'tensorflow', ext: 'aar' ..... }
匯入 Pre-trained Model
在上一課時中已經下載了 Pre-trained model 的二進位制包,解壓縮這個包,會發現裡面有這些檔案:
其中model.ckpt.* 是我們在訓練自己的模型時會用到的檔案(下一系列課程我們會專門講解如何訓練自己的模型),這裡暫時忽略;frozen_inference_graph.pb 檔案正是我們需要的,開箱即用的模型檔案,把這個檔案作為一個 asset 匯入專案中。在 Android Studio 中,單擊 New | Folder | Assets Folder 命令建立一個 assets 目錄,將 frozen_inference_graph.pb 複製到 assets 目錄中,重新命名為 model.pb。
匯入測試資料
在機器學習的世界裡面,絕大部分的輸入和輸出資料都是數字,換句話來說,當訓練這個識別模型的時候,你不會告訴它這張圖片上的是人,而是告訴它這個圖片上面的物體代號是 1;模型在輸出識別結果的時候,也不會輸出人,汽車這樣的字元,而是輸出 1、2、3 這樣的數字,那麼這些數字代表的是什麼,去哪裡找這樣的對應關係呢?
首先我們要清楚一點,模型和訓練資料要存在一一對應關係。要麼是使用的資料提前訓練了模型,要麼資料當下訓練模型。因為TensorFlow Object Detection API 中的模型訓練時使用的是 MS COCO 的物體資料集合,所以我們可以在這裡(關注公眾號,後臺留言提供下載連結)下載到相應的標籤檔案, 我們開啟這個檔案:
0: unlabeled
1: person
2: bicycle
3: car
4: motorcycle
5: airplane
6: bus
7: train
8: truck
9: boat
10: traffic light
11: fire hydrant
12: street sign
13: stop sign
14: parking meter
15: bench
16: bird
17: cat
18: dog
19: horse
20: sheep
21: cow
22: elephant
23: bear
把這個檔案也存到 assets 目錄中,重新命名為 labels.txt,現在 assets 目錄應該是這樣的:
關鍵的程式碼
現在相關的資源都匯入到專案裡面了,接下來我們寫一點程式碼把模型和資料載入起來!
想一想我們需要做哪些工作:
- 載入模型 model.pb,獲取一個 TensorFlowInferenceInterface 來進行後續操作。
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "model.pb");
獲取到 TensorFlowInferenceInterface 的物件之後就可以在這個物件上面輸入圖片資料並獲取識別結果了。誇張的講人工智慧科學家大牛們現在研究的就是如何實現這個介面,這裡我們先不探究原理,只會用就好,以後的系列課程我們慢慢的去剖析機器學習的原理。
- 把資料集 labels.txt 的內容讀到陣列中,供查詢識別結果中的物體名稱;
List<String> labels = new ArrayList<>();
InputStream labelsInput = getAssets().open("labels.text");
BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
String line;
while ((line = br.readLine()) != null) {
labels.add(line);
}
br.close();
至此,相信你應該瞭解了tensorflow框架,訓練模型,測試資料是如何在一個應用程式中實現的。是不是很簡單!下一課程讓我們一起完成一個完整的apk程式,也會把完成的原始碼後臺發給大家!