1. 程式人生 > >pytorch-05

pytorch-05

1.儲存和載入模型

1.儲存引數

torch.save(model_object.state_dict(), ‘params.pkl’) model_object.load_state_dict(torch.load(‘params.pkl’)) model_object:為模型的例項化 載入時,要定義該類和例項

2.儲存模型

torch.save(model_object, ‘model.pkl’) model = torch.load(‘model.pkl’) model_object:為模型的例項化 載入時,model即為可用模型

2.torch.max(test_output, 1)

torch.max(test_output, 1)

輸出格式[tensor([最大值]),tensor([最大值的位置])] 引數1/0,輸出行/列最大

3.直接呼叫資料集送入模型

mages=Variable(test_dataset.test_data[:100].reshape(-1, 28*28).float()) .test_data[:100]:test_dataset中標籤為test_data的資料 reshape:改變資料型別 .float():改變資料byte為float型

4.torch.max(test_output, 1)[1].data.numpy().squeeze()

torch.max(test_output, 1)[1].data.numpy().squeeze() data.numpy():資料從variable轉為tensor(data) 再轉為numpy(numpy)