1. 程式人生 > >Deeplearning4j 實戰 (12):Mnist替代資料集Fashion Mnist在CNN上的實驗及結果

Deeplearning4j 實戰 (12):Mnist替代資料集Fashion Mnist在CNN上的實驗及結果

Mnist資料集的分類問題一直被認為是深度學習的Hello World。利用2層卷積網路,經過若干輪的訓練後,在相應測試集上的準確率可以達到95%以上。經過調參後,甚至可以達到99%以上。其實,即使不用用卷積層提取特徵,而是用傳統的全連線網路也同樣可以達到非常高的準確率。在Mnist資料集的官網上(http://yann.lecun.com/exdb/mnist/),除了基於神經網路的分類器,利用傳統的分類方法,如:KNN,SVM,也都可以獲得非常好的結果。下面就是部分模型分類效果的截圖:


從以上結果分析可以發現,無論是淺層模型還是深度學習,在Mnist上的分類問題上都可以達到很高的精度,因此從某種角度也可以說,Mnist資料集複雜度不夠,或者說Mnist分類問題並不是一個具有代表性的機器視覺問題。就這個問題,《Deep Learning》一書的作者Ian Goodfellow和著名開源專案Keras的作者Francois Chollet都有自己的評述,詳情可轉到下面兩個連結:

1.Ian Goodfellow Commnet On Mnist DataSet

2.Francois Chollet's Comment

雖然Mnist可能並不是最合適入門深度學習的資料集,但是鑑於長期以來開發人員的使用習慣,想要找到完全替代Mnist的開源資料集確實有點困難,但這個難題最近有了一個比較好的解答,就是類似Mnist的一個服裝影象資料集--Fashion Mnist

和Mnist資料集一樣,Fashion Mnist也是28*28的灰度圖。內容涵蓋了鞋、包、衣服、褲子。它的檔名稱和資料格式和Mnist一模一樣。換句話說,你完全不需要改動你之前在Mnist上的建模邏輯,只需要把相應的檔案替換掉,就可以對Fashion Mnist進行訓練和評估。不過唯一不同的是,Fashion Mnist的分類準確率遠沒有Mnist那麼高。目前在Fashion Mnist的github主頁上,最好的結果也僅僅是在95%左右。當然,如果你自己的網路有了好的結果,可以在主頁上提個issue,也作為是對這個資料集的一個貢獻。

下面主要介紹3個方面的內容:

1.Fashion Mnist基於CNN的建模分類與評估

2.與Mnist的比較

3.簡單的服裝分類應用

首先介紹第一部分的主要內容。對於Fashion Mnist資料集採用卷積神經網路進行分類建模,具體的網路結構是:2Conv+2FC。建模工具是Deeplearning4j。詳細的超引數配置見如下程式碼片段:

    public static MultiLayerNetwork getModel(){
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                        .seed(12345)
                        .iterations(1)
                        //.regularization(true).l2(0.005)
                        .learningRate(0.01)
                        .learningRateScoreBasedDecayRate(0.5)
                        .weightInit(WeightInit.XAVIER)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .updater(Updater.ADAM)
                        .list()
                        .layer(0, new ConvolutionLayer.Builder(5, 5)
                                .nIn(1)
                                .stride(1, 1)
                                .nOut(32)
                                .activation(Activation.LEAKYRELU)
                                .build())
                        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                                .kernelSize(2,2)
                                .stride(2,2)
                                .build())
                        .layer(2, new ConvolutionLayer.Builder(5, 5)
                                .stride(1, 1)
                                .nOut(64)
                                .activation(Activation.LEAKYRELU)
                                .build())
                        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                                .kernelSize(2,2)
                                .stride(2,2)
                                .build())
                        .layer(4, new DenseLayer.Builder().activation(Activation.LEAKYRELU)
                                .nOut(500).build())
                        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .nOut(10)
                                .activation(Activation.SOFTMAX)
                                .build())
                        .backprop(true).pretrain(false)
                        .setInputType(InputType.convolutionalFlat(28, 28, 1));
        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        return model; 
    }
簡單解釋下部分超引數:

激勵函式部分主要用的是LeakeyRelu。

學習率用了Decay的策略。Decay的幅度是50%。

正則化項是可選的(經測試,正則化項在如上配置中,影響不大)

網路結構:2 Conv-with-MaxPooling + 2FC。卷積層中每層的featureMap的數量如上述所示。

