1. 程式人生 > >CNN框架的搭建及各個引數的調節

CNN框架的搭建及各個引數的調節



本文主要講解將CNN應用於人臉識別的流程,程式基於Python+numpy+theano+PIL開發,採用類似LeNet5的CNN模型,應用於olivettifaces人臉資料庫,實現人臉識別的功能,模型的誤差降到了5%以下。本程式只是個人學習過程的一個toy implement,樣本很小,模型隨時都會過擬合。

但是,本文意在理清程式開發CNN模型的具體步驟,特別是針對影象識別,從拿到影象資料庫,到實現一個針對這個影象資料庫的CNN模型,我覺得本文對這些流程的實現具有參考意義。

《本文目錄》

一、olivettifaces人臉資料庫介紹

二、CNN的基本“構件”(LogisticRegression、HiddenLayer、LeNetConvPoolLayer)

三、組建CNN模型,設定優化演算法,應用於Olivetti Faces進行人臉識別

四、訓練結果以及引數設定的討論

五、利用訓練好的引數初始化模型

六、一些需要說明的

一、olivettifaces人臉資料庫介紹

由40個人的400張圖片構成,即每個人的人臉圖片為10張。每張圖片的灰度級為8位,每個畫素的灰度大小位於0-255之間,每張圖片大小為64×64。如下圖,這個圖片大小是1190*942,一共有20*20張人臉,故每張人臉大小是(1190/20)*(942/20)即57*47=2679:


本文所用的訓練資料就是這張圖片,400個樣本,40個類別,乍一看樣本好像比較小,用CNN效果會好嗎?先別下結論,請往下看。

要執行CNN演算法,這張圖片必須先轉化為陣列(或者說矩陣),這個用到python的影象庫PIL,幾行程式碼就可以搞定,具體的方法我之前剛好寫過一篇文章,也是用這張圖,考慮到文章冗長,就不復制過來了,連結在此:《利用Python PIL、cPickle讀取和儲存影象資料庫》

訓練機器學習演算法,我們一般將原始資料分成訓練資料(training_set)、驗證資料(validation_set)、測試資料(testing_set)。本程式將training_set、validation_set、testing_set分別設定為320、40、40個樣本。它們的label為0~39,對應40個不同的人。這部分的程式碼如下:

  1. """ 
  2. 載入影象資料的函式,dataset_path即影象olivettifaces的路徑 
  3. 載入olivettifaces後,劃分為train_data,valid_data,test_data三個資料集 
  4. 函式返回train_data,valid_data,test_data以及對應的label 
  5. """
  6. def load_data(dataset_path):  
  7.     img = Image.open(dataset_path)  
  8.     img_ndarray = numpy.asarray(img, dtype='float64')/256
  9.     faces=numpy.empty((400,2679))  
  10.     for row in range(20):  
  11.        for column in range(20):  
  12.         faces[row*20+column]=numpy.ndarray.flatten(img_ndarray [row*57:(row+1)*57,column*47:(column+1)*47])  
  13.     label=numpy.empty(400)  
  14.     for i in range(40):  
  15.     label[i*10:i*10+10]=i  
  16.     label=label.astype(numpy.int)  
  17.     #分成訓練集、驗證集、測試集,大小如下
  18.     train_data=numpy.empty((320,2679))  
  19.     train_label=numpy.empty(320)  
  20.     valid_data=numpy.empty((40,2679))  
  21.     valid_label=numpy.empty(40)  
  22.     test_data=numpy.empty((40,2679))  
  23.     test_label=numpy.empty(40)  
  24.     for i in range(40):  
  25.     train_data[i*8:i*8+8]=faces[i*10:i*10+8]  
  26.     train_label[i*8:i*8+8]=label[i*10:i*10+8]  
  27.     valid_data[i]=faces[i*10+8]  
  28.     valid_label[i]=label[i*10+8]  
  29.     test_data[i]=faces[i*10+9]  
  30.     test_label[i]=label[i*10+9]  
  31.     #將資料集定義成shared型別,才能將資料複製進GPU,利用GPU加速程式。
  32.     def shared_dataset(data_x, data_y, borrow=True):  
  33.         shared_x = theano.shared(numpy.asarray(data_x,  
  34.                                                dtype=theano.config.floatX),  
  35.                                  borrow=borrow)  
  36.         shared_y = theano.shared(numpy.asarray(data_y,  
  37.                                                dtype=theano.config.floatX),  
  38.                                  borrow=borrow)  
  39.         return shared_x, T.cast(shared_y, 'int32')  
  40.     train_set_x, train_set_y = shared_dataset(train_data,train_label)  
  41.     valid_set_x, valid_set_y = shared_dataset(valid_data,valid_label)  
  42.     test_set_x, test_set_y = shared_dataset(test_data,test_label)  
  43.     rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),  
  44.             (test_set_x, test_set_y)]  
  45.     return rval  

