Neuroph感知機自我學習實現記憶邏輯與
阿新 • • 發佈:2019-02-03
1.建立增加學習演算法的感知機
/** * @author Ragty * @param 增加學習演算法的感知機(記憶邏輯與) * @serialData 2018.4.22 * @param inputNeuralCount */ public void creatPerceptron(int inputNeuralCount){ //設定型別為感知機 this.setNetworkType(NeuralNetworkType.PERCEPTRON); //建立輸入神經元,表示輸入刺激 NeuronProperties inputNeuronProperties = new NeuronProperties(); inputNeuronProperties.setProperty("neuronType",InputNeuron.class); //建立輸入層 Layer inputLayer = LayerFactory.createLayer(inputNeuralCount, inputNeuronProperties); this.addLayer(inputLayer); inputLayer.addNeuron(new BiasNeuron()); //建立輸出神經元(傳輸函式為step) NeuronProperties outputNeuronProperties = new NeuronProperties(); outputNeuronProperties.setProperty("transferFunction", TransferFunctionType.STEP); //建立輸出層 Layer outputLayer = LayerFactory.createLayer(1, outputNeuronProperties); this.addLayer(outputLayer); //輸入層輸出層全連線 ConnectionFactory.fullConnect(inputLayer, outputLayer); NeuralNetworkFactory.setDefaultIO(this); //設定感知機學習演算法 this.setLearningRule(new perceptronLearningRule()); }
2.建立學習演算法
3.訓練資料並測試public class perceptronLearningRule extends SupervisedLearning implements Serializable{ private static final long serialVersionUID = 1L; public perceptronLearningRule() { } /** * @author Ragty * @param 迭代計算權值 * @serialData 2018.4.22 */ @Override protected void updateNetworkWeights(double[] outputError) { int i = 0; for (Neuron neuron : neuralNetwork.getOutputNeurons()) { neuron.setError(outputError[i]); double neuronError = neuron.getError(); // 根據所有的神經元輸入 迭代學習 for (Connection connection : neuron.getInputConnections()) { // 神經元的一個輸入 double input = connection.getInput(); // 計算權值的變更 double weightChange = neuronError * input; // 更新權值 Weight weight = connection.getWeight(); weight.weightChange = weightChange; weight.value += weightChange; } i++; } } }
4.學習結果public class AndPerceptron implements LearningEventListener{ public static void main(String[] args) { new AndPerceptron().run(); } public void run(){ //給出學習的訓練資料(用於訓練神經網路) //資料集有兩個輸入,一個輸出 //dataSetRow的建構函式接受兩個引數,第一個為輸入向量,第二個為期望值 DataSet trainningSet = new DataSet(2,1); trainningSet.addRow(new DataSetRow(new double[]{0,0},new double[]{0})); trainningSet.addRow(new DataSetRow(new double[]{0,1},new double[]{0})); trainningSet.addRow(new DataSetRow(new double[]{1,0},new double[]{0})); trainningSet.addRow(new DataSetRow(new double[]{1,1},new double[]{1})); //建立一個只有兩個輸入節點的感知機 simplePerceptron andPerceptron = new simplePerceptron(2); //給學習過程增加事件監聽器(監督訓練) perceptronLearningRule learningRule = (perceptronLearningRule) andPerceptron.getLearningRule(); learningRule.addListener(this); //使用訓練資料訓練感知機(進行學習) System.out.println("訓練開始"); andPerceptron.learn(trainningSet); //測試感知機是否能正確輸出 System.out.println("測試輸出"); testNeuralNetwork(andPerceptron, trainningSet); } /** * @author Ragty * @param 訓練之後對網路測試(測試感知機) * @serialData 2018.4.22 * @param neuralNetwork * @param data */ public static void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet testSet){ for(DataSetRow testSetRow : testSet.getRows()){ neuralNetwork.setInput(testSetRow.getInput()); neuralNetwork.calculate(); double[] networkOutput = neuralNetwork.getOutput(); System.out.println("Input:"+Arrays.toString(testSetRow.getInput())); System.out.println("Output:"+Arrays.toString(networkOutput)); } } //監督訓練過程 @Override public void handleLearningEvent(LearningEvent event) { // TODO Auto-generated method stub //所有迭代學習演算法的基類, 它為它的所有子類提供迭代學習過程 IterativeLearning bp = (IterativeLearning) event.getSource(); System.out.println("iterate:"+bp.getCurrentIteration()); System.out.println(Arrays.toString(bp.getNeuralNetwork().getWeights())); } }