1. 程式人生 > 其它 >[Deeplearning4j應用教程09]_基於DL4J的自動編碼器

[Deeplearning4j應用教程09]_基於DL4J的自動編碼器

技術標籤:DL4J人工智慧深度學習深度學習神經網路機器學習java資料探勘

基於DL4J的自動編碼器

一、簡介

為什麼要使用自動編碼器? 在實踐中,自動編碼器通常應用於資料的降噪和降維。 這對於表示學習非常有用,而對於資料壓縮則不太有用。
在深度學習中,自動編碼器是“嘗試”以重建其輸入的神經網路。 它可以用作特徵提取的一種形式,並且可以將自動編碼器堆疊起來以建立“深度”網路。 由自動編碼器生成的功能可以輸入到其他演算法中,以進行分類,聚類和異常檢測。

當原始輸入資料具有高維且無法輕鬆繪製時,自動編碼器還可用於資料視覺化。 通過降維,有時可以將輸出壓縮到2D或3D空間中,以進行更好的資料探索。
在實際應用當中,異常檢測能夠用於:網路入侵,欺詐檢測,系統監視,感測器網路事件檢測(IoT)和異常軌跡感測。

二、自編碼器的工作流程

自動編碼器包括:
1、編碼功能(“編碼器”)
2、解碼功能(“解碼器”)
3、距離函式(“損失函式”)
首先,輸入被饋入自動編碼器並轉換為壓縮表示。然後,解碼器學習如何從壓縮的表示中重建原始輸入,在無監督的訓練過程中,損失函式有助於糾正解碼器產生的錯誤。 此過程是自動的(因此稱為“自動”編碼器); 即不需要人工干預。
學習到現在,我們應該已經知道如何使用MultiLayerNetwork和ComputationGraph建立不同的網路配置了,現在,我們將構造一個“堆疊”自動編碼器,該編碼器對MNIST數字執行異常檢測而無需預先訓練。而目的是識別異常數字,即不尋常和不典型的數字。從給定資料集的規範中“脫穎而出”的內容,事件或觀察結果的識別被廣泛稱為異常檢測。異常檢測不需要標註的資料集,並且可以在無監督學習的情況下進行,這很有幫助,因為世界上大多數資料都沒有標註。

通常,異常檢測使用重構誤差來衡量解碼器的效能。正常的資料應具有較低的重構誤差,而異常值應具有較高的重構誤差。

三、基於DL4J的自編碼器實現

3.1、匯入需要的包

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.SplitTestAndTrain; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.lossfunctions.LossFunctions; import javax.swing.*; import java.awt.*; import java.awt.image.BufferedImage; import java.util.*; import java.util.List;

3.2、堆疊式自動編碼器

以下自動編碼器使用兩個堆疊的密集層進行編碼。 MNIST數字轉換為長度為784的平面一維陣列(MNIST影象為28x28畫素,當我們端對端放置它們時等於784)。在網路中,資料的大小變化情況如下:
784→250→10→250→784
程式碼如下:

