1. 程式人生 > >連載二:PyCon2018|用slim微調PNASNet模型(附原始碼)

連載二:PyCon2018|用slim微調PNASNet模型(附原始碼)

第八屆中國Python開發者大會PyConChina2018,由PyChina.org發起,由來自CPyUG/TopGeek等社群的30位組織者,近150位志願者在北京、上海、深圳、杭州、成都等城市舉辦。致力於推動各類Python相關的技術在網際網路、企業應用等領域的研發和應用。

程式碼醫生工作室有幸接受邀請,參加了這次會議的北京站專場。在會上主要分享了《人工智慧實戰案例分享-影象處理與數值分析》。

會上分享的一些案例主要是來源於《python帶我起飛——入門、進階、商業實戰》一書與《深度學習之TensorFlow:入門、原理與進階實戰》一書。另外,還擴充了若干其它案例。在本文作為補充,將會上分享的其它案例以詳細的圖文方式補充進來,並提供原始碼。共分為4期連載。
  1. 用slim呼叫PNASNet模型

  2. 用slim微調PNASNet模型

  3. 用對抗樣本攻擊PNASNet模型

  4. 惡意域名檢測例項

通過微調模型實現分辨男女

案例描述

有一組照片,分為男人和女人。

本案就是讓深度學習模型來學習這些樣本,並能夠找到其中的規律,完成模型的訓練。接著可以使用該模型對圖片中的人物進行識別,區分其性別是男還是女。

本案例中,使用了一個NASNet_A_Mobile的模型來做二次訓練。具體過程分為4步:

(1)準備樣本;

(2)準備NASNet_A_Mobile網路模型;

(3)編寫程式碼進行二次訓練;

(4)使用已經訓練好的模型進行測試。

準備樣本

通過如下連結下載CelebA資料集:

mmlab.ie.cuhk.edu.hk/projects/Ce…

下載完之後,解壓,並手動分出一部分男人與女人的照片。

在本例中,一共用了20000張圖片用來訓練模型,其中訓練樣本由8421張男性頭像和11599張女性頭像構成(在train資料夾下),測試樣本由10張男性頭像和10張女性頭像構成(在val資料夾下)。部分樣本資料如圖5-1。

圖5-1 男女資料集樣本示例

資料樣本整理好後,統一放到data資料夾下。該資料樣本同樣也可以在隨書的配套資源中找到。

程式碼環境及模型準備

為了使讀者能夠快速完成該例項,直觀上感受到模型的識別能力,可以直接使用本書配套的資源。並將其放到程式碼的同級目錄下即可。

如果想體驗下從零開始手動搭建,也可以按照下面的方法準備程式碼環境及預編譯模型。

1. 下載models與部署TensorFlow slim模組

該部分的內容與3.1節完全一樣,這裡不再詳述。

2. 下載NASNet_A_Mobile模型

該部分的內容與3.1節類似。在如圖3-2中的倒數第3個模型,找到 “nasnet-a_mobile_04_10_2017.tar.gz”的下載連結。將其下載並解壓。

3. 整體程式碼檔案部署結構

本案例是通過4個程式碼檔案來實現的,具體檔案及描述如下:

l 5-1 mydataset.py:處理男女圖片資料集的程式碼;

l 5-2 model.py:載入預編譯模型NASNet_A_Mobile,並進行微調的程式碼;

l 5-3 train.py:訓練模型的程式碼;

l 5-4 test.py:測試模型的程式碼。

部署時,將這4個程式碼檔案與slim庫、NASNet_A_Mobile模型、樣本一起放到一個資料夾下即可。完整的檔案結構如圖5-2。

圖5-2 分辨男女案例的檔案結構

程式碼實現:處理樣本資料並生成Dataset物件

本案例中,直接將資料集的相關操作封裝到了“5-1 mydataset.py”程式碼檔案裡。在該檔案中,實現了符合訓練與測試使用場景的資料集。在訓練模式下,會對資料進行亂序處理;在測試模式下,直接使用順序資料。兩種資料集都是按批次讀取。