二、CNN的基本“構件”(LogisticRegression、HiddenLayer、LeNetConvPoolLayer)

卷積神經網路(CNN)的基本結構就是輸入層、卷積層(conv)、子取樣層(pooling)、全連線層、輸出層(分類器)。  卷積層+子取樣層一般都會有若干個,本程式實現的CNN模型參考LeNet5,有兩個“卷積+子取樣層”LeNetConvPoolLayer。全連線層相當於MLP(多層感知機)中的隱含層HiddenLayer。輸出層即分類器,一般採用softmax迴歸(也有人直接叫邏輯迴歸,其實就是多類別的logistics regression),本程式也直接用LogisticRegression表示。 總結起來,要組建CNN模型,必須先定義LeNetConvPoolLayer、HiddenLayer、LogisticRegression這三種layer,這一點在我上一篇文章介紹CNN演算法時講得很詳細,包括程式碼註解,因為太冗長,這裡給出連結:《DeepLearning tutorial(4)CNN卷積神經網路原理簡介+程式碼詳解》

程式碼太長,就不貼具體的了,只給出框架,具體可以下載我的程式碼看看:

  1. #分類器,即CNN最後一層,採用邏輯迴歸(softmax)
  2. class LogisticRegression(object):  
  3.     def __init__(self, input, n_in, n_out):  
  4.         self.W = ....  
  5.         self.b = ....  
  6.         self.p_y_given_x = ...  
  7.         self.y_pred = ...  
  8.         self.params = ...  
  9.     def negative_log_likelihood(self, y):  
  10.     def errors(self, y):  
  11. #全連線層,分類器前一層
  12. class HiddenLayer(object):  
  13.     def __init__(self, rng, input, n_in, n_out, W=None, b=None,activation=T.tanh):  
  14.         self.input = input  
  15.         self.W = ...  
  16.         self.b = ...  
  17.         lin_output = ...  
  18.         self.params = [self.W, self.b]  
  19. #卷積+取樣層(conv+maxpooling)
  20. class LeNetConvPoolLayer(object):  
  21.     def __init__(self, rng, input, filter_shape, image_shape, poolsize=(22)):  
  22.         self.input = input  
  23.         self.W = ...  
  24.         self.b = ...  
  25.         # 卷積
  26.         conv_out = ...  
  27.         # 子取樣
  28.         pooled_out =...  
  29.         self.output = ...  
  30.         self.params = [self.W, self.b]  

三、組建CNN模型,設定優化演算法,應用於Olivetti Faces進行人臉識別

上面定義好了CNN的幾個基本“構件”,現在我們使用這些構件來組建CNN模型,本程式的CNN模型參考LeNet5,具體為:input+layer0(LeNetConvPoolLayer)+layer1(LeNetConvPoolLayer)+layer2(HiddenLayer)+layer3(LogisticRegression)

這是一個串聯結構,程式碼也很好寫,直接用第二部分定義好的各種layer去組建就行了,上一layer的輸出接下一layer的輸入,具體可以看看程式碼evaluate_olivettifaces函式中的“建立CNN模型”部分。

