Java 機器學習庫Smile實戰(一)SVM
阿新 • • 發佈:2022-05-03
本文不會介紹SVM的基本原理,如果想了解SVM基本原理,請參閱相關書籍。
要使用Java機器學習庫Smile,需首先在專案的Maven配置檔案pom.xml中新增如下的maven依賴項:
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
<version>1.4.0</version>
</dependency>
Smile中的SVM是一個泛型類,他可以支援二分類和多分類兩種使用方法,而且這兩種使用方法差異較大,所以分開介紹。
1. 二分類
Smile 庫的SVM類是一個泛型型別,預設情況下進行二分類,選擇引數為核函式型別和懲罰項引數。
import smile.classification.SVM; import smile.math.kernel.GaussianKernel; public class Demo { public static void main(String[]args){ double gamma = 1.0; double C = 1.0; //通過某種方式獲取訓練資料及其類標 double[][] data = ... int[] label = ... SVM<double[]> svm = new SVM<double[]>( new GaussianKernel(gamma), C); svm.learn(data, label); //訓練模型 svm.finish(); //獲取測試資料 double[][] testData = ... int[] result = new int[testData.length]; for(int i=0; i < testData.length; i++){ result[i] = svm.predict(testData[i]); } } }
2. 多分類
接下來是我利用SVM對iris資料集進行分類的程式。首先我們將iris資料儲存iris.txt檔案,如下結構:
5.1 3.5 1.4 0.2 0
4.9 3 1.4 0.2 0
...
每一行代表一個測試資料項,前4列是屬性向量,最後一列是類標(在Smile中類標不能為負數,並且只能是從0開始的正整數,所以上述類標為:0、1、2)。檢測的完整的原始碼如下:
import smile.classification.SVM; import smile.math.kernel.GaussianKernel; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.util.ArrayList; import java.util.Arrays; import java.util.List; /** * Created by zhanghuayan on 2017/1/16. */ public class ClassificationTest { public static void main(String[] args) throws Exception { List<List<Double>> datas = new ArrayList<List<Double>>(); List<Double> data = new ArrayList<Double>(); List<Integer> labels = new ArrayList<Integer>(); String line; List<String> lines; File file = new File("iris.txt"); BufferedReader reader = new BufferedReader(new FileReader(file)); while ((line = reader.readLine()) != null) { lines = Arrays.asList(line.trim().split("t")); for (int i = 0; i < lines.size() - 1; i++) { data.add(Double.parseDouble(lines.get(i))); } labels.add(Integer.parseInt( lines.get(lines.size() - 1))); datas.add(data); data = new ArrayList<Double>(); } //轉換label int[] label = new int[labels.size()]; for (int i = 0; i < label.length; i++) { label[i] = labels.get(i); } //轉換屬性 int rows = datas.size(); int cols = datas.get(0).size(); double[][] srcData = new double[rows][cols]; for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { srcData[i][j] = datas.get(i).get(j); } } SVM<double[]> svm = new SVM<double[]>( new GaussianKernel(1.0), 1.0, 3, SVM.Multiclass.ONE_VS_ALL); svm.learn(srcData, label); svm.finish(); double right = 0; for (int i = 0; i < srcData.length; i++) { int tag = svm.predict(srcData[i]); if (tag == label[i]) { right += 1; } } right = right / srcData.length; System.out.println( "Accrurate: " + right * 100 + "%"); } }