Python 應用 Caffe 模型進行分類（Caffe Python 介面）
阿新 • 發佈：2018-12-05
遍歷一個資料夾（含子資料夾）下的所有圖片，逐張進行預測，並將圖片複製到對應類別的資料夾。
"""Classify every image under SRC_DIR with a trained Caffe model.

Each image goes through one forward pass. It is then copied into
OUT_DIR/<class folder>/ when the top softmax probability exceeds
CONF_THRESHOLD, or into OUT_DIR/<UNKNOWN_DIR>/ otherwise. Each image is shown
in an OpenCV window while processing: Esc stops early, Space pauses.
"""
import os
import shutil
import sys

import cv2
import numpy as np  # noqa: F401  (kept from the original script)

import caffe
from caffe.proto import caffe_pb2  # noqa: F401  (kept from the original script)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
ROOT = 'D:/stomach_raw_data/deepid/'                        # model root directory
DEPLOY = ROOT + 'deploy_all.prototxt'                       # deploy (inference) net definition
CAFFE_MODEL = ROOT + 'id_128_net_iter_1695000.caffemodel'   # trained weights

SRC_DIR = 'D:\\2D'                   # images to classify (searched recursively)
OUT_DIR = 'D:/sto_img_1695000/'      # per-class output folders are created under here

INPUT_BLOB = 'data_1'                # input blob name in the deploy file
PROB_BLOB = 'softmax'                # output probability blob name

# Output folder for each class index; entry i must correspond to softmax output i.
DIRS = ['0_CA', '1_FV', '2_GB', '3_GA', '4_SV', '5_PY', '6_OT', '7_IV']
# Folder for low-confidence predictions (spelling kept so existing output trees stay valid).
UNKNOWN_DIR = 'unkown'
# A prediction is accepted only if its top probability is strictly above this.
CONF_THRESHOLD = 0.70

# The original script deleted every file cv2.imread could not decode. That also
# removes non-image files and images OpenCV merely fails to open, so deletion
# is now opt-in.
DELETE_UNREADABLE = False


def dirlist(path, allfile):
    """Recursively collect every non-directory entry under *path*.

    Entries are appended to *allfile* in os.listdir order, descending into each
    sub-directory as it is encountered. The same list is returned.
    """
    for filename in os.listdir(path):
        filepath = os.path.join(path, filename)
        if os.path.isdir(filepath):
            dirlist(filepath, allfile)
        else:
            allfile.append(filepath)
    return allfile


def is_bgr_img(img):
    """Return True if *img* has a 3-D shape (height, width, channels).

    Returns False for None (what cv2.imread yields for unreadable files) and
    for arrays of any other rank. Grayscale 2-D input returns False instead
    of raising ValueError.
    """
    shape = getattr(img, 'shape', None)
    return shape is not None and len(shape) == 3


def _load_net():
    """Load the trained network and a Transformer configured for its input blob."""
    net = caffe.Net(DEPLOY, CAFFE_MODEL, caffe.TEST)
    # The input shape (N, C, H, W) is taken from the network itself;
    # Transformer.preprocess resizes each image to (H, W).
    transformer = caffe.io.Transformer({INPUT_BLOB: net.blobs[INPUT_BLOB].data.shape})
    # OpenCV images are (H, W, C); Caffe expects (C, H, W).
    transformer.set_transpose(INPUT_BLOB, (2, 0, 1))
    # set_mean is not used: the model was trained without mean subtraction.
    # set_raw_scale / set_channel_swap are not used: cv2.imread already returns
    # 8-bit BGR pixels (0-255).
    return net, transformer


def _classify(net, transformer, image):
    """Run one forward pass on *image*; return (class_index, top_prob, elapsed_ms)."""
    net.blobs[INPUT_BLOB].data[...] = transformer.preprocess(INPUT_BLOB, image)
    start = cv2.getTickCount()
    net.forward()
    elapsed_ms = (cv2.getTickCount() - start) * 1000 / cv2.getTickFrequency()
    prob = net.blobs[PROB_BLOB].data[0].flatten()
    order = int(prob.argmax())
    return order, prob[order], elapsed_ms


def _copy_to_class_dir(src, folder):
    """Copy file *src* into OUT_DIR/<folder>/, creating missing directories."""
    dst_dir = os.path.join(OUT_DIR, folder)
    if not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)  # also creates OUT_DIR itself if it does not exist
    # Byte-for-byte copy: cv2.imwrite would re-encode (lossy for JPEG) and
    # force every image to 8-bit BGR.
    shutil.copy2(src, os.path.join(dst_dir, os.path.basename(src)))


def main():
    """Classify all files under SRC_DIR and sort them into per-class folders."""
    caffe.set_mode_gpu()
    net, transformer = _load_net()

    imgnames = dirlist(SRC_DIR, [])
    if not imgnames:
        sys.exit('no files found under %s' % SRC_DIR)

    # Debug output kept from the original: the prefix (before '_') of the
    # first file's parent folder name, then the file's path.
    first = imgnames[0]
    print(os.path.basename(os.path.dirname(first)).split('_')[0])
    print(first)

    total_ms = 0.0
    count = 0
    for imgname in imgnames:
        image = cv2.imread(imgname)
        if image is None:
            # cv2.imread returns None (it does not raise) when it cannot decode a file.
            print(imgname)
            if DELETE_UNREADABLE:
                os.remove(imgname)
            continue

        order, prob_max, elapsed_ms = _classify(net, transformer, image)
        total_ms += elapsed_ms
        count += 1
        print('max = %f,class = %d,all = %d\n' % (prob_max, order, count))

        folder = DIRS[order] if prob_max > CONF_THRESHOLD else UNKNOWN_DIR
        _copy_to_class_dir(imgname, folder)

        cv2.imshow('cv2', image)
        key = cv2.waitKey(1)
        if key == 27:    # Esc: stop processing
            break
        if key == 32:    # Space: pause until any key is pressed
            cv2.waitKey()
    cv2.destroyAllWindows()

    if count:
        print('classified %d images, mean forward time %.2f ms' % (count, total_ms / count))


if __name__ == '__main__':
    main()