1. 程式人生 > >My Roadmap in Reinforcement Learning

My Roadmap in Reinforcement Learning

一、前言

前段時間接受導師的建議,學習了一些強化學習GANs的內容,第一週先看的強化學習,二三週看的GANs。強化學習(RL)是一個很有趣的領域,一直以來也是我很喜歡的一個AI的分支,被譽為是AI皇冠上的明珠,因為通過RL能很直觀地反映出“智慧”。第一週看完之後有不少收穫,當時想著要寫一篇部落格記錄下來,結果一拖再拖…
時至今日,已經是第四周了,本來給自己定的本週計劃是入門object tracking領域,幾天過去了,感覺tracking的入門門檻相對比較高,自己這幾天論文看得很艱難,有點迷失了,加上我導師找我談話催論文,弄得我心煩意亂無心學習。既然如此,不如利用現在的迷失時間來歸納一下RL方面的內容。閒話少說,現在開始進入RL時間。

================================================================

二、 從Q-Learning談起

要入門RL,首先的入門演算法就是Q-Learning了。Q-Learning如果換一個更大的名字,應該是基於值迭代的馬爾科夫鏈的一個求解演算法,認識到Q-Learning的這個名字有助於更深地把握其背後的數學思想,從而能夠將其用到求解其他基於馬爾科夫鏈建模的數學模型中去(至少我是這麼認為的。。)。
Q-Learning因為是基於值迭代的求解演算法(這個和後面會提到的基於策略迭代的演算法,比如policy gradient剛好對立),所以理解起來,它其實就是在玩一個Q-tabel,也就是儲存著Q

(state,action)的一個表格,表格的橫欄表示不同的state,縱欄表示不同的action。

Q-Learning演算法的目的就是不斷地迭代優化這個Q-tabel,直至其收斂以逼近理想化的Q-value(真正理想化的Q-value是得不到的,只能approximate),收斂之後每個state的決策就可以直接從Q-tabel中查詢Q-value最大的action作為當前state的決策了。

π(s)=argmaxaQ(s,a)

而Q-learning優化Q-tabel的方法基於以下幾個數學公式:

一是將discounted future reward定義為

Rt=rt+γrt

+1+γ2rt+2+γ3rt+3...+γntrn(其中γ是discount factor)

二是將Q-value(或者Q function)定義為:

the maximum discounted future reward when we perform action a in state s, and continue optimally from that point on
或者
the best possible score at the end of the game after performing action a in state s

從這個定義來看Q-value是一個很理想化的值,想真正地得到理想的Q-value是沒有辦法。但是,我們可以去逼近它,所以說Q-Learning整個演算法其實就是在為了逼近理想化的Q-value而不斷地“努力”。

Q-value(或者Q function)用數學表達就是:

Q(st,at)=maxRt
(注意:在intel的博文中寫成了Rt+1,我覺得應該有問題)

三是,拜一、二、所賜,就能順理成章地得到大名鼎鼎的Bellman Equation了:

Q(s,a)=r+γmaxaQ(s,a)

其中,r代表即時的reward(也就是執行action後立馬得到的reward),s是執行r之後轉到的下一個state,rs都是通過simulator(也就是game模擬器)觀測到的。

拿到了Bellman Equation,就可以快樂地迭代更新Q-table中的Q-value了,整個的演算法流程如下:

這裡寫圖片描述

上述演算法流程在“select and carry out an action a”的時候可以有不同的策略,比如可以是隨機的,也可以是以ϵ的概率的隨機選擇,1ϵ的概率選擇當前Q-value最高對應的action,這種選取方法被稱為ϵ-greedy exploration。

雖然Q-Learning演算法迭代之初得到的Q-value可能相比理想的Q-value差別很大,但是已經有理論證明,只要有足夠多的iterations,那麼Q-tabel最終會收斂並且能夠表示“理想的(卻是我們真正想要的)”Q-value。

————————————————————————————————————-

Sarsa和Sarsa-Lambda

