1. 程式人生 > >Android+TensorFlow+CNN+MNIST 手寫數字識別實現

Android+TensorFlow+CNN+MNIST 手寫數字識別實現

SkySeraph 2018

Overview

本文系“SkySeraph AI 實踐到理論系列”第一篇,咱以AI界的HelloWord 經典MNIST資料集為基礎,在Android平臺,基於TensorFlow,實現CNN的手寫數字識別。
Code here~

Practice

Environment

  • TensorFlow: 1.2.0
  • Python: 3.6
  • Python IDE: PyCharm 2017.2
  • Android IDE: Android Studio 3.0

Train & Evaluate(Python+TensorFlow)

訓練和評估部分主要目的是生成用於測試用的pb檔案,其儲存了利用TensorFlow python API構建訓練後的網路拓撲結構和引數資訊,實現方式有很多種,除了cnn外還可以使用rnn,fcnn等。
其中基於cnn的函式也有兩套,分別為tf.layers.conv2d和tf.nn.conv2d, tf.layers.conv2d使用tf.nn.conv2d作為後端處理,引數上filters是整數,filter是4維張量。原型如下:
convolutional.py檔案
def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding=’valid’, data_format=’channels_last’,
dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(), kernel_regularizer=None, bias_regularizer=None,
activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None,
reuse=None)

gen_nn_ops.py 檔案

def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None)

官方Demo例項中使用的是layers module,結構如下:

  • Convolutional Layer #1:32個5×5的filter,使用ReLU啟用函式
  • Pooling Layer #1:2×2的filter做max pooling,步長為2
  • Convolutional Layer #2:64個5×5的filter,使用ReLU啟用函式
  • Pooling Layer #2:2×2的filter做max pooling,步長為2
  • Dense Layer #1:1024個神經元,使用ReLU啟用函式,dropout率0.4 (為了避免過擬合,在訓練的時候,40%的神經元會被隨機去掉)
  • Dense Layer #2 (Logits Layer):10個神經元,每個神經元對應一個類別(0-9)

核心程式碼在cnn_model_fn(features, labels, mode)函式中,完成卷積結構的完整定義,核心程式碼如下.

也可以採用傳統的tf.nn.conv2d函式, 核心程式碼如下。

