Analysis of the common/vec_env/util.py module in the baselines algorithm library

Source code of the util.py module:

"""
Helpers for dealing with vectorized environments.
"""

from collections import OrderedDict

import gym
import numpy as np


def copy_obs_dict(obs):
    """
    Deep-copy an observation dict.
    """
    return {k: np.copy(v) for k, v in obs.items()}


def dict_to_obs(obs_dict):
    """
    Convert an observation dict into a raw array if the
    original observation space was not a Dict space.
    """
    if set(obs_dict.keys()) == {None}:
        return obs_dict[None]
    return obs_dict


def obs_space_info(obs_space):
    """
    Get dict-structured information about a gym.Space.

    Returns:
      A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.
    """
    if isinstance(obs_space, gym.spaces.Dict):
        assert isinstance(obs_space.spaces, OrderedDict)
        subspaces = obs_space.spaces
    elif isinstance(obs_space, gym.spaces.Tuple):
        assert isinstance(obs_space.spaces, tuple)
        subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))}
    else:
        subspaces = {None: obs_space}
    keys = []
    shapes = {}
    dtypes = {}
    for key, box in subspaces.items():
        keys.append(key)
        shapes[key] = box.shape
        dtypes[key] = box.dtype
    return keys, shapes, dtypes


def obs_to_dict(obs):
    """
    Convert an observation into a dict.
    """
    if isinstance(obs, dict):
        return obs
    return {None: obs}

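Functions:

def copy_obs_dict(obs):
def obs_to_dict(obs):

copy_obs_dict assumes the observation passed in is already a dict: it iterates over obs.items() and returns a new dict in which every value has been copied with np.copy, so later writes to the copy do not affect the original arrays.

In obs_to_dict, an observation that is not a dict is wrapped in a dict with None as its only key; an observation that is already a dict is returned as-is.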
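The behavior of these two helpers can be checked with a small standalone sketch. To keep it runnable without importing baselines, the two functions are re-declared here with the same logic as in util.py above; the sample observations are made up for illustration.

```python
import numpy as np


def copy_obs_dict(obs):
    """Copy every value of an observation dict with np.copy (same logic as util.py)."""
    return {k: np.copy(v) for k, v in obs.items()}


def obs_to_dict(obs):
    """Wrap a non-dict observation under the key None (same logic as util.py)."""
    if isinstance(obs, dict):
        return obs
    return {None: obs}


# A raw array observation gets wrapped under the single key None (no copy is made).
raw = np.zeros(3, dtype=np.float32)
wrapped = obs_to_dict(raw)
print(list(wrapped.keys()))  # [None]

# A dict observation is passed through unchanged (the very same object).
dict_obs = {"image": np.ones((2, 2)), "pos": np.array([1.0, 2.0])}
print(obs_to_dict(dict_obs) is dict_obs)  # True

# copy_obs_dict copies each array, so writing to the copy leaves the original intact.
copied = copy_obs_dict(wrapped)
copied[None][0] = 5.0
print(raw[0], copied[None][0])  # 0.0 5.0
```

Note that obs_to_dict itself does not copy anything: wrapped[None] is the original array, which is why copy_obs_dict exists as a separate step.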

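Function

def dict_to_obs(obs_dict)

This function undoes what obs_to_dict does. If the input dict has exactly one key and that key is None, the value stored under it is returned, i.e. the raw observation (usually an np.ndarray) is unwrapped from the dict.

Any other dict, for example one with string keys coming from a Dict observation space, is returned unchanged.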
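A minimal round-trip sketch, again re-declaring the two helpers with the same logic as util.py so it runs on its own, shows that only the exact key set {None} is unwrapped:

```python
import numpy as np


def obs_to_dict(obs):
    """Wrap a non-dict observation under the key None (same logic as util.py)."""
    if isinstance(obs, dict):
        return obs
    return {None: obs}


def dict_to_obs(obs_dict):
    """Unwrap {None: x} back to x; return any other dict unchanged (same logic as util.py)."""
    if set(obs_dict.keys()) == {None}:
        return obs_dict[None]
    return obs_dict


raw = np.arange(4)
# Round trip: a raw array comes back from obs_to_dict -> dict_to_obs as the same object.
restored = dict_to_obs(obs_to_dict(raw))
print(restored is raw)  # True

# A genuine dict observation (string keys) is returned as-is.
dict_obs = {"a": np.zeros(2), "b": np.ones(2)}
print(dict_to_obs(dict_obs) is dict_obs)  # True

# A dict containing None plus other keys is NOT unwrapped.
mixed = {None: np.zeros(1), "extra": np.ones(1)}
print(dict_to_obs(mixed) is mixed)  # True
```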

Function

def obs_space_info(obs_space):

The input argument is the observation space (for example env.observation_space).

    if isinstance(obs_space, gym.spaces.Dict):
        assert isinstance(obs_space.spaces, OrderedDict)
        subspaces = obs_space.spaces
    elif isinstance(obs_space, gym.spaces.Tuple):
        assert isinstance(obs_space.spaces, tuple)
        subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))}
    else:
        subspaces = {None: obs_space}

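First, the type of obs_space is checked and its subspaces are normalized into a dict: for a Dict space the OrderedDict stored in obs_space.spaces is used directly; for a Tuple space each subspace is keyed by its integer index; any other space (such as a Box) becomes a single entry under the key None.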
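The branching can be illustrated without installing gym. The sketch below is a hypothetical mirror, not the baselines code: plain containers stand in for the gym space types (an OrderedDict plays the role of a Dict space's .spaces, a tuple plays the role of a Tuple space's .spaces), and FakeBox and normalize_subspaces are made-up names. FakeBox only carries shape and dtype, the two attributes obs_space_info reads.

```python
from collections import OrderedDict
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class FakeBox:
    """Hypothetical stand-in for a leaf space such as gym.spaces.Box.

    A dataclass is used (not a NamedTuple) so that it is not mistaken
    for a tuple by the isinstance check below.
    """
    shape: tuple
    dtype: np.dtype


def normalize_subspaces(spaces):
    """Hypothetical mirror of the branching at the top of obs_space_info.

    OrderedDict   -> plays the role of gym.spaces.Dict (used as-is)
    tuple         -> plays the role of gym.spaces.Tuple (keyed by index)
    anything else -> a single leaf space stored under the key None
    """
    if isinstance(spaces, OrderedDict):
        return spaces
    elif isinstance(spaces, tuple):
        return {i: spaces[i] for i in range(len(spaces))}
    else:
        return {None: spaces}


box = FakeBox((84, 84, 3), np.dtype(np.uint8))
vec = FakeBox((4,), np.dtype(np.float32))

print(list(normalize_subspaces(box).keys()))  # [None]
print(list(normalize_subspaces((box, vec)).keys()))  # [0, 1]
print(list(normalize_subspaces(OrderedDict(image=box, state=vec)).keys()))  # ['image', 'state']
```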

Next, the key, shape, and dtype of every subspace are extracted, giving:

    Returns:
      A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.

Finally, this information is returned as a tuple.
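One typical use of the returned (keys, shapes, dtypes) tuple is to preallocate one observation buffer per key for a batch of environments; baselines' DummyVecEnv follows a similar pattern, although its exact code may differ. The sketch below assumes the subspaces have already been normalized into a dict, reuses the same extraction loop as obs_space_info, and uses hypothetical helper names (FakeBox, space_info_from_subspaces, alloc_buffers) so it runs without gym.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class FakeBox:
    """Hypothetical stand-in for a leaf space (e.g. gym.spaces.Box): only shape and dtype."""
    shape: tuple
    dtype: np.dtype


def space_info_from_subspaces(subspaces):
    """Same extraction loop as the end of obs_space_info, applied to a normalized dict."""
    keys = []
    shapes = {}
    dtypes = {}
    for key, box in subspaces.items():
        keys.append(key)
        shapes[key] = box.shape
        dtypes[key] = box.dtype
    return keys, shapes, dtypes


def alloc_buffers(keys, shapes, dtypes, num_envs):
    """Hypothetical helper: one zero-filled array per key with a leading num_envs axis."""
    return {k: np.zeros((num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in keys}


# A Tuple-like space with two leaves, already normalized to {index: subspace}.
subspaces = {
    0: FakeBox((84, 84, 3), np.dtype(np.uint8)),
    1: FakeBox((4,), np.dtype(np.float32)),
}
keys, shapes, dtypes = space_info_from_subspaces(subspaces)
print(keys)  # [0, 1]

bufs = alloc_buffers(keys, shapes, dtypes, num_envs=8)
print(bufs[0].shape, bufs[0].dtype)  # (8, 84, 84, 3) uint8
print(bufs[1].shape, bufs[1].dtype)  # (8, 4) float32
```

Keeping shapes and dtypes as dicts keyed the same way as keys lets a single loop handle Box, Tuple, and Dict spaces alike.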