1. 程式人生 > >最佳閾值劃分問題

最佳閾值劃分問題

static int 有時 string ner 1.4 ble pro comparing

分類問題有時是個回歸問題。這就需要找到閾值來將樣本準確地劃分到類別。

例如一個文本情感分類問題:情感有0(消極)、1(中性)、2(積極)三種類別。回歸器返回的情感的分值分別為0.2,0.3,0.4,0.45,0.66,1.2,1.3,1.4,它們對應的類別分別為0,0,1,2,1,1,2,2,需要找到兩個閾值x,y,小於x的表示0類別,x和y之間的表示1類別,大於y的表示2類別。

如果以窮舉法尋找最佳答案,複雜度為O(樣本數^類別數)。

如果使用貪心法,複雜度可以降為O(樣本數)。
做法是把每次尋找閾值當做一個二分類問題。這種方法以犧牲準確性為代價,換取時間效率的提升。

import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

public class Main {
Random r = new Random();

class Sample {
    int score;//樣本的得分
    int type;//樣本的類別

    Sample(int score, int type) {
        this.score = score;
        this.type = type;
    }

    @Override
    public String toString() {
        return "(" + score + "," + type + ")";
    }
}

Sample[] generateProblem() {
    int n = r.nextInt(4) + 2;
    Sample[] a = new Sample[n];
    for (int i = 0; i < n; i++) {
        a[i] = new Sample(r.nextInt(50), r.nextInt(3));
    }
    Arrays.sort(a, Comparator.comparingInt(x -> x.score));
    return a;
}

int bruteforceScore(Sample[] a) {
    int bestI = 0, bestJ = 0;
    int bestScore = 0;
    for (int i = 0; i <= a.length; i++) {//第一個閾值
        for (int j = i; j <= a.length; j++) {//第二個閾值
            int score = getScore(a, i, j);
            if (score > bestScore) {
                bestScore = score;
                bestI = i;
                bestJ = j;
            }
        }
    }
    System.out.println("ans i: " + bestI + " ans j:" + bestJ);
    return bestScore;
}

int getScore(Sample[] a, int i, int j) {
    int rightCount = 0;
    for (int k = 0; k < a.length; k++) {
        if (k < i && a[k].type == 0) {
            rightCount++;
        } else if (k >= i && k < j && a[k].type == 1) {
            rightCount++;
        } else if (k >= j && a[k].type == 2) {
            rightCount++;
        }
    }
    return rightCount;
}

int mine(Sample[] a) {
    int bestI = 0;
    long bestRightCOunt = 0;
    long rightCount = Arrays.stream(a).filter(x -> x.type != 0).count();
    for (int i = 0; i < a.length; i++) {
        if (rightCount >= bestRightCOunt) {
            bestRightCOunt = rightCount;
            bestI = i;
        }
        if (a[i].type == 0) rightCount++;
        else rightCount--;
    }
    if (rightCount >= bestRightCOunt) {
        bestI = a.length;
    }
    bestRightCOunt = 0;
    final int goodI = bestI;
    if (goodI == a.length) {
        //全0的情況
        return getScore(a, a.length, a.length);
    }
    rightCount = Arrays.stream(a).filter(x -> x.score >= a[goodI].score && x.type == 2).count();
    int bestJ = 0;
    for (int i = bestI; i < a.length; i++) {
        if (rightCount > bestRightCOunt) {
            bestRightCOunt = rightCount;
            bestJ = i;
        }
        if (a[i].type == 1) rightCount++;
        else if (a[i].type == 2) rightCount--;
    }
    if (rightCount > bestRightCOunt) {
        bestJ = a.length;
    }
    System.out.println();
    System.out.println("bestI " + bestI + " bestJ " + bestJ);
    return getScore(a, bestI, bestJ);
}

Main() {
    while (true) {
        Sample a[] = generateProblem();
        for (Sample i : a) {
            System.out.print(i);
        }
        System.out.println();
        int x = bruteforceScore(a);
        int y = mine(a);
        System.out.println(x + " " + y);
        if (x != y) break;
    }
}

public static void main(String[] args) {
    new Main();
}
}

最佳閾值劃分問題