1. 程式人生 > >基於現有 TensorFlow 模型構建 Android 應用

基於現有 TensorFlow 模型構建 Android 應用

轉自
在之前寫的一篇文章 TensorFlow,從一個 Android Demo 開始 中通過編譯官方的 Demo 接觸到了 TensorFlow 實際使用場景。這篇文章打算從一個Android 開發者的角度切入,看看構建一個基於 TensorFlow 的 Android 應用的完整流程。
相關程式碼可檢視:GitHub 專案地址

通過 TensorFlow 用已有模型構建 Android 應用

在 Google 的 TensorFlow examples project 中,有一個 Sample 叫作 TF Classify,它通過使用 Google Inception 模型對實時的相機影象幀進行分類,並顯示展示當前影象的分類推斷結果。
在這裡插入圖片描述

下面我們就基於這個現有模型,在 Android 平臺上實現一個可以對物品進行分類的影象識別應用。

獲取資料模型

這裡可以直接下載 Google 提供的一個數據模型 inception5h.zip ,其中 .pb 字尾的檔案是已經訓練好的模型,而 .txt 對應的是訓練資料包含的所有標籤。
在這裡插入圖片描述
這個模型可對 1008 種物品識別分類,具體有哪些類可以檢視標籤資訊,至於每個類別到底訓練了多少張圖片就不得而知了。
在這裡插入圖片描述

在 Android 專案中引入 TensorFlow

跟在專案中整合其他第三庫一樣,先在 build.gradle 中新增對 TensorFlow 的依賴。

compile 'org.tensorflow:tensorflow-android:1.6.0'

這裡我們直接使用了 Google 為我們編譯好的 TensorFlow 現成庫了,如果你想自行對 TensorFlow 進行 NDK 交叉編譯得到庫檔案也可以。

影象識別功能的實現

  • 複製模型檔案到專案 assets 資料夾

如下圖所示,我們在專案 assets 資料夾下建立一個 model 資料夾,並把之前下載的 inception5h.zip 解壓後的全部檔案複製到該資料夾下。
在這裡插入圖片描述

  • 新增模型呼叫的相關類

因為我們要實現的功能和官方 demo 相似,只是訓練的有所模型不同。既然對模型的使用方式是一樣的,那這裡就直接使用 Google demo 專案中提供的

Classifier.javaTensorFlowImageClassifier.java 這兩個類來實現。

我們可以先跳過這部分內容的具體實現,等到對整體流程有個大致認識後再回過頭來消化掉,這樣可以更好地去理解。

這裡我們重點關注下面兩個方法,一個是 TensorFlowImageClassifier 的靜態方法 create 方法:

/**
     * Initializes a native TensorFlow session for classifying images.
     *
     * @param assetManager The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
     * @param imageMean The assumed mean of the image values.
     * @param imageStd The assumed std of the image values.
     * @param inputName The label of the image input node.
     * @param outputName The label of the output node.
     * @throws IOException
     */
    public static Classifier create(AssetManager assetManager, String modelFilename, String labelFilename,
            int inputSize, int imageMean, float imageStd, String inputName, String outputName)

該方法需要傳入模型相關的引數進行初始化,完成後返回一個 Classifier 例項。

通過 Classifier 物件,我們可以呼叫其 recognizeImage 方法來識別我們傳入的 bitmap 影象資料,該方法會返回影象類別後對物品類別進行推斷的標籤結果:

/**
 * 進行圖片識別
 */
 public List<Recognition> recognizeImage(final Bitmap bitmap) 
  • 相關主要功能程式碼的實現:

相關程式碼可檢視:GitHub 專案地址

public class MainActivity extends AppCompatActivity implements View.OnClickListener {
    ...
    
    // 模型相關配置
    private static final int INPUT_SIZE = 224;
    private static final int IMAGE_MEAN = 117;
    private static final float IMAGE_STD = 1;
    private static final String INPUT_NAME = "input";
    private static final String OUTPUT_NAME = "output";
    private static final String MODEL_FILE = "file:///android_asset/model/tensorflow_inception_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/model/imagenet_comp_graph_label_strings.txt";

    private Executor executor;
    private Uri currentTakePhotoUri;