這部分的知識在第4章已經有全面的介紹,這裡不再詳述。完整程式碼如下:

程式碼5-1 mydataset

 1 import tensorflow as tf
 2 import sys                                      
 3 nets_path = r'slim'                                             #載入環境變數
 4 if nets_path not in sys.path:
 5    sys.path.insert(0,nets_path)
 6 else:
 7     print('already add slim')
 8 from nets.nasnet import nasnet                               #匯出nasnet
 9 slim = tf.contrib.slim                                         #slim
10 image_size = nasnet.build_nasnet_mobile.default_image_size     #獲得圖片輸入尺寸 224
11 from preprocessing import preprocessing_factory            #影象處理
12 
13 import os
14 def list_images(directory):
15    """
16    獲取所有directory中的所有圖片和標籤
17    """
18
19    #返回path指定的資料夾包含的檔案或資料夾的名字的列表
20    labels = os.listdir(directory)
21    #對標籤進行排序,以便訓練和驗證按照相同的順序進行
22    labels.sort()
23    #建立檔案標籤列表
24    files_and_labels = []
25    for label in labels:
26        for f in os.listdir(os.path.join(directory, label)):
27            #轉換字串中所有大寫字元為小寫再判斷
28            if 'jpg' in f.lower() or 'png' in f.lower():
29                #加入列表
30                files_and_labels.append((os.path.join(directory, label, f), label))
31    #理解為解壓 把資料路徑和標籤解壓出來
32    filenames, labels = zip(*files_and_labels)
33    #轉換為列表 分別儲存資料路徑和對應標籤
34    filenames = list(filenames)
35    labels = list(labels)
36    #列出分類總數 比如兩類:['man', 'woman']
37    unique_labels = list(set(labels))
38
39    label_to_int = {}
40    #迴圈列出資料和資料下標,給每個分類打上標籤{'woman': 2, 'man': 1,none:0}
41    for i, label in enumerate(sorted(unique_labels)):
42        label_to_int[label] = i+1
43    print(label,label_to_int[label])
44    #把每個標籤化為0 1 這種形式
45    labels = [label_to_int[l] for l in labels]
46    print(labels[:6],labels[-6:])
47    return filenames, labels              #返回儲存資料路徑和對應轉換後的標籤
48
49 num_workers = 2                          #定義並行處理資料的執行緒數量
50
51 #影象批量預處理
52 image_preprocessing_fn = preprocessing_factory.get_preprocessing('nasnet_mobile', is_training=True)
53 image_eval_preprocessing_fn = preprocessing_factory.get_preprocessing('nasnet_mobile', is_training=False)
54
55 def _parse_function(filename, label):      #定義影象解碼函式
56    image_string = tf.read_file(filename)
57    image = tf.image.decode_jpeg(image_string, channels=3)          
58    return image, label
59
60 def training_preprocess(image, label):    #定義調整影象大小函式
61    image = image_preprocessing_fn(image, image_size, image_size)
62    return image, label
63
64 def val_preprocess(image, label):       #定義評估影象預處理函式
65    image = image_eval_preprocessing_fn(image, image_size, image_size)
66    return image, label
67
68 #建立帶批次的資料集
69 def creat_batched_dataset(filenames, labels,batch_size,isTrain = True):
70
71    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
72
73    dataset = dataset.map(_parse_function, num_parallel_calls=num_workers)    #對影象解碼
74
75    if isTrain == True:
76        dataset = dataset.shuffle(buffer_size=len(filenames))                #打亂資料順序
77        dataset = dataset.map(training_preprocess, num_parallel_calls=num_workers)#調整影象大小
78    else:
79        dataset = dataset.map(val_preprocess,num_parallel_calls=num_workers)    #調整影象大小
80
81    return dataset.batch(batch_size)                                           #返回批次資料
82
83 #根據目錄返回資料集
84 def creat_dataset_fromdir(directory,batch_size,isTrain = True):
85    filenames, labels = list_images(directory)
86    num_classes = len(set(labels))
87    dataset = creat_batched_dataset(filenames, labels,batch_size,isTrain)
88    return dataset,num_classes 複製程式碼

