深度學習 之 DeepLearning4j 預測股市走向
阿新 • • 發佈:2018-05-23
深度學習 之 DeepLearning4j。上一篇預測花的類型,是沒有用到中間件的;實際情況是數據量非常大,所以不實用。這次使用DeepLearning4j來預測股市走向,後續加上spark。代碼如下:
/**
 * Value holder for one trading day's market data (one row of the CSV file).
 * Populated via the no-arg constructor plus setters; read back with getters.
 * Field names mirror the original Chinese source ("closeprice" spelling kept
 * for compatibility with existing callers such as getCloseprice()).
 */
public class DailyData {
    // 開盤價 — opening price
    private double openPrice;
    // 收盤價 — closing price
    private double closeprice;
    // 最高價 — daily high
    private double maxPrice;
    // 最低價 — daily low
    private double minPrice;
    // 成交量 — traded volume (shares)
    private double turnover;
    // 成交額 — traded amount (currency)
    private double volume;

    public DailyData() {
    }

    public double getTurnover() {
        return turnover;
    }

    public double getVolume() {
        return volume;
    }

    public double getOpenPrice() {
        return openPrice;
    }

    public double getCloseprice() {
        return closeprice;
    }

    public double getMaxPrice() {
        return maxPrice;
    }

    public double getMinPrice() {
        return minPrice;
    }

    public void setOpenPrice(double openPrice) {
        this.openPrice = openPrice;
    }

    public void setCloseprice(double closeprice) {
        this.closeprice = closeprice;
    }

    public void setMaxPrice(double maxPrice) {
        this.maxPrice = maxPrice;
    }

    public void setMinPrice(double minPrice) {
        this.minPrice = minPrice;
    }

    public void setTurnover(double turnover) {
        this.turnover = turnover;
    }

    public void setVolume(double volume) {
        this.volume = volume;
    }

