利用tensorflow進行德國交通標誌識別》
一,前言
利用暑假學習了一段時間的DL,德國交通標誌馬上實踐一下,下面主要提到我用的模型以及資料處理,如果有哪些地方大佬們覺得哪些地方做得不好或者可以改進得,非常歡迎直接在評論裡提出來,初學需要一起努力。
二,資料
我只需要上面四類檔案中“影象與註釋”的內容(其他feature我也不會用。。。),下載下來之後開啟看看檔案構成。
像官網介紹的一樣,一共43個種類的資料,43個資料夾,每個資料夾都有小到幾百張,大到上千張圖片,在檔案的最後都以csv格式給出來了label,影象用ppm格式給了出來,接下來我們嘗試對資料進行處理。
1, 首先官網給出了從資料夾中讀取資料的py程式,可以直接拿來用,可以省掉很多時間,我們先來看看他的程式以及讀取結果。
上面中文註釋是我看程式的時候加上去的方便理解,range()函式中的數字是你要讀取的檔案種類,現在是(1,2),表示讀取了第二類檔案!如果全資料集的話應該是(0,43),中間有一段resize()函式,這一段是方便讀取的時候就統一照片尺寸,有利於後面的引數設定。
接下來用matplotlib繪製圖像,將每個種類的第一張照片以及他的數量輸出,程式碼以及效果如下圖所示:
三,模型分析
對資料有了初步的認識,接下來我們就要準備模型,這裡選擇了難度程度較低的LeNet-5模型,LeNet-5模型由兩個卷積層與兩個池化層,以及兩個全連線層構成。在前向傳播traffic_inference中卷積層,池化層以及全連線層的引數設定如下:第一層卷積層深度為32,卷積核大小為5,第一層池化層中strides=[1,2,2,1],padding='SAME'。第一層全連線節點個數為512
後面的卷積層以及池化層引數設定都類似,這裡不再一一寫出 , 下面給出一組卷積層以及池化層的程式碼作為參考。
接下來應該是全連線神經網路,現在我們的資料是一個矩陣形式,而全連結神經網路需要我們把資料進行“拍扁”,變成一維的資料才能繼續進行,我們對矩陣進行拉伸成一個一維陣列,具體程式碼如下:
上面經過第二層池化層之後,我們將資料通過pool2的shape()函式來獲取裡面一共有多少個引數,將其轉化為list,注意我們是一個一個batch訓練的,所以shape中的第一個數字表示batch_size的大小,後面才是我們的引數,求得nodes之後我們就求得了pool2中一共的引數個數,通過reshape()函式我們將pool2重現拉伸為一維陣列,之後便可以輸入到全連結神經網路中,具體程式碼最後給出。
目前為止我們實現了前向傳播的整個過程,接下來我們要實現反向傳播過程,也就是我們的迴圈訓練過程,所有程式碼在最後的github連線中給了出來
其中我們定義了滑動平均類為: variable_average = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECARY, global_step) 其中我們設定了MOVING_AVERAGE_DECARY中是滑動平均率,滑動平均可以讓我們採取一種這種的方式更新我們的引數,可以防止一定的噪聲。
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_,1))
cross_entropy採用了tf.nn.sparse_softmax_cross_entropy_with_logits(),因為最後的結果只有一個,所以可以採用上面的函式,把損失函式和softmax()函式結合到一起了,加快了運算速度,而最後的tf.argmax()函式是因為你的正確標籤是one-hot型別的,至於如何轉換成one-hot型別我們等會再說
learning_rate = tf.train.exponential_decay( LEARNING_RATE_BASE, global_step, 400, LEARNING_RATE_DECAY, staircase = True )
learning_rate採用的指數下降函式,學習率剛開始設定一個合適的值,然後隨著訓練輪數的疊加而不斷降低,這樣我們就可以在迭代後期採用較小的學習率防止產生loss區線震盪,同時學習率也不宜過大,如果過大會出現loss值是None或者非常大的情況,如果出現說明你應該調低你的初始學習率,調低大概2-3個位數。
我們通過官網給出來的程式讀取的資料不可以直接放進網路中訓練!
不管是images的資料還是標籤的資料都不符合網路要求。 ①首先我們應該把照片的size重新置成一個固定值,這個我們已經在讀取程式中更改過了,就是在readTrafficSign.py中的resize()函式中做的。 ②接下來將圖片矩陣轉換成numpy矩陣才可以輸入到網路中。而且讀取出來的標籤中是str格式,我們也需要將其轉換成int形式,並且轉換成one-hot資料形式,普通的資料list不能被演算法學習,簡單來說就是你輸入進去一堆一維的陣列,每個數字都表示一個例項的標籤,感覺資料和標籤都輸入進去就可以訓練了,但是有一個問題,程式根本不知道你有多少個標籤,以後可以嘗試用sklearn來快速轉化一下(無奈現在還不會),所以現在我們需要構造一個稀疏矩陣,具體程式碼如下:
這樣的話資料就可以使用了,用一個一個batch_size輸入到程式中訓練,需要注意的是從檔案中讀取資料訓練的時候需要打亂順序讀取,不再給出打亂檔案的程式自己感興趣的可以自己完成。
其實中間還有很多細節沒有提到,其中的具體內容就不在這裡寫出來,感興趣可以去看一看。