pytorch筆記-batch
阿新 • • 發佈:2018-12-11
在學習莫煩大神的pytorch視訊的batch部分,由於pytorch版本更新,產生了一些不相容的情況。原始碼如下:
import torch import torch.utils.data as Data torch.manual_seed(1) # 設定隨機數種子 BATCH_SIZE=5 x=torch.linspace(1,10,10) y=torch.linspace(10,1,10) torch_dataset=Data.TensorDataset(data_tensor=x,target_tensor=y) loader=Data.DataLoader(#變成小批資料 dataset=torch_dataset, batch_size=BATCH_SIZE,#每一組batch裡面原資料個數 shuffle=True, #是否將原資料打亂分組 num_workers=2 ) for epoch in range(3): for step,(batch_x,batch_y) in enumerate(loader): print('Epoch:',epoch)
直接執行會報錯,是由於Data.TensorDataset()函式版本更新後接受引數為*tensor,不再設預設值,故只需將對應行改為:
torch_dataset=Data.TensorDataset(x,y)
但是會繼續報錯: The “freeze_support()” line can be omitted if the program is not going to be frozen to produce an executable. 只需把訓練過程放在if name == ‘main’:以下即可。更正後程式碼:
import torch import torch.utils.data as Data torch.manual_seed(1) # 設定隨機數種子 BATCH_SIZE=5 x=torch.linspace(1,10,10) y=torch.linspace(10,1,10) torch_dataset=Data.TensorDataset(x,y) loader=Data.DataLoader(#變成小批資料 dataset=torch_dataset, batch_size=BATCH_SIZE,#每一組batch裡面原資料個數 shuffle=True, #是否將原資料打亂分組 num_workers=2 ) if __name__ == '__main__': for epoch in range(3): # 訓練所有!整套!資料 3 次 for step, (batch_x, batch_y) in enumerate(loader): # 每一步 loader 釋放一小批資料用來學習 # 假設這裡就是你訓練的地方... # 打出來一些資料 print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())