除了建模的部分,資料的ETL部分同樣很重要。在具體實現中,我直接利用Deeplearning4j自帶的一個解析Mnist資料集的元件:MnistManager。它的主要功能就是讀取解壓後的二進位制Mnist資料集以及相應的分類標籤。由於Fashion Mnist和原始Mnist資料集在資料格式上完全相同,所以可以直接使用Mnist的元件進行解析。在讀取的時候,我們可以根據要求設定batchSize,一個batch的資料和標籤會封裝在一個DataSet物件中。由這些DataSet構成的迭代器即可作為最終訓練或者測試的資料。下面具體看下以上邏輯的實現:

    public static DataSet fetch(int batchSize , boolean binarize, MnistManager man, boolean save, boolean train) {        
        float[][] featureData = new float[batchSize][0];
        float[][] labelData = new float[batchSize][0];

        int actualExamples = 0;
        for (int i = 0; i < batchSize && cursor < totalExamples; i++, cursor++) {
            byte[] img = man.readImageUnsafe(order[cursor]);
            int label = man.readLabel(order[cursor]);
            
            float[] featureVec = new float[img.length];
            featureData[actualExamples] = featureVec;
            labelData[actualExamples] = new float[10];
            labelData[actualExamples][label] = 1.0f;

            for (int j = 0; j < img.length; j++) {
                float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned
                if (binarize) {
                    if (v > 30.0f)
                        featureVec[j] = 1.0f;
                    else
                        featureVec[j] = 0.0f;
                } else {
                    featureVec[j] = v / 255.0f;
                }
            }
            if( save ){
                Mat mat = new Mat(28, 28, CV_8SC1, new BytePointer(img)); 
                
                if( train )
                    JavaCVUtil.imWrite(mat, "FashionMnist/trainData/" + label + "_" + cursor + ".jpg");
                else
                    JavaCVUtil.imWrite(mat, "FashionMnist/testData/" + label + "_" + cursor + ".jpg");
            }
            actualExamples++;
        }

        if (actualExamples < batchSize) {
            featureData = Arrays.copyOfRange(featureData, 0, actualExamples);
            labelData = Arrays.copyOfRange(labelData, 0, actualExamples);
        }

        INDArray features = Nd4j.create(featureData);
        INDArray labels = Nd4j.create(labelData);
        return new DataSet(features, labels);
    }
    
    public static DataSetIterator getData(String dir, boolean train , int batchSize, boolean save) throws IOException{
        String featureFileDir = dir;
        String labelFileDir = dir;
        cursor = 0;
        if( train ){
            featureFileDir += "train-images-idx3-ubyte";
            labelFileDir += "train-labels-idx1-ubyte";
            totalExamples = 60000;
            order = new int[totalExamples];
        }else{
            featureFileDir += "t10k-images-idx3-ubyte";
            labelFileDir += "t10k-labels-idx1-ubyte";
            totalExamples = 10000;
            order = new int[totalExamples];
        }
        for (int i = 0; i < order.length; i++)order[i] = i;
        MathUtils.shuffleArray(order, 123456L); //shuffle order
        MnistManager man = new MnistManager(featureFileDir, labelFileDir, train);
        List<DataSet> res = new LinkedList<DataSet>();
        while(cursor < totalExamples){
            res.add(fetch(batchSize, false, man, save, train));
        }
        ExistingDataSetIterator iter = new ExistingDataSetIterator(res);
        return iter;
    }

以上兩個靜態方法就是解析資料、讀取標籤、封裝資料並生成可迭代資料集的過程。其中,getData這個方法可以根據引數的不同,生成訓練或者測試資料集。在fetch這個方法裡,可以選擇是二值化還是正常歸一化。我這裡選擇的是正常歸一化。此外,為了方便看到Fashion Mnist的影象形式,可以選擇是否以圖片的形式生成這些圖片。如果生成圖片的話,則可以看到下面這些圖:



這些圖片我會上傳到CSDN上供大家下載。下載連線

從截圖可以看出是衣服和褲子兩個品類。檔名中的第一個數字是這張圖片的分類標籤。這樣方便直接從圖片進行建模。訓練集共6W張圖片,測試集共1W張圖片。

