解決 Python 在 import caffe 時出現 No module named _caffe 的問題
阿新 • 發佈：2019-02-01
在 Python 檔案的最上面（必須在 import caffe 之前）加入下列內容即可，並把其中的路徑改成你自己的 Caffe 原始碼目錄：
import sys

# Make pycaffe importable: append both the Caffe `python` directory and its
# `caffe` subdirectory (in that order) to the module search path. Replace the
# paths with your own Caffe checkout.
sys.path.extend([
    "/home/zhangqi/Desktop/caffe-master/python",
    "/home/zhangqi/Desktop/caffe-master/python/caffe",
])
#############################################################################################
記得將 caffe_forward.py 中的路徑（Caffe 目錄，以及模型定義、權重與資料檔）改成你自己環境中的路徑。
import os
import sys

# Root of the Caffe checkout whose pycaffe should be imported. Set the
# CAFFE_ROOT environment variable (or edit the default) for your machine.
CAFFE_ROOT = os.environ.get("CAFFE_ROOT", "/home/zhangqi/Desktop/caffe-master")

# Append the Caffe `python` directory and its `caffe` subdirectory, in that
# order. Entries already on sys.path are skipped so re-running this setup
# does not grow sys.path with duplicates.
for _caffe_dir in (os.path.join(CAFFE_ROOT, "python"),
                   os.path.join(CAFFE_ROOT, "python", "caffe")):
    if _caffe_dir not in sys.path:
        sys.path.append(_caffe_dir)
import sys

import caffe
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np

# Network definition, trained weights and input board file.
# Edit these for your own environment.
model_defination = '/usr/xhh/model/general_prediction/cnn/forward_network.prototxt'
weights = '/usr/xhh/model/general_prediction/cnn/train_iter_146000.caffemodel'
data_path = '/usr/xhh/model/general_prediction/cnn/current_position.txt'

# Geometry of the input fed to the 'data' blob: one sample made of
# NUM_PLANES planes of BOARD_SIZE x BOARD_SIZE (i.e. shape (1, 3, 19, 19)).
BOARD_SIZE = 19
NUM_PLANES = 3
# Number of top-ranked output positions printed by load_net.
TOP_K = 50


def load_net(list):
    """Run one CPU forward pass on the board stored in ``data_path``.

    Prints the index of the highest-scoring output position, then the
    indices of the ``TOP_K`` highest-scoring positions in descending score
    order (ties broken by lower index; each index is printed once).

    Args:
        list: Position indices (ints or numeric strings) whose score is set
            to 0 before ranking. Empty / blank strings are ignored, so
            ``['']`` -- what ``__main__`` builds when no positions are given
            -- disables masking, and an empty sequence is accepted. The name
            shadows the builtin but is kept so keyword callers still work.
    """
    masked_positions = [p for p in list if str(p).strip() != '']

    caffe.set_mode_cpu()
    net = caffe.Net(model_defination, weights, caffe.TEST)

    input_shape = (1, NUM_PLANES, BOARD_SIZE, BOARD_SIZE)
    net.blobs['data'].reshape(*input_shape)
    net.blobs['data'].data[...] = get_img_datum(data_path).reshape(input_shape)

    outputs = net.forward()
    # Copy the first (only) sample so that masking does not write into the
    # network's output blob. Assumes the 'loss' blob holds one score per
    # board position -- TODO confirm against the prototxt.
    scores = np.array(np.asarray(outputs['loss'])[0])
    for position in masked_positions:
        scores[int(position)] = 0

    # A stable sort of the negated scores ranks highest-first and breaks ties
    # by lower index, yielding every index exactly once. (Looking indices up
    # by value with list.index() repeats the first index whenever scores tie,
    # e.g. for all the positions zeroed above.)
    ranking = np.argsort(-scores, kind='mergesort')
    print(int(ranking[0]))
    for position in ranking[:TOP_K]:
        print(int(position))


def _parse_plane(digits, size=BOARD_SIZE):
    """Return a (size, size) float array filled column by column from *digits*.

    The k-th character becomes ``int(char)`` at ``[k % size, k // size]``;
    positions past the end of *digits* stay 0.

    Raises:
        ValueError: if *digits* is longer than ``size * size`` or contains a
            character ``int()`` cannot parse.
    """
    if len(digits) > size * size:
        raise ValueError('board plane has {0} digits, expected at most {1}'.format(
            len(digits), size * size))
    flat = np.zeros(size * size)
    flat[:len(digits)] = [int(ch) for ch in digits]
    # order='F' makes the first (row) index vary fastest, matching the
    # original fill loop that advanced i before j.
    return flat.reshape((size, size), order='F')


def get_img_datum(data_path):
    """Parse the first line of *data_path* into a (NUM_PLANES, BOARD_SIZE, BOARD_SIZE) array.

    The line must hold at least NUM_PLANES whitespace-separated fields;
    field p is a string of single digits, one per board point, that becomes
    plane p (see _parse_plane for the fill order). Later lines are ignored.

    Raises:
        ValueError: if the file is empty, the first line has too few fields,
            or a field is too long or contains a non-digit.
    """
    with open(data_path) as board_file:
        fields = board_file.readline().split()
    if len(fields) < NUM_PLANES:
        raise ValueError('{0}: expected {1} fields on the first line, got {2}'.format(
            data_path, NUM_PLANES, len(fields)))
    img = np.zeros((NUM_PLANES, BOARD_SIZE, BOARD_SIZE))
    for plane_index in range(NUM_PLANES):
        img[plane_index] = _parse_plane(fields[plane_index])
    return img


if __name__ == '__main__':
    # Masked positions arrive as a comma-separated list, usually bracketed
    # like "[12,40,300]". The shell may split it into several argv entries,
    # so the pieces are re-joined before any surrounding brackets are removed.
    joined_args = ''.join(sys.argv[1:]).strip().strip('[]')
    load_net(joined_args.split(','))