1. 程式人生 > >Android端實現深度學習

Android端實現深度學習

這裡截取了本人畢業設計關於移動端實現深度學習的章節。本章節將詳細介紹如何實現移動端呼叫深度學習模型進行使用,簡單來說就是兩個步驟,生成可供呼叫的模型和呼叫模型。這裡我們用到的人臉檢測模型為第三章節訓練出來的mAP最高的模型。
4.1 固定模型
為了使android能夠呼叫,檢測模型一定要轉成pb檔案。
4.1.1 讀取檢查點生成新的pbtxt檔案
SSD模型訓練過程中生成的checkpoints檔案和pbtxt檔案雖然包含了整個模型的引數,但是我們需要的只是生成預測值的那一部分網路結點。對於其他loss,gradient這些引數,我們並不需要。明確了我們需要的網路後,我們還要給它的輸入層和輸出層取名字,不然在android中無法呼叫這個模型。具體來說,我們需要寫一段程式碼來重新得到需要的pbtxt檔案,程式碼中關鍵的步驟如下:

1.首先定義一個有名字的佔位符,用於表示輸入輸出資料的格式。告訴系統:這裡有一個值/向量/矩陣,現在沒法給你具體的資料,不過正式執行的會補上的。這裡的正式執行指的就是後續在android中呼叫這個模型的時候,我們就要給它輸入這個格式的資料。這裡要注意的是,這個佔位符的名字一定要取。因為呼叫的時候是要明確接收資料的變數的,否則呼叫模型將出錯。本文定義如下:
這裡寫圖片描述

2.在佔位符後面緊跟著接收資料後,要處理的函式或者網路。一般情況下我們要對網路輸出的預測值進行處理,便於後續在android程式設計中的呼叫。因為Tensorflow[8,9]和android裡面的資料格式不同,最好是能都化成float一維陣列形式。SSD模型有6層的輸出,所以我們不能直接使用網路輸出的預測值,而要將其合併為一層,再對最後的輸出值取個名字,用於後續模型的讀取。

3.恢復checkpoint檔案,重新生成pbtxt檔案。在經過網路得到最終輸出值後,我們要再其後面新增saver,表明以上的網路模型引數將恢復。之後再開啟一個Tensorflow會話,從checkpoint中恢復引數,並將上述的圖(如果沒有明確取名的話,圖將為預設圖),寫入新的pbtxt檔案。為了能夠更清楚地被理解,這裡將上一步的最後輸出和這一步的程式碼寫下來:

這裡寫圖片描述
4.1.2 執行官方工具固定模型引數
首先明確執行檔案需要的引數:輸入的圖檔案(上一步驟生成的pbtxt檔案),輸入的檢查點檔案(上一步中中使用的checkpoint檔案),輸出的pb檔案存放路徑和名字,最後還要輸出結點的名字(上一步中輸出值的名字)。確定後,執行Tensorflow[8,9]官方工具freeze_graph.py檔案,便會生成可供呼叫的pb檔案。
4.2 Android專案呼叫


關於Android Studio環境的配置,大多可以自行搜尋找到教程,本文就不做細說。關鍵的點有1.修改配置檔案,2.根據tensorflow github上的教程生成libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar,3.將上一節生成的pb檔案放入專案的app/src/main/assets下,assets不存在則自行建立,4.將jar包放在app/libs下,並add as library,5.將so檔案放置在app/src/main/jniLibs下,jniLibs不存在則自行建立。

4.2.1 具體呼叫方法
環境都配置好後,要能順利呼叫深度學習模型還需要以下幾個步驟:
1.匯入jar包:

這裡寫圖片描述

2.在要呼叫的java類的類定義首行,匯入so檔案:

這裡寫圖片描述

3.定義變數和物件

這裡寫圖片描述

4.Tensorflow介面初始化

這裡寫圖片描述

5.人臉檢測模型的呼叫

這裡寫圖片描述

在得到推理輸出的資料後,將其轉變為最終需要的形式。本文中對資料的處理流程為,先進行資料格式的轉換(資料轉為xmin,ymin,xmax,ymax的形式)和挑選,具體為選擇一定數量預測分數大於0.3的資料,再將它們進行非最大抑制處理得到最終的預測位置和類別。