pytorch讀取自己的csv資料集
阿新 • • 發佈:2022-05-24
自己需要完成一個基於時序訊號資料的自編碼器,但是在讀取資料時遇到了一些困難。資料集(無標籤)由csv檔案構成,每個資料樣本為一個csv檔案,csv檔案中的資料為4列512行。
CSDN同步:pytorch讀取自己的csv資料集
python程式碼
// 重寫dataset類 import os import torch import pandas as pd from torch.utils.data import Dataset from torch.utils.data import DataLoader class myDataset(Dataset): def __init__(self, data_dir): # data_dir:輸入路徑 # 讀資料夾下每個資料檔案的名稱 self.file_name = os.listdir(data_dir) self.data_path = [] # 把每一個檔案的路徑拼接起來 for index in range(len(self.file_name)): self.data_path.append(os.path.join(data_dir, self.file_name[index])) def __len__(self): return len(self.file_name) def __getitem__(self, index): # 讀取每一個數據 data = pd.read_csv(self.data_path[index], header=None) # 轉成張量 data = torch.tensor(data.values) return data // 例項化,讀取資料 in_dir = r"E:\202205\train" # 讀取資料集 train_dataset = myDataset(data_dir=in_dir) # 載入資料集 train_iter = DataLoader(train_dataset)
看看讀取的資料的大小和格式
for x in enumerate(train_iter):
print(x)
print(type(x))
當index = 0 時的結果為
(0, tensor([[[ 5.4378, 5.4378, 5.4378, 5.4378], [ 5.6137, 5.6137, 5.6137, 5.6137], [ 5.7855, 5.7855, 5.7855, 5.7855], ..., [219.6983, 219.6983, 219.6983, 219.6983], [214.3309, 214.3309, 214.3309, 214.3309], [209.1339, 209.1339, 209.1339, 209.1339]]], dtype=torch.float64)) <class 'tuple'>
對於我來說,tuple不能直接用,需要取出裡面的tensor
# 取出tuple的第二列
mydata = x[1]
print(mydata)
print(mydata.shape)
print(type(mydata))
當index = 0時的結果為
tensor([[[ 5.4378, 5.4378, 5.4378, 5.4378], [ 5.6137, 5.6137, 5.6137, 5.6137], [ 5.7855, 5.7855, 5.7855, 5.7855], ..., [219.6983, 219.6983, 219.6983, 219.6983], [214.3309, 214.3309, 214.3309, 214.3309], [209.1339, 209.1339, 209.1339, 209.1339]]], dtype=torch.float64) torch.Size([1, 512, 4]) <class 'torch.Tensor'>
OK,能得到我想要的東西了。
參考文章:https://blog.csdn.net/qq_42653159/article/details/124511467