1. 程式人生 > >深度學習-基於spark的多層神經網路

深度學習-基於spark的多層神經網路

最後我們再寫3篇基於spark的深度學習,這篇是手寫識別的,用的是spark的local模式,如果想用叢集模式在submit的時候設定-useSparkLocal false,或者在程式中設定useSparkLocal=false,程式碼如下

public class MnistMLPExample {
    private static final Logger log = LoggerFactory.getLogger(MnistMLPExample.class);
@Parameter(names = "-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)"
, arity = 1)//@Parameter這個是引數的意思,引數名是sparklocal,下面的true代表使用local模式,arity=1代表這個引數消費一個引數值 private boolean useSparkLocal = true; @Parameter(names = "-batchSizePerWorker", description = "Number of examples to fit each worker with")//每個worker訓練多少個例子,下面指定了16 private int batchSizePerWorker = 16; @Parameter
(names = "-numEpochs", description = "Number of epochs for training")//訓練的步數,下面指定了15 private int numEpochs = 15; public static void main(String[] args) throws Exception { new MnistMLPExample().entryPoint(args);//呼叫入口方法 } protected void entryPoint(String[] args) throws Exception {
//Handle command line arguments JCommander jcmdr = new JCommander(this);//使用JCommander類處理命令列引數,這個好高階,之前spark也沒見過 try { jcmdr.parse(args);//解析引數 } catch (ParameterException e) { //User provides invalid input -> print the usage info jcmdr.usage();//如果使用者提供無效的輸入,提示使用方法 try { Thread.sleep(500); } catch (Exception e2) { } throw e; } SparkConf sparkConf = new SparkConf();//終於看到了sparkConf if (useSparkLocal) {//如果使用local模式,使用所有的核 sparkConf.setMaster("local[*]"); } sparkConf.setAppName("DL4J Spark MLP Example");//設定任務名 JavaSparkContext sc = new JavaSparkContext(sparkConf);//建立上下文環境 //Load the data into memory then parallelize//把資料並行載入記憶體 //This isn't a good approach in general - but is simple to use for this example DataSetIterator iterTrain = new MnistDataSetIterator(batchSizePerWorker, true, 12345);//建立繼承自BaseDatasetIterator的手寫資料迭代器,放入每個worker批大小,是否是訓練資料,種子三個引數DataSetIterator iterTest = new MnistDataSetIterator(batchSizePerWorker, true, 12345);//同樣搞一個測試迭代器,都設定為true的意思是訓練和測試使用的是同一批資料 List<DataSet> trainDataList = new ArrayList<>();//建立訓練和測試的陣列並把迭代器內容裝入陣列 List<DataSet> testDataList = new ArrayList<>(); while (iterTrain.hasNext()) { trainDataList.add(iterTrain.next()); } while (iterTest.hasNext()) { testDataList.add(iterTest.next()); } JavaRDD<DataSet> trainData = sc.parallelize(trainDataList);//訓練測試資料並行化,變成RDD JavaRDD<DataSet> testData = sc.parallelize(testDataList); //---------------------------------- //Create network configuration and conduct network training MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()//下面的和之前一樣了 .seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .activation("leakyrelu") .weightInit(WeightInit.XAVIER) .learningRate(0.02) .updater(Updater.NESTEROVS).momentum(0.9) .regularization(true).l2(1e-4) .list() .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build()) .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation("softmax").nIn(100).nOut(10).build()) .pretrain(false).backprop(true) .build(); //Configuration for Spark training: see http://deeplearning4j.org/spark for explanation of these configuration options//從這個連結可以看spark的相關配置https://deeplearning4j.org/cn/spark TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker) //Each DataSet object: contains (by default) 32 examples//ParameterAveragingTrainingMaster提供了一系列叢集執行的配置,上面的連結裡有詳細說明,建議想用spark的都要通讀
.averagingFrequency(5)//該專案控制引數平均化和再分發的頻率,按大小等於batchSizePerWorker的微批次的數量計算。總體上的規則是:
  • 平均化週期太短(例如averagingFrequency=1)可能效率不佳(相對於計算量而言,網路通訊和初始化開銷太多)
  • 平均化週期太長(例如averagingFrequency=200)可能會導致表現不佳(不同工作節點例項的引數可能出現很大差異)
  • 通常將平均化週期設在5~10個微批次的範圍內比較保險
.workerPrefetchNumBatches(2) //Async prefetching: 2 examples per workerSpark工作節點能夠以非同步方式預抓取一定數量的微批次(DataSet物件),從而避免資料載入時的等待。
  • 將該項的值設定為0會禁用預提取。
  • 比較合理的預設值通常是2。過高的值在許多情況下並無幫助(但會使用更多的記憶體)
.batchSizePerWorker(batchSizePerWorker)//該專案控制每個工作節點的微批次大小。這與單機定型中的微批次大小設定相仿。換言之,這是每個工作節點中每次引數更新所使用的樣例數量 .build();//Create the Spark networkSparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);//spark配置,網路配置,叢集執行配置放入SparkDl4jMultiLayer這個類,獲取一個spark網路,感覺已經封裝的很高階了//Execute training:for (int i = 0; i < numEpochs; i++) {//按步數訓練,列印步數 sparkNet.fit(trainData);log.info("Completed Epoch {}", i);} //Perform evaluation (distributed)Evaluation evaluation = sparkNet.evaluate(testData);//評估測試資料,並列印log.info("***** Evaluation *****");log.info(evaluation.stats());//Delete the temp training files, now that we are done with themtm.deleteTempFiles(sc);//刪除臨時的訓練資料檔案log.info("***** Example Complete *****");}}