Deeplearning4j 實戰(2):Deeplearning4j 手寫體數字識別Spark實現【轉】
阿新 • • 發佈:2019-01-28
from:http://blog.csdn.net/wangongxi/article/details/54616842
在前兩天的部落格中,我們用Deeplearning4j做了Mnist資料集的分類。算是第一個深度學習的應用。像Mnist資料集這樣圖片尺寸不大,而且是黑白的開源圖片集在本地完成訓練是可以的,畢竟我們用了Lenet這樣相對簡單的網路結構,而且本地的機器配置也有8G左右的記憶體。但實際生產中,圖片的數量要多得多,尺寸也大得多,用的網路也會是AlexNet、GoogLenet這樣更多層數的網路,所以往往我們需要用叢集來解決計算資源的問題。由於Deeplearning4j本身基於Spark
我們還是以Mnist資料集為例來做Deeplearning4j的第一個Spark版本的應用。首先需要在上一篇部落格的基礎上,在pom裡面加入新的依賴:
- <dependency>
- <groupId>org.nd4j</groupId>
- <artifactId>nd4j-kryo_${scala.binary.version}</artifactId>
- <version>${nd4j.version}</version>
-
</
接著程式碼分為2個部分,一個部分是將Mnist資料集在本地以JavaRDD<DataSet>的形式存到磁碟並最終推到HDFS上作為Spark job的輸入資料來源。另一個部分則是模型的訓練和儲存。
第一部分的邏輯大致如下:本地建立Spark任務-->獲取所有Mnist圖片的路徑-->讀取圖片並提取特徵,打上標註,以DataSet的形式作為一張圖片的wrapper-->將所有圖片構成的JavaRDD<DataSet>儲存下來。
這裡原始的Mnist資料集是以圖片形式存在,不再是二進位制格式的資料。這個例子這樣處理,也是方便日後用同樣的方式讀取一般的圖片。Mnist的圖片如下:
- SparkConf conf = new SparkConf()
- .setMaster("local[*]") //local mode
- .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator")
- .setAppName("Mnist Java Spark (Java)");
- JavaSparkContext jsc = new JavaSparkContext(conf);
- final List<String> lstLabelNames = Arrays.asList("零","一","二","三","四","五","六","七","八","九"); //Chinese Label
- final ImageLoader imageLoader = new ImageLoader(28, 28, 1); //Load Image
- final DataNormalization scaler = new ImagePreProcessingScaler(0, 1); //Normalize
- String srcPath = args[0];
- FileSystem hdfs = FileSystem.get(URI.create(srcPath),jsc.hadoopConfiguration()); //hdfs read local file system
- FileStatus[] fileList = hdfs.listStatus(new Path(srcPath));
- List<String> lstFilePath = new ArrayList<>();
- for( FileStatus fileStatus : fileList){
- lstFilePath.add(srcPath + "/" + fileStatus.getPath().getName());
- }
- JavaRDD<String> javaRDDImagePath = jsc.parallelize(lstFilePath);
- JavaRDD<DataSet> javaRDDImageTrain = javaRDDImagePath.map(new Function<String, DataSet>() {
- @Override
- public DataSet call(String imagePath) throws Exception {
- FileSystem fs = FileSystem.get(new Configuration());
- DataInputStream in = fs.open(new Path(imagePath));
- INDArray features = imageLoader.asRowVector(in); //features tensor
- String[] tokens = imagePath.split("\\/");
- String label = tokens[tokens.length-1].split("\\.")[0];
- int intLabel = Integer.parseInt(label);
- INDArray labels = Nd4j.zeros(10); //labels tensor
- labels.putScalar(0, intLabel, 1.0);
- DataSet trainData = new DataSet(features, labels); //DataSet, wrapper of features and labels
- trainData.setLabelNames(lstLabelNames);
- scaler.preProcess(trainData); //normalize
- fs.close();
- return trainData;
- }
- });
- javaRDDImageTrain.saveAsObjectFile("mnistNorm.dat"); //save training data
1.用hdfs.filesystem來獲取檔案。用Java原生態的File來操作也是完全可以的。只不過,這樣讀取檔案的方式,同時適用於讀取本地和HDFS上的檔案。
2.ImageLoader類。這個類是用來讀取圖片檔案的。類似的還有一個類,叫NativeImageLoader。不同的在於,NativeImageLoader是呼叫了OpenCV的相關方法來對圖片做處理的,效率更高,因此推薦使用NativeImageLoader
儲存的RDD的形式如下圖:
然後,講下模型訓練任務的邏輯。讀取HDFS上的以DataSet形式儲存的Mnist檔案-->定義引數中心服務-->定義神經網路結構(Lenet)--> 訓練網路-->儲存訓練好的模型。首先看前兩步的操作:
- SparkConf conf = new SparkConf()
- .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator") //register kryo for nd4j
- .setAppName("Mnist Java Spark (Java)");
- final String imageFilePath = args[0];
- finalint numEpochs = Integer.parseInt(args[1]);
- final String modelPath = args[2];
- finalint numBatch = Integer.parseInt(args[3]);
- //
- JavaSparkContext jsc = new JavaSparkContext(conf);
- //
- JavaRDD<DataSet> javaRDDImageTrain = jsc.objectFile(imageFilePath); //load image data from hdfs
- ParameterAveragingTrainingMaster trainMaster = new ParameterAveragingTrainingMaster.Builder(numBatch) //weight average service
- .workerPrefetchNumBatches(0)
- .saveUpdater(true)
- .averagingFrequency(5)
- .batchSizePerWorker(numBatch)
這裡需要解釋的是ParameterAveragingTrainingMaster這個類。這個類的作用是用於將spark worker節點上各自計算的權重收回到driver節點上進行加權平均,並將最新的權重廣播到worker節點上。也即為:將各個工作節點的引數的均值作為全域性引數值。這種分散式機器學習中,資料並行化的一種操作。
下面是定義神經網路結構和訓練網路:
- int nChannels = 1;
- int outputNum = 10;
- int iterations = 1;
- int seed = 123;
- MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() //define lenent
- .seed(seed)
- .iterations(iterations)
- .regularization(true).l2(0.0005)
- .learningRate(0.1)
- .learningRateScoreBasedDecayRate(0.5)
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- .updater(Updater.ADAM)
- .list()
- .layer(0, new ConvolutionLayer.Builder(5, 5)
- .nIn(nChannels)
- .stride(1, 1)
- .nOut(20)
- .weightInit(WeightInit.XAVIER)