到此的話,資料的ETL和建模的步驟都已經完成,下面就是對模型引數進行訓練。這裡我還是用的GPU來訓練模型。顯示卡是Telsa K80。單卡進行訓練。相應的CUDA版本是8.0。具體的訓練邏輯可見下面程式碼片段:

    public static void main(String[] args)throws IOException {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);  
        final int numEpochs = Integer.parseInt(args[0]);
        final int batchSize = Integer.parseInt(args[1]);
        final String modelSavePath = args[2];
        final String dataPath = args[3];
        CudaEnvironment.getInstance().getConfiguration()
                        // 是否允許多卡
                        .allowMultiGPU(false)
                        .useDevice(7)
                        // 視訊記憶體大小
                        .setMaximumDeviceCache(11L * 1024L * 1024L * 1024L)
                        // 是否允許多卡直接資料的直接訪問
                        .allowCrossDeviceAccess(false);
        DataSetIterator trainData = getData(dataPath+ "/", true, batchSize, false);
        DataSetIterator testData = getData(dataPath+ "/" , false, batchSize, false);
        MultiLayerNetwork model = getModel();
        for( int i = 0; i < numEpochs; ++i ){
            model.fit(trainData);
            System.out.println("Epoch :" + i + " Finish");
            System.out.println("Score: " + model.score());
            Evaluation eval = model.evaluate(testData);
            System.out.println(eval.stats());
            System.out.println();  
        }
        Evaluation eval = model.evaluate(testData);
        System.out.println(eval.stats());
        ModelSerializer.writeModel(model, modelSavePath, true);
    }
其中batchSize等可以通過args引數傳入設定。注意,最後我們把模型進行了儲存。在每一輪的訓練後,我們都列印了損失函式的值,並同時在測試集上評估了此時模型的準確性。我們一共訓練了100輪。下面給出部分訓練過程中的模型資訊:

Epoch :0 Finish
Score: 0.4417036887606827

==========================Scores========================================
 Accuracy:        0.7986
 Precision:       0.8004
 Recall:          0.7986
 F1 Score:        0.7995
========================================================================
第一輪的loss值和模型評估。可以說,效果不佳。和Mnist的第一輪相差甚遠(下面會有Mnist的相應訓練資訊)。

Epoch :99 Finish
Score: 0.020469321075697797

==========================Scores========================================
 Accuracy:        0.9072
 Precision:       0.9088
 Recall:          0.9072
 F1 Score:        0.908
========================================================================
100輪訓練完之後,勉強達到了90%左右。應該說,結果一般。

到這裡,我就沒有再訓練下去了。那麼到此,第一部分的主要工作就完成了。最終經過100輪的訓練,loss值達到0.02,模型的準確率在90%。

接著介紹下第二部分,也就是和Mnist比較的內容。

Mnist的訓練過程和上面的一模一樣,唯一不同的是,資料集換成Mnist的就可以了。同樣經過100輪的訓練,我們來看下對比結果。

  Mnist DataSet Fashion Mnist DataSet
Epoch 1
==========================Scores====================================
 Accuracy:        0.9545
 Precision:       0.955
 Recall:          0.954
 F1 Score:        0.9545
====================================================================
==========================Scores===================================
 Accuracy:        0.7986
 Precision:       0.8004
 Recall:          0.7986
 F1 Score:        0.7995
===================================================================
Epoch 100
==========================Scores====================================
 Accuracy:        0.9922
 Precision:       0.9921
 Recall:          0.9921
 F1 Score:        0.9921
====================================================================
==========================Scores=====================================
 Accuracy:        0.9072
 Precision:       0.9088
 Recall:          0.9072
 F1 Score:        0.908
=====================================================================
從表格裡就可以直觀的看出兩個資料集在同樣的模型、超引數配置下,最終評估效果的不同了。

Mnist資料集很容易就達到了95%的準確率,甚至最後達到了99.22%。然而Fashion Mnist最終也只有徘徊在90%上下。由此可見,Fashion Mnist資料集的分類問題更為複雜。2層卷積神經網路的效果可能也就是在90%左右了(PS:這個講述並沒有什麼理論依據,但從github主頁看到他人用Keras搭建類似結構的網路來訓練Fashion Mnist,也是在90%上下,所以作此推測,僅僅是實驗結果)。

最後一個部分介紹下基於剛才訓練的模型如何搭建一個Web應用。

