多層前饋神經元網路
阿新 • • 發佈:2019-02-09
做了一個神經元網路分類器。開始步長設定為迭代次數的倒數,效果不好;後來調整到 0.2 效果比較好。測試一個拋物線邊界的例子,準確率大約 96% 以上。
/**
 * A fully connected multi-layer feed-forward network of sigmoid units, trained one sample at a
 * time with error back-propagation.
 *
 * <p>Layer 0 is the input layer and simply holds the input vector; every later layer computes
 * {@code sigmoid(sum(previousOutput * weight) + theta)}. Weights and biases start as uniform
 * random values in [0, 1) drawn from {@link Math#random()}.
 *
 * <p>An instance may be trained once: after {@link #train} returns, further calls to
 * {@link #train} throw {@link IllegalStateException}.
 */
public final class NeuroNetwork {

    /** Learning rate applied to every weight and bias update. */
    private static final double LEARNING_RATE = 0.2;

    /** Mutable state of a single unit. */
    private static class Neurode {
        /** Back-propagated error term of this unit for the current sample. */
        double err;
        /** Activation of this unit for the current sample. */
        double output;
        /** Bias added to the weighted input before the sigmoid is applied. */
        double theta;
    }

    /** Life-cycle state of the network. */
    private enum Status {
        NEW, TRAINED
    }

    /** Whether {@link #train} has already been run on this instance. */
    private Status status;

    /** Number of layers; layer 0 is the input layer. */
    private final int depth;

    /** Units of each layer, indexed as {@code neurodes[layer][unit]}. */
    private final Neurode[][] neurodes;

    /** {@code weights[d][i][j]} connects unit i of layer d to unit j of layer d+1. */
    private final double[][][] weights;

    /**
     * Creates a network with randomly initialized weights and biases.
     *
     * @param depth the number of layers, including the input layer; must be at least 1
     * @param numNeurodes the number of units in each layer; must have at least {@code depth}
     *     entries, each of them positive
     * @throws IllegalArgumentException if {@code depth} is smaller than 1, if
     *     {@code numNeurodes} has fewer than {@code depth} entries, or if one of the first
     *     {@code depth} entries is not positive
     */
    public NeuroNetwork(int depth, int[] numNeurodes) {
        if (depth < 1 || numNeurodes.length < depth) {
            throw new IllegalArgumentException(
                    "depth=" + depth + ", layer sizes given=" + numNeurodes.length);
        }
        for (int d = 0; d < depth; d++) {
            if (numNeurodes[d] <= 0) {
                throw new IllegalArgumentException(
                        "layer " + d + " has non-positive size " + numNeurodes[d]);
            }
        }
        this.depth = depth;

        // Create the units; every unit gets a random bias.
        neurodes = new Neurode[depth][];
        for (int d = 0; d < depth; d++) {
            neurodes[d] = new Neurode[numNeurodes[d]];
            for (int i = 0; i < numNeurodes[d]; i++) {
                neurodes[d][i] = new Neurode();
                neurodes[d][i].theta = Math.random();
            }
        }

        // There is one weight matrix per pair of adjacent layers, hence depth - 1 of them.
        weights = new double[depth - 1][][];
        for (int d = 0; d < depth - 1; d++) {
            weights[d] = new double[numNeurodes[d]][numNeurodes[d + 1]];
            for (int i = 0; i < numNeurodes[d]; i++) {
                for (int j = 0; j < numNeurodes[d + 1]; j++) {
                    weights[d][i][j] = Math.random();
                }
            }
        }
        status = Status.NEW;
    }

    /**
     * Runs a forward pass and stores the activation of every unit in its {@code output} field.
     *
     * @param data the input vector; its first {@code neurodes[0].length} entries are read
     */
    private void calculateOutput(double[] data) {
        // The input layer just mirrors the input vector.
        for (int i = 0; i < neurodes[0].length; i++) {
            neurodes[0][i].output = data[i];
        }
        // Every later layer applies the sigmoid to its biased weighted input.
        for (int d = 1; d < depth; d++) {
            for (int j = 0; j < neurodes[d].length; j++) {
                double input = 0.0;
                for (int i = 0; i < neurodes[d - 1].length; i++) {
                    input += neurodes[d - 1][i].output * weights[d - 1][i][j];
                }
                input += neurodes[d][j].theta;
                neurodes[d][j].output = 1.0 / (1.0 + Math.exp(0.0 - input));
            }
        }
    }