    private TextView result;
    private ImageView ivPicture;
    private Classifier classifier;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        if (!isTaskRoot()) {
            finish();
        }

        setContentView(R.layout.activity_main);

        findViewById(R.id.iv_choose_picture).setOnClickListener(this);
        findViewById(R.id.iv_take_photo).setOnClickListener(this);

        ivPicture = findViewById(R.id.iv_picture);
        result = findViewById(R.id.tv_classifier_info);

        // 避免耗時任務佔用 CPU 時間片造成UI繪製卡頓,提升啟動頁面載入速度
        Looper.myQueue().addIdleHandler(idleHandler);

    }

    /**
     *  主執行緒訊息佇列空閒時(檢視第一幀繪製完成時)處理耗時事件
     */
    MessageQueue.IdleHandler idleHandler = new MessageQueue.IdleHandler() {
        @Override
        public boolean queueIdle() {
            // 初始化 Classifier
            if (classifier == null) {
                // 建立 TensorFlowImageClassifier
               classifier = TensorFlowImageClassifier.create(MainActivity.this.getAssets(),
                       MODEL_FILE, LABEL_FILE, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME, OUTPUT_NAME);
            }

            // 初始化執行緒池
            executor = new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
                @Override
                public Thread newThread(@NonNull Runnable r) {
                    Thread thread = new Thread(r);
                    thread.setDaemon(true);
                    thread.setName("ThreadPool-ImageClassifier");
                    return thread;
                }
            });
            // 請求許可權
            requestMultiplePermissions();
            // 返回 false 時只會回撥一次
            return false;
        }
    };

    @Override
    public void onClick(View view) {
        switch (view.getId()) {
            case R.id.iv_choose_picture :
                choosePicture();
                break;
            case R.id.iv_take_photo :
                takePhoto();
                break;
            default:break;
        }
    }

    /**
     * 選擇一張圖片並裁剪獲得一個小圖
     */
    private void choosePicture() {
        Intent intent = new Intent(Intent.ACTION_GET_CONTENT);
        intent.setType("image/*");
        startActivityForResult(intent, PICTURE_REQUEST_CODE);
    }

    /**
     * 使用系統相機拍照
     */
    private void takePhoto() {
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.CAMERA}, CAMERA_PERMISSIONS_REQUEST_CODE);
        } else {
            openSystemCamera();
        }
    }

    /**
     * 開啟系統相機
     */
    private void openSystemCamera() {
        //呼叫系統相機
        Intent takePhotoIntent = new Intent();
        takePhotoIntent.setAction(MediaStore.ACTION_IMAGE_CAPTURE);

        //這句作用是如果沒有相機則該應用不會閃退,要是不加這句則當系統沒有相機應用的時候該應用會閃退
        if (takePhotoIntent.resolveActivity(getPackageManager()) == null) {
            Toast.makeText(this, "當前系統沒有可用的相機應用", Toast.LENGTH_SHORT).show();
            return;
        }

        String fileName = "TF_" + System.currentTimeMillis() + ".jpg";
        File photoFile = new File(FileUtil.getPhotoCacheFolder(), fileName);

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
            //通過FileProvider建立一個content型別的Uri
            currentTakePhotoUri = FileProvider.getUriForFile(this, "gdut.bsx.tensorflowtraining.fileprovider", photoFile);
            //對目標應用臨時授權該 Uri 所代表的檔案
            takePhotoIntent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION);
        } else {
            currentTakePhotoUri = Uri.fromFile(photoFile);
        }

        //將拍照結果儲存至 outputFile 的Uri中,不保留在相簿中
        takePhotoIntent.putExtra(MediaStore.EXTRA_OUTPUT, currentTakePhotoUri);
        startActivityForResult(takePhotoIntent, TAKE_PHOTO_REQUEST_CODE);
    }

    /**
     * 處理圖片
     * @param imageUri
     */
    private void handleInputPhoto(Uri imageUri) {
        // 載入圖片
        GlideApp.with(MainActivity.this).asBitmap().listener(new RequestListener<Bitmap>() {

            @Override
            public boolean onLoadFailed(@Nullable GlideException e, Object model, Target<Bitmap> target, boolean isFirstResource) {
                Log.d(TAG,"handleInputPhoto onLoadFailed");
                Toast.makeText(MainActivity.this, "圖片載入失敗", Toast.LENGTH_SHORT).show();
                return false;
            }

            @Override
            public boolean onResourceReady(Bitmap resource, Object model, Target<Bitmap> target, DataSource dataSource, boolean isFirstResource) {
                Log.d(TAG,"handleInputPhoto onResourceReady");
                startImageClassifier(resource);
                return false;
            }
        }).load(imageUri).into(ivPicture);

        result.setText("Processing...");
    }

    /**
     * 開始圖片識別匹配
     * @param bitmap
     */
    private void startImageClassifier(final Bitmap bitmap) {
        executor.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    Log.i(TAG, Thread.currentThread().getName() + " startImageClassifier");
                    Bitmap croppedBitmap = getScaleBitmap(bitmap, INPUT_SIZE);

                    final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
                    Log.i(TAG, "startImageClassifier results: " + results);
                    runOnUiThread(new Runnable() {
                        @Override
                        public void run() {
                            result.setText(String.format("results: %s", results));
                        }
                    });
                } catch (IOException e) {
                    Log.e(TAG, "startImageClassifier getScaleBitmap " + e.getMessage());
                }
            }
        });
    }

   /**
     * 請求相機和外部儲存許可權
     */
    private void requestMultiplePermissions() {

        String storagePermission = Manifest.permission.WRITE_EXTERNAL_STORAGE;
        String cameraPermission = Manifest.permission.CAMERA;

        int hasStoragePermission = ActivityCompat.checkSelfPermission(this, storagePermission);
        int hasCameraPermission = ActivityCompat.checkSelfPermission(this, cameraPermission);

        List<String> permissions = new ArrayList<>();
        if (hasStoragePermission != PackageManager.PERMISSION_GRANTED) {
            permissions.add(storagePermission);
        }

        if (hasCameraPermission != PackageManager.PERMISSION_GRANTED) {
            permissions.add(cameraPermission);
        }
        
        if (!permissions.isEmpty()) {
            String[] params = permissions.toArray(new String[permissions.size()]);
            ActivityCompat.requestPermissions(this, params, PERMISSIONS_REQUEST);
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);

        if (resultCode == RESULT_OK) {
            if (requestCode == PICTURE_REQUEST_CODE) {
                // 處理選擇的圖片
                handleInputPhoto(data.getData());
            } else if (requestCode == OPEN_SETTING_REQUEST_COED){
                requestMultiplePermissions();
            } else if (requestCode == TAKE_PHOTO_REQUEST_CODE) {
                // 如果拍照成功,載入圖片並識別
                handleInputPhoto(currentTakePhotoUri);
            }
        }
    }

    /**
     * 對圖片進行縮放
     * @param bitmap
     * @param size
     * @return
     * @throws IOException
     */
    private static Bitmap getScaleBitmap(Bitmap bitmap, int size) throws IOException {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        float scaleWidth = ((float) size) / width;
        float scaleHeight = ((float) size) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        return Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
    }
}

執行效果

圖片選擇和拍照獲取介面:
在這裡插入圖片描述

物品識別結果展示介面:
在這裡插入圖片描述

相關程式碼可檢視:GitHub 專案地址
是不是覺得通過 TensorFlow 在現有的資料模型基礎下,我們可以很簡單就完成了一個簡單的影象識別應用。
在使用這個模型來推斷物品型別的過程中,發現好像有時候準確率不是那麼高,這時候改怎麼辦。如果說只是想識別一些特定種類的物品,哪有又該怎麼辦?
在之前一篇文章中我有提到過,機器學習是依靠對大量有標籤的樣本資料進行反覆訓練後才逐步得到的最佳模型。對未知無標籤樣本的推斷依賴這個模型的準確程度。所以我們可以通過對現有模型進行遷移訓練(retrain)來定製我們自己的模型。
下面就通過對現有的 Google Inception-V3 模型進行 retrain ,對 5 種花朵樣本資料的進行訓練,來完成一個可以識別五種花朵的模型。
具體實現方式可以參考我的另外一篇文章:通過遷移訓練來定製 TensorFlow 模型