Caffe:如何fine tune一個現有的網路(VGG16)——將資料預處理並儲存為h5格式
阿新 • • 發佈:2019-01-04
在訓練神經網路的過程中,常常需要fine tune一個現有的網路,首先是需要對輸入資料進行預處理,包括有:
- 對尺寸大小進行處理
- 將正負例和測試的data&label儲存為h5檔案
- 將h5檔案中data&label對應的書序打亂
實現程式碼如下:
1. 導包以及VGG網路初始化
import numpy as np
import matplotlib.pyplot as plt
import skimage
import skimage.io
import skimage.transform
import os
import h5py
%matplotlib inline
plt.rcParams['figure.figsize' ]=(10,10)
plt.rcParams['image.interpolation']='nearest'
plt.rcParams['image.cmap']='gray'
VGG_MEAN = [103.939, 116.779, 123.68]
2.處理圖片RGB三通道
def preprocess(img):
out = np.copy(img) * 255
out = out[:, :, [2,1,0]] # swap channel from RGB to BGR
# sub mean
out[:,:,0] -= VGG_MEAN[0]
out[:,:,1] -= VGG_MEAN[1 ]
out[:,:,2] -= VGG_MEAN[2]
out = out.transpose((2,0,1)) # h, w, c -> c, h, w
return out
3.畫素歸一化
def deprocess(img):
out = np.copy(img)
out = out.transpose((1,2,0)) # c, h, w -> h, w, c
out[:,:,0] += VGG_MEAN[0]
out[:,:,1] += VGG_MEAN[1]
out[:,:,2] += VGG_MEAN[2]
out = out[:, :, [2 ,1,0]]
out /= 255
return out
4.尺寸處理
# returns image of shape [224, 224, 3]
# [height, width, depth]
def load_image(path):
# load image
img = skimage.io.imread(path)
img = img / 255.0
assert (0 <= img).all() and (img <= 1.0).all()
#print "Original Image Shape: ", img.shape
# we crop image from center
short_edge = min(img.shape[:2])
yy = int((img.shape[0] - short_edge) / 2)
xx = int((img.shape[1] - short_edge) / 2)
crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
# resize to 224, 224
resized_img = skimage.transform.resize(crop_img, (224, 224))
return resized_img
5.迴圈遍歷檔案儲存資料以及label
關鍵程式碼:
儲存count值方便後續使用以及檢查:
imgData_count = 0
imgTest_count = 0
FilePrefixlist = [] #存取檔案字首名的列表
#分別將list中出現的image名字和label儲存在不同的矩陣
PositiveList = np.loadtxt(r'plane.txt',dtype=np.int)
#獲取檔案的字首名,字首名為string型別
with open(r'plane.txt', 'r') as f:
while True:
line = f.readline() #逐行讀取
if not line:
break
linesplit = line.split(' ')
FilePrefixlist.append(linesplit[0]) #只取得第一列的資料即檔案的字首名
labelPositiveList = PositiveList[:,1]
#統計正例中儲存為訓練集的個數
labelPositiveCount=np.sum(labelPositiveList==1)
labelNegativeCount=np.sum(labelPositiveList==-1)
#初始化訓練集和測試集的data和label
imgData = np.zeros([labelPositiveCount+190,3,224,224],dtype= np.float32)
label = []
imgTest = np.zeros([labelNegativeCount+95,3,224,224],dtype= np.float32)
labelTest =[]
接下里開始正式讀資料和label,以其中某一個檔案資料為例:
#通過讀正類指令碼檔案將正類中train和test的儲存到對應data中
for index in range(len(FilePrefixlist)):
line=FilePrefixlist[index]
#如果label=1,那麼是訓練集
if labelPositiveList[index]==1 :
imgData[imgData_count,:,:,:]=preprocess(load_image(path+'/'+line+'.jpg'))
label.append(1)
imgData_count = imgData_count+1
#否則label就是-1,代表這是一個測試集的資料,放在測試集中
else:
imgTest[imgTest_count,:,:,:]=preprocess(load_image(path+'/'+line+'.jpg'))
labelTest.append(1)
imgTest_count = imgTest_count+1
上述過程將所有data存在numpy數組裡面,label存在list中用append()方式追加,於是需要將list轉變為numpy陣列:
#將label列表變為numpy
label = np.array(label)
labelTest = np.array(labelTest)
使用shuffle打亂順序:
#打亂h5檔案訓練集正負例順序
index = [i for i in range(len(imgData))]
np.random.shuffle(index)
imgData = imgData[index]
label = label[index]
建立h5檔案,放入data和label:
f = h5py.File('aeroplane_train.h5','w')#相對路徑,絕對路徑會報錯
f['data']=imgData
f['label']=label
f.close()
#HDF5的讀取:
f = h5py.File('aeroplane_train.h5','r') #開啟h5檔案
f.keys() #可以檢視所有的主鍵
a = f['data'][:] #取出主鍵為data的所有的鍵值
f.close()
資料預處理以及儲存過程關鍵程式碼如上所示。
在編碼中遇到一些小坑:
1、win與linux在寫路徑是正反斜槓”/”“\”的問題,win下複製的路徑與自己新增的完整路徑的斜槓方向不同。。。
2、在loadtxt的時候,由於\t或者\n會識別為轉義字元,於是需要在路徑前加上r,否則會報錯,例如:
PositiveList = np.loadtxt(r'C:\Users\Administrator\plane.txt',dtype=np.int)