CNN-Based Face Recognition
阿新 · Published 2019-02-17
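A while ago I read a face-analysis paper, "Age and Gender Classification using Convolutional Neural Networks", published at CVPR 2015. It is very well written, clearly organized and tightly argued, which makes it an excellent read for anyone studying deep learning, so I am writing up some notes on it here.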
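Of course, I also tried to implement the paper. It uses the Adience benchmark face dataset, which I could not find online, so I had to substitute other data. At first I collected a dataset myself (yikes!) of about 1,000 faces. When I trained the network on it, the loss blew up and never converged. It took me a long time to realize that the likely cause was the small number of samples. Later, following advice from an expert on Zhihu, I enlarged the dataset and reran the experiments in Caffe. Training not only converged, it reached an accuracy of about 98%!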
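The `DATA` layers in the network definition below read from LMDB databases (`face_train_lmdb`, `face_val_lmdb`) and a mean file. The post does not say how these were built. This sketch assumes one common route: write `path label` list files and feed them to Caffe's `convert_imageset` and `compute_image_mean` tools.

The layout is hypothetical, with one folder per class (`male/`, `female/`). The label order and the 90/10 train/validation split are also illustrative choices, not the author's.

```python
import os
import random

# Hypothetical layout: ROOT/male/*.jpg gets label 0, ROOT/female/*.jpg gets label 1.
CLASSES = ["male", "female"]
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def collect_samples(root):
    """Return (relative_path, label) pairs for every image in root/<class>/."""
    samples = []
    for label, cls in enumerate(CLASSES):
        for name in sorted(os.listdir(os.path.join(root, cls))):
            if os.path.splitext(name)[1].lower() in IMAGE_EXTS:
                # convert_imageset prepends its ROOTFOLDER argument to these paths
                samples.append((cls + "/" + name, label))
    return samples


def write_lists(root, train_txt, val_txt, val_fraction=0.1, seed=0):
    """Shuffle once with a fixed seed, hold out val_fraction, write both list files."""
    samples = collect_samples(root)
    random.Random(seed).shuffle(samples)
    n_val = int(len(samples) * val_fraction)
    for path, subset in ((val_txt, samples[:n_val]), (train_txt, samples[n_val:])):
        with open(path, "w") as f:
            for rel_path, label in subset:
                f.write(f"{rel_path} {label}\n")
    return len(samples) - n_val, n_val
```

The resulting lists could then be converted to LMDB and a mean file, for example:

- `convert_imageset --resize_height=256 --resize_width=256 --shuffle faces/ train.txt data/mywork/face_train_lmdb`
- `compute_image_mean data/mywork/face_train_lmdb data/mywork/face_mean.binaryproto`

Resizing to 256×256 is an assumption, chosen only so that the network's 227-pixel crops fit inside each image.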
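When reproducing this, keep two things in mind. First, the dataset must not be too small, or training will not converge. Second, the network's parameters have to be chosen sensibly. Training with cuda-convnet (the network structure popular for the CIFAR-10 classification problem) was fast and reached a best accuracy of 96%. The architecture described in the paper trained much more slowly but reached a best accuracy of 98.5%.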
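One concrete place where parameter choice shows up is the per-layer multipliers in the network definition below. In this older Caffe format, the two `blobs_lr` values scale the solver's base learning rate for a layer's weights and biases. The two `weight_decay` values scale the solver's weight decay in the same way.

The post does not include its solver settings, so `BASE_LR` and `BASE_WEIGHT_DECAY` below are hypothetical values. The sketch only multiplies them out to make the effective per-layer settings visible.

```python
# Hypothetical solver values: the post does not show its solver.prototxt.
BASE_LR = 0.001
BASE_WEIGHT_DECAY = 0.0005

# (blobs_lr weights, blobs_lr bias, weight_decay weights, weight_decay bias),
# copied from the network definition.
LAYER_MULTIPLIERS = {
    "conv1": (1, 2, 1, 0),
    "conv2": (1, 2, 1, 0),
    "conv3": (1, 2, 1, 0),
    "fc6": (1, 2, 1, 0),
    "fc7": (1, 2, 1, 0),
    "fc8": (10, 20, 1, 0),
}


def effective_settings(base_lr, base_decay, multipliers):
    """Return {layer: (lr_weights, lr_bias, decay_weights, decay_bias)}."""
    return {
        name: (base_lr * lr_w, base_lr * lr_b, base_decay * wd_w, base_decay * wd_b)
        for name, (lr_w, lr_b, wd_w, wd_b) in multipliers.items()
    }


settings = effective_settings(BASE_LR, BASE_WEIGHT_DECAY, LAYER_MULTIPLIERS)
for name, (lr_w, lr_b, wd_w, wd_b) in settings.items():
    print(f"{name}: lr={lr_w:g}/{lr_b:g}  decay={wd_w:g}/{wd_b:g}")
```

Because every bias has a decay multiplier of 0, biases are never pulled toward zero. The 10×/20× boost on `fc8` is a pattern often seen in Caffe fine-tuning recipes, though the post does not say why it was used here.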
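I only implemented gender classification. Age classification is essentially the same recipe with different labels, but preparing a dataset for it takes a lot of manpower and resources, so I did not pursue it.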
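Age is "the same recipe" because only the output layer and the label set change. The gender network's `fc8` has `num_output: 2`. The paper's age task uses eight age groups, so `fc8` would need eight outputs instead.

This sketch turns a vector of `fc8` scores into a label using a softmax followed by an argmax. It works unchanged for either task as long as the label list matches `num_output`. The label names and their order are hypothetical; the post does not say which index means which gender.

```python
import math

# Hypothetical label order: the post does not state which index is which gender.
GENDER_LABELS = ["male", "female"]
# Eight age groups, given generic hypothetical names here.
AGE_LABELS = [f"age_group_{i}" for i in range(8)]


def softmax(scores):
    """Numerically stable softmax: subtract the max score before exponentiating."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict(scores, labels):
    """Return (label, probability) for the highest-scoring class."""
    if len(scores) != len(labels):
        raise ValueError("fc8 num_output must match the number of labels")
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]


print(predict([0.3, 2.1], GENDER_LABELS))
print(predict([0.1, 0.2, 3.0, 0.4, 0.0, -1.0, 0.5, 0.2], AGE_LABELS))
```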
The structure of the convolutional neural network is shown in the figure below:
Without further ado, here is the code:
age_net (configured here for gender classification, hence `num_output: 2` in `fc8`):
```protobuf
name: "facenet"
# Training input: LMDB of labeled face images, randomly mirrored
layers {
  name: "data"
  type: DATA
  top: "data"
  top: "label"
  data_param { source: "data/mywork/face_train_lmdb" backend: LMDB batch_size: 50 }
  transform_param { crop_size: 227 mean_file: "data/mywork/face_mean.binaryproto" mirror: true }
  include: { phase: TRAIN }
}
# Validation input: same preprocessing, no mirroring
layers {
  name: "data"
  type: DATA
  top: "data"
  top: "label"
  data_param { source: "data/mywork/face_val_lmdb" backend: LMDB batch_size: 50 }
  transform_param { crop_size: 227 mean_file: "data/mywork/face_mean.binaryproto" mirror: false }
  include: { phase: TEST }
}
# Block 1: 96 filters of 7x7, stride 4 -> ReLU -> max pool -> LRN
layers {
  name: "conv1"
  type: CONVOLUTION
  bottom: "data"
  top: "conv1"
  blobs_lr: 1       # learning-rate multiplier for the weights
  blobs_lr: 2       # learning-rate multiplier for the biases
  weight_decay: 1   # weight-decay multiplier for the weights
  weight_decay: 0   # no weight decay on the biases
  convolution_param {
    num_output: 96
    kernel_size: 7
    stride: 4
    weight_filler { type: "gaussian" std: 0.01 }
    bias_filler { type: "constant" value: 0 }
  }
}
layers { name: "relu1" type: RELU bottom: "conv1" top: "conv1" }
layers {
  name: "pool1"
  type: POOLING
  bottom: "conv1"
  top: "pool1"
  pooling_param { pool: MAX kernel_size: 3 stride: 2 }
}
layers {
  name: "norm1"
  type: LRN
  bottom: "pool1"
  top: "norm1"
  lrn_param { local_size: 5 alpha: 0.0001 beta: 0.75 }
}
# Block 2: 256 filters of 5x5 -> ReLU -> max pool -> LRN
layers {
  name: "conv2"
  type: CONVOLUTION
  bottom: "norm1"
  top: "conv2"
  blobs_lr: 1
  blobs_lr: 2
  weight_decay: 1
  weight_decay: 0
  convolution_param {
    num_output: 256
    pad: 2
    kernel_size: 5
    weight_filler { type: "gaussian" std: 0.01 }
    bias_filler { type: "constant" value: 1 }
  }
}
layers { name: "relu2" type: RELU bottom: "conv2" top: "conv2" }
layers {
  name: "pool2"
  type: POOLING
  bottom: "conv2"
  top: "pool2"
  pooling_param { pool: MAX kernel_size: 3 stride: 2 }
}
layers {
  name: "norm2"
  type: LRN
  bottom: "pool2"
  top: "norm2"
  lrn_param { local_size: 5 alpha: 0.0001 beta: 0.75 }
}
# Block 3: 384 filters of 3x3 -> ReLU -> max pool
layers {
  name: "conv3"
  type: CONVOLUTION
  bottom: "norm2"
  top: "conv3"
  blobs_lr: 1
  blobs_lr: 2
  weight_decay: 1
  weight_decay: 0
  convolution_param {
    num_output: 384
    pad: 1
    kernel_size: 3
    weight_filler { type: "gaussian" std: 0.01 }
    bias_filler { type: "constant" value: 0 }
  }
}
layers { name: "relu3" type: RELU bottom: "conv3" top: "conv3" }
layers {
  name: "pool5"
  type: POOLING
  bottom: "conv3"
  top: "pool5"
  pooling_param { pool: MAX kernel_size: 3 stride: 2 }
}
# Two fully connected layers of 512 units, each followed by ReLU and 50% dropout
layers {
  name: "fc6"
  type: INNER_PRODUCT
  bottom: "pool5"
  top: "fc6"
  blobs_lr: 1
  blobs_lr: 2
  weight_decay: 1
  weight_decay: 0
  inner_product_param {
    num_output: 512
    weight_filler { type: "gaussian" std: 0.005 }
    bias_filler { type: "constant" value: 1 }
  }
}
layers { name: "relu6" type: RELU bottom: "fc6" top: "fc6" }
layers { name: "drop6" type: DROPOUT bottom: "fc6" top: "fc6" dropout_param { dropout_ratio: 0.5 } }
layers {
  name: "fc7"
  type: INNER_PRODUCT
  bottom: "fc6"
  top: "fc7"
  blobs_lr: 1
  blobs_lr: 2
  weight_decay: 1
  weight_decay: 0
  inner_product_param {
    num_output: 512
    weight_filler { type: "gaussian" std: 0.005 }
    bias_filler { type: "constant" value: 1 }
  }
}
layers { name: "relu7" type: RELU bottom: "fc7" top: "fc7" }
layers { name: "drop7" type: DROPOUT bottom: "fc7" top: "fc7" dropout_param { dropout_ratio: 0.5 } }
# Output layer: one score per class (2 classes for gender), trained at 10x/20x the base rate
layers {
  name: "fc8"
  type: INNER_PRODUCT
  bottom: "fc7"
  top: "fc8"
  blobs_lr: 10
  blobs_lr: 20
  weight_decay: 1
  weight_decay: 0
  inner_product_param {
    num_output: 2
    weight_filler { type: "gaussian" std: 0.01 }
    bias_filler { type: "constant" value: 0 }
  }
}
# Accuracy is reported only in the TEST phase; the softmax loss is used in both phases
layers { name: "accuracy" type: ACCURACY bottom: "fc8" bottom: "label" top: "accuracy" include: { phase: TEST } }
layers { name: "loss" type: SOFTMAX_LOSS bottom: "fc8" bottom: "label" top: "loss" }
```
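This sketch traces the blob sizes through the definition above. Caffe rounds convolution output sizes down and pooling output sizes up. That is why a 227×227 crop ends up as a 384×7×7 `pool5` blob feeding `fc6`.

The 3-channel (color) input is an assumption; the post does not state it. ReLU, LRN and dropout layers are left out because they do not change the shape.

```python
import math

# Shape-changing layers from the definition: (name, kind, num_output, kernel, stride, pad).
LAYERS = [
    ("conv1", "conv", 96, 7, 4, 0),
    ("pool1", "pool", None, 3, 2, 0),
    ("conv2", "conv", 256, 5, 1, 2),
    ("pool2", "pool", None, 3, 2, 0),
    ("conv3", "conv", 384, 3, 1, 1),
    ("pool5", "pool", None, 3, 2, 0),
]


def conv_out(size, kernel, stride, pad):
    # Caffe convolution: floor((size + 2*pad - kernel) / stride) + 1
    return (size + 2 * pad - kernel) // stride + 1


def pool_out(size, kernel, stride, pad):
    # Caffe pooling: ceil((size + 2*pad - kernel) / stride) + 1.
    # Every pooling layer here has pad 0, so Caffe's extra clipping for padded pooling never applies.
    return math.ceil((size + 2 * pad - kernel) / stride) + 1


def trace_shapes(crop_size=227, channels=3):
    """Return [(layer, channels, height, width)] for a square crop; channels=3 is assumed."""
    size, shapes = crop_size, []
    for name, kind, num_output, kernel, stride, pad in LAYERS:
        if kind == "conv":
            size, channels = conv_out(size, kernel, stride, pad), num_output
        else:
            size = pool_out(size, kernel, stride, pad)
        shapes.append((name, channels, size, size))
    return shapes


for name, c, h, w in trace_shapes():
    print(f"{name}: {c} x {h} x {w}")
# conv1: 96 x 56 x 56
# pool1: 96 x 28 x 28
# conv2: 256 x 28 x 28
# pool2: 256 x 14 x 14
# conv3: 384 x 14 x 14
# pool5: 384 x 7 x 7
```

So each image reaches `fc6` as 384 × 7 × 7 = 18,816 values.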
The structure is very simple; after about 4 hours of training it reached the expected accuracy.
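To put "very simple" in numbers, this sketch counts the learnable parameters in each layer of the definition above: weights plus one bias per output. It makes two assumptions that the post does not state:

- `conv1` receives 3-channel input.
- `pool5` outputs 384×7×7, which is what a 227-pixel crop produces under Caffe's sizing rules.

```python
# (layer, outputs, inputs per output position, kernel size); fully connected layers use kernel 1.
LAYERS = [
    ("conv1", 96, 3, 7),            # assumes 3-channel (color) input
    ("conv2", 256, 96, 5),
    ("conv3", 384, 256, 3),
    ("fc6", 512, 384 * 7 * 7, 1),   # pool5 output flattened: 384 x 7 x 7
    ("fc7", 512, 512, 1),
    ("fc8", 2, 512, 1),             # 2 outputs for gender
]


def count_params(layers):
    """Return {layer: weights + biases}; no layer here uses grouped convolution."""
    return {name: n_out * n_in * k * k + n_out for name, n_out, n_in, k in layers}


counts = count_params(LAYERS)
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:>5}: {n:>10,d}  ({n / total:.1%})")
print(f"total: {total:,d}")  # total: 11,411,970
```

Under these assumptions, `fc6` holds roughly 84% of the approximately 11.4 million parameters, so the three convolutional layers are comparatively light.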