程式碼11行,匯入了preprocessing_factory函式,該函式是slim模組中封裝好的工廠函式,用於生成模型的預處理函式。利用統一封裝好的預處理函式,對樣本進行操作(程式碼60、61行),可以提升開發效率,並能夠減小出錯的可能性。

工廠函式的知識點,屬於Python基礎知識,這裡不再詳述。有興趣的讀者可以參考《python帶我起飛——入門、進階、商業實戰》一書的6.10節。

注意:

這裡用了一個技巧。仿照原NASNet_A_Mobile模型的分類方法,在對分類標籤排號時,將標籤為0的分類空出來,男人與女人分別為1和2。

另外在程式碼42行,用到的變數unique_labels是從集合物件轉化過來的。在使用時需要對齊固定順序,所以使用了sorted函式進行變換。如果沒有這句,在下次啟動的時候,有可能出現標籤序號與名稱對應不上的現象。在多次中斷,多次訓練的場景下,會造成訓練結果的混亂。這部分知識在《python帶我起飛——入門、進階、商業實戰》的第四章集合部分的內容中,也做了重點的強調。

程式碼實現:定義微調模型類MyNASNetModel

在微調模型的實現中,統一通過定義類MyNASNetModel來實現。在類MyNASNetModel中,大致可分為2大動作:初始化設定、構建模型。

l 初始化設定:定義好構建模型時所需要的必要引數;

l 構建模型:針對訓練、測試、應用的三種情況分別構建不同的網路模型。在訓練過程中,還要支援載入預編譯模型及微調模型。

實現定義類MyNASNetModel並進行初始化模型設定的程式碼如下:

程式碼5-2 model

  1 import sys                                      
  2 nets_path = r'slim'                                    #載入環境變數
  3 if nets_path not in sys.path:
  4    sys.path.insert(0,nets_path)
  5 else:
  6    print('already add slim')
  7
  8 import tensorflow as tf
  9 from nets.nasnet import nasnet                       #匯出nasnet
 10 slim = tf.contrib.slim 
 11
 12 import os  
 13 mydataset = __import__("5-1  mydataset")
 14 creat_dataset_fromdir = mydataset.creat_dataset_fromdir
 15
 16 class MyNASNetModel(object):
 17    """微調模型類MyNASNetModel
 18    """
 19    def __init__(self, model_path=''):
 20        self.model_path = model_path              #原始模型的路徑           複製程式碼

程式碼20行為初始化MyNASNetModel類的操作。model_path指的是所要載入的原始預編譯模型。該操作只有在訓練模式下是有意義的。在測試和應用模式下,可以為空。

構建MyNASNetModel類中的基本模型

在構建模型中,無論是訓練、測試還是應用,都需要將最基本的NASNet_A_Mobile模型載入。這裡通過定義MyNASNetModel類的MyNASNet方法來實現。具體的實現方式與3.3節的實現基本一致,不同的是3.3節構建的是PNASNet網路結構,這裡構建的NASNet_A_Mobile結構。

程式碼5-2 model(續)

 21  def MyNASNet(self,images,is_training):
 22        arg_scope = nasnet.nasnet_mobile_arg_scope()          #獲得模型名稱空間
 23        with slim.arg_scope(arg_scope):
 24            #構建NASNet Mobile模型
 25            logits, end_points = nasnet.build_nasnet_mobile(images,num_classes = self.num_classes+1, is_training=is_training)
 26
 27        global_step = tf.train.get_or_create_global_step()      #定義記錄步數的張量
 28
 29        return logits,end_points,global_step                   #返回有用的張量
複製程式碼


程式碼25行中,往num_classes引數裡傳的值代表分類的個數,在本案例中分為男人和女人,一共兩類(即,self.num_classes=2,該值是在後文5.2.8節中,build_model方法被賦值的)。再加上一個None類。於是傳入的值為self.num_classes+1。

實現MyNASNetModel類中的微調操作

微調操作是針對訓練場景下使用的。通過定義MyNASNetModel類中的FineTuneNASNet方法來實現。微調操作主要是對預編譯模型的超參進行選擇性恢復。