//搭建模型
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .weightInit(WeightInit.XAVIER)
            .updater(new AdaGrad(0.05))
            .activation(Activation.RELU)
            .l2(0.0001)
            .list()
            .layer(new DenseLayer.Builder().nIn(784).nOut(250)
                .build())
            .layer(new DenseLayer.Builder().nIn(250).nOut(10)
                .build())
            .layer(new DenseLayer.Builder().nIn(10).nOut(250)
                .build())
            .layer(new OutputLayer.Builder().nIn(250).nOut(784)
                .activation(Activation.LEAKYRELU)
                .lossFunction(LossFunctions.LossFunction.MSE)
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        //監聽器
        net.setListeners(Collections.singletonList(new ScoreIterationListener(10)));

3.3、使用MNIST迭代器

像Deeplearning4j的大多數內建迭代器一樣,MNIST迭代器擴充套件了DataSetIterator類。 該API允許簡單地例項化資料集並在後臺自動下載資料。
程式碼如下:

//載入資料,並且進行訓練集與測試集的劃分:40000訓練資料,10000測試資料
        DataSetIterator iter = new MnistDataSetIterator(100,50000,false);

        List<INDArray> featuresTrain = new ArrayList<>();
        List<INDArray> featuresTest = new ArrayList<>();
        List<INDArray> labelsTest = new ArrayList<>();

        Random r = new Random(12345);
        while(iter.hasNext()){
            DataSet ds = iter.next();
            SplitTestAndTrain split = ds.splitTestAndTrain(80, r);  //按照8:2的比例劃分資料集 (miniBatch = 100)
            featuresTrain.add(split.getTrain().getFeatures());
            DataSet dsTest = split.getTest();
            featuresTest.add(dsTest.getFeatures());
            INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //進行獨熱編碼轉換: 表示 -> 索引
            labelsTest.add(indexes);
        }

3.4、無監督訓練

現在,我們已經設定了網路配置並與我們的MNIST測試/訓練迭代器一起例項化了,訓練只需要幾行程式碼。
之前,我們使用setListeners()方法將ScoreIterationListener附加到模型。根據用於執行此程式碼電腦的瀏覽器,可以開啟偵錯程式/檢查器以檢視偵聽器輸出。 由於Deeplearning4j的內部使用SL4J進行日誌記錄,因此此輸出重定向到控制檯,並且Zeppelin重定向了該輸出。 這有助於減少電腦的混亂情況。
程式碼如下:

//訓練模型
int nEpochs = 3;
for( int epoch=0; epoch<nEpochs; epoch++ ){
     for(INDArray data : featuresTrain){
           net.fit(data,data);
            }
        System.out.println("Epoch " + epoch + " complete");
     }

3.5、評估模型

現在,我們已經對自動編碼器進行了訓練,那麼,我們將根據測試資料來評估模型。每個示例將被單獨打分,並且將構成一個對映,該對映將每個數字與(得分,示例)對列表相關聯。
最後,我們將計算N個最佳分數和N個最差分數。
程式碼如下:

//根據測試資料評估模型
//分別對測試集中的每個樣本評分
//組成一個對映,將每個數字與(得分,樣本)對列表相關聯
//然後找到每位數中N個最佳分數和N個最差分數
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());

    for( int i=0; i<featuresTest.size(); i++ ){
        INDArray testData = featuresTest.get(i);
        INDArray labels = labelsTest.get(i);
        int nRows = testData.rows();
        for( int j=0; j<nRows; j++){
            INDArray example = testData.getRow(j, true);
            int digit = (int)labels.getDouble(j);
            double score = net.score(new DataSet(example,example));
            // 將(得分,樣本)對新增到適當的列表
            List digitAllPairs = listsByDigit.get(digit);
            digitAllPairs.add(new ImmutablePair<>(score, example));
        }
    }

    //Sort each list in the map by score
    Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
        @Override
        public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
            return Double.compare(o1.getLeft(),o2.getLeft());
        }
    };

    for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
        Collections.sort(digitAllPairs, c);
    }

    //排序後,為每個數字選擇N個最佳分數和N個最差分數(根據重構誤差),其中N = 5
    List<INDArray> best = new ArrayList<>(50);
    List<INDArray> worst = new ArrayList<>(50);
    for( int i=0; i<10; i++ ){
        List<Pair<Double,INDArray>> list = listsByDigit.get(i);
        for( int j=0; j<5; j++ ){
            best.add(list.get(j).getRight());
            worst.add(list.get(list.size()-j-1).getRight());
        }
    }

3.6、結果視覺化

//預設視覺化
if (visualize) {
            //視覺化最好和最差的數字
            MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0, best, "Best (Low Rec. Error)");
            bestVisualizer.visualize();

            MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0, worst, "Worst (High Rec. Error)");
            worstVisualizer.visualize();
        }
//視覺化方法
public static class MNISTVisualizer {
        private double imageScale;
        private List<INDArray> digits;  //數字(作為行向量),每個INDArray一個
        private String title;
        private int gridWidth;

        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
            this(imageScale, digits, title, 5);
        }
        
        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
            this.imageScale = imageScale;
            this.digits = digits;
            this.title = title;
            this.gridWidth = gridWidth;
        }

        public void visualize(){
            JFrame frame = new JFrame();
            frame.setTitle(title);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

            JPanel panel = new JPanel();
            panel.setLayout(new GridLayout(0,gridWidth));

            List<JLabel> list = getComponents();
            for(JLabel image : list){
                panel.add(image);
            }

            frame.add(panel);
            frame.setVisible(true);
            frame.pack();
        }

        private List<JLabel> getComponents(){
            List<JLabel> images = new ArrayList<>();
            for( INDArray arr : digits ){
                BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
                for( int i=0; i<784; i++ ){
                    bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
                }
                ImageIcon orig = new ImageIcon(bi);
                Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE); 
                ImageIcon scaled = new ImageIcon(imageScaled);
                images.add(new JLabel(scaled));
            }
            return images;
        }
    }

最終結果如下圖所示:
在這裡插入圖片描述
最差的手寫數字:
在這裡插入圖片描述
最好的手寫數字:
在這裡插入圖片描述
完整程式碼如下:

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.util.List;
public class xx {

    public static boolean visualize = true;

