1. 程式人生 > >Triplet資料集匯入遇到的BUG

Triplet資料集匯入遇到的BUG

最開始借鑑的版本使用的是github上找的一個比較靠譜的版本,並且以前在跑過關於mnist的triplet程式碼編寫,但是在呼叫cifar10資料集是發現並不能用,報出以下錯誤

 

錯誤出現大概就是dataloader的迭代器問題,對於太底層的東西不是很清楚,大概覺得是格式問題,當時覺得github上給的程式碼應該是沒問題的,因為畢竟已經用mnist的資料集測試過了,主要錯誤在__getitem__中,我當時寫的_getitem__是這樣的,

def __getitem__(self, index):
    path1, path2, path3 = self.triplets[index]
    img1 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path1)]))
    img2 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path2)]))
    img3 = self.loader(os.path.join(self.base_path,self.filenamelist[int(path3)]))
    if self.transform is not None:
        img1 = self.transform(img1)
        img2 = self.transform(img2)
        img3 = self.transform(img3)

    return img1, img2, img3

後來總是在迭代器上報錯:

Traceback (most recent call last):
  File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1689, in <module>
    main()
  File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1683, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\pydevd.py", line 1083, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "E:\Professional_Software\Pycharm\PyCharm 2018.3\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:/Retrieval/First-reproduction-mutlilabel/utils/DataProcessing.py", line 157, in <module>
    for batch_idx, triplet_train_data in enumerate(dset_triplet_train_loader, 0):#train_input, train_label, batch_ind 
  File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 322, in __next__
    return self._process_next_batch(batch)
  File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 357, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
  File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "E:\Professional_Software\Anaconda\envs\Python36Pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "E:\Retrieval\First-reproduction-mutlilabel\utils\DataProcessing.py", line 93, in __getitem__
    img3 = Image.open(os.path.join(self.data_path, self.img_filename_lists[path3 + 1]))
IndexError: list index out of range
根據raise batch.exc_type(batch.exc_msg) google搜尋出這麼一篇部落格

https://discuss.pytorch.org/t/cannot-unsqueeze-empty-tensor/1300,說的和最終發現的問題有點聯絡,我在懷疑應該是在__getitem__上return 三張圖片 而一般不是triplet網路是返回的一張圖片,一個標籤,一個索引號,所以應該會在呼叫時 發生資料型別相關的錯誤,但是不是太懷疑的,因為我一直堅信使用前面mnist跑過 return img1, img2, img3,返回三張圖片不加一個標籤,一個索引號是沒問題的,,最終查了一整天的問題還是沒查出來,妥協了,參考了github上的一個triplet資料載入程式碼。https://github.com/CaptainEven/FaceRecognition
 重改的Triplet資料載入程式碼,變成這樣:

class RamMakTrilet(Dataset):
    def __init__(self,
                 root,
                 img_name_file,
                 num_cls = 10,
                 num_triplets = 1000,
                 limit = 500,
                 transforms = None,
                 ):
        self.transforms = transforms
        self.triplets = [Sel.select_triplet(root,img_name_file, num_cls, limit)
                         for i in range(num_triplets)]

    def __getitem__(self, index):
        '''
        每次返回一個triplet
        '''
        triplet = self.triplets[index]
        # print(triplet)
        data = [self.transforms(Image.open(img_path))
                for img_path in triplet[:3]]
        label = triplet[3:]
        # print(label)
        return data, label
def select_triplet(dir, img_name_file, num_classes, limit=500, is_car=False):
    '''
    從triplet_dir中隨機選擇一個三元組
    '''
    # np.random.seed(100)  # 設定固定的隨機數種子,便於驗證

    # 獲取anchor, positive, negative 圖片ID
    anchor_cls = np.random.choice(num_classes)#0~9
    anchor_id = anchor_cls*limit + np.random.choice(limit)
    positive_id = anchor_cls*limit + \
        get_negative_id(limit, anchor_id - anchor_cls*limit)
    negative_cls = get_negative_id(num_classes, anchor_cls)  # 隨機選擇一個反例型別
    negative_id = negative_cls*limit + np.random.choice(limit)  # 隨機選擇一個反例ID
    img_name = os.path.join(dir, img_name_file)
    # 獲取anchor, positive, negative 圖片地址
    fp = open(img_name, 'r')  # 這裡是open檔案
    img_filename = [x.strip() for x in fp]  # 返回照片檔名 這個是list格式
    fp.close()
    anchor_path = os.path.join(dir, img_filename[anchor_id])
    positive_path = os.path.join(dir, img_filename[positive_id])
    negative_path = os.path.join(dir, img_filename[negative_id])

    # 返回6元組
    return anchor_path, positive_path, negative_path, anchor_cls, anchor_cls, negative_cls
最終在主函式中呼叫如下:
for batch_idx, triplet_train_data in enumerate(train_loader, 0):#train_input, train_label, batch_ind 
    data, _ = triplet_train_data
    data1, data2, data3 = data

完美解決問題