測試一個訓練好的caffe模型
阿新 • • 發佈:2019-01-22
在學習caffe的過程中,訓練了出了模型出來,出了當時的準確率和loss值,並沒有看到給定輸入看到真正的輸出,這個時候需要測試一下訓練出來的模型,實際檢視一下效果,其中用到的配置檔案和網路模型在caffe的目錄下都有,自己測試自己的模型時需要修改為自己的*.prototxt和*.caffemodel
#!/usr/bin/env python #coding=utf-8 #因為需要sudo許可權,只能作如下處理,自己新增caffe的位置,編譯之後,然後用sudo許可權執行該.py程式 import sys caffe_root='/home/dyh/caffe/caffe/' sys.path.insert(0, caffe_root+'python') import caffe import numpy as np ''' 這個測試模型可以幫助自己測試自己訓練出來的模型效果如何 ''' caffe.set_mode_gpu() # caffe.set_device(0) #deploy檔案就是用來測試訓練好的網路的,給其輸入,自己寫測試來輸出類別 model_def = '/home/dyh/caffe/caffe/models/bvlc_reference_caffenet/deploy.prototxt' # model_def = '/home/dyh/caffe-workspace/face_detect/deploy_full_conv.prototxt' model_weights = '/home/dyh/caffe-workspace/caffe_case/caffe_case模板/bvlc_reference_caffenet.caffemodel' # model_weights = '/home/dyh/caffe-workspace/face_detect/model/solver_iter_25000.caffemodel' net = caffe.Net(model_def, #測試的模型,caffe已經給出了,照著用 model_weights,#訓練好的引數 caffe.TEST) #使用的模式 #載入均值檔案 mu = np.load('/home/dyh/caffe/caffe/python/caffe/imagenet/ilsvrc_2012_mean.npy') mu = mu.mean(1).mean(1) print 'mean-substracted values',zip('BGR',mu) transformer = caffe.io.Transformer({'data':net.blobs['data'].data.shape}) #[h,w,c]->[c,h,w] transformer.set_transpose('data', [2,0,1]) # transformer.set_mean('data', mu) #減均值 transformer.set_raw_scale('data', 255)#變換到[0-1] transformer.set_channel_swap('data', [2,1,0])#RGB->BGR #按照caffe的輸入格式reshape輸入 net.blobs['data'].reshape(1, #batch,想一張的測試 3, #channel 227, #height 227) #weight img = caffe.io.load_image('/home/dyh/caffe/caffe/examples/images/cat.jpg') # img = caffe.io.load_image('/home/dyh/caffe-workspace/face_detect/train/1/4_nonface_0image54477.jpg') #將輸入進行預處理達到ceffe的輸入格式要求 transformer_img = transformer.preprocess('data',img) #讓deploy裡面的資料層接收到輸入的圖片 net.blobs['data'].data[...] = transformer_img #前向傳播一次就行 output = net.forward() #在網路的最後一個層是輸出的每個類別的概率 output_pro = output['prob'][0] #概率最大的就是 print 'predict class is:',output_pro.argmax() lables_path = '/home/dyh/caffe/caffe/data/ilsvrc12/synset_words.txt' lables = np.loadtxt(lables_path, str,delimiter= '\t')#一行一行的讀取並轉換為ndarray print lables[output_pro.argmax()]#第xx行是類別
輸入影象是一隻貓,最終的結果如下
I0422 07:28:21.366673 14366 net.cpp:744] Ignoring source layer loss
mean-substracted values [('B', 104.0069879317889), ('G', 116.66876761696767), ('R', 122.6789143406786)]
predict class is: 282
n02123159 tiger cat