    /**
     * Classifies one input vector.
     *
     * @param data the input vector; its length must equal the size of the input layer
     * @param output buffer that receives the activation of every output-layer unit; its length
     *     must equal the size of the output layer
     * @return the index of the output unit with the largest activation (the lowest such index
     *     when several units tie)
     * @throws IllegalArgumentException if either array has the wrong length
     */
    public int predict(double[] data, double[] output) {
        if (data.length != neurodes[0].length || output.length != neurodes[depth - 1].length) {
            throw new IllegalArgumentException();
        }
        calculateOutput(data);
        Neurode[] last = neurodes[depth - 1];
        double best = last[0].output;
        int label = 0;
        for (int i = 0; i < last.length; i++) {
            output[i] = last[i].output;
            if (best < output[i]) {
                best = output[i];
                label = i;
            }
        }
        return label;
    }

    /**
     * Trains the network by repeatedly sweeping over all samples until the largest weight
     * gradient seen in a sweep drops below {@code threshold} or {@code maxIteration} sweeps have
     * been made. One progress line is printed to standard output per sweep.
     *
     * @param data training inputs; {@code data[i]} is the i-th sample and must have as many
     *     entries as the input layer has units
     * @param target training labels; {@code target[i]} belongs to {@code data[i]} and must have
     *     as many entries as the output layer has units
     * @param maxIteration the maximum number of sweeps over the training set
     * @param threshold training stops once the largest gradient magnitude of a sweep is strictly
     *     below this value
     * @param errorRate currently unused; kept for interface compatibility
     * @return {@code true} if the threshold was reached, {@code false} if training stopped
     *     because {@code maxIteration} was exhausted
     * @throws IllegalStateException if this network has already been trained
     * @throws IllegalArgumentException if {@code data} is empty, if {@code target} has fewer rows
     *     than {@code data}, or if a used row has the wrong length
     */
    public boolean train(double[][] data, double target[][], int maxIteration, double threshold,
            double errorRate) {
        if (status == Status.TRAINED) {
            throw new IllegalStateException();
        }
        // Validate every row up front so a malformed sample cannot abort training half-way
        // through with the weights already modified.
        if (data.length <= 0 || target.length < data.length) {
            throw new IllegalArgumentException();
        }
        int inputSize = neurodes[0].length;
        int outputSize = neurodes[depth - 1].length;
        for (int r = 0; r < data.length; r++) {
            if (data[r].length != inputSize || target[r].length != outputSize) {
                throw new IllegalArgumentException("sample " + r + " has the wrong dimension");
            }
        }

        int round = 1;
        boolean convergence = false;
        while (round <= maxIteration && !convergence) {
            double delta = 0.0;
            for (int r = 0; r < data.length; r++) {
                double res = trainWithOneSample(data[r], target[r], LEARNING_RATE);
                if (delta < res) {
                    delta = res;
                }
            }
            convergence = (delta < threshold);
            System.out.printf(" %d round of train, delta is %f %n", round, delta);
            round++;
        }
        status = Status.TRAINED;
        return convergence;
    }

