Splitting a dataset in PyTorch, and the cause of "TypeError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray"
阿新 • Published: 2019-01-10
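When working with PyTorch, you often need to split a dataset into a training set and a validation set. A common way to do this is the train_test_split function from sklearn.model_selection.
It is used as follows: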
    from sklearn.model_selection import train_test_split
    import numpy as np
    import torch
    from torch.autograd import Variable
    from torch.utils.data import DataLoader

    # train_path, train_label_path and opt (the parsed options) are defined elsewhere
    traindata = np.load(train_path)  # image_num * W * H
    trainlabel = np.load(train_label_path)

    # Add a channel axis: image_num * 1 * W * H
    train_data = traindata[:, np.newaxis, ...]
    train_label_data = trainlabel[:, np.newaxis, ...]

    # Split the NumPy arrays into training and validation sets at a 9:1 ratio
    x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)

    # Convert to float tensors only after the split
    x_tra = Variable(torch.from_numpy(x_tra))
    x_tra = x_tra.float()
    y_tra = Variable(torch.from_numpy(y_tra))
    y_tra = y_tra.float()
    x_val = Variable(torch.from_numpy(x_val))
    x_val = x_val.float()
    y_val = Variable(torch.from_numpy(y_val))
    y_val = y_val.float()

    # DataLoader for the training set
    traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
    trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)

    # DataLoader for the validation set
    validataset = torch.utils.data.TensorDataset(x_val, y_val)
    valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
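The snippet above depends on files and options defined elsewhere, so it cannot be run on its own. The following is a minimal self-contained sketch of the same order of operations (split first, convert to tensors afterwards). The random arrays, their sizes, and the names `images`, `labels`, `train_ds` and `val_ds` are assumptions made purely for illustration and are not from the original post.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-ins for the arrays loaded with np.load (hypothetical sizes).
rng = np.random.default_rng(0)
images = rng.random((20, 16, 16), dtype=np.float32)  # image_num * W * H
labels = rng.random((20, 16, 16), dtype=np.float32)

# Add a channel axis, then split while the data are still NumPy arrays.
x = images[:, np.newaxis, ...]
y = labels[:, np.newaxis, ...]
x_tra, x_val, y_tra, y_val = train_test_split(x, y, test_size=0.1, random_state=0)

# Convert to float tensors only after splitting.
train_ds = TensorDataset(torch.from_numpy(x_tra).float(), torch.from_numpy(y_tra).float())
val_ds = TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val).float())

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False)

# Report how many samples ended up in each set.
print(len(train_ds), len(val_ds))
```

Printing the two dataset lengths is a quick way to confirm that the split ratio is what you intended before starting training.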
Note: if you instead use it as shown below, it raises "TypeError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray".
    from sklearn.model_selection import train_test_split
    import numpy as np
    import torch
    from torch.autograd import Variable
    from torch.utils.data import DataLoader

    # train_path and train_label_path are defined elsewhere
    traindata = np.load(train_path)  # image_num * W * H
    trainlabel = np.load(train_label_path)
    train_data = traindata[:, np.newaxis, ...]
    train_label_data = trainlabel[:, np.newaxis, ...]

    # The data are converted to tensors BEFORE the split
    x_train = Variable(torch.from_numpy(train_data))
    x_train = x_train.float()
    y_train = Variable(torch.from_numpy(train_label_data))
    y_train = y_train.float()

    # Split the original training data into training and validation sets (9:1),
    # so that early stopping can be used later.
    # This call raises the error, because x_train and y_train are tensors.
    x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)
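Cause of the error: the x_train and y_train passed to train_test_split should be numpy.ndarray objects, not Tensors. As the message itself indicates, the split ends up calling the tensor's take() with a NumPy index array, which take() does not accept. Split the NumPy arrays first, and convert the results to tensors afterwards.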
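If the data already exist only as tensors, one possible workaround (not from the original post) is to let train_test_split divide an array of sample indices and then index the tensors yourself. The sketch below assumes hypothetical random tensors `x_train` and `y_train` with 20 samples; the sizes and the `tra_idx` and `val_idx` names are illustrative only.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split

# Hypothetical tensors that already exist (sizes chosen for illustration).
x_train = torch.rand(20, 1, 16, 16)
y_train = torch.rand(20, 1, 16, 16)

# Split an array of sample indices rather than the tensors themselves,
# so scikit-learn only ever sees a NumPy array.
indices = np.arange(x_train.shape[0])
tra_idx, val_idx = train_test_split(indices, test_size=0.1, random_state=0)

# Convert the NumPy index arrays to LongTensors before indexing the tensors.
tra_idx = torch.from_numpy(tra_idx).long()
val_idx = torch.from_numpy(val_idx).long()

x_tra, y_tra = x_train[tra_idx], y_train[tra_idx]
x_val, y_val = x_train[val_idx], y_train[val_idx]
```

Because the same index arrays are applied to both tensors, each image stays paired with its label after the split.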