    /**
     * Human-readable dump of all six fields, identical in content to the
     * original. Fixed idiom: values are appended directly instead of being
     * concatenated into temporary strings inside append().
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("開盤價=").append(this.openPrice).append(", ");
        builder.append("收盤價=").append(this.closeprice).append(", ");
        builder.append("最高價=").append(this.maxPrice).append(", ");
        builder.append("最低價=").append(this.minPrice).append(", ");
        builder.append("成交量=").append(this.turnover).append(", ");
        builder.append("成交額=").append(this.volume);
        return builder.toString();
    }
}
public class StockDataIterator implements DataSetIterator { /** * */ private static final long serialVersionUID = 1L; private static final int VECTOR_SIZE = 6; //每批次的訓練數據組數 private int batchNum; //每組訓練數據長度(DailyData的個數) private int exampleLength; //數據集 private List<DailyData> dataList; //存放剩余數據組的index信息 private List<Integer> dataRecord; private double[] maxNum; /** * 構造方法 * */ public StockDataIterator(){ dataRecord = new ArrayList<>(); } /** * 加載數據並初始化 * */ public boolean loadData(String fileName, int batchNum, int exampleLength){ this.batchNum = batchNum; this.exampleLength = exampleLength; maxNum = new double[6]; //加載文件中的股票數據 try { readDataFromFile(fileName); }catch (Exception e){ e.printStackTrace(); return false; } //重置訓練批次列表 resetDataRecord(); return true; } /** * 重置訓練批次列表 * */ private void resetDataRecord(){ dataRecord.clear(); int total = dataList.size()/exampleLength+1; for( int i=0; i<total; i++ ){ dataRecord.add(i * exampleLength); } } /** * 從文件中讀取股票數據 * */ public List<DailyData> readDataFromFile(String fileName) throws IOException{ dataList = new ArrayList<>(); BufferedReader in = new BufferedReader(new InputStreamReader(StockDataIterator.class.getResourceAsStream(fileName) ,"UTF-8")); String line = in.readLine(); for(int i=0;i<maxNum.length;i++){ maxNum[i] = 0; } System.out.println("讀取數據.."); while(line!=null){ String[] strArr = line.split(","); if(strArr.length>=7) { DailyData data = new DailyData(); //獲得最大值信息,用於歸一化 double[] nums = new double[6]; for(int j=0;j<6;j++){ nums[j] = Double.valueOf(strArr[j+2]); if( nums[j]>maxNum[j] ){ maxNum[j] = nums[j]; } } //構造data對象 data.setOpenPrice(Double.valueOf(nums[0])); data.setCloseprice(Double.valueOf(nums[1])); data.setMaxPrice(Double.valueOf(nums[2])); data.setMinPrice(Double.valueOf(nums[3])); data.setTurnover(Double.valueOf(nums[4])); data.setVolume(Double.valueOf(nums[5])); dataList.add(data); } line = in.readLine(); } in.close(); System.out.println("反轉list..."); Collections.reverse(dataList); return 
dataList; } public double[] getMaxArr(){ return this.maxNum; } public void reset(){ resetDataRecord(); } public boolean hasNext(){ return dataRecord.size() > 0; } public DataSet next(){ return next(batchNum); } /** * 獲得接下來一次的訓練數據集 * */ public DataSet next(int num){ if( dataRecord.size() <= 0 ) { throw new NoSuchElementException(); } int actualBatchSize = Math.min(num, dataRecord.size()); int actualLength = Math.min(exampleLength,dataList.size()-dataRecord.get(0)-1); INDArray input = Nd4j.create(new int[]{actualBatchSize,VECTOR_SIZE,actualLength}, ‘f‘); INDArray label = Nd4j.create(new int[]{actualBatchSize,1,actualLength}, ‘f‘); DailyData nextData = null,curData = null; //獲取每批次的訓練數據和標簽數據 for(int i=0;i<actualBatchSize;i++){ int index = dataRecord.remove(0); int endIndex = Math.min(index+exampleLength,dataList.size()-1); curData = dataList.get(index); for(int j=index;j<endIndex;j++){ //獲取數據信息 nextData = dataList.get(j+1); //構造訓練向量 int c = endIndex-j-1; input.putScalar(new int[]{i, 0, c}, curData.getOpenPrice()/maxNum[0]); input.putScalar(new int[]{i, 1, c}, curData.getCloseprice()/maxNum[1]); input.putScalar(new int[]{i, 2, c}, curData.getMaxPrice()/maxNum[2]); input.putScalar(new int[]{i, 3, c}, curData.getMinPrice()/maxNum[3]); input.putScalar(new int[]{i, 4, c}, curData.getTurnover()/maxNum[4]); input.putScalar(new int[]{i, 5, c}, curData.getVolume()/maxNum[5]); //構造label向量 label.putScalar(new int[]{i, 0, c}, nextData.getCloseprice()/maxNum[1]); curData = nextData; } if(dataRecord.size()<=0) { break; } } return new DataSet(input, label); } public int batch() { return batchNum; } public int cursor() { return totalExamples() - dataRecord.size(); } public int numExamples() { return totalExamples(); } public void setPreProcessor(DataSetPreProcessor preProcessor) { throw new UnsupportedOperationException("Not implemented"); } public int totalExamples() { return (dataList.size()) / exampleLength; } public int inputColumns() { return dataList.size(); } public int 
totalOutcomes() { return 1; } @Override public List<String> getLabels() { throw new UnsupportedOperationException("Not implemented"); } @Override public void remove() { throw new UnsupportedOperationException(); } @Override public boolean resetSupported() { // TODO Auto-generated method stub return false; } @Override public boolean asyncSupported() { // TODO Auto-generated method stub return false; } @Override public DataSetPreProcessor getPreProcessor() { // TODO Auto-generated method stub return null; } }
/**
 * Entry point: trains a two-layer GravesLSTM regression network (DL4J) on the
 * sequences produced by StockDataIterator, then rolls the trained RNN forward
 * 20 steps from a fixed seed day and prints the de-normalised predictions.
 */
public class Dtest {
    // 6 input features per time step (matches StockDataIterator.VECTOR_SIZE)
    private static final int IN_NUM = 6;
    // single regression output: next day's closing price
    private static final int OUT_NUM = 1;
    // number of full passes over the training data
    private static final int Epochs = 1;
    private static final int lstmLayer1Size = 50;
    private static final int lstmLayer2Size = 100;

    /**
     * Builds the network: LSTM(nIn -> 50, tanh) -> LSTM(50 -> 100, tanh)
     * -> RnnOutputLayer(100 -> nOut, identity activation, MSE loss).
     * SGD optimisation with the RMSPROP updater, L2 = 0.001, fixed seed.
     */
    public static MultiLayerNetwork getNetModel(int nIn, int nOut) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .seed(12345)
                .l2(0.001)
                .updater(Updater.RMSPROP)
                .list()
                .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(lstmLayer1Size)
                        .activation(Activation.TANH).build())
                .layer(1, new GravesLSTM.Builder().nIn(lstmLayer1Size).nOut(lstmLayer2Size)
                        .activation(Activation.TANH).build())
                .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
                        .nIn(lstmLayer2Size).nOut(nOut).build())
                .pretrain(false).backprop(true)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // log the training score after every iteration
        net.setListeners(new ScoreIterationListener(1));
        return net;
    }

    /**
     * Runs the training loop. After each epoch it seeds the RNN state with one
     * fixed day (getInitArray) and prints 20 autoregressive predictions,
     * de-normalised by the close-price maximum (getMaxArr()[1]).
     */
    public static void train(MultiLayerNetwork net, StockDataIterator iterator) {
        // iterative training (original comment: 叠代訓練)
        for (int i = 0; i < Epochs; i++) {
            DataSet dataSet = null;
            while (iterator.hasNext()) {
                dataSet = iterator.next();
                net.fit(dataSet);
            }
            iterator.reset();
            System.out.println();
            System.out.println("=================>完成第" + i + "次完整訓練");
            INDArray initArray = getInitArray(iterator);
            System.out.println("預測結果:");
            for (int j = 0; j < 20; j++) {
                // rnnTimeStep keeps hidden state, so each call continues the sequence
                INDArray output = net.rnnTimeStep(initArray);
                System.out.print(output.getDouble(0) * iterator.getMaxArr()[1] + " ");
            }
            System.out.println();
            // clear hidden state so the next epoch's prediction starts fresh
            net.rnnClearPreviousState();
        }
    }

    /**
     * Builds the 1x6x1 seed input for prediction, normalised by the maxima
     * collected during loading.
     * NOTE(review): the six values are hard-coded — presumably one real day
     * from the CSV; confirm they line up with the feature order
     * (open, close, max, min, turnover, volume). The values at indices 2/3
     * (3327.81 < 3470.37) look like they may be swapped vs. max/min — verify.
     */
    private static INDArray getInitArray(StockDataIterator iter) {
        double[] maxNums = iter.getMaxArr();
        INDArray initArray = Nd4j.zeros(1, 6, 1);
        initArray.putScalar(new int[]{0, 0, 0}, 3433.85 / maxNums[0]);
        initArray.putScalar(new int[]{0, 1, 0}, 3445.41 / maxNums[1]);
        initArray.putScalar(new int[]{0, 2, 0}, 3327.81 / maxNums[2]);
        initArray.putScalar(new int[]{0, 3, 0}, 3470.37 / maxNums[3]);
        initArray.putScalar(new int[]{0, 4, 0}, 304197903.0 / maxNums[4]);
        initArray.putScalar(new int[]{0, 5, 0}, 3.8750365e+11 / maxNums[5]);
        return initArray;
    }

    public static void main(String[] args) {
        String inputFile = "sz399905.csv";
        int batchSize = 1;
        int exampleLength = 30;
        // initialise the deep neural network (original comment: 初始化深度神經網絡)
        StockDataIterator iterator = new StockDataIterator();
        iterator.loadData(inputFile, batchSize, exampleLength);
        MultiLayerNetwork net = getNetModel(IN_NUM, OUT_NUM);
        train(net, iterator);
    }
}
數據格式如下:
sz399905 2015/12/11 7320.16 7290.7 7253.84 7347.36 72132287 1.12E+11 -0.008096367
sz399905 2015/12/10 7374.35 7350.21 7332.98 7437.71 78990424 1.30E+11 -0.003262696
sz399905 2015/12/9 7369.11 7374.27 7322.87 7431.04 83299991 1.32E+11 -0.004034229
sz399905 2015/12/8 7555.46 7404.14 7398.56 7555.46 94938823 1.47E+11 -0.026056828
sz399905 2015/12/7 7526.22 7602.23 7476.19 7602.77 92881296 1.47E+11 0.012055908
sz399905 2015/12/4 7533.61 7511.67 7464.28 7600.34 101362535 1.55E+11 -0.007772264
sz399905 2015/12/3 7413.22 7570.51 7412.65 7571.45 95329412 1.43E+11 0.022232394
sz399905 2015/12/2 7423.5 7405.86 7201.66 7444.22 102647475 1.50E+11 -0.005115571
sz399905 2015/12/1 7403.94 7443.94 7358.37 7519.94 113008679 1.73E+11 0.004797257
sz399905 2015/11/30 7388.28 7408.4 7035.55 7467.47 129234023 1.97E+11 0.004376285
sz399905 2015/11/27 7839.31 7376.12 7317.65 7852 152970489 2.34E+11 -0.063240404
sz399905 2015/11/26 7962.17 7874.08 7859.63 7974.73 140404615 2.29E+11 -0.006096653
sz399905 2015/11/25 7803.29 7922.38 7795.16 7925.54 124435501 2.07E+11 0.015885106
sz399905 2015/11/24 7739.09 7798.5 7635.78 7799.01 110258558 1.69E+11 0.0070143
參考文章:
https://blog.csdn.net/a398942089/article/details/52294082
深度學習 之 DeepLearning4j 預測股市走向