    /**
     * Performs one back-propagation step for a single sample.
     *
     * @param data the input vector of the sample
     * @param target the desired output-layer activations for the sample
     * @param rate the learning rate
     * @return the largest absolute weight gradient (before scaling by {@code rate}) of this step
     */
    private double trainWithOneSample(double[] data, double[] target, double rate) {
        calculateOutput(data);

        // Error of the output layer: sigmoid derivative times the difference to the target.
        for (int j = 0; j < neurodes[depth - 1].length; j++) {
            double output = neurodes[depth - 1][j].output;
            neurodes[depth - 1][j].err = output * (1 - output) * (target[j] - output);
        }

        // Errors of the hidden layers, from the last hidden layer back to layer 1. This must
        // happen before any weight is changed because it reads the current weights.
        for (int d = depth - 2; d > 0; d--) {
            for (int j = 0; j < neurodes[d].length; j++) {
                double error = 0.0;
                for (int k = 0; k < neurodes[d + 1].length; k++) {
                    error += neurodes[d + 1][k].err * weights[d][j][k];
                }
                double output = neurodes[d][j].output;
                neurodes[d][j].err = output * (1 - output) * error;
            }
        }

        // Update the weights and remember the largest gradient magnitude.
        double maxDelta = 0.0;
        for (int d = 0; d < depth - 1; d++) {
            for (int i = 0; i < neurodes[d].length; i++) {
                for (int j = 0; j < neurodes[d + 1].length; j++) {
                    double delta = neurodes[d][i].output * neurodes[d + 1][j].err;
                    weights[d][i][j] += rate * delta;
                    if (maxDelta < Math.abs(delta)) {
                        maxDelta = Math.abs(delta);
                    }
                }
            }
        }

        // Update the biases of all non-input layers.
        for (int d = 1; d < depth; d++) {
            for (int j = 0; j < neurodes[d].length; j++) {
                neurodes[d][j].theta += rate * neurodes[d][j].err;
            }
        }
        return maxDelta;
    }
}
測試:
/**
 * Demo driver: trains a {@code NeuroNetwork} on a regular grid of points in the unit square and
 * measures its accuracy on a finer grid. Points are split into three classes by the parabola
 * {@code y = 4 * (x - 0.5)^2} and the vertical line {@code x = 0.5}.
 */
public class TestMain {

    /**
     * Generates an {@code m x m} grid of labelled points covering the unit square.
     *
     * @param m the number of grid points per axis
     * @return a two-element array: element 0 holds the {@code m*m} points as {@code {x, y}}
     *     pairs, element 1 holds the matching one-hot label rows of length 3
     */
    public static double[][][] generateData(int m) {
        double[][][] res = new double[2][][];
        double[][] data = new double[m * m][2];
        // Java zero-initializes arrays, so only the winning class needs to be set.
        double[][] label = new double[m * m][3];
        for (int i = 0; i < m; i++) {
            double x = gridCoordinate(i, m);
            for (int j = 0; j < m; j++) {
                double y = gridCoordinate(j, m);
                int row = i * m + j;
                data[row][0] = x;
                data[row][1] = y;
                label[row][calculateLabel(x, y)] = 1;
            }
        }
        res[0] = data;
        res[1] = label;
        return res;
    }

    /**
     * Maps grid index {@code k} of an {@code m}-point axis to a coordinate in [0, 1]. A
     * single-point axis yields 0 instead of the NaN that {@code 0 / 0.0} would produce.
     */
    private static double gridCoordinate(int k, int m) {
        return (m > 1) ? k / (m - 1.0) : 0.0;
    }

    /**
     * Returns the true class of a point.
     *
     * @param x the x coordinate
     * @param y the y coordinate
     * @return 0 if the point lies strictly above the parabola {@code y = 4 * (x - 0.5)^2},
     *     otherwise 1 if {@code x < 0.5}, otherwise 2
     */
    public static int calculateLabel(double x, double y) {
        if (y > 4.0 * (x - 0.5) * (x - 0.5)) {
            return 0;
        } else if (x < 0.5) {
            return 1;
        } else {
            return 2;
        }
    }

    /**
     * Trains a 2-3-3 network on a 10 x 10 grid, evaluates it on a 50 x 50 grid and prints every
     * prediction followed by the overall accuracy.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int[] num = { 2, 3, 3 };
        int m = 10;
        NeuroNetwork inst = new NeuroNetwork(num.length, num);
        double[][][] trainData = generateData(m);
        inst.train(trainData[0], trainData[1], 1000000, 0.001, 0.8);

        int t = 50, success = 0;
        double[][][] testData = generateData(t);
        // Dedicated buffer for the raw activations so the generated labels are not overwritten.
        double[] scores = new double[num[num.length - 1]];
        for (int i = 0; i < t * t; i++) {
            double px = testData[0][i][0];
            double py = testData[0][i][1];
            int res = inst.predict(testData[0][i], scores);
            int ans = calculateLabel(px, py);
            if (res == ans) {
                success++;
            }
            System.out.printf("<%f, %f> : %d %b%n", px, py, res, res == ans);
        }
        System.out.printf("Accuracy rate is %f%n", (success + 0.0) / (t * t));
    }
}