Keras LSTM 時間序列預測
阿新 • • 發佈:2018-11-25
Keras LSTM 時間序列預測
international-airline-passengers.csv資料記錄:
time,passengers "1949-01",112 "1949-02",118 "1949-03",132 "1949-04",129 "1949-05",121 "1949-06",135 "1949-07",148 "1949-08",148 "1949-09",136 "1949-10",119 "1949-11",104 "1949-12",118 "1950-01",115 "1950-02",126 "1950-03",141 "1950-04",135 "1950-05",125 "1950-06",149 "1950-07",170 "1950-08",170 "1950-09",158 "1950-10",133 "1950-11",114 "1950-12",140 "1951-01",145 "1951-02",150 "1951-03",178 "1951-04",163 "1951-05",172 "1951-06",178 "1951-07",199 "1951-08",199 "1951-09",184 "1951-10",162 "1951-11",146 "1951-12",166 "1952-01",171 "1952-02",180 "1952-03",193 "1952-04",181 "1952-05",183 "1952-06",218 "1952-07",230 "1952-08",242 "1952-09",209 "1952-10",191 "1952-11",172 "1952-12",194 "1953-01",196 "1953-02",196 "1953-03",236 "1953-04",235 "1953-05",229 "1953-06",243 "1953-07",264 "1953-08",272 "1953-09",237 "1953-10",211 "1953-11",180 "1953-12",201 "1954-01",204 "1954-02",188 "1954-03",235 "1954-04",227 "1954-05",234 "1954-06",264 "1954-07",302 "1954-08",293 "1954-09",259 "1954-10",229 "1954-11",203 "1954-12",229 "1955-01",242 "1955-02",233 "1955-03",267 "1955-04",269 "1955-05",270 "1955-06",315 "1955-07",364 "1955-08",347 "1955-09",312 "1955-10",274 "1955-11",237 "1955-12",278 "1956-01",284 "1956-02",277 "1956-03",317 "1956-04",313 "1956-05",318 "1956-06",374 "1956-07",413 "1956-08",405 "1956-09",355 "1956-10",306 "1956-11",271 "1956-12",306 "1957-01",315 "1957-02",301 "1957-03",356 "1957-04",348 "1957-05",355 "1957-06",422 "1957-07",465 "1957-08",467 "1957-09",404 "1957-10",347 "1957-11",305 "1957-12",336 "1958-01",340 "1958-02",318 "1958-03",362 "1958-04",348 "1958-05",363 "1958-06",435 "1958-07",491 "1958-08",505 "1958-09",404 "1958-10",359 "1958-11",310 "1958-12",337 "1959-01",360 "1959-02",342 "1959-03",406 "1959-04",396 "1959-05",420 "1959-06",472 "1959-07",548 "1959-08",559 "1959-09",463 "1959-10",407 "1959-11",362 "1959-12",405 "1960-01",417 "1960-02",391 "1960-03",419 "1960-04",461 "1960-05",472 "1960-06",535 "1960-07",622 "1960-08",606 "1960-09",508 
"1960-10",461 "1960-11",390 "1960-12",432
Keras LSTM 時間序列預測程式 lstm_airline_predict.py:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import LSTM, Dense, Activation


