train set、dev set和test set_CodingPark程式設計公園
阿新 • • 發佈:2021-01-27
理論講解
-
train set:該集合是用於訓練模型的。
-
dev set:該集合是用於在訓練模型中評估模型,以促進模型優化的。
-
test set:該集合是用於測試訓練好的模型是否有效的。
你使用了train set訓練一個模型,這個模型有一個優化目標,利用dev set來評估你的模型,確定你模型離你的目標差距。在不斷迭代中不斷用train set訓練模型,dev set評估模型,不斷靠近你的目標直至最優。之後用test set來驗證模型效果。
⚠️注意:dev set 和 test set需要在同一分佈下
這裡是因為如果你dev set 和 test set屬於不同分佈,你在dev set訓練出的模型是不會符合test set的。
你的dev set+優化目標就是圖中靶子+靶心,你訓練的模型就是這這個靶子裡評估,不斷優化靠近靶心,但是如果你dev set 和 test set屬於不同分佈,你就會發現你最終驗證的是圖上紅色線畫的靶子,這顯而易見你訓練的模型完全不能很好符合test set。
所以要讓你的模型符合test set,必須要dev set 和 test set屬於同一分佈
程式碼實踐
class ChineseSentimentClsProcessor(DataProcessor):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
file_path = os.path.join(data_dir, 'train_sentiment.txt')
f = open(file_path, 'r', encoding='utf-8')
train_data = []
index = 0
for line in f.readlines():
guid = 'train-%d' % index # 引數guid是用來區分每個example的
line = line.replace('\n', '').split('\t')
text_a = tokenization.convert_to_unicode(str(line[1])) # 要分類的文字
label = str(line[2]) # 文字對應的情感類別
train_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) # 加入到InputExample列表中
index += 1
return train_data
def get_dev_examples(self, data_dir):
file_path = os.path.join(data_dir, 'test_sentiment.txt')
f = open(file_path, 'r', encoding='utf-8')
dev_data = []
index = 0
for line in f.readlines():
guid = 'dev-%d' % index
line = line.replace("\n", "").split("\t")
text_a = tokenization.convert_to_unicode(str(line[1]))
label = str(line[2])
dev_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
index += 1
return dev_data
def get_test_examples(self, data_dir):
file_path = os.path.join(data_dir, 'test_sentiment.txt')
f = open(file_path, 'r', encoding='utf-8')
test_data = []
index = 0
for line in f.readlines():
guid = 'dev-%d' % index
line = line.replace("\n", "").split("\t")
text_a = tokenization.convert_to_unicode(str(line[1]))
label = str(line[2])
test_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
index += 1
return test_data
def get_labels(self):
return ['0', '1', '2']