
Analysis of the common/wrapper.py module in the baselines algorithm library

 

The common/wrapper.py module:

 

import gym

class TimeLimit(gym.Wrapper):
    def __init__(self, env, max_episode_steps=None):
        super(TimeLimit, self).__init__(env)
        # Must be set to an int in practice: comparing an int with None raises TypeError in Python 3
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        # Step budget exhausted: force the episode to end and mark it as truncated
        if self._elapsed_steps >= self._max_episode_steps:
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        self._elapsed_steps = 0  # restart the step counter for the new episode
        return self.env.reset(**kwargs)

class ClipActionsWrapper(gym.Wrapper):
    def step(self, action):
        import numpy as np
        action = np.nan_to_num(action)  # NaN -> 0, +/-inf -> largest/smallest finite float
        # Clamp each component into the action space's bounds
        action = np.clip(action, self.action_space.low, self.action_space.high)
        return self.env.step(action)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

 

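The module defines two wrapper classes around gym environments. TimeLimit caps the number of steps the wrapped environment env may take in one episode: if the episode has not terminated by the time the step count reaches max_episode_steps, the wrapper forces the terminal flag done=True and records the reason in the returned info dict: info['TimeLimit.truncated'] = True. The counter is reset to 0 on every call to reset().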
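To see the step-counting logic in isolation, here is a minimal sketch that mirrors TimeLimit without depending on gym. DummyEnv (a hypothetical environment that never ends an episode on its own) and TimeLimitSketch are illustrative names, not part of baselines; the real class subclasses gym.Wrapper and uses the old 4-tuple step API shown above.

```python
class DummyEnv:
    """Hypothetical stand-in environment that never ends an episode by itself."""

    def reset(self):
        return 0  # initial observation

    def step(self, action):
        # Old gym API: (observation, reward, done, info)
        return 0, 1.0, False, {}


class TimeLimitSketch:
    """Mirrors the TimeLimit logic without subclassing gym.Wrapper."""

    def __init__(self, env, max_episode_steps):
        self.env = env
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            # Force termination and flag that the time limit caused it
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        self._elapsed_steps = 0  # new episode: restart the step counter
        return self.env.reset(**kwargs)


env = TimeLimitSketch(DummyEnv(), max_episode_steps=3)
env.reset()
dones = []
for _ in range(3):
    _, _, done, info = env.step(None)
    dones.append(done)
print(dones)  # [False, False, True]
print(info)   # {'TimeLimit.truncated': True}
```

Because the inner environment never sets done itself, the only way the episode ends here is through the time limit, which is exactly the case the 'TimeLimit.truncated' key lets callers distinguish.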

 

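ClipActionsWrapper preprocesses the actions passed to the gym environment. Any np.nan entries in the action (a NumPy vector) are replaced with 0 by np.nan_to_num, which also maps +inf/-inf to the largest/smallest finite float; then any component outside the range [action_space.low, action_space.high] is clipped to the nearest bound with np.clip.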
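The two preprocessing steps can be checked on their own with NumPy. The sketch below assumes a 4-dimensional Box action space with bounds [-1, 1]; the helper name clip_action is hypothetical and simply repeats what ClipActionsWrapper.step does before forwarding the action to env.step.

```python
import numpy as np


def clip_action(action, low, high):
    """Hypothetical helper: the same two steps ClipActionsWrapper.step applies."""
    action = np.nan_to_num(action)     # NaN -> 0.0, +/-inf -> largest/smallest finite float
    return np.clip(action, low, high)  # clamp each component into [low, high]


# Assumed bounds, standing in for action_space.low / action_space.high
low = np.array([-1.0, -1.0, -1.0, -1.0])
high = np.array([1.0, 1.0, 1.0, 1.0])

raw = np.array([np.nan, 2.5, -np.inf, 0.3])
print(clip_action(raw, low, high).tolist())  # [0.0, 1.0, -1.0, 0.3]
```

Note that the order matters: clipping first would leave NaN untouched (np.clip propagates NaN), so nan_to_num has to run before np.clip.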

 

 

 

 

 
