程式人生 > weka 建立迴歸模型並輸出打分結果

weka 建立迴歸模型並輸出打分結果

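本例從資料庫中讀取資料，依序將前 80% 的資料作為建模資料、其餘作為測試資料，建立多層感知器（MultilayerPerceptron）迴歸模型，再對測試資料進行打分，並輸出打分結果（同時寫入 testApplyResult.csv）。文末的 loadDataSetFromOracle 是 weka.Utils 中的輔助方法，透過 Weka 的 InstanceQuery 從 Oracle 讀取資料。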

package weka.regression;

import java.io.File;

import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;

import weka.Utils;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.functions.MultilayerPerceptron;
import
weka.core.Instance; import weka.core.Instances; public class RegressionTest { public static void main(String[] args) { RegressionTest regressionTest = new RegressionTest(); Instances modelData = regressionTest.loadData(); //資料拆分成建模資料和測試資料 int trainSize = (int) Math.round(modelData.numInstances() * 0.80
); int testSize = modelData.numInstances() - trainSize; modelData.setClassIndex(modelData.numAttributes() - 1); Instances train = new Instances(modelData, 0, trainSize); train.setClassIndex(train.numAttributes() - 1); MultilayerPerceptron model = regressionTest.trainModel(train); Instances test = new
Instances(modelData, trainSize, testSize); test.setClassIndex(test.numAttributes() - 1); regressionTest.evaluate(model, test); } public MultilayerPerceptron trainModel(Instances train){ MultilayerPerceptron model = new MultilayerPerceptron(); try { model.buildClassifier(train); } catch (Exception e) { e.printStackTrace(); } System.out.println(model); return model; } public Instances loadData(){ String sql = "select * from qy_car_model"; Instances modelData = null; try { modelData = Utils.loadDataSetFromOracle(sql); } catch (Exception e) { e.printStackTrace(); } return modelData; } public void evaluate(MultilayerPerceptron model,Instances testData){ Evaluation eval;// 構造評價器 try { eval = new Evaluation(testData); eval.evaluateModel(model, testData);// 用測試資料集來評價m_classifier double sum = testData.numInstances(); // 測試語料例項數 StringBuffer buf = new StringBuffer(); buf.append("SPD,ALPHA,PEDAL,BRAKE,RECOV,STEER,TEMP,真實值,預測值\n"); for (int i = 0; i < sum; i++) { Instance ins = testData.instance(i); for(int j=0;j<ins.numAttributes();j++){ buf.append(ins.value(j)).append(",");//輸出每條資料 } buf.append(model.classifyInstance(ins)); buf.append("\n"); } File file = new File("testApplyResult.csv"); FileUtils.writeStringToFile(file , buf.toString(),"UTF-8"); System.out.println(buf); } catch (Exception e) { e.printStackTrace(); } } }
/**
 * Runs {@code sql} against the database described by {@code DatabaseUtils.props.oracle} using the
 * built-in account and returns the result as Weka instances.
 *
 * @param sql query to execute
 * @return the query result
 * @throws Exception if connecting, querying or converting the result fails
 */
public static Instances loadDataSetFromOracle(String sql) throws Exception {
    // NOTE(review): credentials are hard-coded in source; prefer supplying them through the
    // overload below (e.g. from configuration) instead of committing them.
    return loadDataSetFromOracle(sql, "ywf", "ywf", new File("DatabaseUtils.props.oracle"));
}

/**
 * Runs {@code sql} with the given account and Weka database properties file, returning the result
 * as Weka instances. The database connection is released before returning.
 *
 * @param sql query to execute
 * @param username database user name
 * @param password database password
 * @param propsFile Weka {@code DatabaseUtils} properties file (JDBC URL, driver, type mappings)
 * @return the query result
 * @throws Exception if connecting, querying or converting the result fails
 */
public static Instances loadDataSetFromOracle(String sql, String username, String password,
        File propsFile) throws Exception {
    InstanceQuery query = new InstanceQuery();
    Instances data;
    try {
        query.setUsername(username);
        query.setPassword(password);
        query.initialize(propsFile);
        query.setQuery(sql);
        data = query.retrieveInstances();
    } catch (Exception e) {
        // Release the connection, but never let a disconnect error hide the real failure.
        try {
            query.disconnectFromDatabase();
        } catch (Exception disconnectError) {
            e.addSuppressed(disconnectError);
        }
        throw e;
    }
    // retrieveInstances() leaves the JDBC connection open; close it now that the data is loaded.
    query.disconnectFromDatabase();
    return data;
}