    public static void main(String[] args) throws Exception {

        //搭建模型. 784 輸入/輸出 (MNIST 圖片大小為 28x28).
        //784 -> 250 -> 10 -> 250 -> 784
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .weightInit(WeightInit.XAVIER)
            .updater(new AdaGrad(0.05))
            .activation(Activation.RELU)
            .l2(0.0001)
            .list()
            .layer(new DenseLayer.Builder().nIn(784).nOut(250)
                .build())
            .layer(new DenseLayer.Builder().nIn(250).nOut(10)
                .build())
            .layer(new DenseLayer.Builder().nIn(10).nOut(250)
                .build())
            .layer(new OutputLayer.Builder().nIn(250).nOut(784)
                .activation(Activation.LEAKYRELU)
                .lossFunction(LossFunctions.LossFunction.MSE)
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.setListeners(Collections.singletonList(new ScoreIterationListener(10)));

       //載入資料,並且進行訓練集與測試集的劃分:40000訓練資料,10000測試資料
        DataSetIterator iter = new MnistDataSetIterator(100,50000,false);

        List<INDArray> featuresTrain = new ArrayList<>();
        List<INDArray> featuresTest = new ArrayList<>();
        List<INDArray> labelsTest = new ArrayList<>();

        Random r = new Random(12345);
        while(iter.hasNext()){
            DataSet ds = iter.next();
            SplitTestAndTrain split = ds.splitTestAndTrain(80, r);  //按照8:2的比例進行劃分(from miniBatch = 100)
            featuresTrain.add(split.getTrain().getFeatures());
            DataSet dsTest = split.getTest();
            featuresTest.add(dsTest.getFeatures());
            INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //通過獨熱編碼將表示轉換為索引
            labelsTest.add(indexes);
        }

        //訓練模型
        int nEpochs = 3;
        for( int epoch=0; epoch<nEpochs; epoch++ ){
            for(INDArray data : featuresTrain){
                net.fit(data,data);
            }
            System.out.println("Epoch " + epoch + " complete");
        }

        //根據測試資料評估模型
//分別對測試集中的每個樣本評分
//組成一個對映,將每個數字與(得分,樣本)對列表相關聯
//然後找到每位數中N個最佳分數和N個最差分數
 Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
        for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());

        for( int i=0; i<featuresTest.size(); i++ ){
            INDArray testData = featuresTest.get(i);
            INDArray labels = labelsTest.get(i);
            int nRows = testData.rows();
            for( int j=0; j<nRows; j++){
                INDArray example = testData.getRow(j, true);
                int digit = (int)labels.getDouble(j);
                double score = net.score(new DataSet(example,example));
                // 將(得分,樣本)對新增到適當的列表
                List digitAllPairs = listsByDigit.get(digit);
                digitAllPairs.add(new ImmutablePair<>(score, example));
            }
        }

        //按分數對映對每個列表進行排序
        Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
            @Override
            public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
                return Double.compare(o1.getLeft(),o2.getLeft());
            }
        };

        for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
            Collections.sort(digitAllPairs, c);
        }

        排序後,為每個數字選擇N個最佳分數和N個最差分數(根據重構誤差),其中N = 5
        List<INDArray> best = new ArrayList<>(50);
        List<INDArray> worst = new ArrayList<>(50);
        for( int i=0; i<10; i++ ){
            List<Pair<Double,INDArray>> list = listsByDigit.get(i);
            for( int j=0; j<5; j++ ){
                best.add(list.get(j).getRight());
                worst.add(list.get(list.size()-j-1).getRight());
            }
        }

        //預設進行視覺化
        if (visualize) {
            //視覺化最好與最差的數字
            MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0, best, "Best (Low Rec. Error)");
            bestVisualizer.visualize();

            MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0, worst, "Worst (High Rec. Error)");
            worstVisualizer.visualize();
        }
    }

    public static class MNISTVisualizer {
        private double imageScale;
        private List<INDArray> digits;  //Digits (as row vectors), one per INDArray
        private String title;
        private int gridWidth;

        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
            this(imageScale, digits, title, 5);
        }
        //設定視覺化的圖片大小、數字、標題與顯示圖片的網格大小
        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
            this.imageScale = imageScale;
            this.digits = digits;
            this.title = title;
            this.gridWidth = gridWidth;
        }

        public void visualize(){
            JFrame frame = new JFrame();
            frame.setTitle(title);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

            JPanel panel = new JPanel();
            panel.setLayout(new GridLayout(0,gridWidth));

            List<JLabel> list = getComponents();
            for(JLabel image : list){
                panel.add(image);
            }

            frame.add(panel);
            frame.setVisible(true);
            frame.pack();
        }

        private List<JLabel> getComponents(){
            List<JLabel> images = new ArrayList<>();
            for( INDArray arr : digits ){
                BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
                for( int i=0; i<784; i++ ){
                    bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
                }
                ImageIcon orig = new ImageIcon(bi);
                Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE);
                ImageIcon scaled = new ImageIcon(imageScaled);
                images.add(new JLabel(scaled));
            }
            return images;
        }
    }
}