1. 程式人生 > >使用caffe的python layer自定義資料增強層DataAugmentationLayer

使用caffe的python layer自定義資料增強層DataAugmentationLayer

專案地址:https://github.com/zhongqianli/caffe_python_layer
caffe自定義網路層的一種方式是使用python layer,這種方式需要使用pycaffe執行,命令列的方式執行會報錯。

編寫DataAugmentationLayer

這個類的基類是caffe.Layer,需要編寫setup,reshape,forward,backward四個方法,每個方法都有top和bottom引數,可以通過top[0].data和bottom[0].data獲取一個4維的資料,分別是batch_size、通道數、高、寬。

import caffe
import json
import cv2
import numpy as np
import random

# 4 pixel pad, random crop
# img: 64x3x32x32
def zeropadding_and_crop(data):
    # # cifar10
    # # padding_img = np.pad(img, ((4, 4), (4, 4), (4, 4)), "constant", padder=0)
    padding_img = np.zeros((np.shape(data)[0], 3, 40, 40), dtype=np.uint8)
    padding_img[..., 4:36, 4:36] = data[...]
    # #
    # cv2.imshow("pad", data[0][0])
    row_rand_num = random.randrange(9)
    col_rand_num = random.randrange(9)
    croped_img = padding_img[..., row_rand_num : row_rand_num + 32, col_rand_num : col_rand_num + 32]

    return croped_img

class DataAugmentationLayer(caffe.Layer):
    def setup(self, bottom, top):
        pass
    def reshape(self, bottom, top):
        top[0].reshape(*bottom[0].data.shape)
        pass

    def forward(self, bottom, top):
        top[0].data[...] = zeropadding_and_crop(bottom[0].data)
        pass

    def backward(self, top, propagate_down, bottom):
        pass

使用自定義的網路層DataAugmentationLayer

pycaffe最好使用net.xxx的方式建立網路,因為第二種方式會自動命名,可能會出現一下意想不到的問題。

# 第一種方式,推薦使用
net.data = L.Python(net.data_temp,
                            python_param=dict(module="custom_data_augmentation",
                                              layer="DataAugmentationLayer"),
                            include=dict(phase=caffe_pb2.Phase.Value('TRAIN')))

# 第二種方式,不推薦這種方式
data = L.Python(data_temp,
                            python_param=dict(module="custom_data_augmentation",
                                              layer="DataAugmentationLayer"),
                            include=dict(phase=caffe_pb2.Phase.Value('TRAIN')))