1. 程式人生 > >Neuroph感知機自我學習實現記憶邏輯與

Neuroph感知機自我學習實現記憶邏輯與

1.建立增加學習演算法的感知機

	/**
	 * @author Ragty
	 * @param  增加學習演算法的感知機(記憶邏輯與)
	 * @serialData 2018.4.22
	 * @param inputNeuralCount
	 */
	public void creatPerceptron(int inputNeuralCount){
		
		//設定型別為感知機
		this.setNetworkType(NeuralNetworkType.PERCEPTRON);
		
		//建立輸入神經元,表示輸入刺激
		NeuronProperties inputNeuronProperties = new NeuronProperties();
		inputNeuronProperties.setProperty("neuronType",InputNeuron.class);
		
		//建立輸入層
		Layer inputLayer = LayerFactory.createLayer(inputNeuralCount, inputNeuronProperties);
		this.addLayer(inputLayer);
		inputLayer.addNeuron(new BiasNeuron());
		
		//建立輸出神經元(傳輸函式為step)
		NeuronProperties outputNeuronProperties = new NeuronProperties();
		outputNeuronProperties.setProperty("transferFunction", TransferFunctionType.STEP);
		
		//建立輸出層
		Layer outputLayer = LayerFactory.createLayer(1, outputNeuronProperties);
		this.addLayer(outputLayer);
		
		//輸入層輸出層全連線
		ConnectionFactory.fullConnect(inputLayer, outputLayer);
		NeuralNetworkFactory.setDefaultIO(this);
		
		//設定感知機學習演算法
		this.setLearningRule(new perceptronLearningRule());
	}


2.建立學習演算法

public class perceptronLearningRule extends SupervisedLearning implements Serializable{

	private static final long serialVersionUID = 1L;

	public perceptronLearningRule() {

    }
	
	/**
	 * @author Ragty
	 * @param  迭代計算權值
	 * @serialData 2018.4.22
	 */
	@Override
    protected void updateNetworkWeights(double[] outputError) {
        int i = 0;
        for (Neuron neuron : neuralNetwork.getOutputNeurons()) {
            neuron.setError(outputError[i]); 
            double neuronError = neuron.getError();
            // 根據所有的神經元輸入 迭代學習
            for (Connection connection : neuron.getInputConnections()) {
                // 神經元的一個輸入
                double input = connection.getInput();
                // 計算權值的變更
                double weightChange =  neuronError * input;
                // 更新權值
                Weight weight = connection.getWeight();
                weight.weightChange = weightChange;                
                weight.value += weightChange;
            }

            i++;
        }
    }

	

}
3.訓練資料並測試
public class AndPerceptron implements LearningEventListener{

	public static void main(String[] args) {
		new AndPerceptron().run();
	}
	
	public void run(){
		
	   //給出學習的訓練資料(用於訓練神經網路)
	   //資料集有兩個輸入,一個輸出
	   //dataSetRow的建構函式接受兩個引數,第一個為輸入向量,第二個為期望值
	   DataSet trainningSet = new DataSet(2,1);	
	   trainningSet.addRow(new DataSetRow(new double[]{0,0},new double[]{0}));
	   trainningSet.addRow(new DataSetRow(new double[]{0,1},new double[]{0}));
	   trainningSet.addRow(new DataSetRow(new double[]{1,0},new double[]{0}));
	   trainningSet.addRow(new DataSetRow(new double[]{1,1},new double[]{1}));
	   
	   //建立一個只有兩個輸入節點的感知機
	   simplePerceptron andPerceptron = new simplePerceptron(2);
	   
	   //給學習過程增加事件監聽器(監督訓練)
	   perceptronLearningRule learningRule = (perceptronLearningRule) andPerceptron.getLearningRule();
	   learningRule.addListener(this);	
	   
	   //使用訓練資料訓練感知機(進行學習)
	   System.out.println("訓練開始");
	   andPerceptron.learn(trainningSet);
	   
	   //測試感知機是否能正確輸出
	   System.out.println("測試輸出");
	   testNeuralNetwork(andPerceptron, trainningSet);
	}
	
	
	/**
	 * @author Ragty
	 * @param  訓練之後對網路測試(測試感知機)
	 * @serialData 2018.4.22
	 * @param neuralNetwork
	 * @param data
	 */
	public static void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet testSet){
		
		for(DataSetRow testSetRow : testSet.getRows()){
			neuralNetwork.setInput(testSetRow.getInput());
			neuralNetwork.calculate();
			double[] networkOutput = neuralNetwork.getOutput();
			
			System.out.println("Input:"+Arrays.toString(testSetRow.getInput()));
			System.out.println("Output:"+Arrays.toString(networkOutput));
		}
		
	}

	
	//監督訓練過程
	@Override
	public void handleLearningEvent(LearningEvent event) {
		// TODO Auto-generated method stub
		//所有迭代學習演算法的基類, 它為它的所有子類提供迭代學習過程
		IterativeLearning bp = (IterativeLearning) event.getSource();
		System.out.println("iterate:"+bp.getCurrentIteration());
		System.out.println(Arrays.toString(bp.getNeuralNetwork().getWeights()));
	}
	
	
	
}
4.學習結果