1. 程式人生 > 實用技巧 >Alink漫談(八) : 二分類評估 AUC、K-S、PRC、Precision、Recall、LiftChart 如何實現

Alink漫談(八) : 二分類評估 AUC、K-S、PRC、Precision、Recall、LiftChart 如何實現

Alink漫談(八) : 二分類評估 AUC、K-S、PRC、Precision、Recall、LiftChart 如何實現

目錄

0x00 摘要

Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演演算法平臺,是業界首個同時支援批式演演算法、流式演演算法的機器學習平臺。二分類評估是對二分類演演算法的預測結果進行效果評估。本文將剖析Alink中對應程式碼實現。

0x01 相關概念

如果對本文某些概念有疑惑,可以參見之前文章 [白話解析] 通過例項來梳理概念 :準確率 (Accuracy)、精準率(Precision)、召回率(Recall) 和 F值(F-Measure)

0x02 示例程式碼

public class EvalBinaryClassExample {

    AlgoOperator getData(boolean isBatch) {
Row[] rows = new Row[]{
Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),
Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")
}; String[] schema = new String[]{"label", "detailInput"}; if (isBatch) {
return new MemSourceBatchOp(rows, schema);
} else {
return new MemSourceStreamOp(rows, schema);
}
} public static void main(String[] args) throws Exception {
EvalBinaryClassExample test = new EvalBinaryClassExample();
BatchOperator batchData = (BatchOperator) test.getData(true); BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
.setLabelCol("label")
.setPredictionDetailCol("detailInput")
.linkFrom(batchData)
.collectMetrics(); System.out.println("RocCurve:" + metrics.getRocCurve());
System.out.println("AUC:" + metrics.getAuc());
System.out.println("KS:" + metrics.getKs());
System.out.println("PRC:" + metrics.getPrc());
System.out.println("Accuracy:" + metrics.getAccuracy());
System.out.println("Macro Precision:" + metrics.getMacroPrecision());
System.out.println("Micro Recall:" + metrics.getMicroRecall());
System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
}
}

程式輸出

RocCurve:([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0],[0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 1.0, 1.0, 1.0])
AUC:0.8333333333333333
KS:0.6666666666666666
PRC:0.9027777777777777
Accuracy:0.6
Macro Precision:0.3
Micro Recall:0.6
Weighted Sensitivity:0.6

在 Alink 中,二分類評估有批處理,流處理兩種實現,下面一一為大家介紹( Alink 複雜之一在於大量精細的資料結構,所以下文會大量列印程式中變數以便大家理解)。

2.1 主要思路

  • 把 [0,1] 分成假設 100000個桶(bin)。所以得到positiveBin / negativeBin 兩個100000的陣列。

  • 根據輸入給positiveBin / negativeBin賦值。positiveBin就是 TP + FP,negativeBin就是 TN + FN。這些是後續計算的基礎。

  • 遍歷bins中每一個有意義的點,計算出totalTrue和totalFalse,並且在每一個點上計算該點的混淆矩陣,tpr,以及rocCurve,recallPrecisionCurve,liftChart在該點對應的資料;

  • 依據曲線內容計算並且儲存 AUC/PRC/KS

具體後續還有詳細呼叫關係綜述。

0x03 批處理

3.1 EvalBinaryClassBatchOp

EvalBinaryClassBatchOp是二分類評估的實現,功能是計算二分類的評估指標(evaluation metrics)。

輸入有兩種:

  • label column and predResult column
  • label column and predDetail column。如果有predDetail,則predResult被忽略

我們例子中 "prefix1" 就是 label,"{\"prefix1\": 0.9, \"prefix0\": 0.1}" 就是 predDetail

Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")

具體類摘錄如下:

public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp<EvalBinaryClassBatchOp> implements BinaryEvaluationParams <EvalBinaryClassBatchOp>, EvaluationMetricsCollector<BinaryClassMetrics> {

	@Override
public BinaryClassMetrics collectMetrics() {
return new BinaryClassMetrics(this.collect().get(0));
}
}

可以看到,其主要工作都是在基類BaseEvalClassBatchOp中完成,所以我們會首先看BaseEvalClassBatchOp。

3.2 BaseEvalClassBatchOp

我們還是從 linkFrom 函式入手,其主要是做了幾件事:

  • 獲取配置資訊
  • 從輸入中提取某些列:"label","detailInput"
  • calLabelPredDetailLocal會按照partition分別計算evaluation metrics
  • 綜合reduce上述計算結果
  • SaveDataAsParams函式會把最終數值輸入到 output table

具體程式碼如下

@Override
public T linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR); // Judge the evaluation type from params.
ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams()); DataSet<BaseMetricsSummary> res;
switch (type) {
case PRED_DETAIL: {
String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
// 從輸入中提取某些列:"label","detailInput"
DataSet<Row> data = in.select(new String[] {labelColName, predDetailColName}).getDataSet();
// 按照partition分別計算evaluation metrics
res = calLabelPredDetailLocal(data, positiveValue, binary);
break;
}
......
} // 綜合reduce上述計算結果
DataSet<BaseMetricsSummary> metrics = res
.reduce(new EvaluationUtil.ReduceBaseMetrics()); // 把最終數值輸入到 output table
this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING}); return (T)this;
} // 執行中一些變數如下
labelColName = "label"
predDetailColName = "detailInput"
type = {[email protected]} "PRED_DETAIL"
binary = true
positiveValue = null

3.2.0 呼叫關係綜述

