libsvm支援向量機迴歸示例
阿新 • • 發佈:2019-01-09
libsvm支援向量機演算法包的基本使用,此處演示的是支援向量迴歸機
import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.util.ArrayList; import java.util.List; import libsvm.svm; import libsvm.svm_model; import libsvm.svm_node; import libsvm.svm_parameter; import libsvm.svm_problem; public class SVM { public static void main(String[] args) { // 定義訓練集點a{10.0, 10.0} 和 點b{-10.0, -10.0},對應lable為{1.0, -1.0} List<Double> label = new ArrayList<Double>(); List<svm_node[]> nodeSet = new ArrayList<svm_node[]>(); getData(nodeSet, label, "file/train.txt"); int dataRange=nodeSet.get(0).length; svm_node[][] datas = new svm_node[nodeSet.size()][dataRange]; // 訓練集的向量表 for (int i = 0; i < datas.length; i++) { for (int j = 0; j < dataRange; j++) { datas[i][j] = nodeSet.get(i)[j]; } } double[] lables = new double[label.size()]; // a,b 對應的lable for (int i = 0; i < lables.length; i++) { lables[i] = label.get(i); } // 定義svm_problem物件 svm_problem problem = new svm_problem(); problem.l = nodeSet.size(); // 向量個數 problem.x = datas; // 訓練集向量表 problem.y = lables; // 對應的lable陣列 // 定義svm_parameter物件 svm_parameter param = new svm_parameter(); param.svm_type = svm_parameter.EPSILON_SVR; param.kernel_type = svm_parameter.LINEAR; param.cache_size = 100; param.eps = 0.00001; param.C = 1.9; // 訓練SVM分類模型 System.out.println(svm.svm_check_parameter(problem, param)); // 如果引數沒有問題,則svm.svm_check_parameter()函式返回null,否則返回error描述。 svm_model model = svm.svm_train(problem, param); // svm.svm_train()訓練出SVM分類模型 // 獲取測試資料 List<Double> testlabel = new ArrayList<Double>(); List<svm_node[]> testnodeSet = new ArrayList<svm_node[]>(); getData(testnodeSet, testlabel, "file/test.txt"); svm_node[][] testdatas = new svm_node[testnodeSet.size()][dataRange]; // 訓練集的向量表 for (int i = 0; i < testdatas.length; i++) { for (int j = 0; j < dataRange; j++) { testdatas[i][j] = testnodeSet.get(i)[j]; } } double[] testlables = new double[testlabel.size()]; // a,b 對應的lable for (int i = 0; i < testlables.length; i++) { testlables[i] = testlabel.get(i); } // 
預測測試資料的lable double err = 0.0; for (int i = 0; i < testdatas.length; i++) { double truevalue = testlables[i]; System.out.print(truevalue + " "); double predictValue = svm.svm_predict(model, testdatas[i]); System.out.println(predictValue); err += Math.abs(predictValue - truevalue); } System.out.println("err=" + err / datas.length); } public static void getData(List<svm_node[]> nodeSet, List<Double> label, String filename) { try { FileReader fr = new FileReader(new File(filename)); BufferedReader br = new BufferedReader(fr); String line = null; while ((line = br.readLine()) != null) { String[] datas = line.split(","); svm_node[] vector = new svm_node[datas.length - 1]; for (int i = 0; i < datas.length - 1; i++) { svm_node node = new svm_node(); node.index = i + 1; node.value = Double.parseDouble(datas[i]); vector[i] = node; } nodeSet.add(vector); double lablevalue = Double.parseDouble(datas[datas.length - 1]); label.add(lablevalue); } } catch (Exception e) { e.printStackTrace(); } } }
訓練資料(file/train.txt),每行一個樣本,以逗號分隔,最後一列為目標值
17.6,17.7,17.7,17.7,17.8
17.7,17.7,17.7,17.8,17.8
17.7,17.7,17.8,17.8,17.9
17.7,17.8,17.8,17.9,18
17.8,17.8,17.9,18,18.1
17.8,17.9,18,18.1,18.2
17.9,18,18.1,18.2,18.4
18,18.1,18.2,18.4,18.6
18.1,18.2,18.4,18.6,18.7
18.2,18.4,18.6,18.7,18.9
18.4,18.6,18.7,18.9,19.1
18.6,18.7,18.9,19.1,19.3
測試資料(file/test.txt),格式同上
18.7,18.9,19.1,19.3,19.6
18.9,19.1,19.3,19.6,19.9
19.1,19.3,19.6,19.9,20.2
19.3,19.6,19.9,20.2,20.6
19.6,19.9,20.2,20.6,21
19.9,20.2,20.6,21,21.5
20.2,20.6,21,21.5,22