Ordinal Regression with Multiple Output CNN for Age Estimation
阿新 • Published: 2018-11-06
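The paper proposes an end-to-end deep learning network for age estimation that combines a CNN with ordinal regression. The network architecture is shown below.
The input is:
Finally, the predicted age is computed from the outputs of all 100 subproblems, using the following formula: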
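A minimal NumPy sketch of how an age maps onto the binary subproblems and back. It assumes the usual ordinal-regression setup in which subproblem t answers "is the age greater than min_age + t?", so the predicted age is min_age plus the number of subproblems that answer yes. The function names and the min_age offset are hypothetical and not taken from the author's code.

```python
import numpy as np


def encode_ages(ages, num_tasks=100, min_age=0):
    """Binary targets: entry t is 1 if the age is greater than min_age + t, else 0."""
    ages = np.asarray(ages)
    thresholds = min_age + np.arange(num_tasks)
    return (ages[..., None] > thresholds).astype(np.int64)


def decode_ages(task_probs, min_age=0):
    """Predicted age = min_age + number of subproblems whose P(yes) exceeds 0.5."""
    task_probs = np.asarray(task_probs)
    return min_age + np.sum(task_probs > 0.5, axis=-1)


# Round trip: decoding the "perfect" predictions recovers the original ages.
print(decode_ages(encode_ages([3, 30, 75])).tolist())
```

Counting the positive answers does not require the predicted yes/no sequence to be monotone, which independently trained binary outputs do not guarantee.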
Computing the loss function
A cross-entropy loss is used, with a different weight assigned to each class.
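- We have $K = 100$ subtasks (written $T$ in the formulas below), so the final loss function is

$$E = -\frac{1}{N}\sum_{i=1}^{N}\sum_{t=1}^{T}\lambda^t\,\mathbf{1}\{o_i^t = y_i^t\}\,w_i^t\,\log p(o_i^t \mid x_i, W^t)$$

- Here $N$ is the number of training samples, $x_i$ is the $i$-th input image, $o_i^t$ and $y_i^t$ are the output and the ground-truth label of the $t$-th subtask for that image, $W^t$ are the parameters of the $t$-th subtask, and $\mathbf{1}\{\cdot\}$ is the indicator function.
- $\lambda^t$ is the data weight of the $t$-th subtask, reflecting how much of the training data has the corresponding age (normalized with square roots): $\lambda^t = \frac{\sqrt{N^t}}{\sum_{i=1}^{T}\sqrt{N^i}}$, where $N^t$ is the number of training samples with the age of the $t$-th subtask.
- $w_i^t$ is the weight of each class in each binary subproblem and can simply be set to 1.

Finally, the parameters are optimized by backpropagation.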
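The two weighting terms can be checked numerically with a short NumPy sketch. It computes $\lambda^t$ from per-age sample counts and the weighted multi-task cross-entropy for logits of shape (N, T, 2). The array layout and the function names are assumptions for illustration, not the author's Caffe implementation.

```python
import numpy as np


def task_weights(counts):
    """lambda^t = sqrt(N^t) / sum_i sqrt(N^i), where counts[t] = N^t."""
    root = np.sqrt(np.asarray(counts, dtype=np.float64))
    return root / root.sum()


def ordinal_loss(logits, targets, lambdas, sample_weights=None):
    """Weighted multi-task cross-entropy.

    logits:         (N, T, 2) scores for the two classes of each binary subtask
    targets:        (N, T) ground-truth labels in {0, 1}
    lambdas:        (T,) task weights lambda^t
    sample_weights: (N, T) weights w_i^t, defaulting to 1
    """
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.int64)
    lambdas = np.asarray(lambdas, dtype=np.float64)
    n = logits.shape[0]
    # Numerically stable log-softmax over the two classes of each subtask.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Keep only the log-probability of the ground-truth class (the indicator term).
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    if sample_weights is None:
        sample_weights = np.ones_like(picked)
    return -np.sum(lambdas[None, :] * sample_weights * picked) / n


weights = task_weights([1, 4, 9, 16])
loss = ordinal_loss(np.zeros((3, 4, 2)), np.zeros((3, 4), dtype=int), weights)
print(weights, loss)
```

With all-zero logits every class has probability 0.5, so the loss equals ln 2 (about 0.693) whenever the $\lambda^t$ sum to 1, which they do by construction.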
Code analysis
The paper comes with Caffe code, and its input is in HDF5 format:
layer {
  name: "data"
  type: "HDF5Data"
  top: "data"
  top: "label"
  hdf5_data_param {
    source: "../data/SourceDataForCaffe/normal_Tutu_morph_wiki/hdf5_data_gray/train.txt"
    batch_size: 256
  }
  include {
    phase: TRAIN
  }
}
layer {
  name: "data"
  type: "HDF5Data"
  top: "data"
  top: "label"
  hdf5_data_param {
    source: "../data/SourceDataForCaffe/normal_Tutu_morph_wiki/hdf5_data_gray/test.txt"
    batch_size: 256
  }
  include {
    phase: TEST
  }
}
To convert the data to HDF5, you can refer to the following code:
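import os
import sys

import h5py
import matplotlib.pyplot as plt
import numpy as np

IMAGE_SIZE = (60, 60)  # images are expected to be pre-resized to 60x60
MEAN_VALUE = 128       # mean pixel value (0-255 scale) subtracted from every image

# Usage: python this_script.py train.txt
# Each non-empty line of the list file: <image_path> <label> [<label> ...]
filename = sys.argv[1]
setname, _ = os.path.splitext(filename)  # also works if the path contains extra dots

with open(filename, 'r') as f:
    lines = [line.split() for line in f if line.strip()]
np.random.shuffle(lines)  # shuffle the samples before writing them

sample_size = len(lines)
imgs = np.zeros((sample_size, 1) + IMAGE_SIZE, dtype=np.float32)  # N x C x H x W
labels = np.array([[float(v) for v in fields[1:]] for fields in lines],
                  dtype=np.float32)

for i, fields in enumerate(lines):
    raw = plt.imread(fields[0])
    img = raw.astype(np.float32)
    if np.issubdtype(raw.dtype, np.floating):
        img *= 255.0  # matplotlib reads PNG files as floats in [0, 1]
    if img.ndim == 3:
        img = img[:, :, 0]  # keep a single channel (grayscale input)
    imgs[i, 0] = img - MEAN_VALUE
    if (i + 1) % 1000 == 0:
        print('Processed {} images!'.format(i + 1))

h5_filename = '{}.h5'.format(setname)
with h5py.File(h5_filename, 'w') as h:
    # Caffe's HDF5Data layer looks up datasets by its "top" names,
    # so they must be called "data" and "label" to match the layers above.
    h.create_dataset('data', data=imgs)
    h.create_dataset('label', data=labels)

# The HDF5Data "source" file lists one .h5 path per line.
with open('{}_h5.txt'.format(setname), 'w') as f:
    f.write(h5_filename + '\n')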
These are followed by the convolutional and fully connected (FC) layers, and finally the loss layer:
layer {
  name: "loss"
  type: "OrdinalRegressionLoss"
  bottom: "fc_output"
  bottom: "label"
  top: "loss"
  include {
    phase: TRAIN
  }
  ordinal_regression_loss_param {
    weight_file: "../data/weight_file/weight_file_normal.txt"
  }
}
layer {
  name: "loss"
  type: "OrdinalRegressionLoss"
  bottom: "fc_output"
  bottom: "label"
  top: "loss"
  include {
    phase: TEST
  }
  ordinal_regression_loss_param {
    weight_file: "../data/weight_file/weight_file_normal.txt"
  }
}
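Caffe has no built-in loss function for multi-task training, so the author wrote a corresponding layer (OrdinalRegressionLoss), which requires recompiling Caffe. For reference, see:
https://github.com/kongsicong/Age_recognition_OR/tree/master/OrdinalRegression
Once this is done, the network can be trained.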
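Because the custom layer has to supply gradients for backpropagation, it may help to see what its backward pass computes. The sketch below is a NumPy version of the weighted multi-task cross-entropy together with its gradient with respect to the logits $z$, namely $\frac{\partial E}{\partial z_{i,t,c}} = \frac{\lambda^t w_i^t}{N}\,(p_{i,t,c} - \mathbf{1}\{c = y_i^t\})$, where $p$ are the softmax probabilities and $c \in \{0, 1\}$ is the class. It is an illustration under an assumed (N, T, 2) logit layout, not the code of the OrdinalRegressionLoss layer in the linked repository; the function name is hypothetical.

```python
import numpy as np


def ordinal_loss_forward_backward(logits, targets, lambdas, sample_weights=None):
    """Return (E, dE/dlogits) for logits (N, T, 2) and targets (N, T) in {0, 1}."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.int64)
    lambdas = np.asarray(lambdas, dtype=np.float64)
    n, num_tasks, _ = logits.shape
    if sample_weights is None:
        sample_weights = np.ones((n, num_tasks))  # w_i^t = 1
    # Numerically stable log-softmax over the two classes of each subtask.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    probs = np.exp(log_probs)
    onehot = np.eye(2)[targets]                    # indicator 1{c = y}, shape (N, T, 2)
    scale = lambdas[None, :] * sample_weights / n  # lambda^t * w_i^t / N, shape (N, T)
    loss = -np.sum(scale * np.sum(onehot * log_probs, axis=-1))
    # Softmax + cross-entropy gradient: scale * (p - onehot).
    grad = scale[..., None] * (probs - onehot)
    return loss, grad


loss, grad = ordinal_loss_forward_backward(
    np.zeros((2, 3, 2)), np.ones((2, 3), dtype=int), np.full(3, 1.0 / 3))
print(loss, grad.shape)
```

The gradient for each subtask sums to zero over its two classes, and a finite-difference check is an easy way to validate a hand-written Backward implementation against this reference.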
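TensorFlow model training
I wrote a corresponding TensorFlow model and trained it on the age-classification data stored in LMDB. The test result was:
Mean age error: 8.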
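Since the reported figure is an age error rather than a classification accuracy, evaluation needs the same decoding step as prediction. Below is a framework-independent NumPy sketch (it does not use TensorFlow) that turns raw network outputs into ages and computes the mean absolute error. It assumes the final layer produces 2 scores per subtask stored as consecutive (no, yes) pairs and that "mean age error" means mean absolute error; both are assumptions, as are the function names.

```python
import numpy as np


def predict_ages(fc_output, min_age=0):
    """fc_output: (N, 2*T) scores, assumed to be (no, yes) pairs per subtask."""
    fc_output = np.asarray(fc_output, dtype=np.float64)
    n, width = fc_output.shape
    logits = fc_output.reshape(n, width // 2, 2)
    logits = logits - logits.max(axis=-1, keepdims=True)  # stable softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    older = probs[..., 1] > 0.5  # P(age > threshold of subtask t) > 0.5
    return min_age + older.sum(axis=1)


def mean_absolute_error(pred_ages, true_ages):
    pred = np.asarray(pred_ages, dtype=np.float64)
    true = np.asarray(true_ages, dtype=np.float64)
    return float(np.mean(np.abs(pred - true)))


# Two samples, four subtasks: the first says "older" twice, the second three times.
scores = np.zeros((2, 4, 2))
scores[0, :2, 1] = 1.0
scores[0, 2:, 0] = 1.0
scores[1, :3, 1] = 1.0
scores[1, 3:, 0] = 1.0
pred = predict_ages(scores.reshape(2, 8))
print(pred.tolist(), mean_absolute_error(pred, [1, 5]))
```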
Reference code on GitHub: https://github.com/kongsicong/Age_recognition_OR