【強化學習】python 實現 q-learning 迷宮通用模板
阿新 • • 發佈:2018-12-19
本文作者:hhh5460
本文地址:https://www.cnblogs.com/hhh5460/p/10145797.html
0.說明
這裡提供了二維迷宮問題的一個比較通用的模板,拿到後需要修改的地方非常少。
對於任意的二維迷宮的 class Agent,只需修改三個地方:MAZE_R, MAZE_R, rewards,其他的不要動!如下所示:
class Agent(object): '''個體類''' MAZE_R = 6 # 迷宮行數 MAZE_C = 6 # 迷宮列數 def __init__(self, alpha=0.1, gamma=0.9):'''初始化''' # ... ... self.rewards = [0,-10,0, 0, 0, 0, 0,-10,0, 0,-10, 0, 0,-10,0,-10, 0, 0, 0,-10,0,-10, 0, 0, 0,-10,0,-10, 1, 0, 0, 0,0,-10, 0,10,] # 獎勵集。出口獎勵10,陷阱獎勵-10,元寶獎勵1 # ... ...
1.完整程式碼
import pandas as pd import random import time import pickle import pathlib import os import tkinter as tk ''' 6*6 的迷宮: ------------------------------------------- | 入口 | 陷阱 | | | | | ------------------------------------------- | | 陷阱 | | | 陷阱 | | ------------------------------------------- | | 陷阱 | | 陷阱 | | | ------------------------------------------- | | 陷阱 | | 陷阱 | | | ------------------------------------------- | | 陷阱 | | 陷阱 | 元寶 | | ------------------------------------------- | | | | 陷阱 | | 出口 | ------------------------------------------- 作者:hhh5460 時間:20181219 地點:Tai Zi Miao''' class Maze(tk.Tk): '''環境類(GUI)''' UNIT = 40 # pixels MAZE_R = 6 # grid row MAZE_C = 6 # grid column def __init__(self): '''初始化''' super().__init__() self.title('迷宮') h = self.MAZE_R * self.UNIT w = self.MAZE_C * self.UNIT self.geometry('{0}x{1}'.format(h, w)) #視窗大小 self.canvas = tk.Canvas(self, bg='white', height=h, width=w) # 畫網格 for c in range(1, self.MAZE_C): self.canvas.create_line(c * self.UNIT, 0, c * self.UNIT, h) for r in range(1, self.MAZE_R): self.canvas.create_line(0, r * self.UNIT, w, r * self.UNIT) # 畫陷阱 self._draw_rect(1, 0, 'black') # 在1列、0行處,下同 self._draw_rect(1, 1, 'black') self._draw_rect(1, 2, 'black') self._draw_rect(1, 3, 'black') self._draw_rect(1, 4, 'black') self._draw_rect(3, 2, 'black') self._draw_rect(3, 3, 'black') self._draw_rect(3, 4, 'black') self._draw_rect(3, 5, 'black') self._draw_rect(4, 1, 'black') # 畫獎勵 self._draw_rect(4, 4, 'yellow') # 畫玩家(儲存!!) self.rect = self._draw_rect(0, 0, 'red') self.canvas.pack() # 顯示畫作! def _draw_rect(self, x, y, color): '''畫矩形, x,y表示橫,豎第幾個格子''' padding = 5 # 內邊距5px,參見CSS coor = [self.UNIT * x + padding, self.UNIT * y + padding, self.UNIT * (x+1) - padding, self.UNIT * (y+1) - padding] return self.canvas.create_rectangle(*coor, fill = color) def move_agent_to(self, state, step_time=0.01): '''移動玩家到新位置,根據傳入的狀態''' coor_old = self.canvas.coords(self.rect) # 形如[5.0, 5.0, 35.0, 35.0](第一個格子左上、右下座標) x, y = state % 6, state // 6 #橫豎第幾個格子 padding = 5 # 內邊距5px,參見CSS coor_new = [self.UNIT * x + padding, self.UNIT * y + padding, self.UNIT * (x+1) - padding, self.UNIT * (y+1) - padding] dx_pixels, dy_pixels = coor_new[0] - coor_old[0], coor_new[1] - coor_old[1] # 左上角頂點座標之差 self.canvas.move(self.rect, dx_pixels, dy_pixels) self.update() # tkinter內建的update! time.sleep(step_time) class Agent(object): '''個體類''' MAZE_R = 6 # 迷宮行數 MAZE_C = 6 # 迷宮列數 def __init__(self, alpha=0.1, gamma=0.9): '''初始化''' self.states = range(self.MAZE_R * self.MAZE_C) # 狀態集。0~35 共36個狀態 self.actions = list('udlr') # 動作集。上下左右 4個動作 ↑↓←→ ←↑→↓↖↗↘↙ self.rewards = [0,-10,0, 0, 0, 0, 0,-10,0, 0,-10, 0, 0,-10,0,-10, 0, 0, 0,-10,0,-10, 0, 0, 0,-10,0,-10, 1, 0, 0, 0,0,-10, 0,10,] # 獎勵集。出口獎勵10,陷阱獎勵-10,元寶獎勵5 #self.hell_states = [1,7,13,19,25,15,31,37,43,10] # 陷阱位置 self.alpha = alpha self.gamma = gamma self.q_table = pd.DataFrame(data=[[0 for _ in self.actions] for _ in self.states], index=self.states, columns=self.actions) def save_policy(self): '''儲存Q table''' with open('q_table.pickle', 'wb') as f: pickle.dump(self.q_table, f, pickle.HIGHEST_PROTOCOL) def load_policy(self): '''匯入Q table''' with open('q_table.pickle', 'rb') as f: self.q_table = pickle.load(f) def choose_action(self, state, epsilon=0.8): '''選擇相應的動作。根據當前狀態,隨機或貪婪,按照引數epsilon''' #if (random.uniform(0,1) > epsilon) or ((self.q_table.ix[state] == 0).all()): # 探索 if random.uniform(0,1) > epsilon: # 探索 action = random.choice(self.get_valid_actions(state)) else: #action = self.q_table.ix[state].idxmax() # 利用 當有多個最大值時,會鎖死第一個! #action = self.q_table.ix[state].filter(items=self.get_valid_actions(state)).idxmax() # 重大改進!然鵝與上面一樣 s = self.q_table.ix[state].filter(items=self.get_valid_actions(state)) action = random.choice(s[s==s.max()].index) # 從可能有多個的最大值裡面隨機選擇一個! return action def get_q_values(self, state): '''取給定狀態state的所有Q value''' q_values = self.q_table.ix[state, self.get_valid_actions(state)] return q_values def update_q_value(self, state, action, next_state_reward, next_state_q_values): '''更新Q value,根據貝爾曼方程''' self.q_table.ix[state, action] += self.alpha * (next_state_reward + self.gamma * next_state_q_values.max() - self.q_table.ix[state, action]) def get_valid_actions(self, state): '''取當前狀態下所有的合法動作''' valid_actions = set(self.actions) if state // self.MAZE_C == 0: # 首行,則 不能向上 valid_actions -= set(['u']) elif state // self.MAZE_C == self.MAZE_R - 1: # 末行,則 不能向下 valid_actions -= set(['d']) if state % self.MAZE_C == 0: # 首列,則 不能向左 valid_actions -= set(['l']) elif state % self.MAZE_C == self.MAZE_C - 1: # 末列,則 不能向右 valid_actions -= set(['r']) return list(valid_actions) def get_next_state(self, state, action): '''對狀態執行動作後,得到下一狀態''' #u,d,l,r,n = -6,+6,-1,+1,0 if action == 'u' and state // self.MAZE_C != 0: # 除首行外,向上-MAZE_C next_state = state - self.MAZE_C elif action == 'd' and state // self.MAZE_C != self.MAZE_R - 1: # 除末行外,向下+MAZE_C next_state = state + self.MAZE_C elif action == 'l' and state % self.MAZE_C != 0: # 除首列外,向左-1 next_state = state - 1 elif action == 'r' and state % self.MAZE_C != self.MAZE_C - 1: # 除末列外,向右+1 next_state = state + 1 else: next_state = state return next_state def learn(self, env=None, episode=1000, epsilon=0.8): '''q-learning演算法''' print('Agent is learning...') for i in range(episode): current_state = self.states[0] if env is not None: # 若提供了環境,則重置之! env.move_agent_to(current_state) while current_state != self.states[-1]: current_action = self.choose_action(current_state, epsilon) # 按一定概率,隨機或貪婪地選擇 next_state = self.get_next_state(current_state, current_action) next_state_reward = self.rewards[next_state] next_state_q_values = self.get_q_values(next_state) self.update_q_value(current_state, current_action, next_state_reward, next_state_q_values) current_state = next_state #if next_state not in self.hell_states: # 非陷阱,則往前;否則待在原位 # current_state = next_state if env is not None: # 若提供了環境,則更新之! env.move_agent_to(current_state) print(i) print('\nok') def test(self): '''測試agent是否已具有智慧''' count = 0 current_state = self.states[0] while current_state != self.states[-1]: current_action = self.choose_action(current_state, 1.) # 1., 100%貪婪 next_state = self.get_next_state(current_state, current_action) current_state = next_state count += 1 if count > self.MAZE_R * self.MAZE_C: # 沒有在36步之內走出迷宮,則 return False # 無智慧 return True # 有智慧 def play(self, env=None, step_time=0.5): '''玩遊戲,使用策略''' assert env != None, 'Env must be not None!' if not self.test(): # 若尚無智慧,則 if pathlib.Path("q_table.pickle").exists(): self.load_policy() else: print("I need to learn before playing this game.") self.learn(env, episode=1000, epsilon=0.5) self.save_policy() print('Agent is playing...') current_state = self.states[0] env.move_agent_to(current_state, step_time) while current_state != self.states[-1]: current_action = self.choose_action(current_state, 1.) # 1., 100%貪婪 next_state = self.get_next_state(current_state, current_action) current_state = next_state env.move_agent_to(current_state, step_time) print('\nCongratulations, Agent got it!') if __name__ == '__main__': env = Maze() # 環境 agent = Agent() # 個體(智慧體) agent.learn(env, episode=1000, epsilon=0.6) # 先學習 #agent.save_policy() #agent.load_policy() agent.play(env) # 再玩耍 #env.after(0, agent.learn, env, 1000, 0.8) # 先學 #env.after(0, agent.save_policy) # 儲存所學 #env.after(0, agent.load_policy) # 匯入所學 #env.after(0, agent.play, env) # 再玩 env.mainloop()
Just enjoy it!