[Reinforcement Learning] A Python implementation of Q-learning, Example 2
Published: 2018-12-18
Problem setting
A 2×2 maze with one entrance (top-left), one exit (bottom-right), and one trap (bottom-left), as sketched below:

    ---------------
    | start |     |
    ---------------
    | die   | end |
    ---------------

This is a two-dimensional problem, but we can reduce the dimension and treat it as a one-dimensional one by numbering the cells.
0. Parameters
epsilon = 0.9   # greediness (probability of exploiting the best known action)
alpha = 0.1     # learning rate
gamma = 0.8     # reward discount factor
1. State set
The explorer has 4 states, i.e. the positions it can reach. So define

states = range(4)  # state set: 0 to 3

So, after performing an action in a given state, how is the next state determined?
def get_next_state(state, action):
    '''Return the next state after performing an action in the given state.'''
    # u, d, l, r, n = -2, +2, -1, +1, 0
    if state % 2 != 1 and action == 'r':    # any column but the last can move right (+1)
        next_state = state + 1
    elif state % 2 != 0 and action == 'l':  # any column but the first can move left (-1)
        next_state = state - 1
    elif state // 2 != 1 and action == 'd': # any row but the last can move down (+2)
        next_state = state + 2
    elif state // 2 != 0 and action == 'u': # any row but the first can move up (-2)
        next_state = state - 2
    else:
        next_state = state
    return next_state
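A quick sanity check of the transition function (these calls are purely illustrative):

print(get_next_state(0, 'r'))  # 1: moving right from the entrance
print(get_next_state(0, 'd'))  # 2: moving down from the entrance lands in the trap
print(get_next_state(0, 'u'))  # 0: an illegal move leaves the state unchanged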
2. Action set
In each state, the explorer has only 4 available actions: up, down, left and right. So define

actions = ['u', 'd', 'l', 'r']  # action set: up, down, left, right; an action 'n' for staying put could also be added

So, for a given state (position), how do we determine all of its legal actions?
def get_valid_actions(state):
    '''Return the set of legal actions in the current state (independent of the reward!).'''
    global actions  # ['u', 'd', 'l', 'r']
    valid_actions = set(actions)
    if state % 2 == 1:   # last column: remove 'move right'
        valid_actions = valid_actions - set(['r'])
    if state % 2 == 0:   # first column: remove 'move left'
        valid_actions = valid_actions - set(['l'])
    if state // 2 == 1:  # last row: remove 'move down'
        valid_actions = valid_actions - set(['d'])
    if state // 2 == 0:  # first row: remove 'move up'
        valid_actions = valid_actions - set(['u'])
    return list(valid_actions)
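For example (illustrative calls):

print(sorted(get_valid_actions(0)))  # ['d', 'r']: top-left corner, can only go down or right
print(sorted(get_valid_actions(3)))  # ['l', 'u']: bottom-right corner, can only go left or up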
3. Reward set
The explorer should receive a reward when it reaches each state (position). So define

rewards = [0, 0, -10, 10]  # reward set: reaching position 3 (the exit) gives +10, position 2 (the trap) gives -10, all others 0

Obviously, getting the reward for a given state is then trivial: rewards[state]. It is a plain lookup by state, so no extra function is needed.
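For example:

print(rewards[2])  # -10, the penalty for falling into the trap
print(rewards[3])  # +10, the reward for reaching the exit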
4. Q table
The most important part. A Q table is a table that records state-action values (Q values). Typical Q tables are two-dimensional, with one row per state and one column per action, roughly like this (all entries start at 0):

       u    d    l    r
0    0.0  0.0  0.0  0.0
1    0.0  0.0  0.0  0.0
2    0.0  0.0  0.0  0.0
3    0.0  0.0  0.0  0.0

(Note that 3-dimensional Q tables also exist.)
So define
q_table = pd.DataFrame(data=[[0.0 for _ in actions] for _ in states],  # floats, so Q values can be updated in place
                       index=states, columns=actions)
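Individual entries can then be read and written by state index and action label, for example via .loc (purely illustrative):

print(q_table.loc[0, 'r'])   # Q value for moving right from state 0 (0.0 before training)
q_table.loc[0, 'r'] = 0.5    # overwrite a single state-action value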
5. The Q-learning algorithm
The pseudocode of the Q-learning algorithm:
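Its essence is the per-step update rule, which the implementation below applies directly (s' is the next state, r the reward for entering it, and the max runs over the legal actions of s'):

    Q(s, a) ← Q(s, a) + α · [ r + γ · max_a' Q(s', a') − Q(s, a) ]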
OK, time to implement it:
# Explore for 300 episodes in total
for i in range(300):
    # 0. Start from the leftmost position (not strictly necessary)
    current_state = 0
    #current_state = random.choice(states)
    while current_state != states[-1]:
        # 1. Among the legal actions of the current state, pick one at random (explore) or greedily (exploit)
        if (random.uniform(0, 1) > epsilon) or ((q_table.loc[current_state] == 0).all()):  # explore
            current_action = random.choice(get_valid_actions(current_state))
        else:
            current_action = q_table.loc[current_state].idxmax()  # exploit (greedy)
        # 2. Perform the current action and get the next state (position)
        next_state = get_next_state(current_state, current_action)
        # 3. Take all Q values of the next state, so their maximum can be used
        next_state_q_values = q_table.loc[next_state, get_valid_actions(next_state)]
        # 4. Update the Q value of the current state-action pair according to the Bellman equation
        q_table.loc[current_state, current_action] += alpha * (
            rewards[next_state] + gamma * next_state_q_values.max()
            - q_table.loc[current_state, current_action])
        # 5. Move to the next state (position)
        current_state = next_state

print('\nq_table:')
print(q_table)
As you can see, this is exactly the same code as in Example 1, character for character!
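Once training finishes, the greedy policy can be read off the Q table, for instance like this (a minimal sketch, not part of the original code):

# Print the greedy action for every state except the terminal one (state 3).
# Restricting idxmax to the legal actions avoids picking a move that would hit a wall.
for s in states[:-1]:
    best_action = q_table.loc[s, get_valid_actions(s)].idxmax()
    print('state {} -> {}'.format(s, best_action))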
6. The environment and updating it
Visualising the environment here apparently requires a GUI, which is a bit of a hassle, and I don't know how to do it from the command line. So I'll skip it for now.
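One possible command-line rendering (my own sketch, not from the original post) is to redraw the 2×2 grid as text after each step, with 'o' marking the explorer, 'X' the trap and 'E' the exit:

import os, time

def render(state):
    '''Crudely draw the 2x2 maze in the terminal; 'o' is the explorer.'''
    cells = [' ', ' ', 'X', 'E']   # cell 2 = trap, cell 3 = exit
    cells[state] = 'o'
    os.system('cls' if os.name == 'nt' else 'clear')
    print('+---+---+')
    print('| {} | {} |'.format(cells[0], cells[1]))
    print('+---+---+')
    print('| {} | {} |'.format(cells[2], cells[3]))
    print('+---+---+')
    time.sleep(0.1)

Calling render(current_state) inside the training loop would then animate each episode.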
7. Full code
'''
The simplest maze, with four cells:
    ---------------
    | start |     |
    ---------------
    | die   | end |
    ---------------
Each cell is a state, and each state has the 4 actions up/down/left/right.

Author: hhh5460
Date:   20181217
'''
import pandas as pd
import random

epsilon = 0.9   # greediness
alpha = 0.1     # learning rate
gamma = 0.8     # reward discount factor

states = range(4)          # the 4 states 0, 1, 2, 3
actions = list('udlr')     # the 4 actions up/down/left/right; an action 'n' (stay) could also be added
rewards = [0, 0, -10, 10]  # reward set: position 3 (exit) gives +10, position 2 (trap) gives -10, all others 0

q_table = pd.DataFrame(data=[[0.0 for _ in actions] for _ in states],  # floats, so Q values can be updated in place
                       index=states, columns=actions)

def get_next_state(state, action):
    '''Return the next state after performing an action in the given state.'''
    # u, d, l, r, n = -2, +2, -1, +1, 0
    if state % 2 != 1 and action == 'r':    # any column but the last can move right (+1)
        next_state = state + 1
    elif state % 2 != 0 and action == 'l':  # any column but the first can move left (-1)
        next_state = state - 1
    elif state // 2 != 1 and action == 'd': # any row but the last can move down (+2)
        next_state = state + 2
    elif state // 2 != 0 and action == 'u': # any row but the first can move up (-2)
        next_state = state - 2
    else:
        next_state = state
    return next_state

def get_valid_actions(state):
    '''Return the legal actions in the current state (independent of the reward!).'''
    global actions
    valid_actions = set(actions)
    if state % 2 == 1:   # last column: no 'move right'
        valid_actions = valid_actions - set(['r'])
    if state % 2 == 0:   # first column: no 'move left'
        valid_actions = valid_actions - set(['l'])
    if state // 2 == 1:  # last row: no 'move down'
        valid_actions = valid_actions - set(['d'])
    if state // 2 == 0:  # first row: no 'move up'
        valid_actions = valid_actions - set(['u'])
    return list(valid_actions)

# Explore for 300 episodes in total
for i in range(300):
    # 0. Start from the leftmost position (not strictly necessary)
    current_state = 0
    #current_state = random.choice(states)
    while current_state != states[-1]:
        # 1. Among the legal actions of the current state, pick one at random (explore) or greedily (exploit)
        if (random.uniform(0, 1) > epsilon) or ((q_table.loc[current_state] == 0).all()):  # explore
            current_action = random.choice(get_valid_actions(current_state))
        else:
            current_action = q_table.loc[current_state].idxmax()  # exploit (greedy)
        # 2. Perform the current action and get the next state (position)
        next_state = get_next_state(current_state, current_action)
        # 3. Take all Q values of the next state, so their maximum can be used
        next_state_q_values = q_table.loc[next_state, get_valid_actions(next_state)]
        # 4. Update the Q value of the current state-action pair according to the Bellman equation
        q_table.loc[current_state, current_action] += alpha * (
            rewards[next_state] + gamma * next_state_q_values.max()
            - q_table.loc[current_state, current_action])
        # 5. Move to the next state (position)
        current_state = next_state

print('\nq_table:')
print(q_table)
8. Result
9. Addendum
I also put together a numpy version, which is an order of magnitude faster than the pandas version!! The code is below.
'''
The simplest maze, with four cells:
    ---------------
    | start |     |
    ---------------
    | die   | end |
    ---------------
Each cell is a state, and each state now has 5 actions: up/down/left/right/stay.
'''
# Author: hhh5460
# Date:   20181218

import numpy as np

epsilon = 0.9   # greediness
alpha = 0.1     # learning rate
gamma = 0.8     # reward discount factor

states = range(4)          # the 4 states 0, 1, 2, 3
actions = list('udlrn')    # the 5 actions up/down/left/right/stay
rewards = [0, 0, -10, 10]  # reward set: position 3 (exit) gives +10, position 2 (trap) gives -10, all others 0

# Label the columns of a numpy array; see https://cloud.tencent.com/developer/ask/72790
q_table = np.zeros(shape=(4, ),  # pitfall no. 2: the shape must NOT be (4, 5)!!
                   dtype=list(zip(actions, ['float'] * 5)))
#dtype=[('u', float), ('d', float), ('l', float), ('r', float), ('n', float)])
#dtype={'names': actions, 'formats': [float] * 5})

def get_next_state(state, action):
    '''Return the next state after performing an action in the given state.'''
    # u, d, l, r, n = -2, +2, -1, +1, 0
    if state % 2 != 1 and action == 'r':    # any column but the last can move right (+1)
        next_state = state + 1
    elif state % 2 != 0 and action == 'l':  # any column but the first can move left (-1)
        next_state = state - 1
    elif state // 2 != 1 and action == 'd': # any row but the last can move down (+2)
        next_state = state + 2
    elif state // 2 != 0 and action == 'u': # any row but the first can move up (-2)
        next_state = state - 2
    else:
        next_state = state
    return next_state

def get_valid_actions(state):
    '''Return the legal actions in the current state (independent of the reward!).'''
    global actions  # ['u', 'd', 'l', 'r', 'n']
    valid_actions = set(actions)
    if state % 2 == 1:   # last column: remove 'move right'
        valid_actions = valid_actions - set(['r'])
    if state % 2 == 0:   # first column: remove 'move left'
        valid_actions = valid_actions - set(['l'])
    if state // 2 == 1:  # last row: remove 'move down'
        valid_actions = valid_actions - set(['d'])
    if state // 2 == 0:  # first row: remove 'move up'
        valid_actions = valid_actions - set(['u'])
    return list(valid_actions)

for i in range(1000):
    #current_state = states[0]  # fixed start
    current_state = np.random.choice(states, 1)[0]
    while current_state != 3:
        if (np.random.uniform() > epsilon) or ((np.array(list(q_table[current_state])) == 0).all()):
            # q_table[current_state] is a numpy.void record, so it has to be handled this way!!
            current_action = np.random.choice(get_valid_actions(current_state), 1)[0]  # explore
        else:
            current_action = actions[np.array(list(q_table[current_state])).argmax()]  # exploit (greedy)
        next_state = get_next_state(current_state, current_action)
        next_state_q_values = [q_table[next_state][action] for action in get_valid_actions(next_state)]
        # Update the current state-action Q value per the Bellman equation, as in the pandas version
        q_table[current_state][current_action] += alpha * (
            rewards[next_state] + gamma * max(next_state_q_values)
            - q_table[current_state][current_action])
        current_state = next_state

print('Final Q-table:')
print(q_table)
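Since the structured array prints as a flat list of records, it can help to view it as an ordinary 2-D array when inspecting the result (a small sketch; column order follows actions, i.e. u, d, l, r, n):

# Convert the record array into a plain (4, 5) float array for easier reading.
q_matrix = np.array(q_table.tolist())
print(q_matrix)  # rows = states 0..3, columns = u, d, l, r, n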