Triplet資料集匯入遇到的BUG
最開始參考的是在 github 上找到的一個比較可靠的版本,而且我以前也用它跑過 mnist 的 triplet 程式碼,但是改成載入 cifar10 資料集時卻發現無法使用,報出以下錯誤:
錯誤大概出在 dataloader 的迭代器上。我對太底層的東西不是很清楚,猜測是資料格式的問題;當時覺得 github 上給的程式碼應該沒問題,因為畢竟已經用 mnist 的資料集測試過了。主要錯誤發生在 __getitem__ 中,我當時寫的 __getitem__ 是這樣的:
def __getitem__(self, index):
    """Return the three loaded (and optionally transformed) images of triplet ``index``.

    ``self.triplets[index]`` holds exactly three indices into
    ``self.filenamelist`` (converted with ``int()``, so strings such as ``"3"``
    are accepted).  Each selected file name is joined onto ``self.base_path``
    and loaded with ``self.loader``.  If ``self.transform`` is set, it is then
    applied to each image in order.

    Returns:
        A 3-tuple ``(img1, img2, img3)``.

    Raises:
        ValueError: if the stored triplet does not have exactly three entries.
        IndexError: if any of the three indices is negative or not smaller
            than ``len(self.filenamelist)``.  This is checked explicitly
            because a negative index would otherwise silently wrap around and
            load the wrong image.  An out-of-range one would otherwise only
            surface as an opaque "list index out of range" re-raised from a
            DataLoader worker.
    """
    path1, path2, path3 = self.triplets[index]
    file_indices = [int(path1), int(path2), int(path3)]

    num_files = len(self.filenamelist)
    for file_idx in file_indices:
        if not 0 <= file_idx < num_files:
            raise IndexError(
                "triplet {} refers to image index {}, but the filename list "
                "only has {} entries".format(index, file_idx, num_files))

    images = [
        self.loader(os.path.join(self.base_path, self.filenamelist[file_idx]))
        for file_idx in file_indices
    ]
    # Transforms run only after all three images are loaded, as before.
    if self.transform is not None:
        images = [self.transform(img) for img in images]
    return tuple(images)
後來總是在迭代器上報錯:
Traceback (most recent call last):
File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1689, in <module>
main()
File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1683, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1083, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "E:/Retrieval/First-reproduction-mutlilabel/utils/DataProcessing.py", line 157, in <module>
for batch_idx, triplet_train_data in enumerate(dset_triplet_train_loader, 0):#train_input, train_label, batch_ind
File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 322, in __next__
return self._process_next_batch(batch)
File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 357, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 106, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 106, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "E:\Retrieval\First-reproduction-mutlilabel\utils\DataProcessing.py", line 93, in __getitem__
img3 = Image.open(os.path.join(self.data_path, self.img_filename_lists[path3 + 1]))
IndexError: list index out of range
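以 raise batch.exc_type(batch.exc_msg) 這一行為關鍵字在 google 上搜尋,找到了這麼一篇部落格。(補充:這一行只是 DataLoader 在主行程中把 worker 裡發生的例外重新拋出,真正的錯誤在它下方的 worker traceback 裡:__getitem__ 取檔名時發生 IndexError: list index out of range,也就是三元組中的圖片索引超出了檔名清單的長度。)參考該文的寫法如下: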
class RamMakTrilet(Dataset):
    """Map-style dataset of randomly pre-sampled image triplets.

    All ``num_triplets`` triplets are drawn once, in ``__init__``, by calling
    ``Sel.select_triplet`` (defined elsewhere).  Each stored triplet is a
    tuple whose first three entries are image paths.  Its remaining entries
    are returned unchanged as the label.  Assuming the 6-tuple layout
    ``(anchor, positive, negative, anchor_cls, anchor_cls, negative_cls)``,
    those are the class ids — confirm against ``Sel``.
    """

    def __init__(self, root, img_name_file, num_cls=10, num_triplets=1000,
                 limit=500, transforms=None):
        """Pre-sample the triplets.

        Args:
            root: directory passed through to ``Sel.select_triplet``.
            img_name_file: filename-list file passed through to
                ``Sel.select_triplet``.
            num_cls: number of classes passed through to
                ``Sel.select_triplet``.
            num_triplets: how many triplets to pre-sample; also ``len(self)``.
            limit: passed through to ``Sel.select_triplet``.
            transforms: optional callable applied to each opened image;
                ``None`` returns the images untransformed.
        """
        self.transforms = transforms
        self.triplets = [Sel.select_triplet(root, img_name_file, num_cls, limit)
                         for _ in range(num_triplets)]

    def __len__(self):
        """Return the number of pre-sampled triplets.

        ``torch.utils.data.Dataset`` does not define ``__len__``, but the
        default samplers used by ``DataLoader`` call ``len(dataset)``.
        """
        return len(self.triplets)

    def __getitem__(self, index):
        """Return one triplet as ``(data, label)``.

        ``data`` is a list of the three opened images, transformed if
        ``transforms`` was given.  ``label`` is ``triplet[3:]``.
        """
        triplet = self.triplets[index]
        data = []
        for img_path in triplet[:3]:
            img = Image.open(img_path)
            # Guard for the default transforms=None, which would otherwise
            # fail with "'NoneType' object is not callable".  Note that the
            # DataLoader's default collate cannot batch raw PIL images, so a
            # tensor-producing transform is still needed for batched loading.
            if self.transforms is not None:
                img = self.transforms(img)
            data.append(img)
        label = triplet[3:]
        return data, label
def select_triplet(dir, img_name_file, num_classes, limit=500, is_car=False):
    """Randomly pick one (anchor, positive, negative) triplet of image paths.

    The filename list ``os.path.join(dir, img_name_file)`` is read one name per
    line (whitespace-stripped).  It is assumed to be grouped by class: class
    ``c`` owns lines ``c * limit`` to ``c * limit + limit - 1``.  The anchor and
    positive come from the same class block; the negative comes from a class
    chosen by ``get_negative_id``.

    Args:
        dir: directory holding the filename list; image paths are joined onto
            it as well.  (The name shadows the builtin but is part of the
            public signature.)
        img_name_file: name of the filename-list file inside ``dir``.
        num_classes: the anchor class is drawn from ``range(num_classes)``.
        limit: number of images per class block.
        is_car: unused; kept so existing callers keep working.

    Returns:
        6-tuple ``(anchor_path, positive_path, negative_path,
        anchor_cls, anchor_cls, negative_cls)``.

    Raises:
        IndexError: if the filename list has fewer than ``num_classes * limit``
            entries, so the class-block layout cannot hold.  This is checked up
            front, so a too-short list fails every time with a clear message
            instead of only when a random draw happens to land past its end.
    """
    # np.random.seed(100)  # fix the random seed here to make results reproducible

    img_name = os.path.join(dir, img_name_file)
    # Reading the file consumes no randomness, so doing it first leaves the
    # sequence of random draws unchanged; `with` closes the handle on errors.
    with open(img_name, 'r') as fp:
        img_filename = [x.strip() for x in fp]  # list of image file names

    needed = num_classes * limit
    if len(img_filename) < needed:
        raise IndexError(
            '{} lists {} images, but num_classes * limit = {} * {} = {} are '
            'required'.format(img_name, len(img_filename), num_classes, limit,
                              needed))

    # Pick the anchor, positive and negative image ids.
    anchor_cls = np.random.choice(num_classes)  # 0 .. num_classes - 1
    anchor_offset = np.random.choice(limit)     # position inside the class block
    anchor_id = anchor_cls * limit + anchor_offset
    # get_negative_id is defined elsewhere; per the original comments it draws
    # a random value in range(n) other than the given one — confirm.
    positive_id = anchor_cls * limit + get_negative_id(limit, anchor_offset)
    negative_cls = get_negative_id(num_classes, anchor_cls)  # random negative class
    negative_id = negative_cls * limit + np.random.choice(limit)  # random negative image

    anchor_path = os.path.join(dir, img_filename[anchor_id])
    positive_path = os.path.join(dir, img_filename[positive_id])
    negative_path = os.path.join(dir, img_filename[negative_id])

    # Return a 6-tuple: three image paths followed by their three class ids.
    return anchor_path, positive_path, negative_path, anchor_cls, anchor_cls, negative_cls


# 最終在主函式中呼叫如下 (finally, it is used in the main script as follows):
# Iterate over the triplet batches.  Assuming train_loader is a DataLoader
# over RamMakTrilet (TODO confirm), each batch unpacks as (data, labels).
# `data` holds the three per-position image batches in triplet order, and the
# labels are not used here.
for batch_idx, triplet_train_data in enumerate(train_loader):
    data, _ = triplet_train_data  # keep the images, discard the labels
    data1, data2, data3 = data    # anchor, positive, negative batches
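完美解決問題。(關鍵在於三元組中的圖片索引必須落在檔名清單的範圍內:新的 select_triplet 以「類別 * limit + 類內偏移」計算索引,因此要求檔名清單依類別排列、每個類別占 limit 行,總行數至少為 num_classes * limit。)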