CNN模型組建好了,就剩下用優化演算法求解了,優化演算法採用批量隨機梯度下降演算法(MSGD),所以要先定義MSGD的一些要素,主要包括:代價函式,訓練、驗證、測試model、引數更新規則(即梯度下降)。這部分的程式碼在evaluate_olivettifaces函式中的“定義優化演算法的一些基本要素”部分。

優化演算法的基本要素也定義好了,接下來就要運用人臉影象資料集來訓練這個模型了,訓練過程有訓練步數(n_epoch)的設定,每個epoch會遍歷所有的訓練資料(training_set),本程式中也就是320個人臉圖。還有迭代次數iter,一次迭代遍歷一個batch裡的所有樣本,具體為多少要看所設定的batch_size。關於引數的設定我在下面會討論。這一部分的程式碼在evaluate_olivettifaces函式中的“訓練CNN階段”部分。

程式碼很長,只貼框架,具體可以下載我的程式碼看看:

  1. def evaluate_olivettifaces(learning_rate=0.05, n_epochs=

    相關推薦

    CNN框架搭建各個引數調節

    本文主要講解將CNN應用於人臉識別的流程,程式基於Python+numpy+theano+PIL開發,採用類似LeNet5的CNN模型,應用於olivettifaces人臉資料庫,實現人臉識別的功能,模型的誤差降到了5%以下。本程式只是個人

    實戰 :Spring MVC + 註解 +SqlServer 框架搭建詳解

    原始碼下載:http://download.csdn.NET/detail/u010469432/6786687 https://blog.csdn.net/u010469432/article/details/17587699 先說一下Spring3 MVC的優點: spring&nb

    第八章 專案框架搭建使用者管理初建

           連線資料庫,使用EA直接生成資料庫指令碼,建立表。 /* ---------------------------------------------------- */ /* Generated by Enterprise Arc

    實戰 :Spring MVC + 註解 +SqlServer 框架搭建詳解

    先說一下Spring3 MVC的優點: spring MVC 屬於輕量級框架 1、Spring3 MVC的學習難度小於Struts2,Struts2用不上的多餘功能太多。呵呵,當然這不是決定因素。 2、Spring3 MVC很容易就可以寫出效能優秀的程式,Str

    eleme專案框架搭建eslint常見報錯(含圖示字型庫說明)

    全域性安裝vue-cli 全域性安裝vue-cli 開啟終端,npm install vue-cli -g 安裝webpack模板 1.專案資料夾下,npm init webpack eleme (eleme是專案名) 2.按照提示選擇,

    SSM框架搭建專案實戰

    (可通過圖片水印檢視部落格地址) 1、基本概念 1.1、Spring         Spring是一個開源框架,Spring是於2003 年興起的一個輕量級的Java 開發框架,由Rod Johnson 在其著作Expert One

    學習Hibernate框架--框架搭建簡單實用

    在學習Hibernate框架之前首先來了解一下什麼是Hibernate框架以及為什麼要使用它 Hibernate是 一個開源的物件關係對映框架 對JDBC進行了非常輕量級的物件封裝 將JacaBean物件和資料庫的表建立對應關係 它是一個優秀的Java持久化層解決方案 是

    MyBatis框架搭建增刪改查操作

    一.搭建好myBatis的環境   1.引入jar包     (1).mybatis-3.2.7jar核心jar     (2).mysql核心jar     (3).junit-4.9.jar   2.建立mybatis全域性配置檔案SqlMapConfig.xm(名字自

    SpringBoot框架簡介搭建

    pro org 除了 運行 maven sha 方法 variable tid Spring Boot 簡介 1. 什麽是SpringBoot 1.1 Spring Boot是由Pivotal團隊提供的全新框架,其設計目的是用來簡化新Spring應用的初始搭建以及開發過程。

    Mybatis框架簡介、搭建核心元素講解

    大家都知道當前比較流行的框架,如Spring、SpringMVC、Hibernate、Mybatis等框架,那今天咱們來講講當前比較流行的企業框架Mybatis。 首先先來介紹一下Mybatis這個框架:Mybatis是一個開源的資料持久化層框架,它內部封裝了通過JDBC訪問資料庫的操作,支

    dubbo框架搭建各種功能的實現

    之前專案中用過dubbo做分散式,最近想把這個框架的搭建及一些基本功能的使用記錄下來。 註冊中心用zookeeper 架構 Provider 暴露服務的服務提供⽅ Consumer 調⽤遠端服務的服務消費⽅ Registry 服務註冊與發現的註冊中⼼ Monitor 統計服務的調

    Ribbon 框架簡介搭建(沒有與SpringCloud整合,獨立使用)

    Ribbon簡介 1.  負載均衡框架,支援可插拔式的負載均衡規則 2.  支援多種協議,如HTTP、UDP等 3.  提供負載均衡客戶端 Ribbon子模組 1.  ribbon-core(ribbon的核心,主要包含負載均衡器、負載均衡介面、客戶端介面

    jhipster框架簡介搭建

    jhipster框架 Jhipster是為您生成一個完整的現代Web應用程式或微服務架構: 它使用Spring Boot在伺服器端實現高效能和強大的Java堆疊,這是一款時尚,現代,移動優先的前端,配有Angular,React和Bootstrap,一個強大的微服務架構,

    CNN網路模型發展程序各個網路優缺點

    此篇部落格主要是講一些目前較好的網路發展,提出問題,解決了什麼問題,有什麼優缺點。 1、Alexnet網路,本網路是2012年的分類冠軍,掀起來深度學習的熱潮,Alexnet引數量巨大,由於受限於硬體計算,作者採用了group 卷積,將特徵圖分別放入不同gpu進行運算,

    分散式系統詳解--框架(Hadoop-HDFS的HA搭建測試)

                  分散式系統詳解 - 框架(Hadoop的HDFS的HA搭建及測試) 一,背景概述 解決問題1:通過上面的圖我們可以明確的看出來,如果的Namenode壞掉了,那我們的整個叢集可以說就是要癱瘓了也就是單節點故障問題。 於是現在就出現了另外一

    SSM框架之Mybatis同時傳入多個物件普通引數

    當傳入多個檔案時,mapper介面檔案的方法引數要使用@param(“xx”)註釋。 例子: mapper: //Student是物件,age是String型別。 int getPojo(@param("student") Student student, @param("age") S

    Vue2.0框架搭建基礎操作目錄說明

    一、概述 vue.js是一套構建使用者介面的漸進式框架。vue採用自底向上增量開發的設計。vue的核心庫只關心檢視層,非常容易學習,非常容易與其它庫和已有專案整合。vue完全有能力驅動採用單檔案元件和vue生態系統支援的庫開發的複雜單頁應用。 vue.js的目標是通過儘可能簡單的API實現響應的資料繫結和

    【SSH】--框架搭建錯誤專案中問題

    2015.11.04日 404是瀏覽器無法連線到伺服器,500是IIS無法解析ASP,總結了一下自己的錯誤, 404: 1.hibernate配置檔案配錯,多寫了主鍵 2.spring裡面JSP跳轉配錯 3.DAO層沒有繼承HibernateDaoSupp

    利用Tensorflow構建CNN影象多分類模型影象引數、資料維度變化情況例項分析

    本文以CIFAR-10為資料集,基於Tensorflow介紹了CNN(卷積神經網路)影象分類模型的構建過程,著重分析了在建模過程中卷積層、池化層、扁平化層、全連線層、輸出層的運算機理,以及經過運算後圖像尺寸、資料維度等引數的變化情況。 CIFAR-10資料

    SSM框架搭建專案開發的步驟

    第一階段: 1、用PowerDesign建資料模型,並匯出SQL檔案; 2、將SQL檔案匯入到MySQL客戶端,建立表格;   MySQL資料遠端訪問:GRANT ALL PRIVILEGES ON . TO ‘root’@’%’IDENTIFIED BY