1. 程式人生 > >7天微課程day6——用ARIMA模型進行時間序列預測

7天微課程day6——用ARIMA模型進行時間序列預測

宣告:

  1. 本文是系列課程的第6課
  2. 本文是對機器學習網站課程的翻譯
  3. 尊重原作者,尊重知識分享

用ARIMA模型進行時間序列預測

ARIMA(AutoRegressive Intergrated Moving Average)是一個非常非常流行的時間序列預測模型。

通過本文,你將瞭解:

  • 簡單瞭解ARIMA的原理
  • ARIMA如何學習時間序列並做出預測
  • ARIMA的調優

ARIMA

ARIMA可以看做幾個用於時間序列預測的模型的合併,它考慮了時間序列中集中標準的分佈模式,提供了簡單而高效的模型。ARIMA有以下幾部分組成:

  • AR: Autoregression.自迴歸模型建立當前預測與先前值的關係。
  • I: Intergrated. 表示加入了差分運算,用於表徵時間序列的趨勢。通過差分可以去除時間序列的上升或下降的趨勢,從而得到平穩序列。
  • MA: Moving Average. 移動平均模型,考慮的是當前值的預測與先前時刻的殘差有關。

這三個部分各有一個引數,通常表示為ARIMA(p, d, q). p, d, q取值皆為整數,三個引數的意義分別如下:

  • p: 對應AR,表示考慮的先前時刻的多少,稱為滯後值。
  • d: 對應I, 表示差分階數。
  • q: 對應MA,表示移動視窗的大小,也稱為移動視窗值。

Shampoo Sales Dataset

該資料集是3年內洗髮水的月銷量,詳見download

from pandas import read_csv
from pandas import datetime
from matplotlib import pyplot

def parser(x):
    return datetime.strptime('190'+x, '%Y-%m')

series = read_csv('shampoo-sales.csv', header=0, parse_dates=[0], index_col=0, squeeze=True, date_parser=parser)
print(series.head())
series.plot()
pyplot.show()
'''output
Month
1901-01-01 266.0
1901-02-01 145.9
1901-03-01 183.1
1901-04-01 119.3
1901-05-01 180.3
Name: Sales, dtype: float64
'''

可以看到洗髮水的銷量有一個明顯的上升趨勢。說明我們至少需要一階差分來使資料變平穩,即d=1.

我們在看一下時間序列的自相關性。

from pandas import read_csv
from pandas import datetime
from matplotlib import pyplot
from pandas.tools.plotting import autocorrelation_plot

def parser(x):
    return datetime.strptime('190'+x, '%Y-%m')

series = read_csv('shampoo-sales.csv', header=0, parse_dates=[0], index_col=0, squeeze=True, date_parser=parser)
autocorrelation_plot(series)
pyplot.show()

可以看到過去5天的自相關性較強,大於0.5,即p=5.

現在,我們初步確定了引數p和d的值,可以先構造一個p=5d=1的ARIMA,再觀察殘差項以確定引數q的值。

ARIMA實踐

Python中有statsmodel來建立ARIMA模型,具體步驟:

  1. ARIMA()定義模型,同時傳入引數p, d, q
  2. fit()訓練模型
  3. predict()預測

下面我們將訓練一個ARIMA模型,並觀察它的殘差已制定後面的引數q。經過前面視覺化分析,我們首先選用ARIMA(5, 1, 0)。

from pandas import read_csv
from pandas import datetime
from pandas import DataFrame
from statsmodels.tsa.arima_model import ARIMA
from matplotlib import pyplot

def parser(x):
    return datetime.strptime('190'+x, '%Y-%m')

series = read_csv('shampoo-sales.csv', header=0, parse_dates=[0], index_col=0, squeeze=True, date_parser=parser)

# fit model
model = ARIMA(series, order=(5, 1, 0))
model_fit = model.fit(disp=0)  # disp=0關閉對訓練資訊的列印
print(model_fit.summary())