因為預編譯模型NASNet_A_Mobile是在ImgNet上訓練的,有1000個分類,而本案例中識別男女的任務只有兩個分類。所以最後兩個輸出層的超參不應該被恢復(由於分類不同,導致超參的個數不同)。在實際使用時,最後兩層的引數需要對其初始化,並單獨訓練即可。

程式碼5-2 model(續)

 30 def FineTuneNASNet(self,is_training):      #實現微調模型的網路操作 
 31        model_path = self.model_path
 32
 33        exclude = ['final_layer','aux_7']      #恢復超參, 除了exclude以外的全部恢復
 34        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
 35        if is_training == True:
 36            init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)
 37        else:
 38            init_fn = None
 39
 40        tuning_variables = []             #將沒有恢復的超參收集起來,用於微調訓練
 41        for v in exclude:
 42            tuning_variables += slim.get_variables(v)
 43
 44        return init_fn, tuning_variables複製程式碼

程式碼中,使用了exclude列表,將不需要恢復的網路節點收集起來(程式碼33行),接著將預訓練模型中的超參值賦值給剩下的節點,完成了預訓練模型的載入(程式碼36行)。最後使用了tuning_variables列表,將不需要恢復的網路節點權重收集起來(程式碼40行),用於微調訓練。

注意:

這裡介紹個技巧,如何獲得exclude中的元素(程式碼33行):通過額外執行程式碼tf.global_variables(),將張量圖中的節點打印出來。從裡面找到最後兩層的節點,並將其填入程式碼中即可。在找到節點後,還可以通過slim.get_variables函式,來檢查該名稱的節點是否正確。例如,可以通過將slim.get_variables('final_layer')的返回值打印出來,來觀察張量圖中是否有final_layer節點。這部分的原理可以參考《深度學習之TensorFlow:入門、原理與進階實戰》書中第4章的內容(在第11章也有類似的案例)。

程式碼實現:實現與訓練相關的其他方法

在MyNASNetModel類中,還需要定義與訓練操作相關的其他方法,具體如下:

l build_acc_base方法:用於構建評估模型的相關節點;

l load_cpk方法:用於載入及儲存模型檢查點

l build_model_train方法:用於構建訓練模型中的損失函式及優化器等操作節點。

具體程式碼如下:

程式碼5-2 model(續)

 45 def build_acc_base(self,labels):#定義評估函式
 46        #返回張量中最大值的索引
 47        self.prediction = tf.to_int32(tf.argmax(self.logits, 1))
 48        #計算prediction、labels是否相同 
 49        self.correct_prediction = tf.equal(self.prediction, labels)
 50        #計算平均值
 51        self.accuracy = tf.reduce_mean(tf.to_float(self.correct_prediction))
 52        #將前5個最高正確率的值取出來,計算平均值
 53        self.accuracy_top_5 = tf.reduce_mean(tf.to_float(tf.nn.in_top_k(predictions=self.logits, targets=labels, k=5)))
 54
 55    def load_cpk(self,global_step,sess,begin = 0,saver= None,save_path = None):                                                    #儲存和匯出模型
 56       if begin == 0:
 57            save_path=r'./train_nasnet'                      #定義檢查點路徑
 58            if not os.path.exists(save_path):
 59                print("there is not a model path:",save_path)
 60            saver = tf.train.Saver(max_to_keep=1)            #生成saver
 61            return saver,save_path
 62        else:
 63            kpt = tf.train.latest_checkpoint(save_path)    #查詢最新的檢查點
 64            print("load model:",kpt)
 65            startepo= 0                                    #計步
 66            if kpt!=None:
 67                saver.restore(sess, kpt)                     #還原模型
 68                ind = kpt.find("-")
 69                startepo = int(kpt[ind+1:])
 70                print("global_step=",global_step.eval(),startepo)    
 71            return startepo  
 72
 73    def build_model_train(self,images,
 74           labels,learning_rate1,learning_rate2,is_training):
 75           self.logits,self.end_points, 
 76           self.global_step= self.MyNASNet(images,is_training=is_training)
 77        self.step_init = self.global_step.initializer
 78
 79        self.init_fn,self.tuning_variables = self.FineTuneNASNet(
 80            is_training=is_training)
 81        #定義損失函式
 82       tf.losses.sparse_softmax_cross_entropy(labels=labels, 
 83            logits=self.logits)
 84        loss = tf.losses.get_total_loss()
 85        #定義微調的率退化學習速率
 86        learning_rate1=tf.train.exponential_decay(
 87                 learning_rate=learning_rate1, global_step=self.global_step,
 88                 decay_steps=100, decay_rate=0.5)
 89        #定義聯調的率退化學習速率
 90        learning_rate2=tf.train.exponential_decay(
 91             learning_rate=learning_rate2, global_step=self.global_step,
 92             decay_steps=100, decay_rate=0.2)                
 93        last_optimizer = tf.train.AdamOptimizer(learning_rate1) #優化器
 94        full_optimizer = tf.train.AdamOptimizer(learning_rate2)   
 95        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  
 96        with tf.control_dependencies(update_ops):      #更新批量歸一化中的引數
 97            #使loss減小方向做優化
 98            self.last_train_op = last_optimizer.minimize(loss, self.global_step,var_list=self.tuning_variables)
 99            self.full_train_op = full_optimizer.minimize(loss, self.global_step)
