將資料分為訓練集、驗證集和測試集,兼顧資料平衡與亂序:每個標籤下的資料隨機抽取 80% 分到訓練集,其餘各 10% 分到驗證集和測試集
阿新 • • 發佈:2020-08-19
import random

import pandas as pd

# Input/output locations, relative to the working directory of the script.
DATASET_PATH = "../data/dataset.xlsx"
TRAIN_PATH = "../data/train.txt"
VAL_PATH = "../data/val.txt"
TEST_PATH = "../data/test.txt"


def split_by_label(data, train_end=0.8, dev_end=0.9):
    """Split a labelled DataFrame into train/dev/test row lists, per label.

    Every label is split on its own so that each label contributes the same
    proportions to the three sets (a stratified split).

    Args:
        data: DataFrame with a ``c1`` column (the label) and a ``c2`` column
            (the actual data).
        train_end: Fraction of each label's shuffled rows at which the
            training slice ends.
        dev_end: Fraction of each label's shuffled rows at which the
            validation slice ends; the remainder goes to the test set.

    Returns:
        A ``(train_list, dev_list, test_list)`` tuple; each element is a list
        of ``[c1, c2]`` rows.
    """
    train_list, dev_list, test_list = [], [], []
    # sort=False keeps labels in order of first appearance, so the sequence
    # of random.sample calls is reproducible for a fixed random seed.
    for _, group in data.groupby("c1", sort=False):
        # `.values` replaces the `.ix` indexer, which pandas 1.0 removed.
        rows = group[["c1", "c2"]].values.tolist()
        shuffled = random.sample(rows, len(rows))
        train_cut = int(train_end * len(shuffled))
        dev_cut = int(dev_end * len(shuffled))
        train_list.extend(shuffled[:train_cut])
        dev_list.extend(shuffled[train_cut:dev_cut])
        test_list.extend(shuffled[dev_cut:])
    return train_list, dev_list, test_list


def write_file(s, f_file):
    """Write rows to ``f_file`` as ``__label__<label>\\t<c h a r s>`` lines.

    The rows are written in a freshly shuffled order (``s`` itself is not
    modified), one row per line, with the items of the row's second element
    separated by single spaces. The number of distinct labels written is
    printed to stdout.

    Args:
        s: Sequence of rows; ``row[0]`` is the label and ``row[1]`` is an
            iterable of strings (typically the text itself).
        f_file: Path of the UTF-8 text file to create or overwrite.
    """
    shuffled = random.sample(s, len(s))
    lines = []
    seen_labels = set()
    for row in shuffled:
        label = row[0]
        seen_labels.add(label)
        # join() puts spaces only between items; the original appended a
        # space after every item and then discarded the result of strip(),
        # leaving a trailing space on every line.
        spaced_text = " ".join(row[1])
        lines.append(f"__label__{label}\t{spaced_text}\n")
    print(len(seen_labels))
    # The context manager closes the file even if writing raises.
    with open(f_file, "w", encoding="utf-8") as out:
        out.writelines(lines)


def main():
    """Read the dataset, split it per label and write the three files."""
    data = pd.read_excel(DATASET_PATH)
    train_list, dev_list, test_list = split_by_label(data)
    write_file(train_list, TRAIN_PATH)
    write_file(dev_list, VAL_PATH)
    write_file(test_list, TEST_PATH)


if __name__ == "__main__":
    main()