1. 程式人生 > 其它 >Java 機器學習庫Smile實戰(一)SVM

Java 機器學習庫Smile實戰(一)SVM

本文不會介紹SVM的基本原理,如果想了解SVM基本原理,請參閱相關書籍。

要使用Java機器學習庫Smile,需首先在專案的Maven配置檔案pom.xml中新增如下的maven依賴項:

<dependency>
    <groupId>com.github.haifengl</groupId>
    <artifactId>smile-core</artifactId>
    <version>1.4.0</version>
</dependency>

Smile中的SVM是一個泛型類,他可以支援二分類和多分類兩種使用方法,而且這兩種使用方法差異較大,所以分開介紹。

1. 二分類

Smile 庫的SVM類是一個泛型型別,預設情況下進行二分類,選擇引數為核函式型別和懲罰項引數。

import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

public class Demo {
    public static void main(String[]args){

        double gamma = 1.0;
        double C = 1.0;

        //通過某種方式獲取訓練資料及其類標
        double[][] data = ...
        int[] label = ...

        SVM<double[]> svm = new SVM<double[]>(
           new GaussianKernel(gamma), C);
        
        svm.learn(data, label); //訓練模型
        svm.finish();

        //獲取測試資料
        double[][] testData = ...
        int[] result = new int[testData.length];
        for(int i=0; i < testData.length; i++){
            result[i] = svm.predict(testData[i]);
        }
    }
}

2. 多分類

接下來是我利用SVM對iris資料集進行分類的程式。首先我們將iris資料儲存iris.txt檔案,如下結構:

5.1 3.5 1.4 0.2 0

4.9 3 1.4 0.2 0

...

每一行代表一個測試資料項,前4列是屬性向量,最後一列是類標(在Smile中類標不能為負數,並且只能是從0開始的正整數,所以上述類標為:0、1、2)。檢測的完整的原始碼如下:

import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Created by zhanghuayan on 2017/1/16.
 */
public class ClassificationTest {

    public static void main(String[] args) throws Exception {

        List<List<Double>> datas = 
            new ArrayList<List<Double>>();
        List<Double> data = new ArrayList<Double>();
        List<Integer> labels = new ArrayList<Integer>();

        String line;
        List<String> lines;
        File file = new File("iris.txt");
        BufferedReader reader = 
        new BufferedReader(new FileReader(file));
        
        while ((line = reader.readLine()) != null) {
            lines = Arrays.asList(line.trim().split("t"));
            for (int i = 0; i < lines.size() - 1; i++) {
                data.add(Double.parseDouble(lines.get(i)));
            }
            labels.add(Integer.parseInt(
                  lines.get(lines.size() - 1)));

            datas.add(data);
            data = new ArrayList<Double>();

        }

        //轉換label
        int[] label = new int[labels.size()];
        for (int i = 0; i < label.length; i++) {
            label[i] = labels.get(i);
        }

        //轉換屬性
        int rows = datas.size();
        int cols = datas.get(0).size();
        double[][] srcData = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                srcData[i][j] = datas.get(i).get(j);
            }
        }

        SVM<double[]> svm = new SVM<double[]>(
           new GaussianKernel(1.0), 1.0, 3, 
           SVM.Multiclass.ONE_VS_ALL);

        svm.learn(srcData, label);
        svm.finish();

        double right = 0;
        for (int i = 0; i < srcData.length; i++) {
            int tag = svm.predict(srcData[i]);
            if (tag == label[i]) {
                right += 1;
            }
        }
        right = right / srcData.length;

        System.out.println(
        "Accrurate: " + right * 100 + "%");
    }
}