100
101        self.build_acc_base(labels)                    #定義評估模型相關指標
102        #寫入日誌,支援tensorBoard操作
103        tf.summary.scalar('accuracy', self.accuracy)    
104        tf.summary.scalar('accuracy_top_5', self.accuracy_top_5)
105
106        #將收集的所有預設圖表併合並
107        self.merged = tf.summary.merge_all()
108        #寫入日誌檔案
109        self.train_writer = tf.summary.FileWriter('./log_dir/train')
110        self.eval_writer = tf.summary.FileWriter('./log_dir/eval')
111        #定義檢查點相關變數
112        self.saver,self.save_path = self.load_cpk(self.global_step,None)複製程式碼

在上面程式碼中,使用了tf.losses介面來獲得loss值。通過呼叫tf.losses.sparse_softmax_cross_entropy 函式計算具體的loss(見程式碼82行)。該函式會自動將loss值新增到內部集合ops.GraphKeys.LOSSES中。然後呼叫tf.losses.get_total_loss函式,將ops.GraphKeys.LOSSES集合中的所有loss值獲取,並返回回來(見程式碼84行)。

在程式碼96行中,在反向優化時,使用了tf.control_dependencies函式對的批量歸一化操作中的均值與方差進行更新。

程式碼實現:構建模型,用於訓練、測試、使用

在MyNASNetModel類中,定義build_model方法用與構建模型的實現。在build_model方法中,通過引數mode來指定模型的具體使用場景。具體程式碼如下:

程式碼5-2 model(續)

113 def build_model(self,mode='train',testdata_dir='./data/val',traindata_dir='./data/train', batch_size=32,learning_rate1=0.001,learning_rate2=0.001):
114
115        if mode == 'train':        
116            tf.reset_default_graph()
117            #建立訓練資料和測試資料的Dataset資料集
118            dataset,self.num_classes = creat_dataset_fromdir(traindata_dir,batch_size)
119            testdataset,_ = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
120
121            #建立一個可初始化的迭代器
122            iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
123            #讀取資料
124            images, labels = iterator.get_next()
125
126            self.train_init_op = iterator.make_initializer(dataset)
127            self.test_init_op = iterator.make_initializer(testdataset)
128
129            self.build_model_train(images, labels,learning_rate1,learning_rate2,is_training=True)
130            self.global_init = tf.global_variables_initializer()    #定義全域性初始化op
131            tf.get_default_graph().finalize()                #將後續的圖設為只讀
132        elif mode == 'test':
133            tf.reset_default_graph()
134
135            #建立測試資料的Dataset資料集
136            testdataset,self.num_classes = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
137
138            #建立一個可初始化的迭代器
139            iterator = tf.data.Iterator.from_structure(testdataset.output_types, testdataset.output_shapes)
140            #讀取資料
141            self.images, labels = iterator.get_next()
142
143            self.test_init_op = iterator.make_initializer(testdataset)
144            self.logits,self.end_points, self.global_step= self.MyNASNet(self.images,is_training=False)
145            self.saver,self.save_path = self.load_cpk(self.global_step,None)                  #定義檢查點相關變數
146            #評估指標
147            self.build_acc_base(labels)
148            tf.get_default_graph().finalize()            #將後續的圖設為只讀
149        elif mode == 'eval':
150            tf.reset_default_graph()
151            #建立測試資料的Dataset資料集
152            testdataset,self.num_classes = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
153
154            #建立一個可初始化的迭代器
155            iterator = tf.data.Iterator.from_structure(testdataset.output_types, testdataset.output_shapes)
156            #讀取資料
157            self.images, labels = iterator.get_next()
158
159            self.logits,self.end_points, self.global_step= self.MyNASNet(self.images,is_training=False)
160            self.saver,self.save_path = self.load_cpk(self.global_step,None)   #定義檢查點相關變數
161            tf.get_default_graph().finalize()                        #將後續的圖設為只讀複製程式碼