因為後續程式碼呼叫關係複雜,所以先給出一個呼叫關係

  • 從輸入中提取某些列:"label","detailInput",in.select(new String[] {labelColName, predDetailColName}).getDataSet()。因為可能輸入還有其他列,而只有某些列是我們計算需要的,所以只提取這些列。
  • 按照partition分別計算evaluation metrics,即呼叫 calLabelPredDetailLocal(data, positiveValue, binary);
    • flatMap會從label列和prediction列中,取出所有labels(注意是取出labels的名字 ),傳送給下游運算元。
    • reduceGroup主要功能是通過 buildLabelIndexLabelArray 去重 "labels名字",然後給每一個label一個ID,得到一個 <labels, ID>的map,最後返回是二元組(map, labels),即({prefix1=0, prefix0=1},[prefix1, prefix0])。從後文看,<labels, ID>Map看來是多分類才用到。二分類只用到了labels。
    • mapPartition 分割槽呼叫 CalLabelDetailLocal 來計算混淆矩陣,主要是分割槽呼叫getDetailStatistics,前文中得到的二元組(map, labels)會作為引數傳遞進來 。
      • getDetailStatistics 遍歷 rows 資料,提取每一個item(比如 "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"),然後通過updateBinaryMetricsSummary累積計算混淆矩陣所需資料。

        • updateBinaryMetricsSummary 把 [0,1] 分成假設 100000個桶(bin)。所以得到positiveBin / negativeBin 兩個100000的陣列。positiveBin就是 TP + FP,negativeBin就是 TN + FN。

          • 如果某個 sample 為 正例 (positive value) 的概率是 p, 則該 sample 對應的 bin index 就是 p * 100000。如果 p 被預測為正例 (positive value) ,則positiveBin[index]++,
          • 否則就是被預測為負例(negative value) ,則negativeBin[index]++。
  • 綜合reduce上述計算結果,metrics = res.reduce(new EvaluationUtil.ReduceBaseMetrics());
    • 具體計算是在BinaryMetricsSummary.merge,其作用就是Merge the bins, and add the logLoss。
  • 把最終數值輸入到 output table,setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()..);
    • 歸併所有BaseMetrics後,得到total BaseMetrics,計算indexes存入params。collector.collect(t.toMetrics().serialize());

      • 實際業務在BinaryMetricsSummary.toMetrics,即基於bin的資訊計算,然後儲存到params。

        • extractMatrixThreCurve函式取出非空的bins,據此計算出ConfusionMatrix array(混淆矩陣), threshold array, rocCurve/recallPrecisionCurve/LiftChart.

          • 遍歷bins中每一個有意義的點,計算出totalTrue和totalFalse,並且在每一個點上計算:
          • curTrue += positiveBin[index]; curFalse += negativeBin[index];
          • 得到該點的混淆矩陣 new ConfusionMatrix(new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
          • 得到 tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
          • rocCurve,recallPrecisionCurve,liftChart在該點對應的資料;
        • 依據曲線內容計算並且儲存 AUC/PRC/KS
        • 對生成的rocCurve/recallPrecisionCurve/LiftChart輸出進行抽樣
        • 依據抽樣後的輸出儲存 RocCurve/RecallPrecisionCurve/LiftChar
        • 儲存正例樣本的度量指標
        • 儲存Logloss
        • Pick the middle point where threshold is 0.5.

3.2.1 calLabelPredDetailLocal

本函式按照partition分別計算評估指標 evaluation metrics。是的,這程式碼很短,但是有個地方需要注意。有時候越簡單的地方越容易疏漏。容易疏漏點是:

第一行程式碼的結果 labels 是第二行程式碼的引數,而並非第二行主體。第二行程式碼主體和第一行程式碼主體一樣,都是data。

private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Row> data, final String positiveValue, oolean binary) {

    DataSet<Tuple2<Map<String, Integer>, String[]>> labels = data.flatMap(new FlatMapFunction<Row, String>() {
@Override
public void flatMap(Row row, Collector<String> collector) {
TreeMap<String, Double> labelProbMap;
if (EvaluationUtil.checkRowFieldNotNull(row)) {
labelProbMap = EvaluationUtil.extractLabelProbMap(row);
labelProbMap.keySet().forEach(collector::collect);
collector.collect(row.getField(0).toString());
}
}
}).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue)); return data
.rebalance()
.mapPartition(new CalLabelDetailLocal(binary))
.withBroadcastSet(labels, LABELS);
}

calLabelPredDetailLocal中具體分為三步驟:

  • 在flatMap會從label列和prediction列中,取出所有labels(注意是取出labels的名字 ),傳送給下游運算元。
  • reduceGroup的主要功能是去重 "labels名字",然後給每一個label一個ID,最後結果是一個<labels, ID>Map。
  • mapPartition 是分割槽呼叫 CalLabelDetailLocal 來計算混淆矩陣。

下面具體看看。

3.2.1.1 flatMap

在flatMap中,主要是從label列和prediction列中,取出所有labels(注意是取出labels的名字 ),傳送給下游運算元。

EvaluationUtil.extractLabelProbMap 作用就是解析輸入的json,獲得具體detailInput中的資訊。

下游運算元是reduceGroup,所以Flink runtime會對這些labels自動去重。如果對這部分有興趣,可以參見我之前介紹reduce的文章。CSDN : [原始碼解析] Flink的groupBy和reduce究竟做了什麼 部落格園 : [原始碼解析] Flink的groupBy和reduce究竟做了什麼

程式中變數如下

row = {[email protected]} "prefix1,{"prefix1": 0.9, "prefix0": 0.1}"
fields = {Object[2]@8925}
0 = "prefix1"
1 = "{"prefix1": 0.9, "prefix0": 0.1}" labelProbMap = {[email protected]} size = 2
"prefix0" -> {[email protected]} 0.1
"prefix1" -> {[email protected]} 0.9 labelProbMap.keySet().forEach(collector::collect); //這裡傳送 "prefix0", "prefix1"
collector.collect(row.getField(0).toString()); // 這裡傳送 "prefix1"
// 因為下一個操作是reduceGroup,所以這些label會被runtime去重
3.2.1.2 reduceGroup

主要功能是通過buildLabelIndexLabelArray去重labels,然後給每一個label一個ID,最後結果是一個<labels, ID>的Map。

reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));

DistinctLabelIndexMap的作用是從label列和prediction列中,取出所有不同的labels,返回一個<labels, ID>的map,根據後續程式碼看,這個map是多分類才用到。Get all the distinct labels from label column and prediction column, and return the map of labels and their IDs.

前面已經提到,這裡的引數rows已經被自動去重。

