使用 Weka 建立迴歸模型並輸出打分結果
阿新 • • 發佈:2019-02-17
本例從資料庫中讀取資料,將其拆分為建模資料與測試資料,以建模資料訓練多層感知器(MultilayerPerceptron)迴歸模型,再對測試資料進行打分,並將打分結果輸出到 CSV 檔案。
package weka.regression;
import java.io.File;
import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import weka.Utils;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Instance;
import weka.core.Instances;
/**
 * End-to-end regression example: loads a data set from the database, splits it
 * into a training and a test partition, trains a {@link MultilayerPerceptron}
 * on the training partition and writes the per-instance predictions for the
 * test partition to a CSV file.
 *
 * <p>The last attribute of the loaded data set is used as the regression target.
 */
public class RegressionTest {

    /** Fraction of the loaded instances used for training; the rest is used for testing. */
    private static final double TRAIN_RATIO = 0.80;

    /** Query that selects the modelling data. */
    private static final String MODEL_DATA_SQL = "select * from qy_car_model";

    /**
     * Header row of the result file. It is hard-coded, so it only lines up with
     * the data rows when the data set has exactly seven feature columns followed
     * by the target column.
     */
    private static final String CSV_HEADER = "SPD,ALPHA,PEDAL,BRAKE,RECOV,STEER,TEMP,真實值,預測值\n";

    /** Name of the CSV file the predictions are written to (relative to the working directory). */
    private static final String RESULT_FILE = "testApplyResult.csv";

    /**
     * Runs the whole pipeline: load, split, train, evaluate.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        RegressionTest regressionTest = new RegressionTest();
        Instances modelData = regressionTest.loadData();

        // Split the data into a training and a test partition.
        // NOTE(review): the rows are split in the order the query returned them,
        // without shuffling - confirm that this order is not correlated with the target.
        int total = modelData.numInstances();
        int trainSize = (int) Math.round(total * TRAIN_RATIO);
        int testSize = total - trainSize;
        if (trainSize == 0 || testSize == 0) {
            // With fewer than three rows one of the partitions would be empty.
            System.err.println("Not enough instances to split into training and test data: " + total);
            return;
        }

        // The last attribute is the value to predict.
        modelData.setClassIndex(modelData.numAttributes() - 1);

        Instances train = new Instances(modelData, 0, trainSize);
        train.setClassIndex(train.numAttributes() - 1);
        MultilayerPerceptron model = regressionTest.trainModel(train);

        Instances test = new Instances(modelData, trainSize, testSize);
        test.setClassIndex(test.numAttributes() - 1);
        regressionTest.evaluate(model, test);
    }

    /**
     * Trains a multilayer perceptron with its default settings and prints the
     * resulting model to standard output.
     *
     * @param train training instances with the class index already set
     * @return the trained model
     * @throws IllegalStateException if the classifier cannot be built
     */
    public MultilayerPerceptron trainModel(Instances train) {
        MultilayerPerceptron model = new MultilayerPerceptron();
        try {
            model.buildClassifier(train);
        } catch (Exception e) {
            // Fail fast: returning an untrained model would only move the failure downstream.
            throw new IllegalStateException("Failed to train the multilayer perceptron", e);
        }
        System.out.println(model);
        return model;
    }

    /**
     * Loads the modelling data from the database.
     *
     * @return the loaded data set, never {@code null}
     * @throws IllegalStateException if the query fails or yields no data set
     */
    public Instances loadData() {
        Instances modelData;
        try {
            modelData = Utils.loadDataSetFromOracle(MODEL_DATA_SQL);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to load model data with query: " + MODEL_DATA_SQL, e);
        }
        if (modelData == null) {
            throw new IllegalStateException("Query produced no data set: " + MODEL_DATA_SQL);
        }
        return modelData;
    }

    /**
     * Evaluates the model on the test data, prints the evaluation summary, and
     * writes one CSV row per test instance (all attribute values followed by the
     * model's prediction) to {@code testApplyResult.csv}. The CSV content is
     * also printed to standard output.
     *
     * @param model    the trained model
     * @param testData test instances with the class index already set
     * @throws IllegalStateException if evaluation, prediction or writing the file fails
     */
    public void evaluate(MultilayerPerceptron model, Instances testData) {
        try {
            // Evaluate the classifier on the test set and report the aggregate statistics.
            Evaluation eval = new Evaluation(testData);
            eval.evaluateModel(model, testData);
            System.out.println(eval.toSummaryString());

            int rowCount = testData.numInstances(); // number of test instances
            StringBuilder buf = new StringBuilder(CSV_HEADER);
            for (int i = 0; i < rowCount; i++) {
                Instance ins = testData.instance(i);
                // Emit every attribute value of the row (the class attribute included).
                for (int j = 0; j < ins.numAttributes(); j++) {
                    buf.append(ins.value(j)).append(",");
                }
                buf.append(model.classifyInstance(ins)).append("\n");
            }

            File file = new File(RESULT_FILE);
            FileUtils.writeStringToFile(file, buf.toString(), "UTF-8");
            System.out.println(buf);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to evaluate the model or write " + RESULT_FILE, e);
        }
    }
}
/**
 * Runs the given SQL query through Weka's {@code InstanceQuery} and returns the
 * result as a data set. The connection settings are read from the file
 * {@code DatabaseUtils.props.oracle} in the working directory.
 *
 * <p>The database connection is closed before this method returns, whether the
 * query succeeds or fails.
 *
 * @param sql the query to execute
 * @return the instances produced by the query
 * @throws Exception if initialising the query, connecting or retrieving the data fails
 */
public static Instances loadDataSetFromOracle(String sql) throws Exception {
    InstanceQuery query = new InstanceQuery();
    // NOTE(review): credentials are hard-coded in source; consider moving them
    // to configuration outside the code base.
    query.setUsername("ywf");
    query.setPassword("ywf");

    Instances data;
    try {
        File file = new File("DatabaseUtils.props.oracle");
        query.initialize(file);
        query.setQuery(sql);
        data = query.retrieveInstances();
    } catch (Exception e) {
        // Release the connection without hiding the original failure.
        try {
            query.disconnectFromDatabase();
        } catch (Exception closeFailure) {
            e.addSuppressed(closeFailure);
        }
        throw e;
    }
    query.disconnectFromDatabase();
    return data;
}