程式碼115行,對mode進行了判斷,並按照具體的場景進行構建模型。針對訓練、測試、使用的三個場景,構建的步驟幾乎一樣,具體如下:

(1)清空張量圖(程式碼116、133、150);

(2)生成資料集(程式碼118、136、152);

(3)定義網路結構(程式碼129、144、159)。

測試與使用的場景是最相似的。在程式碼中測試比使用的操作對了個評估節點的生成(程式碼147)。

注意:

在每個操作分支的最後程式碼部分都加了程式碼tf.get_default_graph().finalize()(見程式碼131、148、161行),這是一個很好的習慣。該程式碼的功能是把圖鎖定,之後想要新增任何新的操作都會產生錯誤。這麼做的意圖是防止在後面訓練或是測試過程中,由於開發人員疏忽,在圖中新增額外的圖操作。一旦在迴圈內部加了某個張量的操作,將會使整體效能大大下降。然而這種錯誤又很難發現。利用鎖定圖的方法,可以避免這種情況的發生。

程式碼實現:通過二次迭代來訓練微調模型

訓練微調模型的操作是在程式碼檔案“5-3 train.py”中單獨實現的。與正常的訓練方式不同,這裡使用了二次迭代的方式:

l 第一次迭代:微調模型,固定預編譯模型載入的權重,只訓練最後兩層;

l 第二次迭代:聯調模型,使用更小的學習率,訓練全部節點。

先將類MyNASNetModel進行例項化,在呼叫其build_model方法構建模型,然後使用session開始訓練。具體程式碼如下:

程式碼5-3 train