public static class DistinctLabelIndexMap implements
GroupReduceFunction<String, Tuple2<Map<String, Integer>, String[]>> {
......
@Override
public void reduce(Iterable<String> rows, Collector<Tuple2<Map<String, Integer>, String[]>> collector) throws Exception {
HashSet<String> labels = new HashSet<>();
rows.forEach(labels::add);
collector.collect(buildLabelIndexLabelArray(labels, binary, positiveValue));
}
} // 變數為
labels = {[email protected]} size = 2
0 = "prefix1"
1 = "prefix0"
binary = true

buildLabelIndexLabelArray的作用是給每一個label一個ID,得到一個 <labels, ID>的map,最後返回是二元組(map, labels),即({prefix1=0, prefix0=1},[prefix1, prefix0])。

// Give each label an ID, return a map of label and ID.
public static Tuple2<Map<String, Integer>, String[]> buildLabelIndexLabelArray(HashSet<String> set,boolean binary, String positiveValue) {
String[] labels = set.toArray(new String[0]);
Arrays.sort(labels, Collections.reverseOrder()); Map<String, Integer> map = new HashMap<>(labels.length);
if (binary && null != positiveValue) {
if (labels[1].equals(positiveValue)) {
labels[1] = labels[0];
labels[0] = positiveValue;
}
map.put(labels[0], 0);
map.put(labels[1], 1);
} else {
for (int i = 0; i < labels.length; i++) {
map.put(labels[i], i);
}
}
return Tuple2.of(map, labels);
} // 程式變數如下
labels = {String[2]@9013}
0 = "prefix1"
1 = "prefix0"
map = {[email protected]} size = 2
"prefix1" -> {[email protected]} 0
"prefix0" -> {[email protected]} 1
3.2.1.3 mapPartition

這裡主要功能是分割槽呼叫 CalLabelDetailLocal 來為後來計算混淆矩陣做準備。

return data
.rebalance()
.mapPartition(new CalLabelDetailLocal(binary)) //這裡是業務所在
.withBroadcastSet(labels, LABELS);

具體工作是 CalLabelDetailLocal 完成的,其作用是分割槽呼叫getDetailStatistics

// Calculate the confusion matrix based on the label and predResult.
static class CalLabelDetailLocal extends RichMapPartitionFunction<Row, BaseMetricsSummary> {
private Tuple2<Map<String, Integer>, String[]> map;
private boolean binary; @Override
public void open(Configuration parameters) throws Exception {
List<Tuple2<Map<String, Integer>, String[]>> list = getRuntimeContext().getBroadcastVariable(LABELS);
this.map = list.get(0);// 前文生成的二元組(map, labels)
} @Override
public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> collector) {
// 呼叫到了 getDetailStatistics
collector.collect(getDetailStatistics(rows, binary, map));
}
}

getDetailStatistics 的作用是:初始化分類評估的度量指標 base classification evaluation metrics,累積計算混淆矩陣需要的資料。主要就是遍歷 rows 資料,提取每一個item(比如 "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"),然後累積計算混淆矩陣所需資料。

// Initialize the base classification evaluation metrics. There are two cases: BinaryClassMetrics and MultiClassMetrics.
private static BaseMetricsSummary getDetailStatistics(Iterable<Row> rows,
String positiveValue,
boolean binary,
Tuple2<Map<String, Integer>, String[]> tuple) {
BinaryMetricsSummary binaryMetricsSummary = null;
MultiMetricsSummary multiMetricsSummary = null;
Tuple2<Map<String, Integer>, String[]> labelIndexLabelArray = tuple; // 前文生成的二元組(map, labels) Iterator<Row> iterator = rows.iterator();
Row row = null;
while (iterator.hasNext() && !checkRowFieldNotNull(row)) {
row = iterator.next();
} Map<String, Integer> labelIndexMap = null;
if (binary) {
// 二分法在這裡
binaryMetricsSummary = new BinaryMetricsSummary(
new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],
new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],
labelIndexLabelArray.f1, 0.0, 0L);
} else {
//
labelIndexMap = labelIndexLabelArray.f0; // 前文生成的<labels, ID>Map看來是多分類才用到。
multiMetricsSummary = new MultiMetricsSummary(
new long[labelIndexMap.size()][labelIndexMap.size()],
labelIndexLabelArray.f1, 0.0, 0L);
} while (null != row) {
if (checkRowFieldNotNull(row)) {
TreeMap<String, Double> labelProbMap = extractLabelProbMap(row);
String label = row.getField(0).toString();
if (ArrayUtils.indexOf(labelIndexLabelArray.f1, label) >= 0) {
if (binary) {
// 二分法在這裡
updateBinaryMetricsSummary(labelProbMap, label, binaryMetricsSummary);
} else {
updateMultiMetricsSummary(labelProbMap, label, labelIndexMap, multiMetricsSummary);
}
}
}
row = iterator.hasNext() ? iterator.next() : null;
} return binary ? binaryMetricsSummary : multiMetricsSummary;
} //變數如下
tuple = {[email protected]} "({prefix1=0, prefix0=1},[prefix1, prefix0])"
f0 = {[email protected]} size = 2
"prefix1" -> {[email protected]} 0
"prefix0" -> {[email protected]} 1
f1 = {String[2]@9258}
0 = "prefix1"
1 = "prefix0" row = {[email protected]} "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"
fields = {Object[2]@9276}
0 = "prefix1"
1 = "{"prefix1": 0.8, "prefix0": 0.2}" labelIndexLabelArray = {[email protected]} "({prefix1=0, prefix0=1},[prefix1, prefix0])"
f0 = {[email protected]} size = 2
"prefix1" -> {[email protected]} 0
"prefix0" -> {[email protected]} 1
f1 = {String[2]@9242}
0 = "prefix1"
1 = "prefix0" labelProbMap = {[email protected]} size = 2
"prefix0" -> {[email protected]} 0.1
"prefix1" -> {[email protected]} 0.9

先回憶下混淆矩陣:

預測值 0 預測值 1
真實值 0 TN FP
真實值 1 FN TP

