Alink漫談(七) : 如何劃分訓練資料集和測試資料集
阿新 • • 發佈:2020-06-13
# Alink漫談(七) : 如何劃分訓練資料集和測試資料集
[TOC]
## 0x00 摘要
Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將為大家展現Alink如何劃分訓練資料集和測試資料集。
## 0x01 訓練資料集和測試資料集
**兩分法**
一般做預測分析時,會將資料分為兩大部分。一部分是訓練資料,用於構建模型,一部分是測試資料,用於檢驗模型。
**三分法**
但有時候模型的構建過程中也需要檢驗模型/輔助模型構建,這時會將訓練資料再分為兩個部分:1)訓練資料;2)驗證資料(Validation Data)。所以這種情況下會把資料分為三部分。
- 訓練資料(Train Data):用於模型構建。
- 驗證資料(Validation Data):可選,用於輔助模型構建,可以重複使用。
- 測試資料(Test Data):用於檢測模型構建,此資料只在模型檢驗時使用,用於評估模型的準確率。絕對不允許用於模型構建過程,否則會導致過渡擬合。
Training set是用來訓練模型或確定模型引數的,如ANN中權值等;
Validation set是用來做模型選擇(model selection),即做模型的最終優化及確定,如ANN的結構;
Test set則純粹是為了測試已經訓練好的模型的推廣能力。當然test set並不能保證模型的正確性,他只是說相似的資料用此模型會得出相似的結果。
**實際應用**
實際應用中,一般只將資料集分成兩類,即training set 和test set,大多數文章並不涉及validation set。我們這裡也不涉及。大家常用的sklearn的train_test_split函式就是將矩陣隨機劃分為訓練子集和測試子集,並返回劃分好的訓練集測試集樣本和訓練集測試集標籤。
## 0x02 Alink示例程式碼
首先我們給出示例程式碼,然後會深入剖析:
```java
public class SplitExample {
public static void main(String[] args) throws Exception {
String url = "iris.csv";
String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
//這裡是批處理
BatchOperator data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema);
SplitBatchOp spliter = new SplitBatchOp().setFraction(0.8);
spliter.linkFrom(data);
BatchOperator trainData = spliter;
BatchOperator testData = spliter.getSideOutput(0);
// 這裡是流處理
CsvSourceStreamOp dataS = new CsvSourceStreamOp().setFilePath(url).setSchemaStr(schema);
SplitStreamOp spliterS = new SplitStreamOp().setFraction(0.4);
spliterS.linkFrom(dataS);
StreamOperator train_data = spliterS;
StreamOperator test_data = spliterS.getSideOutput(0);
}
}
```
## 0x03 批處理
SplitBatchOp是分割批處理的主要類,具體構建DAG的工作是在其linkFrom完成的。
總體思路比較簡單:
1. 假定有一個取樣比例 fraction
2. 將資料集分割槽,平行計算每個分割槽上的記錄數
3. 把每個分割槽上的記錄數累積,得到所有記錄總數 totCount
4. 從上而下計算出一個取樣總數:`numTarget = totCount * fraction`
5. 因為具體選擇元素是在每個分割槽上做的,所以在每個分割槽上,分別計算出來這個分割槽應該取樣的記錄數,比如第n個分割槽上應取樣記錄數:`task_n_count * fraction`
6. 把這些分割槽 "應該取樣的記錄數" 累積,得出來從下而上計算出的取樣總數: `totSelect = task_1_count * fraction + task_2_count * fraction + ... task_n_count * fraction`
7. numTarget 和 totSelect 可能不相等,所以隨機決定把多出來的 `numTarget - totSelect` 加入到某一個task中。
8. 在每個task上取樣得到具體的記錄。
### 3.1 得到記錄數
如果要分割資料,首先必須知道資料集的記錄數。比如這個DataSet的記錄是1萬個?還是十萬個?因為資料集可能會很大,所以這一步操作也使用了並行處理,即把資料分割槽,然後通過mapPartition操作得到每一個分割槽上元素的數目。
```java