1. 程式人生 > >pytorch中資料集的劃分方法及TypeError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray錯誤原因

pytorch中資料集的劃分方法及TypeError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray錯誤原因

在使用pytorch框架時,難免需要對資料集進行訓練集和驗證集的劃分,一般使用sklearn.model_selection中的train_test_split方法

該方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Load the raw arrays from .npy files.
# NOTE(review): train_path / train_label_path are assumed to be defined
# earlier in the script; they are not shown in this snippet.
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
# Insert a singleton axis at position 1 (e.g. a channel axis):
# (image_num, W, H) -> (image_num, 1, W, H)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]

# Split while the data are still numpy arrays: train_test_split must be given
# ndarrays here, not torch Tensors.
x_tra, x_val, y_tra, y_val = train_test_split(
    train_data, train_label_data, test_size=0.1, random_state=0
)  # 90% training / 10% validation

# Convert to float tensors only after splitting. torch.autograd.Variable is
# deprecated since PyTorch 0.4 (it simply returns the Tensor), so plain
# Tensors are used directly.
x_tra = torch.from_numpy(x_tra).float()
y_tra = torch.from_numpy(y_tra).float()
x_val = torch.from_numpy(x_val).float()
y_val = torch.from_numpy(y_val).float()

# NOTE(review): `opt` is presumably an argparse namespace defined elsewhere
# that provides `threads` and `batchSize`.

# DataLoader for the training set (shuffled every epoch).
traindataset = TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)

# DataLoader for the validation set. Sample order does not change validation
# metrics, so shuffling is unnecessary; a fixed order keeps evaluation
# deterministic.
validataset = TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=False)

注意:如果按照如下方式使用,就會報TypeError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray錯誤

from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import DataLoader

# Counter-example: converting to Tensors BEFORE splitting.
# NOTE(review): train_path / train_label_path are assumed to be defined
# earlier in the script; they are not shown in this snippet.
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)

# Insert a singleton axis at position 1 (e.g. a channel axis).
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]

# torch.autograd.Variable is deprecated since PyTorch 0.4; plain Tensors are
# equivalent here.
x_train = torch.from_numpy(train_data).float()
y_train = torch.from_numpy(train_label_data).float()

# Split the original training data into training and validation sets so that
# early stopping can be used later (90% / 10%).
# WRONG: x_train / y_train are Tensors here. With some scikit-learn versions
# this line raises
#   TypeError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray
# Split the numpy arrays first and convert to Tensors afterwards instead.
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)

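報錯原因:train_test_split方法接受的x_train、y_train格式應該為numpy.ndarray,而不應該是Tensor,這點需要注意。從錯誤訊息可以看出,在部分scikit-learn版本中,train_test_split內部會以numpy.ndarray型別的索引呼叫資料物件的take()方法來取出子集;若傳入的是Tensor,實際呼叫的就是torch.Tensor.take,而它只接受Tensor作為索引,因此報錯。正確做法是先用numpy陣列劃分資料集,再轉換為Tensor,也就是第一段程式碼的寫法。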