程式人生 > OpenAI Gym構建自定義強化學習環境

OpenAI Gym構建自定義強化學習環境

OpenAI Gym是開發和比較強化學習演算法的工具包。

OpenAI Gym由兩部分組成:

  • gym開源庫:測試問題的集合。當你測試強化學習的時候,測試問題就是環境,比如機器人玩遊戲,環境的集合就是遊戲的畫面。這些環境有一個公共的介面,允許使用者設計通用的演算法。
  • OpenAI Gym服務。提供一個站點(比如對於遊戲cartpole-v0:https://gym.openai.com/envs/CartPole-v0)和api,允許使用者對他們的測試結果進行比較。

gym的核心介面是Env,作為統一的環境介面。Env包含下面幾個核心方法:

  • reset(self):重置環境的狀態,返回觀察。
  • step(self, action):推進一個時間步長,返回observation,reward,done,info
  • render(self, mode='human', close=False):重繪環境的一幀。預設模式一般比較友好,如彈出一個視窗。

自定義環境

背景

(圖片:迷宮環境示意圖,原圖未能載入)

機器人在一個二維迷宮中走動,迷宮中有火坑、石柱、鑽石。如果機器人掉到火坑中,遊戲結束,如果找到鑽石,可以得到獎勵,遊戲也結束!設計最佳的策略,讓機器人儘快地找到鑽石,獲得獎勵。

操作環境

Python環境:anaconda5.2

pip安裝gym

步驟

anaconda3/lib/python3.6/site-packages/gym/envs 下新建目錄 user ,用於存放自定義的強化學習環境。

user 目錄下新建環境檔案 grid_mdp_v1.py

import logging
import random
import gym

logger = logging.getLogger(__name__)

class GridEnv1(gym.Env):
    """4x4 grid-world maze environment for the old 4-tuple gym API.

    A robot moves on a 16-cell grid containing two fire pits (terminal,
    reward -1 on entry), a stone pillar (obstacle, drawn only) and a
    diamond (terminal, reward +1 on entry).  States are the integers
    1..16; actions are the compass directions 'n', 'e', 's', 'w'.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        self.states = range(1, 17)  # state space: cells 1..16
        # Screen coordinates of each cell centre: state i -> (x[i-1], y[i-1]).
        self.x = [150, 250, 350, 450] * 4
        # BUG FIX: third row was written "[250] * 40", which produced a
        # 52-element list and put states 13-16 in the wrong row on screen.
        self.y = [450] * 4 + [350] * 4 + [250] * 4 + [150] * 4
        # Terminal states: 11 and 12 are fire pits, 15 is the diamond.
        self.terminate_states = {11: 1, 12: 1, 15: 1}
        self.actions = ['n', 'e', 's', 'w']
        # Rewards keyed by "state_action": -1 for stepping into a fire pit,
        # +1 for reaching the diamond.  Missing keys mean reward 0.
        self.rewards = {
            '8_s': -1.0,
            '13_w': -1.0,
            '7_s': -1.0,
            '10_e': -1.0,
            # BUG FIX: original key was the typo '14_4', which can never match
            # a "state_action" key, so the diamond reward was never paid out.
            '14_e': 1.0,
        }
        # State-transition table keyed by "state_action".
        # Missing keys mean the move is blocked and the state is unchanged.
        self.t = {
            '1_s': 5, '1_e': 2,
            '2_w': 1, '2_e': 3,
            '3_s': 6, '3_w': 2, '3_e': 4,
            '4_w': 3, '4_s': 7,
            '5_s': 8,
            '6_n': 3, '6_s': 10, '6_e': 7,
            '7_w': 6, '7_n': 4, '7_s': 11,
            '8_n': 5, '8_e': 9, '8_s': 12,
            # duplicate "self.t['10_w'] = 9" assignment removed
            '10_w': 9, '10_n': 6, '10_e': 11, '10_s': 14,
            '9_w': 8, '9_e': 10, '9_s': 13,
            '13_n': 9, '13_e': 14, '13_w': 12,
            '14_n': 10, '14_e': 15, '14_w': 13,
        }
        self.gamma = 0.8  # discount factor
        self.viewer = None
        self.state = None

    def _seed(self, seed=None):
        """Seed the environment's RNG; return the list of seeds used."""
        # BUG FIX: the original called random.seeding.np_random(seed) —
        # the stdlib random module has no "seeding" attribute, so this
        # always raised AttributeError.  gym's seeding helper is intended.
        from gym.utils import seeding
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def getTerminal(self):
        """Return the dict of terminal states."""
        return self.terminate_states

    def getGamma(self):
        """Return the discount factor."""
        return self.gamma

    def getStates(self):
        """Return the state space (range of cell ids 1..16)."""
        return self.states

    def getAction(self):
        """Return the list of actions."""
        return self.actions

    def getTerminate_states(self):
        """Return the dict of terminal states (alias kept for callers)."""
        return self.terminate_states

    def setAction(self, s):
        """Force the current state to ``s`` (used by external drivers)."""
        self.state = s

    def step(self, action):
        """Advance one time step.

        Returns the old-gym 4-tuple ``(next_state, reward, done, info)``.
        Stepping in a terminal state is a no-op with zero reward.
        """
        state = self.state
        if state in self.terminate_states:
            return state, 0, True, {}
        key = "%d_%s" % (state, action)  # "state_action" lookup key
        # Unknown (blocked) moves leave the state unchanged.
        next_state = self.t.get(key, state)
        self.state = next_state
        is_terminal = next_state in self.terminate_states
        r = self.rewards.get(key, 0.0)
        return next_state, r, is_terminal, {}

    def reset(self):
        """Reset to a uniformly random cell and return it.

        NOTE(review): this may reset directly into a terminal state
        (11, 12 or 15), matching the original behaviour — confirm whether
        that is intended before changing it.
        """
        self.state = self.states[int(random.random() * len(self.states))]
        return self.state

    def render(self, mode='human'):
        """Draw one frame; lazily build the viewer and static geometry.

        Returns an RGB array when ``mode == 'rgb_array'``, otherwise the
        viewer's return value, or ``None`` when no state is set yet.
        """
        from gym.envs.classic_control import rendering
        screen_width = 600
        screen_height = 600
        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)

            # Grid lines of the 4x4 maze: 5 horizontal + 5 vertical,
            # every 100 px from 100 to 500.
            for c in range(100, 501, 100):
                for line in (rendering.Line((100, c), (500, c)),
                             rendering.Line((c, 100), (c, 500))):
                    line.set_color(0, 0, 0)
                    self.viewer.add_geom(line)

            def _add_circle(radius, cx, cy, rgb):
                # Static filled circle at a fixed cell centre.
                geom = rendering.make_circle(radius)
                geom.add_attr(rendering.Transform(translation=(cx, cy)))
                geom.set_color(*rgb)
                self.viewer.add_geom(geom)

            _add_circle(40, 250, 350, (0.8, 0.6, 0.4))  # stone pillar
            _add_circle(40, 450, 250, (1, 0, 0))        # fire pit 1
            _add_circle(40, 150, 150, (1, 0, 0))        # fire pit 2
            _add_circle(40, 450, 150, (0, 0, 1))        # diamond

            # The robot is the only movable geometry; keep its transform.
            self.robot = rendering.make_circle(30)
            self.robotrans = rendering.Transform()
            self.robot.add_attr(self.robotrans)
            self.robot.set_color(0, 1, 0)
            self.viewer.add_geom(self.robot)

        if self.state is None:
            return None
        # Move the robot marker to the current state's cell centre.
        self.robotrans.set_translation(self.x[self.state - 1],
                                       self.y[self.state - 1])
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def close(self):
        """Release the rendering window, if one was created."""
        if self.viewer:
            self.viewer.close()
            # BUG FIX: reset to None so a later render() rebuilds the
            # viewer instead of drawing into a closed window.
            self.viewer = None

user 目錄下新建 __init__.py

from gym.envs.user.grid_mdp_v1 import GridEnv1

anaconda3/lib/python3.6/site-packages/gym/envs/__init__.py 中進行註冊,在最後加入

register(
    id='GridWorld-v1',
    entry_point='gym.envs.user:GridEnv1',
    max_episode_steps=200,
    reward_threshold=100.0,
    )

測試

import gym

env = gym.make('GridWorld-v1')
env.reset()
env.render()
env.close()

(圖片:渲染出的迷宮視窗截圖,原圖未能載入)