
Recognizing a cat with AlexNet in Caffe: a simple implementation

#coding=utf-8
import caffe
import numpy as np
caffe.set_mode_cpu()
model_def = './deploy.prototxt'
model_weights = './bvlc_alexnet.caffemodel'
labels_file = './labels.txt'
with open(labels_file,'r') as f:
    lines = f.readlines()
labels = []
for i in lines:
    labels.append(i.strip('\n').split(',')[1:])

net = caffe.Net(model_def, model_weights, caffe.TEST)
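mu = np.load('./ilsvrc_2012_mean.npy')  # ImageNet mean image, shape (3, 256, 256), BGR order
mu = mu.mean(1)  # average over H -> shape (3, 256)
mu = mu.mean(1)  # average over W -> shape (3,): one mean each for B, G, R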

transforms = caffe.io.Transformer({'data': net.blobs['data'].data.shape})  # input shape (1, 3, 227, 227), as specified in the deploy file
transforms.set_transpose('data', (2, 0, 1))  # reorder dimensions from the image's (H, W, C) = (227, 227, 3) to (C, H, W) = (3, 227, 227)
transforms.set_mean('data', mu)  # subtract the per-channel mean
transforms.set_raw_scale('data', 255)  # scale pixel values from [0, 1] to [0, 255]
transforms.set_channel_swap('data', (2, 1, 0))  # reorder channels from RGB to BGR
net.blobs['data'].reshape(1, 3, 227, 227)

image = caffe.io.load_image('../cat.jpg')
image_pre = transforms.preprocess('data', image)
net.blobs['data'].data[...] = image_pre
output = net.forward()
output_pro = output['prob'][0]
print(labels[output_pro.argmax()])  # most probable class; assumes labels.txt has one line per class, in output order
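
The script prints only the single most probable class. Several ImageNet classes are cat breeds (tabby, tiger cat, Persian, Siamese, ...), so it can help to look at the top few predictions too. Below is a minimal sketch using `np.argsort`; the helper name `top_k` is hypothetical, and it assumes `labels[i]` is the name of output class `i`. In the script above it would be called as `top_k(output_pro, labels)` after `net.forward()`.

```python
import numpy as np

def top_k(prob, labels, k=5):
    """Return the k most probable (label, probability) pairs.

    Assumes labels[i] names output index i, i.e. labels.txt lists the
    classes in the same order as the network's 'prob' output.
    """
    prob = np.asarray(prob).ravel()
    # argsort is ascending: reverse it and keep the first k indices
    idx = np.argsort(prob)[::-1][:k]
    return [(labels[i], float(prob[i])) for i in idx]

if __name__ == "__main__":
    # Toy 4-class example standing in for output['prob'][0]
    fake_prob = np.array([0.1, 0.6, 0.05, 0.25])
    fake_labels = ["dog", "tabby cat", "car", "tiger cat"]
    for name, p in top_k(fake_prob, fake_labels, k=3):
        print(name, round(p, 2))
```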

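Note: pretrained Caffe models such as bvlc_alexnet expect images in BGR channel order, not RGB, and Caffe blobs are laid out as (batch, C, H, W) rather than the usual (H, W, C).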
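
To make the layout concrete, here is a NumPy-only sketch (no Caffe needed) that applies the same transpose and channel swap the Transformer is configured for to a tiny 2x2 image, then adds the batch axis. The toy image and its pixel values are made up for illustration.

```python
import numpy as np

# A tiny 2x2 RGB image in the usual (H, W, C) layout.
# Pixel (0, 0) is pure red, pixel (1, 1) is pure blue.
hwc_rgb = np.zeros((2, 2, 3), dtype=np.float32)
hwc_rgb[0, 0] = [1.0, 0.0, 0.0]  # red
hwc_rgb[1, 1] = [0.0, 0.0, 1.0]  # blue

# (H, W, C) -> (C, H, W): what set_transpose('data', (2, 0, 1)) does
chw_rgb = hwc_rgb.transpose(2, 0, 1)

# RGB -> BGR on the channel axis: what set_channel_swap('data', (2, 1, 0)) does
chw_bgr = chw_rgb[[2, 1, 0], :, :]

# Add the batch axis: (C, H, W) -> (1, C, H, W), matching the 'data' blob
nchw_bgr = chw_bgr[np.newaxis, ...]

print(nchw_bgr.shape)        # (1, 3, 2, 2)
print(nchw_bgr[0, 2, 0, 0])  # red is now in channel 2 -> 1.0
print(nchw_bgr[0, 0, 1, 1])  # blue is now in channel 0 -> 1.0
```

`transforms.preprocess` itself returns a single (C, H, W) array; assigning it with `net.blobs['data'].data[...] = image_pre` broadcasts it into the (1, C, H, W) blob, which is why the script never adds the batch axis explicitly.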

caffe.io.load_image returns images in RGB order, as floating-point values in the range [0, 1].
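
This range matters because the ImageNet mean file is on the 0-255 scale (its per-channel means are roughly 104, 117 and 123 for B, G, R), so the [0, 1] image has to be multiplied by 255 before the mean is subtracted; that is the job of `set_raw_scale('data', 255)`. The sketch below uses a hypothetical constant mean image in place of `ilsvrc_2012_mean.npy`, and `channel_mean` is an illustrative helper, not part of Caffe.

```python
import numpy as np

def channel_mean(mean_image):
    """(3, H, W) mean image -> (3,) per-channel mean.

    Gives the same result as the script's mu.mean(1).mean(1).
    """
    return mean_image.mean(axis=(1, 2))

# Hypothetical stand-in for ilsvrc_2012_mean.npy: a constant BGR mean image
# on the 0-255 scale (values chosen only for illustration).
mean_image = np.ones((3, 4, 4)) * np.array([104.0, 117.0, 123.0])[:, None, None]
mu = channel_mean(mean_image)

# load_image-style input, already in (C, H, W): float values in [0, 1]
img01 = np.full((3, 4, 4), 0.5)

wrong = img01 - mu[:, None, None]          # mean subtracted without scaling
right = img01 * 255.0 - mu[:, None, None]  # scaled first, like raw_scale=255

print(mu)               # [104. 117. 123.]
print(wrong.max() < 0)  # True: every pixel is pushed far below zero
print(right[0, 0, 0])   # 127.5 - 104.0 = 23.5
```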

All files used in this post: https://download.csdn.net/download/daixiangzi/10703475

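Note: transforms.preprocess applies its steps in a fixed order regardless of the order of the set_* calls: resize, transpose, channel swap, raw scale, then mean subtraction (only input_scale, if set, comes after it). So although set_mean is called before set_channel_swap in the script, the BGR mean is subtracted only after the image has already been converted to BGR. Don't be misled by the call order into thinking the mean is subtracted while the channels are still in RGB order.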
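
The ordering can be checked without Caffe. The sketch below re-creates the transpose, channel swap, raw scale and mean subtraction steps in the order `preprocess` applies them (resizing and the optional `input_scale` step are omitted), and compares the result with the misleading order in which the BGR mean is subtracted from RGB data. Both function names are hypothetical and the mean values are illustrative.

```python
import numpy as np

def preprocess_like_caffe(img_hwc_rgb01, mean_bgr, raw_scale=255.0):
    """Mimic the step order of caffe.io.Transformer.preprocess.

    Resizing is skipped: the input is assumed to already have the network's
    H and W. Order: transpose -> channel swap -> raw scale -> mean.
    """
    x = img_hwc_rgb01.astype(np.float32)
    x = x.transpose(2, 0, 1)         # (H, W, C) -> (C, H, W)
    x = x[[2, 1, 0], :, :]           # RGB -> BGR
    x = x * raw_scale                # [0, 1] -> [0, 255]
    x = x - mean_bgr[:, None, None]  # subtract the BGR mean last
    return x

def mean_before_swap(img_hwc_rgb01, mean_bgr, raw_scale=255.0):
    """The misleading order: subtract the BGR mean while data is still RGB."""
    x = img_hwc_rgb01.astype(np.float32).transpose(2, 0, 1) * raw_scale
    x = x - mean_bgr[:, None, None]  # B mean hits the R channel, and vice versa
    return x[[2, 1, 0], :, :]

rng = np.random.default_rng(0)
img = rng.random((4, 4, 3))                 # stand-in for load_image output
mean_bgr = np.array([104.0, 117.0, 123.0])  # illustrative BGR mean

a = preprocess_like_caffe(img, mean_bgr)
b = mean_before_swap(img, mean_bgr)
print(np.allclose(a, b))  # False: the order matters when channel means differ
```

Only the G channel agrees between the two, since it sits in the middle and is untouched by the swap; the B and R channels end up offset by the difference between their means.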