1. 程式人生 > 其它 > python 拆分多類別資料集

python 拆分多類別資料集

原資料集形式:收集的資料來源包括兩個資料夾(folder),每個資料夾中的資料都分為三類(class1–class3),目錄結構為 dataset/<folder>/<class>/<圖片>。

希望得到的資料集形式:將每個來源資料夾的資料拆分為 train 和 test 兩部分,每部分都包含所有類別,目錄結構為 dataset1/<folder>/train/<class>/ 與 dataset1/<folder>/test/<class>/。

完整程式碼(已包含註釋,自測可用,預設按 7:3 的比例拆分每個類別;參考資料:資料集劃分、label 生成及按 label 將圖片分類到不同資料夾):

 # Split every class folder of a dataset into train/test copies.
import os
# import cv2
import random
import shutil
import sys
from random import randint


def fileExist(path1):
    """Ensure the directory *path1* exists, creating parent directories as needed.

    Despite its name, this does more than test for existence: a missing
    *path1* is created recursively. If *path1* already exists (whether as a
    directory or as anything else) the function returns without doing anything.

    Args:
        path1: Directory path to ensure.
    """
    if os.path.exists(path1):
        return
    # makedirs handles both the single-level and the nested case; exist_ok
    # tolerates the directory appearing between the check above and this call.
    os.makedirs(path1, exist_ok=True)


def _copy_files(names, src_dir, dst_dir):
    """Copy each file in *names* from *src_dir* into *dst_dir*, keeping its name."""
    for name in names:
        shutil.copy(src=os.path.join(src_dir, name), dst=os.path.join(dst_dir, name))


def split_dataset(root_path, new_path, ratio=0.7):
    """Randomly split every class folder under *root_path* into train/test copies.

    Input layout::

        root_path/<class>/<file>

    Output layout (files are copied; the source tree is left untouched)::

        new_path/train/<class>/<file>
        new_path/test/<class>/<file>

    For a class holding ``n`` files, ``int(n * ratio)`` randomly chosen files
    go to ``train`` and the rest to ``test``. Destination directories are
    created when missing, so callers do not need to pre-create them.
    Non-directory entries directly under *root_path* and sub-directories
    inside a class folder are ignored.

    Args:
        root_path: Directory whose sub-directories are the classes.
        new_path: Output directory; ``train`` and ``test`` are created inside it.
        ratio: Fraction of each class assigned to the training split, in [0, 1].

    Raises:
        ValueError: If *ratio* is outside ``[0, 1]`` (checked before any copy).
    """
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be within [0, 1], got {ratio!r}")

    # os.listdir order is arbitrary; sorting makes a seeded split reproducible.
    for folder in sorted(os.listdir(root_path)):
        origin_path = os.path.join(root_path, folder)
        # Skip stray files (e.g. .DS_Store) sitting next to the class folders.
        if not os.path.isdir(origin_path):
            continue

        train_path = os.path.join(new_path, "train", folder)
        test_path = os.path.join(new_path, "test", folder)
        fileExist(train_path)
        fileExist(test_path)

        # Only regular files can be copied with shutil.copy.
        img_list = sorted(
            name
            for name in os.listdir(origin_path)
            if os.path.isfile(os.path.join(origin_path, name))
        )
        train_num = int(len(img_list) * ratio)
        train_sample = random.sample(img_list, train_num)
        chosen = set(train_sample)
        # Complement in listing order (a set difference would be hash-ordered).
        test_sample = [name for name in img_list if name not in chosen]

        _copy_files(train_sample, origin_path, train_path)
        _copy_files(test_sample, origin_path, test_path)


if __name__ == '__main__':
    root_path = "dataset"   # input:  dataset/<domain>/<class>/<image>
    new_path = "dataset1"   # output: dataset1/<domain>/{train,test}/<class>/<image>

    # split_dataset creates the train/test class folders itself, so a single
    # pass over the source folders (domains) is enough.
    for domain in sorted(os.listdir(root_path)):
        domain_path = os.path.join(root_path, domain)
        if not os.path.isdir(domain_path):
            continue
        domain_new_path = os.path.join(new_path, domain)
        # NOTE(review): the published listing is cut off inside this call;
        # the default ratio (0.7) is assumed here — confirm.
        split_dataset(domain_path, domain_new_path)