針對混淆矩陣,BinaryMetricsSummary 的作用是Save the evaluation data for binary classification。函式具體計算思路是:

  • 把 [0,1] 分成ClassificationEvaluationUtil.DETAIL_BIN_NUMBER(100000)這麼多桶(bin)。所以binaryMetricsSummary的positiveBin/negativeBin分別是兩個100000的陣列。如果某一個 sample 為 正例(positive value) 的概率是 p, 則該 sample 對應的 bin index 就是 p * 100000。如果 p 被預測為正例(positive value) ,則positiveBin[index]++,否則就是被預測為負例(negative value) ,則negativeBin[index]++。positiveBin就是 TP + FP,negativeBin就是 TN + FN。

  • 所以這裡會遍歷輸入,如果某一個輸入(以"prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"為例),0.9 是prefix1(正例) 的概率,0.1 是為prefix0(負例) 的概率。

    • 既然這個演演算法選擇了 prefix1(正例) ,所以就說明此演演算法是判別成 positive 的,所以在 positiveBin 的 90000 處 + 1。
    • 假設這個演演算法選擇了 prefix0(負例) ,則說明此演演算法是判別成 negative 的,所以應該在 negativeBin 的 90000 處 + 1。

具體對應我們示例程式碼的5個取樣,分類如下:

Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),  positiveBin 90000處+1
Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"), positiveBin 80000處+1
Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"), positiveBin 70000處+1
Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"), negativeBin 75000處+1
Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}") negativeBin 60000處+1

具體程式碼如下

public static void updateBinaryMetricsSummary(TreeMap<String, Double> labelProbMap,
String label,
BinaryMetricsSummary binaryMetricsSummary) {
binaryMetricsSummary.total++;
binaryMetricsSummary.logLoss += extractLogloss(labelProbMap, label); double d = labelProbMap.get(binaryMetricsSummary.labels[0]);
int idx = d == 1.0 ? ClassificationEvaluationUtil.DETAIL_BIN_NUMBER - 1 :
(int)Math.floor(d * ClassificationEvaluationUtil.DETAIL_BIN_NUMBER);
if (idx >= 0 && idx < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER) {
if (label.equals(binaryMetricsSummary.labels[0])) {
binaryMetricsSummary.positiveBin[idx] += 1;
} else if (label.equals(binaryMetricsSummary.labels[1])) {
binaryMetricsSummary.negativeBin[idx] += 1;
} else {
.....
}
}
} private static double extractLogloss(TreeMap<String, Double> labelProbMap, String label) {
Double prob = labelProbMap.get(label);
prob = null == prob ? 0. : prob;
return -Math.log(Math.max(Math.min(prob, 1 - LOG_LOSS_EPS), LOG_LOSS_EPS));
} // 變數如下
ClassificationEvaluationUtil.DETAIL_BIN_NUMBER=100000 // 當 "prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}" 時候
labelProbMap = {[email protected]} size = 2
"prefix0" -> {[email protected]} 0.1
"prefix1" -> {[email protected]} 0.9 d = 0.9
idx = 90000
binaryMetricsSummary = {[email protected]}
labels = {String[2]@9242}
0 = "prefix1"
1 = "prefix0"
total = 1
positiveBin = {long[100000]@9263} // 90000處+1
negativeBin = {long[100000]@9264}
logLoss = 0.10536051565782628 // 當 "prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}" 時候
labelProbMap = {[email protected]} size = 2
"prefix0" -> {[email protected]} 0.4
"prefix1" -> {[email protected]} 0.6 d = 0.6
idx = 60000
binaryMetricsSummary = {[email protected]}
labels = {String[2]@9242}
0 = "prefix1"
1 = "prefix0"
total = 2
positiveBin = {long[100000]@9263}
negativeBin = {long[100000]@9264} // 60000處+1
logLoss = 1.0216512475319812

3.2.2 ReduceBaseMetrics

ReduceBaseMetrics作用是把區域性計算的 BaseMetrics 聚合起來。

DataSet<BaseMetricsSummary> metrics = res
.reduce(new EvaluationUtil.ReduceBaseMetrics());

ReduceBaseMetrics如下

public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
@Override
public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
return null == t1 ? t2 : t1.merge(t2);
}
}

具體計算是在BinaryMetricsSummary.merge,其作用就是Merge the bins, and add the logLoss。

@Override
public BinaryMetricsSummary merge(BinaryMetricsSummary binaryClassMetrics) {
for (int i = 0; i < this.positiveBin.length; i++) {
this.positiveBin[i] += binaryClassMetrics.positiveBin[i];
}
for (int i = 0; i < this.negativeBin.length; i++) {
this.negativeBin[i] += binaryClassMetrics.negativeBin[i];
}
this.logLoss += binaryClassMetrics.logLoss;
this.total += binaryClassMetrics.total;
return this;
} // 程式變數是
this = {[email protected]}
labels = {String[2]@9322}
0 = "prefix1"
1 = "prefix0"
total = 2
positiveBin = {long[100000]@9320}
negativeBin = {long[100000]@9323}
logLoss = 1.742969305058623

3.2.3 SaveDataAsParams

this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});

當歸並所有BaseMetrics之後,得到了total BaseMetrics,計算indexes,存入到params。

public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
@Override
public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
collector.collect(t.toMetrics().serialize());
}
}

實際業務在BinaryMetricsSummary.toMetrics中完成,即基於bin的資訊計算,得到confusionMatrix array, threshold array, rocCurve/recallPrecisionCurve/LiftChart等等,然後儲存到params。

public BinaryClassMetrics toMetrics() {
Params params = new Params();
// 生成若干曲線,比如rocCurve/recallPrecisionCurve/LiftChart
Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> matrixThreCurve =
extractMatrixThreCurve(positiveBin, negativeBin, total); // 依據曲線內容計算並且儲存 AUC/PRC/KS
setCurveAreaParams(params, matrixThreCurve.f2); // 對生成的rocCurve/recallPrecisionCurve/LiftChart輸出進行抽樣
Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> sampledMatrixThreCurve = sample(
PROBABILITY_INTERVAL, matrixThreCurve); // 依據抽樣後的輸出儲存 RocCurve/RecallPrecisionCurve/LiftChar
setCurvePointsParams(params, sampledMatrixThreCurve);
ConfusionMatrix[] matrices = sampledMatrixThreCurve.f0; // 儲存正例樣本的度量指標
setComputationsArrayParams(params, sampledMatrixThreCurve.f1, sampledMatrixThreCurve.f0); // 儲存Logloss
setLoglossParams(params, logLoss, total); // Pick the middle point where threshold is 0.5.
int middleIndex = getMiddleThresholdIndex(sampledMatrixThreCurve.f1);
setMiddleThreParams(params, matrices[middleIndex], labels);
return new BinaryClassMetrics(params);
}

