1. 程式人生 > 實用技巧 > 將資料分為訓練、驗證和測試集,考慮資料平衡問題和亂序:每個標籤下的資料隨機打亂後,80%分到訓練集,驗證集和測試集各分到10%

將資料分為訓練、驗證和測試集,考慮資料平衡問題和亂序:每個標籤下的資料隨機打亂後,80%分到訓練集,驗證集和測試集各分到10%

# Stratified train/dev/test split: every label's rows are shuffled and cut
# 80% / 10% / 10%, so each label keeps roughly the same share in every
# split. The three splits are then written out as text files.
import random

import pandas as pd

# Cumulative cut points inside each label's shuffled rows:
# [0, 0.8) -> train, [0.8, 0.9) -> dev, [0.9, 1.0] -> test.
TRAIN_END_FRACTION = 0.8
DEV_END_FRACTION = 0.9


def write_file(s, f_file):
    """Shuffle the samples and write them to ``f_file``, one per line.

    Each output line is ``__label__<label>\\t<piece> <piece> ...\\n``. The
    ``__label__`` prefix looks like the fastText supervised-input
    convention (presumably the intended consumer -- confirm).

    Args:
        s: Sequence of rows where ``row[0]`` is the label and ``row[1]`` is
            the text. The text is iterated piece by piece, so a plain
            string is emitted as space-separated characters.
        f_file: Path of the UTF-8 text file to create or overwrite.

    Side effects:
        Prints the number of distinct labels that were written.
    """
    # Shuffle again so that rows of the same label are not contiguous.
    shuffled = random.sample(s, len(s))
    seen_labels = set()
    lines = []
    for row in shuffled:
        label, text = row[0], row[1]
        seen_labels.add(label)
        # The original called ``sen_str.strip()`` and discarded the result,
        # leaving a trailing space on every line; joining the pieces yields
        # the intended output without it.
        lines.append(f"__label__{label}\t{' '.join(text)}\n")
    print(len(seen_labels))
    # ``with`` guarantees the file is closed even if the write fails.
    with open(f_file, 'w', encoding='utf-8') as f1:
        f1.write(''.join(lines))


data = pd.read_excel("../data/dataset.xlsx")
train_list, dev_list, test_list = [], [], []
data_value = data.values

# Distinct labels in first-seen order, taken from the first column.
# NOTE(review): rows are selected by column "c1" below, so this assumes the
# first column of the sheet is "c1" -- confirm against the dataset.
list_label = list(dict.fromkeys(row[0] for row in data_value))

for label in list_label:
    # c1 is the label, c2 is the actual text.
    # ``.ix`` was removed in pandas 1.0; ``.loc`` + ``.values`` is equivalent.
    rows = data.loc[data["c1"] == label, ['c1', 'c2']].values.tolist()
    rows = random.sample(rows, len(rows))  # shuffled copy
    train_end = int(TRAIN_END_FRACTION * len(rows))
    dev_end = int(DEV_END_FRACTION * len(rows))
    # Labels with fewer than 10 rows may contribute nothing to the dev
    # split because of the integer truncation above.
    train_list.extend(rows[:train_end])
    dev_list.extend(rows[train_end:dev_end])
    test_list.extend(rows[dev_end:])

write_file(train_list, '../data/train.txt')
write_file(dev_list, '../data/val.txt')
write_file(test_list, '../data/test.txt')