def load_data(file_name, sequence_length=10, split=0.8, shuffle=True):
    """Load a univariate series from a CSV file and cut it into training windows.

    Reads the second column (``usecols=[1]``) of ``file_name``, scales it to
    [0, 1] with a ``MinMaxScaler`` and slices it into overlapping windows of
    ``sequence_length + 1`` consecutive values. The first ``sequence_length``
    values of a window are the model input; the last value is the target.

    Args:
        file_name: Path of the CSV file.
        sequence_length: Number of past time steps fed to the model.
        split: Fraction of windows used for training; the rest is the test set.
        shuffle: Shuffle the windows before splitting (the historical default).
            Pass ``False`` to keep chronological order, so the test set is the
            most recent part of the series and no future values leak into
            training.

    Returns:
        ``(train_x, train_y, test_x, test_y, scaler)``. The ``*_x`` arrays have
        shape ``(n, sequence_length, 1)`` and the ``*_y`` arrays have shape
        ``(n, 1)``, all in the scaled space. ``scaler`` is the fitted
        ``MinMaxScaler`` used to map values back to the original units.

    Raises:
        ValueError: if the series has fewer than ``sequence_length + 1`` values.
    """
    df = pd.read_csv(file_name, sep=',', usecols=[1])
    data_all = np.array(df).astype(float)
    if len(data_all) < sequence_length + 1:
        raise ValueError(
            'need at least %d values to build one window, got %d'
            % (sequence_length + 1, len(data_all)))
    # NOTE: the scaler is fitted on the whole series (train + test), so the
    # test range influences the normalization; acceptable for this demo.
    scaler = MinMaxScaler()
    data_all = scaler.fit_transform(data_all)
    # Every start index i with i + sequence_length < len(data_all) yields a
    # complete window (the previous bound stopped one window early).
    windows = [data_all[i: i + sequence_length + 1]
               for i in range(len(data_all) - sequence_length)]
    reshaped_data = np.array(windows).astype('float64')
    if shuffle:
        np.random.shuffle(reshaped_data)
    # Both x and y are min-max scaled, because the whole series was scaled
    # above. Use scaler.inverse_transform to get back to real values.
    x = reshaped_data[:, :-1]
    y = reshaped_data[:, -1]
    split_boundary = int(reshaped_data.shape[0] * split)
    train_x = x[:split_boundary]
    test_x = x[split_boundary:]
    train_y = y[:split_boundary]
    test_y = y[split_boundary:]
    return train_x, train_y, test_x, test_y, scaler


def build_model(input_dim=1):
    """Build and compile a two-layer stacked LSTM regressor.

    Args:
        input_dim: Size of the last dimension of ``train_x``, whose shape is
            ``(n_samples, time_steps, input_dim)``.

    Returns:
        A compiled ``Sequential`` model (MSE loss, RMSprop optimizer) that
        outputs one value per sample.
    """
    model = Sequential()
    # input_shape=(None, input_dim): any number of time steps, input_dim
    # features per step. This replaces the removed Keras 1 arguments
    # input_dim=/output_dim=; 50 is the number of units.
    model.add(LSTM(50, input_shape=(None, input_dim), return_sequences=True))
    model.add(LSTM(100, return_sequences=False))
    model.add(Dense(1))
    model.add(Activation('linear'))
    model.compile(loss='mse', optimizer='rmsprop')
    return model


def train_model(train_x, train_y, test_x, test_y):
    """Train a fresh model, predict on the test set and plot the scaled results.

    Training can be stopped early with Ctrl-C. Prediction then runs on the
    partially trained model instead of crashing.

    Returns:
        ``(predict, test_y)``. ``predict`` is a flat 1-D array of scaled
        predictions; ``test_y`` is returned unchanged.
    """
    model = build_model()
    try:
        # `epochs` replaces the Keras 1 argument `nb_epoch`.
        model.fit(train_x, train_y, batch_size=512, epochs=30,
                  validation_split=0.1)
    except KeyboardInterrupt:
        # The old handler printed `predict` here, before it was ever assigned,
        # which raised UnboundLocalError.
        print('Training interrupted; predicting with the partially trained model.')
    predict = model.predict(test_x)
    predict = np.reshape(predict, (predict.size, ))
    print(predict)
    print(test_y)
    try:
        plt.figure(1)
        plt.plot(predict, 'r:')
        plt.plot(test_y, 'g-')
        plt.legend(['predict', 'true'])
    except Exception as e:
        # Plotting is best-effort: report the error but still return results.
        print(e)
    return predict, test_y


if __name__ == '__main__':
    train_x, train_y, test_x, test_y, scaler = load_data('international-airline-passengers.csv')
    # The LSTM expects 3-D input: (samples, time_steps, features).
    train_x = np.reshape(train_x, (train_x.shape[0], train_x.shape[1], 1))
    test_x = np.reshape(test_x, (test_x.shape[0], test_x.shape[1], 1))
    predict_y, test_y = train_model(train_x, train_y, test_x, test_y)
    # inverse_transform expects a 2-D (n, 1) array; map back to passenger counts.
    predict_y = scaler.inverse_transform(np.reshape(predict_y, (-1, 1)))
    test_y = scaler.inverse_transform(np.reshape(test_y, (-1, 1)))
    plt.figure(2)
    plt.plot(predict_y, 'g:')
    plt.plot(test_y, 'r-')
    plt.legend(['predict', 'true'])
    plt.show()