# plot residual errors
residuals = DataFrame(model_fit.resid)
residuals.plot()
pyplot.show()
residuals.plot(kind='kde')
pyplot.show()
print(residuals.describe())
'''output
                             ARIMA Model Results
==============================================================================
Dep. Variable:                D.Sales   No. Observations:                   35
Model:                 ARIMA(5, 1, 0)   Log Likelihood                -196.170
Method:                       css-mle   S.D. of innovations             64.241
Date:                Mon, 12 Dec 2016   AIC                            406.340
Time:                        11:09:13   BIC                            417.227
Sample:                    02-01-1901   HQIC                           410.098
                         - 12-01-1903
=================================================================================
                    coef    std err          z      P>|z|      [95.0% Conf. Int.]
---------------------------------------------------------------------------------
const            12.0649      3.652      3.304      0.003         4.908    19.222
ar.L1.D.Sales    -1.1082      0.183     -6.063      0.000        -1.466    -0.750
ar.L2.D.Sales    -0.6203      0.282     -2.203      0.036        -1.172    -0.068
ar.L3.D.Sales    -0.3606      0.295     -1.222      0.231        -0.939     0.218
ar.L4.D.Sales    -0.1252      0.280     -0.447      0.658        -0.674     0.424
ar.L5.D.Sales     0.1289      0.191      0.673      0.506        -0.246     0.504
                                    Roots
=============================================================================
                 Real           Imaginary           Modulus         Frequency
-----------------------------------------------------------------------------
AR.1           -1.0617           -0.5064j            1.1763           -0.4292
AR.2           -1.0617           +0.5064j            1.1763            0.4292
AR.3            0.0816           -1.3804j            1.3828           -0.2406
AR.4            0.0816           +1.3804j            1.3828            0.2406
AR.5            2.9315           -0.0000j            2.9315           -0.0000
-----------------------------------------------------------------------------

count   35.000000
mean    -5.495213
std     68.132882
min   -133.296597
25%    -42.477935
50%     -7.186584
75%     24.748357
max    133.237980
'''

從表中可以看到ARIMA用於AR的常量係數和其他5個係數。
還有一個關於殘差的圖。


可以看到殘差還有一個明顯的趨勢,這正是MA模型要解決的問題。
殘差值的分佈圖如下:

在該圖中可以發現殘差並不是正態分佈,有一定的偏離,有一點不對稱,說明殘差中還有噪聲之外的資訊。這點從residuals.discribe()的輸出中也可以看得出來。

值得一提的是,在上面的例子中我們只是為了觀察residual的情況,所以沒有必要劃分訓練集和測試集。

ARIMA模型的預測

fit()後的ARIMA模型的返回值是ARIMAResults物件,對該物件呼叫predict()可以進行預測。predict()接受時刻t的序列作為引數。

例如,假設我們訓練集的時間序列有100個觀測值,若要預測下一時刻的輸出,則呼叫predict(start=101, end=101)會返回一個shape=(1,)的陣列。

predict()還有一個引數typ需要指定,若typ='linear',輸出差分值;若typ='levels,輸出預測值。Fortunately,在做一步預測時,我們可以用forecast()來代替predict(),就不需要那麼多引數要考慮了。

這裡採用的預測方式依然是walk-forward。walk-forward每一步都要加入新的觀測值,那我們就每一步都re-create ARIMA模型。

from pandas import read_csv
from pandas import datetime
from matplotlib import pyplot
from statsmodels.tsa.arima_model import ARIMA
from sklearn.metrics import mean_squared_error

def parser(x):
    return datetime.strptime('190'+x, '%Y-%m')

series = read_csv('shampoo-sales.csv', header=0, parse_dates=[0], index_col=0, squeeze=True, date_parser=parser)
X = series.values
size = int(len(X) * 0.66)
train, test = X[0:size], X[size:len(X)]
history = [x for x in train]
predictions = list()

for t in range(len(test)):
    model = ARIMA(history, order=(5, 1, 0))
    model_fit = model.fit(disp=0)
    output = model_fit.forecast()
    yhat = output[0]
    predictions.append(yhat)
    obs = test[t]
    history.append(obs)
    print('predicted=%f, expected=%f' % (yhat, obs))
error = mean_squared_error(test, predictions)
print('Test MSE: %.3f' % error)
# plot
pyplot.plot(test)
pyplot.plot(predictions, color='red')
pyplot.show()
'''output
predicted=349.117688, expected=342.300000
predicted=306.512968, expected=339.700000
predicted=387.376422, expected=440.400000
predicted=348.154111, expected=315.900000
predicted=386.308808, expected=439.300000
predicted=356.081996, expected=401.300000
predicted=446.379501, expected=437.400000
predicted=394.737286, expected=575.500000
predicted=434.915566, expected=407.600000
predicted=507.923407, expected=682.000000
predicted=435.483082, expected=475.300000
predicted=652.743772, expected=581.300000
predicted=546.343485, expected=646.900000
Test MSE: 6958.325
'''

ARIMA調優