Test(Android+TensorFlow)

  • 核心是使用API介面: TensorFlowInferenceInterface.java
  • 配置gradle 或者 自編譯TensorFlow原始碼匯入jar和so
    compile ‘org.tensorflow:tensorflow-android:1.2.0’
  • 匯入pb檔案.pb檔案放assets目錄,然後讀取

    String actualFilename = labelFilename.split(“file:///android_asset/“)[1];
    Log.i(TAG, “Reading labels from: “ + actualFilename);
    BufferedReader br = null;
    br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
    String line;
    while ((line = br.readLine()) != null) {
    c.labels.add(line);
    }
    br.close();

  • TensorFlow介面使用

  • 最終效果:

Theory

MNIST

MNIST,最經典的機器學習模型之一,包含0~9的數字,28*28大小的單色灰度手寫數字圖片資料庫,其中共60,000 training examples和10,000 test examples。
檔案目錄如下,主要包括4個二進位制檔案,分別為訓練和測試圖片及Label。

如下為訓練圖片的二進位制結構,在真實資料前(pixel),有部分描述欄位(魔數,圖片個數,圖片行數和列數),真實資料的儲存採用大端規則。
(大端規則,就是資料的高位元組儲存在低記憶體地址中,低位元組儲存在高記憶體地址中)

在具體實驗使用,需要提取真實資料,可採用專門用於處理位元組的庫struct中的unpack_from方法,核心方法如下:
struct.unpack_from(self._fourBytes2, buf, index)

MNIST作為AI的Hello World入門例項資料,TensorFlow封裝對其封裝好了函式,可直接使用
mnist = input_data.read_data_sets(‘MNIST’, one_hot=True)

CNN(Convolutional Neural Network)

CNN Keys

  • CNN,Convolutional Neural Network,中文全稱卷積神經網路,即所謂的卷積網(ConvNets)。
  • 卷積(Convolution)可謂是現代深度學習中最最重要的概念了,它是一種數學運算,讀者可以從下面連結[23]中卷積相關數學機理,包括分別從傅立葉變換和狄拉克δ函式中推到卷積定義,我們可以從字面上巨集觀粗魯的理解成將因子翻轉相乘捲起來。
  • 卷積動畫。演示如下圖[26],更多動畫演示可參考[27]
  • 神經網路。一個由大量神經元(neurons)組成的系統,如下圖所示[21]

    其中x表示輸入向量,w為權重,b為偏值bias,f為啟用函式。

  • Activation Function 啟用函式: 常用的非線性啟用函式有Sigmoid、tanh、ReLU等等,公式如下如所示。

    • Sigmoid缺點
      • 函式飽和使梯度消失(神經元在值為 0 或 1 的時候接近飽和,這些區域,梯度幾乎為 0)
      • sigmoid 函式不是關於原點中心對稱的(無0中心化)
    • tanh: 存在飽和問題,但它的輸出是零中心的,因此實際中 tanh 比 sigmoid 更受歡迎。
    • ReLU
      • 優點1:ReLU 對於 SGD 的收斂有巨大的加速作用
      • 優點2:只需要一個閾值就可以得到啟用值,而不用去算一大堆複雜的(指數)運算
      • 缺點:需要合理設定學習率(learning rate),防止訓練時dead,還可以使用Leaky ReLU/PReLU/Maxout等代替
  • Pooling池化。一般分為平均池化mean pooling和最大池化max pooling,如下圖所示[21]為max pooling,除此之外,還有重疊池化(OverlappingPooling)[24],空金字塔池化(Spatial Pyramid Pooling)[25]
    • 平均池化:計算影象區域的平均值作為該區域池化後的值。
    • 最大池化:選影象區域的最大值作為該區域池化後的值。

CNN Architecture

  • 三層神經網路。分別為輸入層(Input layer),輸出層(Output layer),隱藏層(Hidden layer),如下圖所示[21]
  • CNN層級結構。 斯坦福cs231n中闡述了一種[INPUT-CONV-RELU-POOL-FC],如下圖所示[21],分別為輸入層,卷積層,激勵層,池化層,全連線層。
  • CNN通用架構分為如下三層結構:
    • Convolutional layers 卷積層
    • Pooling layers 匯聚層
    • Dense (fully connected) layers 全連線層
  • 動畫演示。參考[22]。

Regression + Softmax

機器學習有監督學習(supervised learning)中兩大演算法分別是分類演算法和迴歸演算法,分類演算法用於離散型分佈預測,迴歸演算法用於連續型分佈預測。
迴歸的目的就是建立一個迴歸方程用來預測目標值,迴歸的求解就是求這個迴歸方程的迴歸係數。
其中迴歸(Regression)演算法包括Linear Regression,Logistic Regression等, Softmax Regression是其中一種用於解決多分類(multi-class classification)問題的Logistic迴歸演算法的推廣,經典例項就是在MNIST手寫數字分類上的應用。

Linear Regression

Linear Regression是機器學習中最基礎的模型,其目標是用預測結果儘可能地擬合目標label

  • 多元線性迴歸模型定義
  • 多元線性迴歸求解
  • Mean Square Error (MSE)
    • Gradient Descent(梯度下降法)
    • Normal Equation(普通最小二乘法)
    • 區域性加權線性迴歸(LocallyWeightedLinearRegression, LWLR ):針對線性迴歸中模型欠擬合現象,在估計中引入一些偏差以便降低預測的均方誤差。
    • 嶺迴歸(ridge regression)和縮減方法
  • 選擇: Normal Equation相比Gradient Descent,計算量大(需計算X的轉置與逆矩陣),只適用於特徵個數小於100000時使用;當特徵數量大於100000時使用梯度法。當X不可逆時可替代方法為嶺迴歸演算法。LWLR方法增加了計算量,因為它對每個點做預測時都必須使用整個資料集,而不是計算出迴歸係數得到迴歸方程後代入計算即可,一般不選擇。
  • 調優: 平衡預測偏差和模型方差(高偏差就是欠擬合,高方差就是過擬合)
    • 獲取更多的訓練樣本 - 解決高方差
    • 嘗試使用更少的特徵的集合 - 解決高方差
    • 嘗試獲得其他特徵 - 解決高偏差
    • 嘗試新增多項組合特徵 - 解決高偏差
    • 嘗試減小 λ - 解決高偏差
    • 嘗試增加 λ -解決高方差

Softmax Regression

  • Softmax Regression估值函式(hypothesis)
  • Softmax Regression代價函式(cost function)
  • 理解:
  • Softmax Regression & Logistic Regression:
    • 多分類 & 二分類。Logistic Regression為K=2時的Softmax Regression
    • 針對K類問題,當類別之間互斥時可採用Softmax Regression,當非斥時,可採用K個獨立的Logistic Regression
  • 總結: Softmax Regression適用於類別數量大於2的分類,本例中用於判斷每張圖屬於每個數字的概率。

References & Recommends

MNIST

Softmax

CNN

TensorFlow+CNN / TensorFlow+Android

相關推薦

Android+TensorFlow+CNN+MNIST 數字識別實現

SkySeraph 2018 Overview 本文系“SkySeraph AI 實踐到理論系列”第一篇,咱以AI界的HelloWord 經典MNIST資料集為基礎,在Android平臺,基於TensorFlow,實現CNN的手寫數字識別。Code here~ Practice Env

Tensorflow實踐 mnist數字識別

model 損失函數 兩層 最簡 sin test http gif bat minst數據集      tensorflow的文檔中就自帶了mnist手寫數字識別的例子,是一個很經典也比較簡單

TensorflowMNIST數字識別:分類問題(1)

一、MNIST資料集讀取 one hot 獨熱編碼獨熱編碼是一種稀疏向量,其中:一個向量設為1,其他元素均設為0.獨熱編碼常用於表示擁有有限個可能值的字串或識別符號優點:   1、將離散特徵的取值擴充套件到了歐式空間,離散特徵的某個取值就對應歐式空間的某個點    2、機器學習演算法中,

TensorflowMNIST數字識別:分類問題(2)

整體程式碼: #資料讀取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnis

基於tensorflowMNIST數字識別(二)--入門篇

一、本文的意義       因為谷歌官方其實已經寫了MNIST入門和深入兩篇教程了,那我寫這些文章又是為什麼呢,只是抄襲?那倒並不是,更準確的說應該是筆記吧,然後用更通俗的語言來解釋,並且補充

基於tensorflowMNIST數字識別(三)--神經網路篇

想想還是要說點什麼     抱歉啊,第三篇姍姍來遲,確實是因為我懶,而不是忙什麼的,所以這次再加點料,以表示我的歉意。廢話不多說,我就直接開始講了。 加入神經網路的意義     前面也講到了,使用普通的訓練方法,也可以進行識別,但是識別的精度不夠高,

Tensorflow案例5:CNN演算法-Mnist數字識別

學習目標 目標 應用tf.nn.conv2d實現卷積計算 應用tf.nn.relu實現啟用函式計算 應用tf.nn.max_pool實現池化層的計算 應用卷積神經網路實現影象分類識別 應用

tensorflow 基礎學習五:MNIST數字識別

truncate averages val flow one die correct 表示 data MNIST數據集介紹: from tensorflow.examples.tutorials.mnist import input_data # 載入MNIST數據集,

MNIST數字識別——CNN

  參考:http://www.tensorfly.cn/tfdoc/tutorials/mnist_pros.html 網上已經有很多相關內容的部落格、資料,有很多也寫得挺好的,我也是參考別人的,這裡就不再寫原理上的東西了。附一下我做實驗的程式碼,簡單記錄一下遇到的問題。 實

Tensorflow專案 - 使用CNN進行數字識別

首先書寫如下的程式碼 程式碼有詳細的註釋,這裡就不一一解釋了 #coding=utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #引入資料集 mnist=inp

TensorFlow筆記(1)非線性迴歸、MNIST數字識別

程式 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt # numpy生成200個隨機點,下面這麼寫可以得到200行1列的矩陣 x_data = np.linspace(-0.5,

tensorflow實戰:MNIST數字識別的優化2-代價函式優化,準確率98%

最簡單的tensorflow的手寫識別模型,這一節我們將會介紹其簡單的優化模型。我們會從代價函式,多層感知器,防止過擬合,以及優化器的等幾個方面來介紹優化過程。    1.代價函式的優化:             我們可以這樣將代價函式理解為真實值與預測值的差距,我們神經

Tensorflow案例4:Mnist數字識別(線性神經網路)及其侷限性

學習目標 目標 應用matmul實現全連線層的計算 說明準確率的計算 應用softmax_cross_entropy_with_logits實現softamx以及交叉熵損失計算 說明全連線層在神經網路的作用 應用全連

CNN實現MNIST數字識別

關鍵詞:CNN、TensorFlow、卷積、池化、特徵圖 一. 前言 本文用TensorFlow實現了CNN(卷積神經網路)的經典結構LeNet-5, 具體CNN的LeNet-5模型原理見《深度學習(四)卷積神經網路入門學習(1)》,講得還是比較清楚的。

TensorFlow——MNIST數字識別

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #載入資料集 mnist=input_data.read_data_sets('MNIST_data',one_hot=True) #

TensorFlow——Mnist數字識別並可視化 實戰教程(一)

要點: 該教程為深度學習tensorflow框架mnist手寫數字識別。 實戰教程分為(一)(二)(三)分別從tensorflow和MATLAB雙角度來實現。 筆者資訊:Next_Legend  Q

TensorFlow程式碼實現(一)[MNIST數字識別]

最簡單的神經網路結構: 資料來源準備:資料在之前的文章中分析過了 在這裡我們就構造一層神經網路: 前提準備: 引數: train images:因為圖片是28*28的個數,換算成一維陣列就是784,因此我們定義x = tf.placeholder(tf

TensorFlow實現機器學習的“Hello World”--Mnist數字識別

TensorFlow實現機器學習的“Hello World” 上一篇部落格我們已經說了TensorFlow大概怎麼使用,這次來說說機器學習中特別經典的案例,也相當於是機器學習的“Hello World”,他就是Mnist手寫數字識別,也就是通過訓練機器讓他能看

TensorFlow筆記之一:MNIST數字識別

     本人剛剛開始接觸深度學習不久,對於tensorflow的瞭解也有限,想通過tensorflow這個框架來學習深度學習及其優化與識別。現在直接進入主題。     1.手寫識別的介紹:            MNIST手寫識別在機器學習中就像c語言中Hello Wor

Mnist數字識別CNN實現

Mnist手寫數字識別之CNN實現 最近有點閒,想整一下機器學習,本以為自己程式設計還不錯,想想機器學習也不難,結果被自己啪啪啪的打臉,還疼的不行。 廢話不多說,開始搞事情。 本部落格的主要內容是:通過TF一步一步用卷積神經網路(CNN)實現手寫Mnist數字識別 如果你