1. 程式人生 > >pytorch筆記-batch

pytorch筆記-batch

在學習莫煩大神的pytorch視訊的batch部分,由於pytorch版本更新,產生了一些不相容的情況。原始碼如下:

import torch
import torch.utils.data as Data
torch.manual_seed(1) # 設定隨機數種子


BATCH_SIZE=5
x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)

torch_dataset=Data.TensorDataset(data_tensor=x,target_tensor=y)
loader=Data.DataLoader(#變成小批資料
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,#每一組batch裡面原資料個數
    shuffle=True,  #是否將原資料打亂分組
    num_workers=2
)

for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):
        print('Epoch:',epoch)

直接執行會報錯,是由於Data.TensorDataset()函式版本更新後接受引數為*tensor,不再設預設值,故只需將對應行改為:

torch_dataset=Data.TensorDataset(x,y)

但是會繼續報錯: The “freeze_support()” line can be omitted if the program is not going to be frozen to produce an executable. 只需把訓練過程放在if name == ‘main’:以下即可。更正後程式碼:

import torch
import torch.utils.data as Data
torch.manual_seed(1) # 設定隨機數種子


BATCH_SIZE=5
x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)

torch_dataset=Data.TensorDataset(x,y)
loader=Data.DataLoader(#變成小批資料
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,#每一組batch裡面原資料個數
    shuffle=True,  #是否將原資料打亂分組
    num_workers=2
)
if __name__ == '__main__':
    for epoch in range(3):   # 訓練所有!整套!資料 3 次
        for step, (batch_x, batch_y) in enumerate(loader):  # 每一步 loader 釋放一小批資料用來學習
            #  假設這裡就是你訓練的地方...
            #  打出來一些資料
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())