pytorch-05
阿新 • • 發佈:2018-12-20
1.儲存和載入模型
1.儲存引數
torch.save(model_object.state_dict(), ‘params.pkl’) model_object.load_state_dict(torch.load(‘params.pkl’)) model_object:為模型的例項化 載入時,要定義該類和例項
2.儲存模型
torch.save(model_object, ‘model.pkl’) model = torch.load(‘model.pkl’) model_object:為模型的例項化 載入時,model即為可用模型
2.torch.max(test_output, 1)
torch.max(test_output, 1)
輸出格式[tensor([最大值]),tensor([最大值的位置])] 引數1/0,輸出行/列最大
3.直接呼叫資料集送入模型
mages=Variable(test_dataset.test_data[:100].reshape(-1, 28*28).float()) .test_data[:100]:test_dataset中標籤為test_data的資料 reshape:改變資料型別 .float():改變資料byte為float型
4.torch.max(test_output, 1)[1].data.numpy().squeeze()
torch.max(test_output, 1)[1].data.numpy().squeeze() data.numpy():資料從variable轉為tensor(data) 再轉為numpy(numpy)