觀察上圖的Q-Learning演算法流程圖,雖然每次都是根據最大化的原則來選擇a所對應的Q(s,a)來更新Q(s,a),但是進入下一個iteration後,選擇的action未必就是a,因此這樣來看,傳統的Q-Learning看起來有點“不負責任”,盲目地追求用盡量大的值來更新Q(s,a)卻又不真正地執行a。由此有了Q-Learning的另一個版本,叫做SARSA,這個奇怪的名字其實就是(state-action-reward-sate-action)的首字母組合,從名字就能看出,SARSA是屬於“言出必行”型別的演算法,是一個“實踐派”,既然使用了最大的Q(s,a)來更新Q(s,a),那麼我就下一個iteration就執行a。相比之下,Q-Learning則有點冒險,因為它過度地去explore了,不像SARSA那麼保守務實,一步一一個腳印。

兩者演算法的對比見下圖,從Q(s,a)的更新來看,Sarsa言出必行。因此,Q-Learning是off-policy,而Sarsa是on-policy演算法:

另外還有一個Q-Learning的版本叫做Sarsa-Lambda,理解起來就是,Sarsa是屬於單步更新演算法,以“尋寶為例”,Sarsa只會給尋找到寶藏的前一步一個獎勵,而忽略了之前許多步的作用。因此引入了回合更新演算法,這樣就可以照顧到先前的步。

而lambda=1的時候相當於同等地看待先前的所有步,為了引入時間上的discount,所以lambda常常介於0,1之間(0對應的就是普通的Sarsa了)

為了更具體地瞭解一下Sarsa-lambda的思想和具體實現,可以看一下下面的程式碼語句塊:

class SarsaLambdaTable(RL): # 繼承 RL class
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        ...
    def check_state_exist(self, state):
        ...
    def learn(self, s, a, r, s_, a_):
        # 這部分和 Sarsa 一樣
        self.check_state_exist(s_)
        q_predict = self.q_table.ix[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.ix[s_, a_]
        else:
            q_target = r
        error = q_target - q_predict

        # 這裡開始不同:
        # 對於經歷過的 state-action, 我們讓他+1, 證明他是得到 reward 路途中不可或缺的一環
        self.eligibility_trace.ix[s, a] += 1

        # Q table 更新
        self.q_table += self.lr * error * self.eligibility_trace

        # 隨著時間衰減 eligibility trace 的值, 離獲取 reward 越遠的步, 他的"不可或缺性"越小
        self.eligibility_trace *= self.gamma*self.lambda_

從程式碼中可以看到,借用eligibility_trace為“橋樑”,在每一步的Q tabel更新時可以更新整個Q table(或者先前所有步),並且eligibility_trace隨時間以gamma,隨步距以lambda_的速度衰減。

————————————————————————————————————-
以上就是初入RL需要掌握的三個經典演算法, Q-Learning(off-policy), Sarsa(on-policy), Sarsa-Lambda(on-policy)。其中Sarsa和Sarsa-Lambda主要參考了莫煩強化學習教程,在我看來應該是優於Q-Learning的演算法,尤其是Sarsa-Lambda的回合更新在reward的分配上是一個考慮相對更加周全的演算法,不過Sarsa, Sarsa-Lambda背後的主要思想都是繼承了Q-Learning,都是基於值迭代的馬爾科夫鏈的一個求解演算法!

================================================================

三、Deep Q Network

前邊提到的Q-Learning及其衍生演算法有一個弊病,那就是當state的數量很大時,比如在處理視訊遊戲(比如Atari遊戲)的時候,用畫素組合表徵的state是一個天文數字,這樣一來Q-Learning就有點捉襟見肘了。而神經網路對於建模高維結構化資料是一大利器,所以用Network去扮演高維Q table的角色,學習一個Q function成了一個自然的想法。

使用神經網路來approximate Qs,a有兩種可以選擇的結構:

其中,以state作為輸入,輸出各個action對應的Q-value只需要一次forward pass就能得到Q tabel的一行了,更加方便和容易建模。

需要一提的是,如果輸入的state是影象或者視訊,通常不會使用pooling層,因為pooling層會對transition不敏感,而視訊遊戲中對物體的位置資訊是需要保留的。

根據對Q function的理解,給定一個transition<s,a,r,s>,那麼DQN的損失函式定義為:

L=12[r+maxaQ(s,a)Q(s,a)]2

其中r+maxaQ(s,a)是target,Q(s,a)是prediction,ss的下一個state。

先前的Q-table更新演算法由此變更為:
————————————————————————-
1. Do a feedforward pass for the current st