使用caffe的python layer自定義資料增強層DataAugmentationLayer
阿新 • • 發佈:2019-01-03
專案地址:https://github.com/zhongqianli/caffe_python_layer
caffe自定義網路層的一種方式是使用python layer,這種方式需要使用pycaffe執行,命令列的方式執行會報錯。
編寫DataAugmentationLayer
這個類的基類是caffe.Layer,需要編寫setup,reshape,forward,backward四個方法,每個方法都有top和bottom引數,可以通過top[0].data和bottom[0].data獲取一個4維的資料,分別是batch_size、通道數、高、寬。
import caffe import json import cv2 import numpy as np import random # 4 pixel pad, random crop # img: 64x3x32x32 def zeropadding_and_crop(data): # # cifar10 # # padding_img = np.pad(img, ((4, 4), (4, 4), (4, 4)), "constant", padder=0) padding_img = np.zeros((np.shape(data)[0], 3, 40, 40), dtype=np.uint8) padding_img[..., 4:36, 4:36] = data[...] # # # cv2.imshow("pad", data[0][0]) row_rand_num = random.randrange(9) col_rand_num = random.randrange(9) croped_img = padding_img[..., row_rand_num : row_rand_num + 32, col_rand_num : col_rand_num + 32] return croped_img class DataAugmentationLayer(caffe.Layer): def setup(self, bottom, top): pass def reshape(self, bottom, top): top[0].reshape(*bottom[0].data.shape) pass def forward(self, bottom, top): top[0].data[...] = zeropadding_and_crop(bottom[0].data) pass def backward(self, top, propagate_down, bottom): pass
使用自定義的網路層DataAugmentationLayer
pycaffe最好使用net.xxx的方式建立網路,因為第二種方式會自動命名,可能會出現一下意想不到的問題。
# 第一種方式,推薦使用 net.data = L.Python(net.data_temp, python_param=dict(module="custom_data_augmentation", layer="DataAugmentationLayer"), include=dict(phase=caffe_pb2.Phase.Value('TRAIN'))) # 第二種方式,不推薦這種方式 data = L.Python(data_temp, python_param=dict(module="custom_data_augmentation", layer="DataAugmentationLayer"), include=dict(phase=caffe_pb2.Phase.Value('TRAIN')))