CLASS torch.utils.data.TensorDataset(*tensors)

Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

Parameters

*tensors (Tensor) – tensors that all have the same size in the first dimension.
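The equal-first-dimension requirement is enforced at construction time; passing tensors with mismatched first dimensions raises an AssertionError. A minimal sketch (toy shapes chosen here for illustration):

```python
import torch
from torch.utils.data import TensorDataset

x = torch.zeros(4, 3)    # first dimension is 4
y = torch.zeros(4, 1)    # first dimension is 4: compatible
TensorDataset(x, y)      # constructs without error

bad = torch.zeros(3, 1)  # first dimension is 3: mismatch
try:
    TensorDataset(x, bad)
except AssertionError:
    print("mismatched first dimensions are rejected")
```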

import torch

train_features = torch.Tensor([[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1], [10.1, 11.1, 12.1]])
train_labels = torch.Tensor([[1.1], [2.1], [3.1], [4.1]])
dataset = torch.utils.data.TensorDataset(train_features, train_labels)
print(dataset)
for i in dataset:
    print(i)

Output:

<torch.utils.data.dataset.TensorDataset object at 0x0000023D5A814B38>
(tensor([1.1000, 2.1000, 3.1000]), tensor([1.1000]))
(tensor([4.1000, 5.1000, 6.1000]), tensor([2.1000]))
(tensor([7.1000, 8.1000, 9.1000]), tensor([3.1000]))
(tensor([10.1000, 11.1000, 12.1000]), tensor([4.1000]))

TensorDataset does not concatenate the two tensors; it zips them sample by sample. The i-th sample is the tuple of the i-th row of each tensor, retrieved by indexing along the first dimension, so in this code each feature row is paired with its corresponding label.
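This pairing also means the dataset supports `len()` and integer indexing directly. A minimal sketch using the same toy data as above:

```python
import torch
from torch.utils.data import TensorDataset

# Same toy data as above: four feature rows paired with four labels.
features = torch.tensor([[1.1, 2.1, 3.1],
                         [4.1, 5.1, 6.1],
                         [7.1, 8.1, 9.1],
                         [10.1, 11.1, 12.1]])
labels = torch.tensor([[1.1], [2.1], [3.1], [4.1]])

dataset = TensorDataset(features, labels)

# len() is the size of the first dimension, and dataset[i] returns
# the tuple (features[i], labels[i]).
print(len(dataset))  # 4
x, y = dataset[2]
print(x)             # tensor([7.1000, 8.1000, 9.1000])
print(y)             # tensor([3.1000])
```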

batch_size = 2
train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)
print(train_iter)
for X, y in train_iter:
    print(X)
    print(y)

Output:

<torch.utils.data.dataloader.DataLoader object at 0x0000024E1888B898>

tensor([[1.1000, 2.1000, 3.1000],
[4.1000, 5.1000, 6.1000]])
tensor([[1.1000],
[2.1000]])
tensor([[ 7.1000, 8.1000, 9.1000],
[10.1000, 11.1000, 12.1000]])
tensor([[3.1000],
[4.1000]])

CLASS torch.utils.data.DataLoader: as the output above shows, DataLoader splits the n wrapped (feature vector, label) samples into n/batch_size batches, each containing (batch_size × feature vectors, batch_size × labels). It returns an iterable object; batch_size sets the number of samples per batch.
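When n is not divisible by batch_size, DataLoader emits a smaller final batch by default; passing drop_last=True discards that incomplete batch instead. A minimal sketch (toy tensors chosen here for illustration):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Five samples with batch_size=2: the final batch holds only one sample.
features = torch.arange(15, dtype=torch.float32).reshape(5, 3)
labels = torch.arange(5, dtype=torch.float32).reshape(5, 1)
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=2)
sizes = [X.shape[0] for X, y in loader]
print(sizes)  # [2, 2, 1]

loader = DataLoader(dataset, batch_size=2, drop_last=True)
sizes = [X.shape[0] for X, y in loader]
print(sizes)  # [2, 2]
```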