深度學習-基於spark的多層神經網路
阿新 • • 發佈:2019-01-01
最後我們再寫3篇基於spark的深度學習,這篇是手寫識別的,用的是spark的local模式,如果想用叢集模式在submit的時候設定-useSparkLocal false,或者在程式中設定useSparkLocal=false,程式碼如下
public class MnistMLPExample { private static final Logger log = LoggerFactory.getLogger(MnistMLPExample.class); @Parameter(names = "-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)".averagingFrequency(5)//該專案控制引數平均化和再分發的頻率,按大小等於batchSizePerWorker的微批次的數量計算。總體上的規則是:, arity = 1)//@Parameter這個是引數的意思,引數名是sparklocal,下面的true代表使用local模式,arity=1代表這個引數消費一個引數值 private boolean useSparkLocal = true; @Parameter(names = "-batchSizePerWorker", description = "Number of examples to fit each worker with")//每個worker訓練多少個例子,下面指定了16 private int batchSizePerWorker = 16; @Parameter(names = "-numEpochs", description = "Number of epochs for training")//訓練的步數,下面指定了15 private int numEpochs = 15; public static void main(String[] args) throws Exception { new MnistMLPExample().entryPoint(args);//呼叫入口方法 } protected void entryPoint(String[] args) throws Exception {//Handle command line arguments JCommander jcmdr = new JCommander(this);//使用JCommander類處理命令列引數,這個好高階,之前spark也沒見過 try { jcmdr.parse(args);//解析引數 } catch (ParameterException e) { //User provides invalid input -> print the usage info jcmdr.usage();//如果使用者提供無效的輸入,提示使用方法 try { Thread.sleep(500); } catch (Exception e2) { } throw e; } SparkConf sparkConf = new SparkConf();//終於看到了sparkConf if (useSparkLocal) {//如果使用local模式,使用所有的核 sparkConf.setMaster("local[*]"); } sparkConf.setAppName("DL4J Spark MLP Example");//設定任務名 JavaSparkContext sc = new JavaSparkContext(sparkConf);//建立上下文環境 //Load the data into memory then parallelize//把資料並行載入記憶體 //This isn't a good approach in general - but is simple to use for this example DataSetIterator iterTrain = new MnistDataSetIterator(batchSizePerWorker, true, 12345);//建立繼承自BaseDatasetIterator的手寫資料迭代器,放入每個worker批大小,是否是訓練資料,種子三個引數DataSetIterator iterTest = new MnistDataSetIterator(batchSizePerWorker, true, 12345);//同樣搞一個測試迭代器,都設定為true的意思是訓練和測試使用的是同一批資料 List<DataSet> trainDataList = new ArrayList<>();//建立訓練和測試的陣列並把迭代器內容裝入陣列 List<DataSet> testDataList = new ArrayList<>(); while (iterTrain.hasNext()) { trainDataList.add(iterTrain.next()); } while (iterTest.hasNext()) { testDataList.add(iterTest.next()); } JavaRDD<DataSet> trainData = sc.parallelize(trainDataList);//訓練測試資料並行化,變成RDD JavaRDD<DataSet> testData = sc.parallelize(testDataList); //---------------------------------- //Create network configuration and conduct network training MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()//下面的和之前一樣了 .seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .activation("leakyrelu") .weightInit(WeightInit.XAVIER) .learningRate(0.02) .updater(Updater.NESTEROVS).momentum(0.9) .regularization(true).l2(1e-4) .list() .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build()) .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation("softmax").nIn(100).nOut(10).build()) .pretrain(false).backprop(true) .build(); //Configuration for Spark training: see http://deeplearning4j.org/spark for explanation of these configuration options//從這個連結可以看spark的相關配置https://deeplearning4j.org/cn/spark TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker) //Each DataSet object: contains (by default) 32 examples//ParameterAveragingTrainingMaster提供了一系列叢集執行的配置,上面的連結裡有詳細說明,建議想用spark的都要通讀
- 平均化週期太短(例如averagingFrequency=1)可能效率不佳(相對於計算量而言,網路通訊和初始化開銷太多)
- 平均化週期太長(例如averagingFrequency=200)可能會導致表現不佳(不同工作節點例項的引數可能出現很大差異)
- 通常將平均化週期設在5~10個微批次的範圍內比較保險
- 將該項的值設定為0會禁用預提取。
- 比較合理的預設值通常是2。過高的值在許多情況下並無幫助(但會使用更多的記憶體)