執行結果:
Epoch 1/30 95/95 [==============================] - 5s 53ms/step - loss: 0.1793 - val_loss: 0.1028 Epoch 2/30 95/95 [==============================] - 0s 412us/step - loss: 0.1015 - val_loss: 0.0528 Epoch 3/30 95/95 [==============================] - 0s 353us/step - loss: 0.0532 - val_loss: 0.0183 Epoch 4/30 95/95 [==============================] - 0s 359us/step - loss: 0.0204 - val_loss: 0.0113 Epoch 5/30 95/95 [==============================] - 0s 448us/step - loss: 0.0145 - val_loss: 0.0119 Epoch 6/30 95/95 [==============================] - 0s 507us/step - loss: 0.0140 - val_loss: 0.0114 Epoch 7/30 95/95 [==============================] - 0s 439us/step - loss: 0.0135 - val_loss: 0.0120 Epoch 8/30 95/95 [==============================] - 0s 373us/step - loss: 0.0132 - val_loss: 0.0118 Epoch 9/30 95/95 [==============================] - 0s 454us/step - loss: 0.0129 - val_loss: 0.0127 Epoch 10/30 95/95 [==============================] - 0s 413us/step - loss: 0.0129 - val_loss: 0.0127 Epoch 11/30 95/95 [==============================] - 0s 418us/step - loss: 0.0129 - val_loss: 0.0147 Epoch 12/30 95/95 [==============================] - 0s 369us/step - loss: 0.0139 - val_loss: 0.0145 Epoch 13/30 95/95 [==============================] - 0s 485us/step - loss: 0.0141 - val_loss: 0.0182 Epoch 14/30 95/95 [==============================] - 0s 459us/step - loss: 0.0166 - val_loss: 0.0146 Epoch 15/30 95/95 [==============================] - 0s 549us/step - loss: 0.0138 - val_loss: 0.0168 Epoch 16/30 95/95 [==============================] - 0s 423us/step - loss: 0.0149 - val_loss: 0.0141 Epoch 17/30 95/95 [==============================] - 0s 401us/step - loss: 0.0129 - val_loss: 0.0155 Epoch 18/30 95/95 [==============================] - 0s 383us/step - loss: 0.0134 - val_loss: 0.0141 Epoch 19/30 95/95 [==============================] - 0s 328us/step - loss: 0.0125 - val_loss: 0.0154 Epoch 20/30 95/95 [==============================] - 0s 401us/step - loss: 0.0130 - 
val_loss: 0.0144 Epoch 21/30 95/95 [==============================] - 0s 338us/step - loss: 0.0124 - val_loss: 0.0158 Epoch 22/30 95/95 [==============================] - 0s 359us/step - loss: 0.0131 - val_loss: 0.0148 Epoch 23/30 95/95 [==============================] - 0s 338us/step - loss: 0.0126 - val_loss: 0.0164 Epoch 24/30 95/95 [==============================] - 0s 380us/step - loss: 0.0135 - val_loss: 0.0150 Epoch 25/30 95/95 [==============================] - 0s 378us/step - loss: 0.0127 - val_loss: 0.0167 Epoch 26/30 95/95 [==============================] - 0s 541us/step - loss: 0.0137 - val_loss: 0.0151 Epoch 27/30 95/95 [==============================] - 0s 528us/step - loss: 0.0127 - val_loss: 0.0166 Epoch 28/30 95/95 [==============================] - 0s 423us/step - loss: 0.0134 - val_loss: 0.0150 Epoch 29/30 95/95 [==============================] - 0s 515us/step - loss: 0.0125 - val_loss: 0.0164 Epoch 30/30 95/95 [==============================] - 0s 457us/step - loss: 0.0131 - val_loss: 0.0150 [0.6991743 0.4155811 0.43763575 0.1943914 0.24489456 0.43544254 0.728908 0.27704275 0.7644203 0.24740852 0.58411294 0.33986062 0.28997922 0.13274276 0.74714196 0.5237809 0.36774576 0.5282971 0.23951268 0.6239692 0.15398878 0.4958876 0.10568523 0.55706674 0.32880494 0.60746497 0.294434 ] [[1. ] [0.25675676] [0.4034749 ] [0.11969112] [0.17374517] [0.58108108] [0.4980695 ] [0.25675676] [0.55405405] [0.17760618] [0.5 ] [0.31853282] [0.2992278 ] [0.01930502] [0.58108108] [0.48648649] [0.4015444 ] [0.38030888] [0.13127413] [0.61003861] [0.18339768] [0.38996139] [0.12741313] [0.63899614] [0.40733591] [0.87837838] [0.20656371]]