extractMatrixThreCurve是全文重點。這裡是 Extract the bins who are not empty, keep the middle threshold 0.5,然後初始化了 RocCurve, Recall-Precision Curve and Lift Curve,計算出ConfusionMatrix array(混淆矩陣), threshold array, rocCurve/recallPrecisionCurve/LiftChart.。

/**
* Extract the bins who are not empty, keep the middle threshold 0.5.
* Initialize the RocCurve, Recall-Precision Curve and Lift Curve.
* RocCurve: (FPR, TPR), starts with (0,0). Recall-Precision Curve: (recall, precision), starts with (0, p), p is the precision with the lowest. LiftChart: (TP+FP/total, TP), starts with (0,0). confusion matrix = [TP FP][FN * TN].
*
* @param positiveBin positiveBins.
* @param negativeBin negativeBins.
* @param total sample number
* @return ConfusionMatrix array, threshold array, rocCurve/recallPrecisionCurve/LiftChart.
*/
static Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> extractMatrixThreCurve(long[] positiveBin, long[] negativeBin, long total) {
ArrayList<Integer> effectiveIndices = new ArrayList<>();
long totalTrue = 0, totalFalse = 0; // 計算totalTrue,totalFalse,effectiveIndices
for (int i = 0; i < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER; i++) {
if (0L != positiveBin[i] || 0L != negativeBin[i]
|| i == ClassificationEvaluationUtil.DETAIL_BIN_NUMBER / 2) {
effectiveIndices.add(i);
totalTrue += positiveBin[i];
totalFalse += negativeBin[i];
}
} // 以我們例子,得到
effectiveIndices = {[email protected]} size = 6
0 = {[email protected]} 50000 //這裡加入了中間點
1 = {[email protected]} 60000
2 = {[email protected]} 70000
3 = {[email protected]} 75000
4 = {[email protected]} 80000
5 = {[email protected]} 90000
totalTrue = 3
totalFalse = 2 // 繼續初始化,生成若干curve
final int length = effectiveIndices.size();
final int newLen = length + 1;
final double m = 1.0 / ClassificationEvaluationUtil.DETAIL_BIN_NUMBER;
EvaluationCurvePoint[] rocCurve = new EvaluationCurvePoint[newLen];
EvaluationCurvePoint[] recallPrecisionCurve = new EvaluationCurvePoint[newLen];
EvaluationCurvePoint[] liftChart = new EvaluationCurvePoint[newLen];
ConfusionMatrix[] data = new ConfusionMatrix[newLen];
double[] threshold = new double[newLen];
long curTrue = 0;
long curFalse = 0; // 以我們例子,得到
length = 6
newLen = 7
m = 1.0E-5 // 計算, 其中rocCurve,recallPrecisionCurve,liftChart 都可以從程式碼中看出
for (int i = 1; i < newLen; i++) {
int index = effectiveIndices.get(length - i);
curTrue += positiveBin[index];
curFalse += negativeBin[index];
threshold[i] = index * m;
// 計算出混淆矩陣
data[i] = new ConfusionMatrix(
new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
double tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
// 比如當 90000 這點,得到 curTrue = 1 curFalse = 0 i = 1 index = 90000 tpr = 0.3333333333333333。totalTrue = 3 totalFalse = 2,
// 我們也知道,TPR = TP / (TP + FN) ,所以可以計算 tpr = 1 / 3
rocCurve[i] = new EvaluationCurvePoint(totalFalse == 0 ? 1.0 : 1.0 * curFalse / totalFalse, tpr, threshold[i]);
recallPrecisionCurve[i] = new EvaluationCurvePoint(tpr, curTrue + curTrue == 0 ? 1.0 : 1.0 * curTrue / (curTrue + curFalse), threshold[i]);
liftChart[i] = new EvaluationCurvePoint(1.0 * (curTrue + curFalse) / total, curTrue, threshold[i]);
} // 以我們例子,得到
curTrue = 3
curFalse = 2 threshold = {double[7]@9349}
0 = 0.0
1 = 0.9
2 = 0.8
3 = 0.7500000000000001
4 = 0.7000000000000001
5 = 0.6000000000000001
6 = 0.5 rocCurve = {EvaluationCurvePoint[7]@9315}
1 = {[email protected]}
x = 0.0
y = 0.3333333333333333
p = 0.9
2 = {[email protected]}
x = 0.0
y = 0.6666666666666666
p = 0.8
3 = {[email protected]}
x = 0.5
y = 0.6666666666666666
p = 0.7500000000000001
4 = {[email protected]}
x = 0.5
y = 1.0
p = 0.7000000000000001
5 = {[email protected]}
x = 1.0
y = 1.0
p = 0.6000000000000001
6 = {[email protected]}
x = 1.0
y = 1.0
p = 0.5 recallPrecisionCurve = {EvaluationCurvePoint[7]@9320}
1 = {[email protected]}
x = 0.3333333333333333
y = 1.0
p = 0.9
2 = {[email protected]}
x = 0.6666666666666666
y = 1.0
p = 0.8
3 = {[email protected]}
x = 0.6666666666666666
y = 0.6666666666666666
p = 0.7500000000000001
4 = {[email protected]}
x = 1.0
y = 0.75
p = 0.7000000000000001
5 = {[email protected]}
x = 1.0
y = 0.6
p = 0.6000000000000001
6 = {[email protected]}
x = 1.0
y = 0.6
p = 0.5 liftChart = {EvaluationCurvePoint[7]@9325}
1 = {[email protected]}
x = 0.2
y = 1.0
p = 0.9
2 = {[email protected]}
x = 0.4
y = 2.0
p = 0.8
3 = {[email protected]}
x = 0.6
y = 2.0
p = 0.7500000000000001
4 = {[email protected]}
x = 0.8
y = 3.0
p = 0.7000000000000001
5 = {[email protected]}
x = 1.0
y = 3.0
p = 0.6000000000000001
6 = {[email protected]}
x = 1.0
y = 3.0
p = 0.5 data = {ConfusionMatrix[7]@9339}
0 = {[email protected]}
longMatrix = {[email protected]}
matrix = {long[2][]@9491}
0 = {long[2]@9492}
0 = 0
1 = 0
1 = {long[2]@9493}
0 = 3
1 = 2
rowNum = 2
colNum = 2
labelCnt = 2
total = 5
actualLabelFrequency = {long[2]@9489}
0 = 3
1 = 2
predictLabelFrequency = {long[2]@9490}
0 = 0
1 = 5
tpCount = 2.0
tnCount = 2.0
fpCount = 3.0
fnCount = 3.0
1 = {[email protected]}
longMatrix = {[email protected]}
matrix = {long[2][]@9472}
0 = {long[2]@9474}
0 = 1
1 = 0
1 = {long[2]@9475}
0 = 2
1 = 2
rowNum = 2
colNum = 2
labelCnt = 2
total = 5
actualLabelFrequency = {long[2]@9470}
0 = 3
1 = 2
predictLabelFrequency = {long[2]@9471}
0 = 1
1 = 4
tpCount = 3.0
tnCount = 3.0
fpCount = 2.0
fnCount = 2.0
...... threshold[0] = 1.0;
data[0] = new ConfusionMatrix(new long[][] {{0, 0}, {totalTrue, totalFalse}});
rocCurve[0] = new EvaluationCurvePoint(0, 0, threshold[0]);
recallPrecisionCurve[0] = new EvaluationCurvePoint(0, recallPrecisionCurve[1].getY(), threshold[0]);
liftChart[0] = new EvaluationCurvePoint(0, 0, threshold[0]); return Tuple3.of(data, threshold, new EvaluationCurve[] {new EvaluationCurve(rocCurve),
new EvaluationCurve(recallPrecisionCurve), new EvaluationCurve(liftChart)});
}

3.2.4 計算混淆矩陣

這裡再給大家講講混淆矩陣如何計算,這裡思路比較繞。

3.2.4.1 原始矩陣

呼叫之處是:

// 呼叫之處
data[i] = new ConfusionMatrix(
new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
// 呼叫時候各種賦值
i = 1
index = 90000
totalTrue = 3
totalFalse = 2
curTrue = 1
curFalse = 0

得到原始矩陣,以下都有cur,說明只針對當前點來說

curTrue = 1 curFalse = 0
totalTrue - curTrue = 2 totalFalse - curFalse = 2
3.2.4.2 計算標籤

後續ConfusionMatrix計算中,由此可以得到

actualLabelFrequency = longMatrix.getColSums();
predictLabelFrequency = longMatrix.getRowSums(); actualLabelFrequency = {long[2]@9322}
0 = 3
1 = 2
predictLabelFrequency = {long[2]@9323}
0 = 1
1 = 4

可以看出來,Alink演演算法認為:每列的sum和實際標籤有關;每行sum和預測標籤有關。

得到新矩陣如下

predictLabelFrequency
curTrue = 1 curFalse = 0 1 = curTrue + curFalse
totalTrue - curTrue = 2 totalFalse - curFalse = 2 4 = total - curTrue - curFalse
actualLabelFrequency 3 = totalTrue 2 = totalFalse

後續計算將要基於這些來計算:

計算中就用到longMatrix 對角線上的資料,即longMatrix(0)(0)和 longMatrix(1)(1)。一定要注意,這裡考慮的都是 當前狀態 (畫重點強調)

longMatrix(0)(0) :curTrue

longMatrix(1)(1) :totalFalse - curFalse

totalFalse :( TN + FN )

totalTrue :( TP + FP )

double numTrueNegative(Integer labelIndex) {
// labelIndex為 0 時候,return 1 + 5 - 1 - 3 = 2;
// labelIndex為 1 時候,return 2 + 5 - 4 - 2 = 1;
return null == labelIndex ? tnCount : longMatrix.getValue(labelIndex, labelIndex) + total - predictLabelFrequency[labelIndex] - actualLabelFrequency[labelIndex];
} double numTruePositive(Integer labelIndex) {
// labelIndex為 0 時候,return 1; 這個是 curTrue,就是真實標籤是True,判別也是True。是TP
// labelIndex為 1 時候,return 2; 這個是 totalFalse - curFalse,總判別錯 - 當前判別錯。這就意味著“本來判別錯了但是當前沒有發現”,所以認為在當前狀態下,這也算是TP
return null == labelIndex ? tpCount : longMatrix.getValue(labelIndex, labelIndex);
} double numFalseNegative(Integer labelIndex) {
// labelIndex為 0 時候,return 3 - 1;
// actualLabelFrequency[0] = totalTrue。所以return totalTrue - curTrue,即當前“全部正確”中沒有“判別為正確”,這個就可以認為是“判別錯了且判別為負”
// labelIndex為 1 時候,return 2 - 2;
// actualLabelFrequency[1] = totalFalse。所以return totalFalse - ( totalFalse - curFalse ) = curFalse
return null == labelIndex ? fnCount : actualLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
} double numFalsePositive(Integer labelIndex) {
// labelIndex為 0 時候,return 1 - 1;
// predictLabelFrequency[0] = curTrue + curFalse。
// 所以 return = curTrue + curFalse - curTrue = curFalse = current( TN + FN ) 這可以認為是判斷錯了實際是正確標籤
// labelIndex為 1 時候,return 4 - 2;
// predictLabelFrequency[1] = total - curTrue - curFalse。
// 所以 return = total - curTrue - curFalse - (totalFalse - curFalse) = totalTrue - curTrue = ( TP + FP ) - currentTP = currentFP
return null == labelIndex ? fpCount : predictLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
} // 最後得到
tpCount = 3.0
tnCount = 3.0
fpCount = 2.0
fnCount = 2.0
3.2.4.3 具體程式碼
// 具體計算
public ConfusionMatrix(LongMatrix longMatrix) { longMatrix = {[email protected]}
0 = {long[2]@9324}
0 = 1
1 = 0
1 = {long[2]@9325}
0 = 2
1 = 2 this.longMatrix = longMatrix;
labelCnt = this.longMatrix.getRowNum();
// 這裡就是計算
actualLabelFrequency = longMatrix.getColSums();
predictLabelFrequency = longMatrix.getRowSums(); actualLabelFrequency = {long[2]@9322}
0 = 3
1 = 2
predictLabelFrequency = {long[2]@9323}
0 = 1
1 = 4
labelCnt = 2
total = 5 total = longMatrix.getTotal();
for (int i = 0; i < labelCnt; i++) {
tnCount += numTrueNegative(i);
tpCount += numTruePositive(i);
fnCount += numFalseNegative(i);
fpCount += numFalsePositive(i);
}
}

0x04 流處理

4.1 示例

Alink原有python示例程式碼中,Stream部分是沒有輸出的,因為MemSourceStreamOp沒有和時間相關聯,而Alink中沒有提供基於時間的StreamOperator,所以只能自己仿照MemSourceBatchOp寫了一個。雖然程式碼有些醜,但是至少可以提供輸出,這樣就能夠除錯。

4.1.1 主類

public class EvalBinaryClassExampleStream {

    AlgoOperator getData(boolean isBatch) {
Row[] rows = new Row[]{
Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")
};
String[] schema = new String[]{"label", "detailInput"};
if (isBatch) {
return new MemSourceBatchOp(rows, schema);
} else {
return new TimeMemSourceStreamOp(rows, schema, new EvalBinaryStreamSource());
}
} public static void main(String[] args) throws Exception {
EvalBinaryClassExampleStream test = new EvalBinaryClassExampleStream();
StreamOperator streamData = (StreamOperator) test.getData(false);
StreamOperator sOp = new EvalBinaryClassStreamOp()
.setLabelCol("label")
.setPredictionDetailCol("detailInput")
.setTimeInterval(1)
.linkFrom(streamData);
sOp.print();
StreamOperator.execute();
}
}

4.1.2 TimeMemSourceStreamOp

這個是我自己炮製的。借鑑了MemSourceStreamOp。

public final class TimeMemSourceStreamOp extends StreamOperator<TimeMemSourceStreamOp> {

    public TimeMemSourceStreamOp(Row[] rows, String[] colNames, EvalBinaryStrSource source) {
super(null);
init(source, Arrays.asList(rows), colNames);
} private void init(EvalBinaryStreamSource source, List <Row> rows, String[] colNames) {
Row first = rows.iterator().next();
int arity = first.getArity();
TypeInformation <?>[] types = new TypeInformation[arity]; for (int i = 0; i < arity; ++i) {
types[i] = TypeExtractor.getForObject(first.getField(i));
} init(source, colNames, types);
} private void init(EvalBinaryStreamSource source, String[] colNames, TypeInformation <?>[] colTypes) {
DataStream <Row> dastr = MLEnvironmentFactory.get(getMLEnvironmentId())
.getStreamExecutionEnvironment().addSource(source);
StringBuilder sbd = new StringBuilder();
sbd.append(colNames[0]); for (int i = 1; i < colNames.length; i++) {
sbd.append(",").append(colNames[i]);
}
this.setOutput(dastr, colNames, colTypes);
} @Override
public TimeMemSourceStreamOp linkFrom(StreamOperator<?>... inputs) {
return null;
}
}

4.1.3 Source

定時提供Row,加入了隨機數,讓概率有變化。

class EvalBinaryStreamSource extends RichSourceFunction[Row] {

  override def run(ctx: SourceFunction.SourceContext[Row]) = {
while (true) {
val rdm = Math.random() // 這裡加入了隨機數,讓概率有變化
val rows: Array[Row] = Array[Row](
Row.of("prefix1", "{\"prefix1\": " + rdm + ", \"prefix0\": " + (1-rdm) + "}"),
Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}"))
for(row <- rows) {
println(s"當前值:$row")
ctx.collect(row)
}
Thread.sleep(1000)
}
} override def cancel() = ???
}

4.2 BaseEvalClassStreamOp

Alink流處理類是 EvalBinaryClassStreamOp,主要工作在其基類 BaseEvalClassStreamOp,所以我們重點看後者。

public class BaseEvalClassStreamOp<T extends BaseEvalClassStreamOp<T>> extends StreamOperator<T> {
@Override
public T linkFrom(StreamOperator<?>... inputs) {
StreamOperator<?> in = checkAndGetFirst(inputs);
String labelColName = this.get(MultiEvaluationStreamParams.LABEL_COL);
String positiveValue = this.get(BinaryEvaluationStreamParams.POS_LABEL_VAL_STR);
Integer timeInterval = this.get(MultiEvaluationStreamParams.TIME_INTERVAL); ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams()); DataStream<BaseMetricsSummary> statistics; switch (type) {
case PRED_RESULT: {
......
}
case PRED_DETAIL: {
String predDetailColName = this.get(MultiEvaluationStreamParams.PREDICTION_DETAIL_COL);
//
PredDetailLabel eval = new PredDetailLabel(positiveValue, binary);
// 獲取輸入資料,重點是timeWindowAll
statistics = in.select(new String[] {labelColName, predDetailColName})
.getDataStream()
.timeWindowAll(Time.of(timeInterval, TimeUnit.SECONDS))
.apply(eval);
break;
}
}
// 把各個視窗的資料累積到 totalStatistics,注意,這裡是新變量了。
DataStream<BaseMetricsSummary> totalStatistics = statistics
.map(new EvaluationUtil.AllDataMerge())
.setParallelism(1); // 並行度設定為1 // 基於兩種 bins 計算&序列化,得到當前的 statistics
DataStream<Row> windowOutput = statistics.map(
new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
// 基於bins計算&序列化,得到累積的 totalStatistics
DataStream<Row> allOutput = totalStatistics.map(
new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0)); // "當前" 和 "累積" 做聯合,最終返回
DataStream<Row> union = windowOutput.union(allOutput); this.setOutput(union,
new String[] {ClassificationEvaluationUtil.STATISTICS_OUTPUT, DATA_OUTPUT},
new TypeInformation[] {Types.STRING, Types.STRING}); return (T)this;
}
}

具體業務是:

  • PredDetailLabel 會進行去重標籤名字 和 累積計算混淆矩陣所需資料

    • buildLabelIndexLabelArray 去重 "labels名字",然後給每一個label一個ID,最後結果是一個<labels, ID>Map。
    • getDetailStatistics 遍歷 rows 資料,提取每一個item(比如 "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"),然後通過updateBinaryMetricsSummary累積計算混淆矩陣所需資料。
  • 根據標籤從Window中獲取資料 statistics = in.select().getDataStream().timeWindowAll() .apply(eval);
  • EvaluationUtil.AllDataMerge 把各個視窗的資料累積到 totalStatistics 。
  • 得到windowOutput -------- EvaluationUtil.SaveDataStream,對"當前資料statistics"做處理。實際業務在BinaryMetricsSummary.toMetrics,即基於bin的資訊計算,然後儲存到params,並序列化返回Row。
    • extractMatrixThreCurve函式取出非空的bins,據此計算出ConfusionMatrix array(混淆矩陣), threshold array, rocCurve/recallPrecisionCurve/LiftChart.
    • 依據曲線內容計算並且儲存 AUC/PRC/KS
    • 對生成的rocCurve/recallPrecisionCurve/LiftChart輸出進行抽樣
    • 依據抽樣後的輸出儲存 RocCurve/RecallPrecisionCurve/LiftChar
    • 儲存正例樣本的度量指標
    • 儲存Logloss
    • Pick the middle point where threshold is 0.5.
  • 得到allOutput -------- EvaluationUtil.SaveDataStream , 對"累積資料totalStatistics"做處理。
    • 詳細處理流程同windowOutput。
  • windowOutput 和 allOutput 做聯合。最終返回 DataStream union = windowOutput.union(allOutput);

4.2.1 PredDetailLabel

static class PredDetailLabel implements AllWindowFunction<Row, BaseMetricsSummary, TimeWindow> {
@Override
public void apply(TimeWindow timeWindow, Iterable<Row> rows, Collector<BaseMetricsSummary> collector) throws Exception {
HashSet<String> labels = new HashSet<>();
// 首先還是獲取 labels 名字
for (Row row : rows) {
if (EvaluationUtil.checkRowFieldNotNull(row)) {
labels.addAll(EvaluationUtil.extractLabelProbMap(row).keySet());
labels.add(row.getField(0).toString());
}
}
labels = {[email protected]} size = 2
0 = "prefix1"
1 = "prefix0"
// 之前介紹過,buildLabelIndexLabelArray 去重 "labels名字",然後給每一個label一個ID,最後結果是一個<labels, ID>Map。
// getDetailStatistics 遍歷 rows 資料,累積計算混淆矩陣所需資料( "TP + FN" / "TN + FP")。
if (labels.size() > 0) {
collector.collect(
getDetailStatistics(rows, binary, buildLabelIndexLabelArray(labels, binary, positiveValue)));
}
}
}

4.2.2 AllDataMerge

EvaluationUtil.AllDataMerge 把各個視窗的資料累積

/**
* Merge data from different windows.
*/
public static class AllDataMerge implements MapFunction<BaseMetricsSummary, BaseMetricsSummary> {
private BaseMetricsSummary statistics;
@Override
public BaseMetricsSummary map(BaseMetricsSummary value) {
this.statistics = (null == this.statistics ? value : this.statistics.merge(value));
return this.statistics;
}
}

4.2.3 SaveDataStream

SaveDataStream具體呼叫的函式之前批處理介紹過,實際業務在BinaryMetricsSummary.toMetrics,即基於bin的資訊計算,儲存到params。

這裡與批處理不同的是直接就把"構建出的度量資訊“返回給使用者。

public static class SaveDataStream implements MapFunction<BaseMetricsSummary, Row> {
@Override
public Row map(BaseMetricsSummary baseMetricsSummary) throws Exception {
BaseMetricsSummary metrics = baseMetricsSummary;
BaseMetrics baseMetrics = metrics.toMetrics();
Row row = baseMetrics.serialize();
return Row.of(funtionName, row.getField(0));
}
} // 最後得到的 row 其實就是最終返回給使用者的度量資訊
row = {[email protected]} "{"PRC":"0.9164636268708667","SensitivityArray":"[0.38461538461538464,0.6923076923076923,0.6923076923076923,1.0,1.0,1.0]","ConfusionMatrix":"[[13,8],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.0,0.5,0.5,1.0,1.0]" ...... 還有很多其他的

4.2.4 Union

DataStream<Row> windowOutput = statistics.map(
new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
DataStream<Row> allOutput = totalStatistics.map(
new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0)); DataStream<Row> union = windowOutput.union(allOutput);

最後返回兩種統計資料

4.2.4.1 allOutput
all|{"PRC":"0.7341146115890359","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","ConfusionMatrix":"[[13,10],[2,0]]","MacroRecall":"0.43333333333333335","MacroSpecificity":"0.43333333333333335","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","AUC":"0.5666666666666667","MacroAccuracy":"0.52", ......

4.2.4.2 windowOutput

window|{"PRC":"0.7638888888888888","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","ConfusionMatrix":"[[3,2],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","AUC":"0.6666666666666666","MacroAccuracy":"0.6","RecallArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","KappaArray":"[0.28571428571428564,-0.15384615384615377,0.1666666666666666,0.5454545454545455,0.0,0.0]","MicroFalseNegativeRate":"0.4","WeightedRecall":"0.6","WeightedPrecision":"0.36","Recall":"1.0","MacroPrecision":"0.3",......

0xFF 參考

[[白話解析] 通過例項來梳理概念 :準確率 (Accuracy)、精準率(Precision)、召回率(Recall) 和 F值(F-Measure)](