1. 程式人生 > 其它 >train set、dev set和test set_CodingPark程式設計公園

train set、dev set和test set_CodingPark程式設計公園

技術標籤:基本功機器學習

理論講解

  • train set:該集合是用於訓練模型的。

  • dev set:該集合是用於在訓練模型中評估模型,以促進模型優化的。

  • test set:該集合是用於測試訓練好的模型是否有效的。

你使用了train set訓練一個模型,這個模型有一個優化目標,利用dev set來評估你的模型,確定你模型離你的目標差距。在不斷迭代中不斷用train set訓練模型,dev set評估模型,不斷靠近你的目標直至最優。之後用test set來驗證模型效果。

⚠️注意:dev set 和 test set需要在同一分佈下

這裡是因為如果你dev set 和 test set屬於不同分佈,你在dev set訓練出的模型是不會符合test set的。

你的dev set+優化目標就是圖中靶子+靶心,你訓練的模型就是這這個靶子裡評估,不斷優化靠近靶心,但是如果你dev set 和 test set屬於不同分佈,你就會發現你最終驗證的是圖上紅色線畫的靶子,這顯而易見你訓練的模型完全不能很好符合test set。

所以要讓你的模型符合test set,必須要dev set 和 test set屬於同一分佈


程式碼實踐

class ChineseSentimentClsProcessor(DataProcessor):
  """Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir): file_path = os.path.join(data_dir, 'train_sentiment.txt') f = open(file_path, 'r', encoding='utf-8') train_data = [] index = 0 for line in f.readlines(): guid = 'train-%d' % index # 引數guid是用來區分每個example的 line =
line.replace('\n', '').split('\t') text_a = tokenization.convert_to_unicode(str(line[1])) # 要分類的文字 label = str(line[2]) # 文字對應的情感類別 train_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) # 加入到InputExample列表中 index += 1 return train_data def get_dev_examples(self, data_dir): file_path = os.path.join(data_dir, 'test_sentiment.txt') f = open(file_path, 'r', encoding='utf-8') dev_data = [] index = 0 for line in f.readlines(): guid = 'dev-%d' % index line = line.replace("\n", "").split("\t") text_a = tokenization.convert_to_unicode(str(line[1])) label = str(line[2]) dev_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) index += 1 return dev_data def get_test_examples(self, data_dir): file_path = os.path.join(data_dir, 'test_sentiment.txt') f = open(file_path, 'r', encoding='utf-8') test_data = [] index = 0 for line in f.readlines(): guid = 'dev-%d' % index line = line.replace("\n", "").split("\t") text_a = tokenization.convert_to_unicode(str(line[1])) label = str(line[2]) test_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) index += 1 return test_data def get_labels(self): return ['0', '1', '2']

在這裡插入圖片描述