Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction
本博文是對鄭宇老師團隊所提出的STResNet網路的一個略微擴充說明。本人自己在看完這篇論文的時候,感覺就一個字‘懵’。你說不懂吧,好像又明白點,你說懂吧又感覺有好多細節還是不清楚。好在該論文開放了原始碼。經過對原始碼的一番剖析,總算是弄懂之前不明白的一些細節。不過該原始碼是基於Keras實現的,由於本人之前一直使用Tensorflow,所以又對其利用tf進行了重構,程式碼整體上看起也來更加簡潔,地址見文末。
1.背景
這是2017年發表在AAAI上論文,其研究目的是對某個地方下一時刻車輛進出流量進行預測。作者說到,按照這樣的思想利用論文中提出的模型同樣還可以對某個區域的人流量,外賣訂單量,快遞收發量進行預測
所謂車流量預測,指的是利用歷史資料對某個區域下一時間點的進/出車流量進行預測,也就是論文中所指的In-flow和Out-flow. 同時對於In-flow和Out-flow的統計定義如下所示:
其中圖p0063中公式所表示的含義如下圖所示:
上圖為一個
的區域,代表的是在某時間片
時兩個車輛的執行軌跡,則
。其統計規則如下:
在
時間時,車輛A的移動軌跡(藍色)歷經了4個區域(
);車輛B的移動軌跡歷經了3個區域。
計算in-flow:對於A來說(
)此時有:
;對於B來說(
),但此時不滿足公式,所以
計算out-flow:對於A來說(
)此時有:
;對於B來說(
)此時有:
,所以
不過這都不需要你來統計,論文中所提供的資料集都已統計好了,明白這個意思就好。
2.論文介紹
2.1 資料預處理
在理解模型前我們首先來看看餵給網路的資料都長什麼樣,這樣有助於理解。
如圖p0067所示,最原始的資料集已經將整個北京市劃分成了一個
的小區域,並且也已經統計出了每個小區域每隔半小時(一個時間片)的進出流量,即已經表示成了
的格式。同時論文在實現時候,採用的是用當前時刻的前3個時間片來模擬鄰近性(Closeness),用當前時刻前一天的相同時刻的一個時間片來模擬週期性(Period),用當前時刻前一週的相同時刻的一個時間片來模擬趨勢性(Trend),即程式碼中的len_closeness=3,len_period=1,len_trend=1
作為三個超引數。也就是用這三個部分來預測
時刻的流量。
同時除了車流量資料之外,論文中還引入了其它額外的氣象等資料,分別是:time_feature,holiday_feature,meteorol_feature
最終將這三個部分拼成一個向量meta_feature
.
對於每個時間片來說:
time_feature
有8維度,前面7個維度為one-hot
形式,最後以為表示當天是否為工作日;例如圖p0069中的含義為,該時間片對應為星期四且為工作日。
holiday_feature
有1個維度,0表示時間片所在的當天為工作日,1表示假期。
meteorol_feature
有19個維度,前面17個也為one-hot
形式,表示天氣型別中的一種,後面兩個維度分別表示風速和溫度
最後將這個三個向量拼接成了一個28維度的向量。也就是說,現在我們已經知道了整個網路輸入資料的形式了。對於資料預處理的這部分,直接呼叫下面函式即可獲取:
X_train, Y_train, X_test, Y_test, mmn, external_dim, timestamp_train, timestamp_test = \
load_data(len_closeness=3, len_period=1, len_trend=1, len_test=4*7* 48)
2.2 網路構建
首先定義了網路的輸入部分,筆者將其分成了5個placeholder
,其含義如變數名;然後接著就是定義網路的部分,即Closeness,Period,Trend這三個部分和天氣模組;最後就是評估和訓練模組。分別在下面這幾個方法中被定義:
def _build_placeholder(self):
def _build_stresnet(self, ):
def evaluate(self, mmn, x, y):
def train(self, x, y):
其它的部分在程式碼中都有詳細的註釋。
3.論文結果
筆者在用tersorflow重構的時候考慮到BN對實驗結果影響不大就沒有加上。但是即便如此,其它配置下效果依舊不如論文中的結果。筆記將引數初始化的方式等多個細節都同論文作者的程式碼進行過對比,都保持了一致,最終其它幾個引數配置的結果如下:
Config | epoches | RMSE |
---|---|---|
L2-E | 6082 | 22.32 |
後面的結果會不定期放在github上:
https://github.com/TolicWang/DeepST