Java 機器學習庫Smile實戰(二)AdaBoost
1. AdaBoost演算法簡介
Boost 算法系列的起源來自於PAC Learnability(PAC 可學習性)。這套理論主要研究的是什麼時候一個問題是可被學習的,當然也會探討針對可學習的問題的具體的學習演算法。這套理論是由Valiant提出來的,也因此(還有其他貢獻哈)他獲得了2010年的圖靈獎。下面是他的照片:
PAC 定義了學習演算法的強弱:
弱學習演算法 — 識別錯誤率小於1/2(即準確率僅比隨機猜測略高的學習演算法);
強學習演算法 — 識別準確率很高並能在多項式時間內完成的學習演算法。
同時 ,Valiant和 Kearns首次提出了 PAC學習模型中弱學習演算法和強學習演算法的等價性問題,即任意給定僅比隨機猜測略好的弱學習演算法 ,是否可以將其提升為強學習演算法 ? 如果二者等價 ,那麼只需找到一個比隨機猜測略好的弱學習演算法就可以將其提升為強學習演算法 ,而不必尋找很難獲得的強學習演算法。 也就是這種猜測,讓無數牛人去設計演算法來驗證PAC理論的正確性。
不過很長一段時間都沒有一個切實可行的辦法來實現這個理想。細節決定成敗,再好的理論也需要有效的演算法來執行。終於功夫不負有心人, Schapire在1996年提出一個有效的演算法真正實現了這個夙願,它的名字叫AdaBoost。AdaBoost把多個不同的決策樹用一種非隨機的方式組合起來,表現出驚人的效能!第一,把決策樹的準確率大大提高,可以與SVM媲美。第二,速度快,且基本不用調引數。第三,幾乎不Overfitting。我估計當時Breiman和Friedman肯定高興壞了,因為眼看著他們提出的CART正在被SVM比下去的時候,AdaBoost讓決策樹起死回生!Breiman情不自禁地在他的論文裡讚揚AdaBoost是最好的現貨方法(off-the-shelf,即“拿下了就可以用”的意思)。
下面是Adaboost 演算法的實現思想:
接下來是Smile庫中的AdaBoost演算法的用法實戰程式:
import smile.classification.AdaBoost; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.util.ArrayList; import java.util.Arrays; import java.util.List; /** * Created by zhanghuayan on 2017/1/17. */ public class AdaBoostTest { public static void main(String[] args) throws Exception { List<List<Double>> datas = new ArrayList<List<Double>>(); List<Double> data = new ArrayList<Double>(); List<Integer> labels = new ArrayList<Integer>(); String line; List<String> lines; File file = new File("iris.txt"); BufferedReader reader = new BufferedReader(new FileReader(file)); while ((line = reader.readLine()) != null) { lines = Arrays.asList(line.trim().split("t")); for (int i = 0; i < lines.size() - 1; i++) { data.add(Double.parseDouble(lines.get(i))); } labels.add(Integer.parseInt(lines.get(lines.size() - 1))); datas.add(data); data = new ArrayList<Double>(); } //轉換label int[] label = new int[labels.size()]; for (int i = 0; i < label.length; i++) { label[i] = labels.get(i); } //轉換屬性 int rows = datas.size(); int cols = datas.get(0).size(); double[][] srcData = new double[rows][cols]; for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { srcData[i][j] = datas.get(i).get(j); } } AdaBoost adaBoost = new AdaBoost(srcData, label, 4, 8); double right = 0; for (int i = 0; i < srcData.length; i++) { int tag = adaBoost.predict(srcData[i]); if (i % 10 == 0) System.out.println(); System.out.print(tag + " "); if (tag == label[i]) { right += 1; } } right = right / srcData.length; System.out.println("Accrurate: " + right * 100 + "%"); } }
資料集與Java機器學習庫Smile實戰(一)SVM中的iris資料集格式一致。下面是程式輸出的結果:
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1
2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2
Accrurate: 100.0%
當然這裡的訓練資料和測試資料使用的是同一資料,所以可能出現過擬合現象。因此現實訓練的時候應採取例如交叉驗證等技術手段調整引數。
2. 引數解釋
第3個引數是子樹的個數,第4個引數是每個子樹最大葉子節點數:
AdaBoost adaBoost = new AdaBoost(srcData, label, 4, 8);
訓練好了之後就可以如下呼叫predict()方法:
int tag = adaBoost.predict(srcData[i]);