1. 程式人生 > >【神經網路入門】用JAVA實現感知器演算法

【神經網路入門】用JAVA實現感知器演算法

簡述
隨著網際網路的高速發展,A(AI)B(BigData)C(Cloud)已經成為當下的核心發展方向,假如三者深度結合的話,AI是其中最核心的部分。所以如果說在未來社會,每個人都必須要學會程式設計的話,那麼對於程式設計師來說,人工智慧則是他們所必須掌握的技術(科技發展真tm快)。

這篇文章介紹並用JAVA實現了一種最簡單的感知器網路,不糾結於公式的推導,旨在給大家提供一下學習神經網路的思路,對神經網路有一個大概的認識。

感知器網路模型分析
首先看一張圖

如果稍微對神經網路感興趣的一定對這張圖不陌生,這張圖是神經元的結構圖 
X1~Xm表示輸入,W1~Wm表示突觸權值,Σ表示求和結點,Activation function表示啟用函式,之後輸出一個結果,具體的流程是

神經元接收到輸入,每個輸入都會與其相對路徑上的權值相乘,到了求和結點進行求和,這裡把求和結點的結果設為z :

z = X1 * W1 + X2 * W2 + X3 * W3 + ...... + Xm * Wm

之後將 z 傳入到啟用函式(這裡我們稱啟用函式為 f)進行二分類模式識別 :

if f(x) > e, y = 1
else         y = -1

e 為閾值
y 為分類結果

這裡可以看出,如果 f(x) 的值大於閾值,得到分類 y = 1,反之 y = -1 
注:相對於生物神經元受到刺激表示的反應,如果刺激在可接受範圍之內,則神經元會抑制刺激(y = -1),如果超過範圍則會興奮(y = 1),而這個範圍的分水嶺就是閾值(e)

學習
我們發現,如果權值和閾值都固定的話,那麼這個神經網路就沒有存在的意義了,所以我們引入學習的概念,通過學習,讓神經網路去修改權值和閾值,從而可以動態的修正模式識別的正確率,這才是機器學習的本質。

那麼如何學習呢?當我們在使用之前我們需要提供給此網路一組樣本資料(這裡採取的是有教師模式學習),樣本資料包括輸入資料x和正確的識別結果y’。 
當我們輸入訓練資料x得到模式識別y之後進行判斷,如果 y != y’ ,則會去調整此網路的權值和閾值,調整請看公式,μ 表示學習率(修正率),update 表示需要修正值:

update = μ * (yi - y')
update = (f(x) - y')

m
Σ    Wi += update * Xi
i=1

e += update

當感知器分類結果等於正確分類,update = 0,不調整網路;如果不等於正確分類,則會調整全部的權值(w)與閾值(e)

以上就是我所介紹的感知器最簡單的學習流程:

輸入資料->求和得到z->通過啟用函式等到分類結果->分類結果與正確結果不符則調整網路

下面就讓我們來實現這個簡單的神經網路吧

Java程式碼實現
這裡我所實現的是通過神經網路學習識別整數的正負 
首先定義一個感知器的類

 /**
 * Created by CimZzz on 12/2/17.
 *
 */
public class Perceptron {
    /**
     * 學習率
     */
    private final float learnRate;

    /**
     * 學習次數
     */
    private final int studyCount;

    /**
     * 閾值
     */
    private float e;

    /**
     * 權值
     * 因為判斷整數正負只需要一條輸入,所以這裡只有一個權值,多條輸入可以設定為陣列
     */
    private float w;

    /**
     * 每次學習的正確率
     */
    private float[] correctRate;

    //

    /**
     * 建構函式初始化學習率,學習次數,權值、閾值初始化為0
     * @param learnRate 學習率(取值範圍 0 < learnRate < 1)
     * @param studyCount 學習次數
     */
    public Perceptron(float learnRate, int studyCount) {
        this.learnRate = learnRate;
        this.studyCount = studyCount;

        this.e = 0;
        this.w = 0;

        this.correctRate = new float[studyCount];
    }


    /**
     * 學習函式,samples 是一個包含輸入資料和分類結果的二維陣列,
     * samples[][0] 表示輸入資料
     * samples[][1] 表示正確的分類結果
     * @param samples 訓練資料
     */
    public void fit(int[][] samples) {
        int sampleLength = samples.length;

        for(int i = 0 ; i < studyCount ; i ++) {
            int errorCount = 0;

            for (int[] sample : samples) {
                float update = learnRate * (sample[1]-predict(sample[0]));

                //更新權值、閾值
                w += update * sample[0];
                e += update;

                //計算錯誤次數
                if (update != 0)
                    errorCount++;
            }

            //計算此次學習的正確率
            correctRate[i] = 1 - errorCount * 1.0f / sampleLength;
        }
    }

    /**
     * 求和函式,模擬求和結點操作 輸入資料 * 權值
     * @param num 輸入資料
     * @return 求和結果 z
     */
    private float sum(int num) {
        return num * w + e;
    }

    /**
     * 啟用函式,通過求和結果 z 和閾值 e 進行判斷
     * @param num 輸入資料
     * @return 分類結果
     */
    public int predict(int num) {
        return sum(num) >= 0 ? 1 : -1;
    }

    /**
     * 列印正確率
     */
    public void printCorrectRate() {
        for (int i = 0 ; i < studyCount ; i ++)
            System.out.printf("第%d次學習的正確率 -> %.2f%%\n",i + 1,correctRate[i] * 100);
    }
}

然後寫生成訓練資料的函式

    /**
     * 生成訓練資料
     * @return 訓練資料
     */
    private static int[][] genStudyData() {
        //這裡我們取 -100 ~ 100 之間的整數,大於0的設為模式 y = 1,反之為 y = -1
        int[][] data = new int[201][2];

        for(int i = -100 , j = 0; i <= 100 ; i ++ , j ++) {
            data[j][0] = i;
            data[j][1] = i >= 0 ? 1 : -1;
        }

        return data;
    }

    /**
     * 生成訓練資料
     * @return 訓練資料
     */
    private static int[][] genStudyData2() {
        //這裡我們取 1~250 之間的整數,大於125的設為模式 y = 1,反之為 y = -1
        int[][] data = new int[250][2];

        for(int i = 1 , j = 0; i <= 250 ; i ++ , j ++) {
            data[j][0] = i;
            data[j][1] = i >= 125 ? 1 : -1;
        }

        return data;
    }

最後是主函式


    public static void main(String[] args) {
        //這裡的學習率和訓練次數可以根據情況人為調整
        Perceptron perceptron = new Perceptron(0.4f,500);

        perceptron.fit(genStudyData());
        perceptron.printCorrectRate();

        System.out.println(perceptron.predict(-1));
        System.out.println(perceptron.predict(126));
    }

大家可以測試一下

侷限性
這個感知器神經網路比較簡單,是適用於可線性劃分的資料,比如一維的話正數和負數,二維的座標象限分類;對於不可線性劃分的資料無法進行正確的分類,如尋找質數等


csdn.net/wang19950207/article/details/78693107