162 import tensorflow as tf
163 model = __import__("5-2  model")
164 MyNASNetModel = model.MyNASNetModel
165
166 batch_size = 32
167 train_dir  = 'data/train'
168 val_dir  = 'data/val'
169
170 learning_rate1 = 1e-1                                       #定義兩次迭代的學習率
171 learning_rate2 = 1e-3
172
173 mymode = MyNASNetModel(r'nasnet-a_mobile_04_10_2017\model.ckpt')#初始化模型
174 mymode.build_model('train',val_dir,train_dir,batch_size,learning_rate1 ,learning_rate2 )                                                                    #將模型定義載入圖中
175
176 num_epochs1 = 20                                           #微調的迭代次數
177 num_epochs2 = 200                                        #聯調的迭代次數
178
179 with tf.Session() as sess:
180    sess.run(mymode.global_init)                         #初始全域性節點
181
182    step = 0
183    step = mymode.load_cpk(mymode.global_step,sess,1,mymode.saver,mymode.save_path )#載入模型
184    print(step)
185    if step == 0:                                        #微調
186        mymode.init_fn(sess)                                 #載入預編譯模型權重
187
188        for epoch in range(num_epochs1):
189
190            print('Starting1 epoch %d / %d' % (epoch + 1, num_epochs1))    #輸出進度 
191            #用訓練集初始化迭代器
192            sess.run(mymode.train_init_op)                                #資料集從頭開始
193            while True:
194                try:
195                    step += 1
196                    #預測,合併圖,訓練
197                    acc,accuracy_top_5, summary, _ = sess.run([mymode.accuracy, mymode.accuracy_top_5,mymode.merged,mymode.last_train_op])
198
199                    #mymode.train_writer.add_summary(summary, step)#寫入日誌檔案
200                    if step % 100 == 0:
201                        print(f'step: {step} train1 accuracy: {acc},{accuracy_top_5}')
202                except tf.errors.OutOfRangeError:#資料集指標在最後
203                    print("train1:",epoch," ok")
204                    mymode.saver.save(sess, mymode.save_path+"/mynasnet.cpkt",   global_step=mymode.global_step.eval())
205                    break
206
207        sess.run(mymode.step_init)                    #微調結束,計數器從0開始
208
209    #整體訓練
210    for epoch in range(num_epochs2):
211        print('Starting2 epoch %d / %d' % (epoch + 1, num_epochs2))
212        sess.run(mymode.train_init_op)
213        while True:
214            try:
215                step += 1
216                #預測,合併圖,訓練
217                acc, summary, _ = sess.run([mymode.accuracy, mymode.merged, mymode.full_train_op])
218
219                mymode.train_writer.add_summary(summary, step)#寫入日誌檔案
220
221                if step % 100 == 0:
222                    print(f'step: {step} train2 accuracy: {acc}')
223            except tf.errors.OutOfRangeError:
224                print("train2:",epoch," ok")
225                mymode.saver.save(sess, mymode.save_path+"/mynasnet.cpkt",   global_step=mymode.global_step.eval())
226                break複製程式碼

將以上程式碼執行後,經過一段時間的訓練,可以在本地找到“train_nasnet”資料夾,裡面放著的就是訓練生成的模型檔案。

程式碼實現:測試模型

測試模型的操作是在程式碼檔案“5-4 test.py”中單獨實現的。這裡實現了使用測試資料集對現有模型的評估,並且使用單張圖片放到模型裡進行預測。

1. 定義測試模型所需要的功能函式

首先定義函式check_accuracy實現準確率的計算,接著定義函式check_sex實現男女性別的識別。具體程式碼如下:

程式碼5-4 test

227 import tensorflow as tf
228 model = __import__("5-2  model")
229 MyNASNetModel = model.MyNASNetModel
230
231 import sys                                      
232 nets_path = r'slim'                                     #載入環境變數
233 if nets_path not in sys.path:
234    sys.path.insert(0,nets_path)
235 else:
236    print('already add slim')
237
238 from nets.nasnet import nasnet                     #匯出nasnet
239 slim = tf.contrib.slim                                 #slim
240 image_size = nasnet.build_nasnet_mobile.default_image_size  #獲得圖片輸入尺寸 224
241
242 import numpy as np
243 from PIL import Image
244
245 batch_size = 32
246 test_dir  = 'data/val'
247
248 def check_accuracy(sess):
249    """
250    測試模型準確率
251    """
252    sess.run(mymode.test_init_op)                  #初始化測試資料集
253    num_correct, num_samples = 0, 0                 #定義正確個數 和 總個數
254    i = 0
255    while True:
256        i+=1
257        print('i',i)
258        try:
259            #計算correct_prediction 獲取prediction、labels是否相同 
260            correct_pred,accuracy,logits = sess.run([mymode.correct_prediction,mymode.accuracy,mymode.logits])
261            #累加correct_pred
262            num_correct += correct_pred.sum()
263            num_samples += correct_pred.shape[0]
264            print("accuracy",accuracy,logits)
265
266
267        except tf.errors.OutOfRangeError:          #捕獲異常,資料用完自動跳出
268            print('over')
269            break
270
271    acc = float(num_correct) / num_samples         #計算並返回準確率
272    return acc 
273
274
275 def check_sex(imgdir,sess):                        #定義函式識別男女
276    img = Image.open(image_dir)                      #讀入圖片
277    if "RGB"!=img.mode :                             #檢查圖片格式
278        img = img.convert("RGB") 
279
280    img = np.asarray(img.resize((image_size,image_size)),     #影象預處理  
281                          dtype=np.float32).reshape(1,image_size,image_size,3)
282    img = 2 *( img / 255.0)-1.0 
283
284    prediction = sess.run(mymode.logits, {mymode.images: img})#傳入nasnet輸入端中
285    print(prediction)
286
287    pre = prediction.argmax()                    #返回張量中最大值的索引
288    print(pre)
289
290    if pre == 1: img_id = 'man'
291    elif pre == 2: img_id = 'woman'
292    else: img_id = 'None'
293    plt.imshow( np.asarray((img[0]+1)*255/2,np.uint8 )  )
294    plt.show()
295    print(img_id,"--",image_dir)                    #返回類別
296    return pre複製程式碼

