1. 程式人生 > >pytorch批訓練數據構造

pytorch批訓練數據構造

work port span dataset 線程 載器 訓練數據 () 輸出

這是對莫凡python的學習筆記。

1.創建數據

import torch
import torch.utils.data as Data

BATCH_SIZE = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

可以看到創建了兩個一維數據,x:1~10,y:10~1

2.構造數據集對象,及數據加載器對象

torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(
            dataset = torch_dataset,
            batch_size 
= BATCH_SIZE, shuffle = False, num_workers = 2)

num_workers應該指的是多線程

3.輸出數據集,這一步主要是看一下batch長什麽樣子

for epoch in range(3):
    for step, (batch_x, batch_y) in  enumerate(loader):
        print(Epoch:,epoch,| Step:, step, | batch x:,
                 batch_x.numpy(), | batch y:
, batch_y.numpy())

輸出如下

(‘Epoch:‘, 0, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
(‘Epoch:‘, 0, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32))
(‘Epoch:‘, 1, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
(‘Epoch:‘, 1, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32))
(‘Epoch:‘, 2, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
(‘Epoch:‘, 2, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32))

可以看到,batch_size等於8,則第二個bacth的數據只有兩個。

將batch_size改為5,輸出如下

(‘Epoch:‘, 0, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.], dtype=float32))
(‘Epoch:‘, 0, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6.,  7.,  8.,  9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32))
(‘Epoch:‘, 1, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.], dtype=float32))
(‘Epoch:‘, 1, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6.,  7.,  8.,  9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32))
(‘Epoch:‘, 2, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.], dtype=float32))
(‘Epoch:‘, 2, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6.,  7.,  8.,  9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32))

pytorch批訓練數據構造