pytorch中的inference使用例項
阿新 • • 發佈:2020-02-21
這裡inference兩個程式的連線,如目標檢測,可以利用一個程式提取候選框,然後把候選框輸入到分類cnn網路中。
這裡常需要進行一定的連線。
#載入訓練好的分類CNN網路 model=torch.load('model.pkl') #假設proposal_img是我們提取的候選框,是需要輸入到CNN網路的資料 #先定義transforms對輸入cnn的網路資料進行處理,常包括resize、totensor等操作 data_transforms=transforms.Compose([transforms.RandomSizedCrop(224),transforms.ToTensor()]) #由於transforms是對PIL格式資料操作,所以必要時轉化格式 def tensor_to_PIL(tensor): image = tensor.cpu().clone() image = image.squeeze(0) image = unloader(image) return image #unqueeze(0)是加多一維,對應原來batchsiaze data=data_transforms(proposal_img).unqueeze(0) #新版本pytorch已經不用variable,可以省略這句 data=Variable(data) #貌似這句也是多餘的 torch.no_grad() predict=F.softmax(model(data.cuda()).cuda())
以上這篇pytorch中的inference使用例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。