[Deeplearning4j應用教程09]_基於DL4J的自動編碼器
技術標籤:DL4J人工智慧深度學習深度學習神經網路機器學習java資料探勘
基於DL4J的自動編碼器
一、簡介
為什麼要使用自動編碼器? 在實踐中,自動編碼器通常應用於資料的降噪和降維。 這對於表示學習非常有用,而對於資料壓縮則不太有用。
在深度學習中,自動編碼器是“嘗試”以重建其輸入的神經網路。 它可以用作特徵提取的一種形式,並且可以將自動編碼器堆疊起來以建立“深度”網路。 由自動編碼器生成的功能可以輸入到其他演算法中,以進行分類,聚類和異常檢測。
在實際應用當中,異常檢測能夠用於:網路入侵,欺詐檢測,系統監視,感測器網路事件檢測(IoT)和異常軌跡感測。
二、自編碼器的工作流程
自動編碼器包括:
1、編碼功能(“編碼器”)
2、解碼功能(“解碼器”)
3、距離函式(“損失函式”)
首先,輸入被饋入自動編碼器並轉換為壓縮表示。然後,解碼器學習如何從壓縮的表示中重建原始輸入,在無監督的訓練過程中,損失函式有助於糾正解碼器產生的錯誤。 此過程是自動的(因此稱為“自動”編碼器); 即不需要人工干預。
學習到現在,我們應該已經知道如何使用MultiLayerNetwork和ComputationGraph建立不同的網路配置了,現在,我們將構造一個“堆疊”自動編碼器,該編碼器對MNIST數字執行異常檢測而無需預先訓練。而目的是識別異常數字,即不尋常和不典型的數字。從給定資料集的規範中“脫穎而出”的內容,事件或觀察結果的識別被廣泛稱為異常檢測。異常檢測不需要標註的資料集,並且可以在無監督學習的情況下進行,這很有幫助,因為世界上大多數資料都沒有標註。
三、基於DL4J的自編碼器實現
3.1、匯入需要的包
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.util.List;
3.2、堆疊式自動編碼器
以下自動編碼器使用兩個堆疊的密集層進行編碼。 MNIST數字轉換為長度為784的平面一維陣列(MNIST影象為28x28畫素,當我們端對端放置它們時等於784)。在網路中,資料的大小變化情況如下:
784→250→10→250→784
程式碼如下:
//搭建模型
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new AdaGrad(0.05))
.activation(Activation.RELU)
.l2(0.0001)
.list()
.layer(new DenseLayer.Builder().nIn(784).nOut(250)
.build())
.layer(new DenseLayer.Builder().nIn(250).nOut(10)
.build())
.layer(new DenseLayer.Builder().nIn(10).nOut(250)
.build())
.layer(new OutputLayer.Builder().nIn(250).nOut(784)
.activation(Activation.LEAKYRELU)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
//監聽器
net.setListeners(Collections.singletonList(new ScoreIterationListener(10)));
3.3、使用MNIST迭代器
像Deeplearning4j的大多數內建迭代器一樣,MNIST迭代器擴充套件了DataSetIterator類。 該API允許簡單地例項化資料集並在後臺自動下載資料。
程式碼如下:
//載入資料,並且進行訓練集與測試集的劃分:40000訓練資料,10000測試資料
DataSetIterator iter = new MnistDataSetIterator(100,50000,false);
List<INDArray> featuresTrain = new ArrayList<>();
List<INDArray> featuresTest = new ArrayList<>();
List<INDArray> labelsTest = new ArrayList<>();
Random r = new Random(12345);
while(iter.hasNext()){
DataSet ds = iter.next();
SplitTestAndTrain split = ds.splitTestAndTrain(80, r); //按照8:2的比例劃分資料集 (miniBatch = 100)
featuresTrain.add(split.getTrain().getFeatures());
DataSet dsTest = split.getTest();
featuresTest.add(dsTest.getFeatures());
INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //進行獨熱編碼轉換: 表示 -> 索引
labelsTest.add(indexes);
}
3.4、無監督訓練
現在,我們已經設定了網路配置並與我們的MNIST測試/訓練迭代器一起例項化了,訓練只需要幾行程式碼。
之前,我們使用setListeners()方法將ScoreIterationListener附加到模型。根據用於執行此程式碼電腦的瀏覽器,可以開啟偵錯程式/檢查器以檢視偵聽器輸出。 由於Deeplearning4j的內部使用SL4J進行日誌記錄,因此此輸出重定向到控制檯,並且Zeppelin重定向了該輸出。 這有助於減少電腦的混亂情況。
程式碼如下:
//訓練模型
int nEpochs = 3;
for( int epoch=0; epoch<nEpochs; epoch++ ){
for(INDArray data : featuresTrain){
net.fit(data,data);
}
System.out.println("Epoch " + epoch + " complete");
}
3.5、評估模型
現在,我們已經對自動編碼器進行了訓練,那麼,我們將根據測試資料來評估模型。每個示例將被單獨打分,並且將構成一個對映,該對映將每個數字與(得分,示例)對列表相關聯。
最後,我們將計算N個最佳分數和N個最差分數。
程式碼如下:
//根據測試資料評估模型
//分別對測試集中的每個樣本評分
//組成一個對映,將每個數字與(得分,樣本)對列表相關聯
//然後找到每位數中N個最佳分數和N個最差分數
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());
for( int i=0; i<featuresTest.size(); i++ ){
INDArray testData = featuresTest.get(i);
INDArray labels = labelsTest.get(i);
int nRows = testData.rows();
for( int j=0; j<nRows; j++){
INDArray example = testData.getRow(j, true);
int digit = (int)labels.getDouble(j);
double score = net.score(new DataSet(example,example));
// 將(得分,樣本)對新增到適當的列表
List digitAllPairs = listsByDigit.get(digit);
digitAllPairs.add(new ImmutablePair<>(score, example));
}
}
//Sort each list in the map by score
Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
@Override
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
return Double.compare(o1.getLeft(),o2.getLeft());
}
};
for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
Collections.sort(digitAllPairs, c);
}
//排序後,為每個數字選擇N個最佳分數和N個最差分數(根據重構誤差),其中N = 5
List<INDArray> best = new ArrayList<>(50);
List<INDArray> worst = new ArrayList<>(50);
for( int i=0; i<10; i++ ){
List<Pair<Double,INDArray>> list = listsByDigit.get(i);
for( int j=0; j<5; j++ ){
best.add(list.get(j).getRight());
worst.add(list.get(list.size()-j-1).getRight());
}
}
3.6、結果視覺化
//預設視覺化
if (visualize) {
//視覺化最好和最差的數字
MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0, best, "Best (Low Rec. Error)");
bestVisualizer.visualize();
MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0, worst, "Worst (High Rec. Error)");
worstVisualizer.visualize();
}
//視覺化方法
public static class MNISTVisualizer {
private double imageScale;
private List<INDArray> digits; //數字(作為行向量),每個INDArray一個
private String title;
private int gridWidth;
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
this(imageScale, digits, title, 5);
}
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
this.imageScale = imageScale;
this.digits = digits;
this.title = title;
this.gridWidth = gridWidth;
}
public void visualize(){
JFrame frame = new JFrame();
frame.setTitle(title);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
JPanel panel = new JPanel();
panel.setLayout(new GridLayout(0,gridWidth));
List<JLabel> list = getComponents();
for(JLabel image : list){
panel.add(image);
}
frame.add(panel);
frame.setVisible(true);
frame.pack();
}
private List<JLabel> getComponents(){
List<JLabel> images = new ArrayList<>();
for( INDArray arr : digits ){
BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
for( int i=0; i<784; i++ ){
bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
images.add(new JLabel(scaled));
}
return images;
}
}
最終結果如下圖所示:
最差的手寫數字:
最好的手寫數字:
完整程式碼如下:
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.util.List;
public class xx {
public static boolean visualize = true;
public static void main(String[] args) throws Exception {
//搭建模型. 784 輸入/輸出 (MNIST 圖片大小為 28x28).
//784 -> 250 -> 10 -> 250 -> 784
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new AdaGrad(0.05))
.activation(Activation.RELU)
.l2(0.0001)
.list()
.layer(new DenseLayer.Builder().nIn(784).nOut(250)
.build())
.layer(new DenseLayer.Builder().nIn(250).nOut(10)
.build())
.layer(new DenseLayer.Builder().nIn(10).nOut(250)
.build())
.layer(new OutputLayer.Builder().nIn(250).nOut(784)
.activation(Activation.LEAKYRELU)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(Collections.singletonList(new ScoreIterationListener(10)));
//載入資料,並且進行訓練集與測試集的劃分:40000訓練資料,10000測試資料
DataSetIterator iter = new MnistDataSetIterator(100,50000,false);
List<INDArray> featuresTrain = new ArrayList<>();
List<INDArray> featuresTest = new ArrayList<>();
List<INDArray> labelsTest = new ArrayList<>();
Random r = new Random(12345);
while(iter.hasNext()){
DataSet ds = iter.next();
SplitTestAndTrain split = ds.splitTestAndTrain(80, r); //按照8:2的比例進行劃分(from miniBatch = 100)
featuresTrain.add(split.getTrain().getFeatures());
DataSet dsTest = split.getTest();
featuresTest.add(dsTest.getFeatures());
INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //通過獨熱編碼將表示轉換為索引
labelsTest.add(indexes);
}
//訓練模型
int nEpochs = 3;
for( int epoch=0; epoch<nEpochs; epoch++ ){
for(INDArray data : featuresTrain){
net.fit(data,data);
}
System.out.println("Epoch " + epoch + " complete");
}
//根據測試資料評估模型
//分別對測試集中的每個樣本評分
//組成一個對映,將每個數字與(得分,樣本)對列表相關聯
//然後找到每位數中N個最佳分數和N個最差分數
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());
for( int i=0; i<featuresTest.size(); i++ ){
INDArray testData = featuresTest.get(i);
INDArray labels = labelsTest.get(i);
int nRows = testData.rows();
for( int j=0; j<nRows; j++){
INDArray example = testData.getRow(j, true);
int digit = (int)labels.getDouble(j);
double score = net.score(new DataSet(example,example));
// 將(得分,樣本)對新增到適當的列表
List digitAllPairs = listsByDigit.get(digit);
digitAllPairs.add(new ImmutablePair<>(score, example));
}
}
//按分數對映對每個列表進行排序
Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
@Override
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
return Double.compare(o1.getLeft(),o2.getLeft());
}
};
for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
Collections.sort(digitAllPairs, c);
}
排序後,為每個數字選擇N個最佳分數和N個最差分數(根據重構誤差),其中N = 5
List<INDArray> best = new ArrayList<>(50);
List<INDArray> worst = new ArrayList<>(50);
for( int i=0; i<10; i++ ){
List<Pair<Double,INDArray>> list = listsByDigit.get(i);
for( int j=0; j<5; j++ ){
best.add(list.get(j).getRight());
worst.add(list.get(list.size()-j-1).getRight());
}
}
//預設進行視覺化
if (visualize) {
//視覺化最好與最差的數字
MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0, best, "Best (Low Rec. Error)");
bestVisualizer.visualize();
MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0, worst, "Worst (High Rec. Error)");
worstVisualizer.visualize();
}
}
public static class MNISTVisualizer {
private double imageScale;
private List<INDArray> digits; //Digits (as row vectors), one per INDArray
private String title;
private int gridWidth;
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
this(imageScale, digits, title, 5);
}
//設定視覺化的圖片大小、數字、標題與顯示圖片的網格大小
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
this.imageScale = imageScale;
this.digits = digits;
this.title = title;
this.gridWidth = gridWidth;
}
public void visualize(){
JFrame frame = new JFrame();
frame.setTitle(title);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
JPanel panel = new JPanel();
panel.setLayout(new GridLayout(0,gridWidth));
List<JLabel> list = getComponents();
for(JLabel image : list){
panel.add(image);
}
frame.add(panel);
frame.setVisible(true);
frame.pack();
}
private List<JLabel> getComponents(){
List<JLabel> images = new ArrayList<>();
for( INDArray arr : digits ){
BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
for( int i=0; i<784; i++ ){
bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE);
ImageIcon scaled = new ImageIcon(imageScaled);
images.add(new JLabel(scaled));
}
return images;
}
}
}