pytorch批訓練數據構造
阿新 • • 發佈:2018-08-11
work port span dataset 線程 載器 訓練數據 () 輸出
這是對莫凡python的學習筆記。
1.創建數據
import torch import torch.utils.data as Data BATCH_SIZE = 8 x = torch.linspace(1,10,10) y = torch.linspace(10,1,10)
可以看到創建了兩個一維數據,x:1~10,y:10~1
2.構造數據集對象,及數據加載器對象
torch_dataset = Data.TensorDataset(x,y) loader = Data.DataLoader( dataset = torch_dataset, batch_size= BATCH_SIZE, shuffle = False, num_workers = 2)
num_workers應該指的是多線程
3.輸出數據集,這一步主要是看一下batch長什麽樣子
for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): print(‘Epoch:‘,epoch,‘| Step:‘, step, ‘| batch x:‘, batch_x.numpy(), ‘| batch y:‘, batch_y.numpy())
輸出如下
(‘Epoch:‘, 0, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32)) (‘Epoch:‘, 0, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32)) (‘Epoch:‘, 1, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32)) (‘Epoch:‘, 1, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32)) (‘Epoch:‘, 2, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32)) (‘Epoch:‘, 2, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32))
可以看到,batch_size等於8,則第二個bacth的數據只有兩個。
將batch_size改為5,輸出如下
(‘Epoch:‘, 0, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10., 9., 8., 7., 6.], dtype=float32)) (‘Epoch:‘, 0, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6., 7., 8., 9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32)) (‘Epoch:‘, 1, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10., 9., 8., 7., 6.], dtype=float32)) (‘Epoch:‘, 1, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6., 7., 8., 9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32)) (‘Epoch:‘, 2, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10., 9., 8., 7., 6.], dtype=float32)) (‘Epoch:‘, 2, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6., 7., 8., 9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32))
pytorch批訓練數據構造