2. 建立會話,進行測試

首先建立會話session,對模型進行測試,接著取2張圖片輸入模型,進行男女的判斷。具體程式碼如下:

程式碼5-4 test(續)

297 mymode = MyNASNetModel()                                     #初始化模型
298 mymode.build_model('test',test_dir )                     #將模型定義載入圖中
299
300 with tf.Session() as sess:  
301    #載入模型
302    mymode.load_cpk(mymode.global_step,sess,1,mymode.saver,mymode.save_path )
303
304    #測試模型的準確性
305    val_acc = check_accuracy(sess)
306    print('Val accuracy: %f\n' % val_acc)
307
308    #單張圖片測試
309    image_dir = 'tt2t.jpg'                                 #選取測試圖片
310    check_sex(image_dir,sess)
311
312    image_dir = test_dir + '\\woman' + '\\000001.jpg'       #選取測試圖片
313    check_sex(image_dir,sess)
314
315    image_dir = test_dir + '\\man' + '\\000003.jpg'         #選取測試圖片
316    check_sex(image_dir,sess)複製程式碼

該程式使用的是迭代了100次資料集後的模型檔案(如果要效果提高,可以再執行久一點)。程式碼執行後,輸出結果如下。

(1)顯示測試集的輸出結果:

i 1

accuracy 0.90625 [[-3.813714 1.4075054 1.1485975 ]

[-7.3948846 6.220533 -1.4093535 ]

[-1.9391974 3.048838 0.21784738]

[-3.873174 4.530942 0.43135062]

……

[-3.8561587 2.7012844 -0.3634925 ]

[-4.4860134 4.7661724 -0.67080706]

[-2.9615571 2.8164086 0.71033645]]

i 2

accuracy 0.90625 [[ -6.6900268 -2.373093 6.6710057 ]

[ -4.1005263 0.74619263 4.980012 ]

[ -5.6469827 0.39027584 1.2689826 ]

……

[ -5.8080773 0.9121424 3.4134243 ]

[ -4.242001 0.08483959 4.056322 ]]

i 3

over

Val accuracy: 0.906250

上面顯示的是測試集中man和woman資料夾中圖片的計算結果。最終模型的準確率為90%。

(2)顯示單張圖片的執行結果:

[[-4.8022223 1.9008529 1.9379601]]

2

圖5-3 分辨男女測試圖片(a)

woman -- tt2t.jpg

[[-6.181205 -2.9042015 6.1356106]]

2

圖5-3 分辨男女測試圖片(b)

woman -- data/val\woman\000001.jpg

[[-4.896065 1.7791721 1.3118265]]

1

圖5-3 分辨男女測試圖片(c)

man -- data/val\man\000003.jpg

上面顯示了3張圖片,分別為自選圖片、測試資料集中的女人圖片、測試資料集中的男人圖片,每張圖片下面顯示了模型識別的結果。可以看到結果與圖片內容一致。

結尾

文內程式碼可以直接執行使用。如果不想手動搭建,還可以下載本文的配套程式碼。

【程式碼獲取】:關注公眾號:xiangyuejiqiren   公眾號回覆“pycon2

如果覺得本文有用

可以分享給更多小夥伴