服裝的分類場景在各大電商企業中有很多應用。雖然不一定需要準確區分運動鞋和休閒鞋,但是區分衣服、褲子、包、鞋還是很有必要的。這個場景在影象檢索等應用方面有著類似文字檢索中Query分析的作用,最終可以減少索引的查詢量。這裡就直接利用這樣的一個開源資料集搭建一個Web服務,用於識別圖片中物品的所屬品類。涉及到的工具有Spring、Tomcat,JSP,還有之前提到的Deeplearning4j和Nd4j。

我在本地的Eclipse中配置了Tomcat的外掛、服務的埠號、上下文的根路徑等。在POM檔案中引入了Spring和Deeplearning4j的相關依賴。最後前端頁面上做了個簡單的上傳圖片的按鈕,最後的模型分類結果會和圖片一起在頁面上做展示。由於這裡面涉及了關於J2EE開發的諸多細節,和主要介紹的內容有些偏離,所以這裡僅僅介紹主要的思路。在後面的文章中,如果有機會的話會詳細介紹Deeplearning4j訓練的模型上線部署的一些方式,當然也包括一些採坑的地方。下面就給出一些示例結果:

這些服裝類的圖片是從蘇寧易購的網站上面下載下來的,而且都是一些不需要做主體檢測的、內容比較明確的圖片。從實際的效果來看,確實可以對這些圖片的品類做相對準確的識別。不過,其中也有誤判的場景,比如長袖襯衫那個場景被預測成了外套。當然這只是一個demo,並不是最終可以達到產品效果的服務,而且在實際的應用中,像襯衫和外套一般並不會要求嚴格區分,畢竟單純靠一張正面的圖片就區分兩個非常相似的品類是非常困難的,雖然並非一定不可以做到,但準確率未必可以保證。

這裡有個地方需要注意:Fashion Mnist是28*28的灰度圖。在做這個實際應用場景的時候,我同樣對這些彩色圖片做了灰度化以及resize的處理。換句話說,和訓練資料保持一致對預測結果也同樣重要。

最後對上面的內容做下小結。Fashion Mnist作為Mnist的替代資料集,無論在資料格式還是檔名稱上都和原始的Mnist保持了高度一致,從而方便研發人員遷移之前的工作。但是,Fashion Mnist的分類比Mnist更有挑戰性,至少從目前github主頁上最優結果以及我自己的實驗來看,很難達到和Mnist一樣的準確性。原因的話,像外套、襯衫;靴子、運動鞋;難免存在外形極其相似的情況。因此,誤判的情況會比較多。不過,從另一個角度說,這也說明相比Mnist,Fashion Mnist資料集的分類問題更為複雜。此外,Fashion Mnist也可以作為諸多電商企業商品圖片分類的一個demo級別的測試資料集。通過做服裝商品分類這樣一個應用,可以對深度學習在產品級別應用的問題上有感性的認識,更重要的可能是發現深度神經網路的侷限性,並非是萬能的。這可能也是Fashion Mnist相比於原始Mnist資料集的價值所在,讓大家對深度學習有理性的認識(原始Mnist很容易達到98%-99%的準確率,容易誤導大家覺得深度學習就是這樣準確,其實資料集本身也有非常大的關係,不能僅僅依靠模型)。

--------------以下更新自2018/3/21

在Deeplearning4j的QQ群裡還有這篇文章的留言區有同學希望我補充下Web部分的程式碼,特此在這裡做些補充。

首先,Web容器我用的是Tomcat-Eclipse的外掛,在pom裡的配置如下:

  <build> 
    <finalName>dl-webapp</finalName>  
    <plugins> 
      <plugin> 
        <groupId>org.apache.tomcat.maven</groupId>  
        <artifactId>tomcat7-maven-plugin</artifactId>  
        <version>2.2</version>  
        <configuration> 
          <port>8080</port>  
          <path>/maven-web-demo</path>  
          <uriEncoding>UTF-8</uriEncoding>  
          <finalName>maven-web-demo</finalName>  
          <server>tomcat7</server> 
        </configuration>  
        <executions> 
          <!-- 打包成功後即開始執行web容器 -->  
          <execution> 
            <phase>package</phase>  
            <goals> 
              <goal>run</goal> 
            </goals> 
          </execution> 
        </executions> 
      </plugin> 
    </plugins> 
  </build>

其次,整個Web工程的編譯目錄結構如下(工程名:DL):