1. 程式人生 > 實用技巧 >對莫煩第一個強化學習例項分析

對莫煩第一個強化學習例項分析

對莫煩第一個強化學習例項分析

1 原始碼

連結如下所示:
小例子

2 分析

2.1 變數

N_STATES = 6   # the length of the 1 dimensional world
ACTIONS = ['left', 'right']     # available actions
EPSILON = 0.9   # greedy police
ALPHA = 0.1     # learning rate
GAMMA = 0.9    # discount factor
MAX_EPISODES = 13   # maximum episodes
FRESH_TIME = 0.3    # fresh time for one move

N_STATES:代表有多少個state,如下圖所示,在這個小例子中,它代表的是冒險者(圖中的O)所能到達的位置,一共有6個。

ACTIONS:冒險者所能採取的動作,即向左(left)或向右(right)。
EPSILON:貪婪度,實際上這是一個用來控制隨機探索概率的。什麼意思?就是說在每次選擇動作之前,都會獲得一個隨機的概率(這個概率大於0小於1),然後程式會用它和EPSILON進行一個比較,當這個概率小於EPSILON時,採用當前state下值最大的動作(也就是我們學習到的值),否則,隨機選擇一個動作。
ALPHA:學習率,個人理解是確定將每次學到的值保留多少的變數,比如說這一次學到了100,但我的學習率為0.1,那麼就只保留$100*0.1=10$
GAMMA:獎勵遞減值,這個值會導致離terminal越遠,獲得的值越小。
MAX_EPISODES:最大訓練輪次。

2.2 建立一個q表

def build_q_table(n_states, actions):
    table = pd.DataFrame(
        np.zeros((n_states, len(actions))),     # q_table initial values
        columns=actions,    # actions's name
    )
    #print(table)    # show table
    return table

上面的程式碼主要是用pandas庫來建立一個q表,q表示q_learning演算法裡面會用到的一個表格,它主要用來儲存state-action對應值(在每個state下每個action對應的值,它代表的是狀態state下采取動作action所能獲得收益的預期)
上述程式碼建立的表格格式如下:

它的行數為state數,列數為action數,q表的大小就是action數量*state數量。在現實的一些應用中,q表可能會非常非常巨大,所以大部分強化學習其實不會用q_learning演算法。

2.3 選擇動作

def choose_action(state, q_table):
    # This is how to choose an action
    state_actions = q_table.iloc[state, :]
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy or state-action have no value
        action_name = np.random.choice(ACTIONS)
    else:   # act greedy
        action_name = state_actions.idxmax()    # replace argmax to idxmax as argmax means a different function in newer version of pandas
    return action_name

這個函式的輸入為state,q_table,輸出則為action_name。
下面再來細看程式碼,首先是第一行

state_actions = q_table.iloc[state, :]

這一段程式碼的意思是取出當前q_table中的第state行,然後賦值給state_actions。
在下面的一個if...else判斷句,顯然就是根據隨機生成的一個位於0與1之間的數,來判斷是選擇q_table中值較大的那個動作,還是隨機選擇一個動作。

2.4 獲取環境反饋值

def get_env_feedback(S, A):
    # This is how agent will interact with the environment
    if A == 'right':    # move right
        if S == N_STATES - 2:   # terminate
            S_ = 'terminal'
            R = 1
        else:
            S_ = S + 1
            R = 0
    else:   # move left
        R = 0
        if S == 0:
            S_ = S  # reach the wall
        else:
            S_ = S - 1
    return S_, R

輸入值為S(state)和A(action),輸出值則為S_(下個state)和R(reward),這裡的下個其實理解成第n+1個比較好。
這段程式碼不用細講,其實就是隻有判斷S到達了terminal,R才有不為0的返回值。

2.5 強化學習主迴圈

def rl():
    # main part of RL loop
    q_table = build_q_table(N_STATES, ACTIONS)
    for episode in range(MAX_EPISODES):
        step_counter = 0
        S = 0
        is_terminated = False
        #update_env(S, episode, step_counter)
        while not is_terminated:

            A = choose_action(S, q_table)
            S_, R = get_env_feedback(S, A)  # take action & get next state and reward
            q_predict = q_table.loc[S, A]
            if S_ != 'terminal':
                q_target = R + GAMMA * q_table.iloc[S_, :].max()   # next state is not terminal
            else:
                q_target = R     # next state is terminal
                is_terminated = True    # terminate this episode

            q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
            S = S_  # move to next state

            #update_env(S, episode, step_counter+1)
            step_counter += 1
    return q_table

直接看程式碼

q_table = build_q_table(N_STATES, ACTIONS)

這一段就是建立一個N_STATES行,ACTION列,初始值全為0的表格,如圖2所示。

for episode in range(MAX_EPISODES):
    step_counter = 0
    S = 0
    is_terminated = False

這一段代表要訓練多少輪次,並且每一輪次需要初始化的變數(注意,是每一輪次都要初始化,比如說輪次為13,那麼S=0就要初始化13次)

step_counter:每一輪次探索者到達終點需要的步數
S:探索者初始位置,S=0時探索者位於一維世界最左邊
is_terminated:用來判斷是否到達終點的布林值,為真時代表當前state位於終點

while not is_terminated:

            A = choose_action(S, q_table)
            S_, R = get_env_feedback(S, A)  # take action & get next state and reward

上述代表代表了每個輪次中,探索者是怎麼行動,程式又是怎樣更新q_table表格的。
第一行,第二行不用多說,主要就是獲取A,S_,R這三個值。

         q_predict = q_table.loc[S, A]

q_predict是q_table中的第S行,第A列的值

 if S_ != 'terminal':
    q_target = R + GAMMA * q_table.iloc[S_, :].max()   # next state is not terminal
else:
    q_target = R     # next state is terminal
    is_terminated = True    # terminate this episode

如果S_不是terminal,q_target則為q_table中第S_行的最大值乘上獎勵遞減值GAMMA再加上獎勵值R;否則,q_target賦值為R,並且將is_terminated設為真值。
根據q_target的更新規則,不難發現距離terminal越遠,q值便越小。
PS:q_target為實際q值,q_predict為預測q值。不管是q_predict還是q_target,兩者都會不斷更新。

下圖是訓練結束後的q_table

越是靠近terminal,q值就越大。