1. 程式人生 > >測試一個訓練好的caffe模型

測試一個訓練好的caffe模型

在學習caffe的過程中,訓練了出了模型出來,出了當時的準確率和loss值,並沒有看到給定輸入看到真正的輸出,這個時候需要測試一下訓練出來的模型,實際檢視一下效果,其中用到的配置檔案和網路模型在caffe的目錄下都有,自己測試自己的模型時需要修改為自己的*.prototxt和*.caffemodel

#!/usr/bin/env python
#coding=utf-8
#因為需要sudo許可權,只能作如下處理,自己新增caffe的位置,編譯之後,然後用sudo許可權執行該.py程式
import sys
caffe_root='/home/dyh/caffe/caffe/'
sys.path.insert(0, caffe_root+'python')
import caffe
import numpy as np
'''
這個測試模型可以幫助自己測試自己訓練出來的模型效果如何
'''
caffe.set_mode_gpu()
# caffe.set_device(0)

#deploy檔案就是用來測試訓練好的網路的,給其輸入,自己寫測試來輸出類別
model_def = '/home/dyh/caffe/caffe/models/bvlc_reference_caffenet/deploy.prototxt'
# model_def = '/home/dyh/caffe-workspace/face_detect/deploy_full_conv.prototxt'
model_weights = '/home/dyh/caffe-workspace/caffe_case/caffe_case模板/bvlc_reference_caffenet.caffemodel'
# model_weights = '/home/dyh/caffe-workspace/face_detect/model/solver_iter_25000.caffemodel'
 
net = caffe.Net(model_def, #測試的模型,caffe已經給出了,照著用
                model_weights,#訓練好的引數
                caffe.TEST) #使用的模式
#載入均值檔案
mu = np.load('/home/dyh/caffe/caffe/python/caffe/imagenet/ilsvrc_2012_mean.npy')
mu = mu.mean(1).mean(1)
print 'mean-substracted values',zip('BGR',mu)

transformer = caffe.io.Transformer({'data':net.blobs['data'].data.shape})
#[h,w,c]->[c,h,w]
transformer.set_transpose('data', [2,0,1])
# transformer.set_mean('data', mu) #減均值
transformer.set_raw_scale('data', 255)#變換到[0-1]
transformer.set_channel_swap('data', [2,1,0])#RGB->BGR

#按照caffe的輸入格式reshape輸入
net.blobs['data'].reshape(1, #batch,想一張的測試
                          3, #channel
                          227, #height
                          227) #weight

img = caffe.io.load_image('/home/dyh/caffe/caffe/examples/images/cat.jpg')
# img = caffe.io.load_image('/home/dyh/caffe-workspace/face_detect/train/1/4_nonface_0image54477.jpg')
#將輸入進行預處理達到ceffe的輸入格式要求
transformer_img = transformer.preprocess('data',img)
#讓deploy裡面的資料層接收到輸入的圖片
net.blobs['data'].data[...] = transformer_img
#前向傳播一次就行
output = net.forward()
#在網路的最後一個層是輸出的每個類別的概率
output_pro = output['prob'][0]
#概率最大的就是
print 'predict class is:',output_pro.argmax()

lables_path = '/home/dyh/caffe/caffe/data/ilsvrc12/synset_words.txt'
lables = np.loadtxt(lables_path, str,delimiter= '\t')#一行一行的讀取並轉換為ndarray
print lables[output_pro.argmax()]#第xx行是類別

輸入影象是一隻貓,最終的結果如下

I0422 07:28:21.366673 14366 net.cpp:744] Ignoring source layer loss
mean-substracted values [('B', 104.0069879317889), ('G', 116.66876761696767), ('R', 122.6789143406786)]
predict class is: 282
n02123159 tiger cat