Analysis of the baselines/common/input.py module in the baselines algorithm library

 

Source code of the baselines/common/input.py module:

import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete

def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    '''
    Create placeholder to feed observations into of the size appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space     observation space

    batch_size: int         size of the batch to be fed into input. Can be left None in most cases.

    name: str               name of the placeholder

    Returns:
    -------

    tensorflow placeholder tensor
    
    '''

    assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
        'Can only deal with Discrete and Box observation spaces for now'

    dtype = ob_space.dtype
    if dtype == np.int8:
        dtype = np.uint8

    return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)


def observation_input(ob_space, batch_size=None, name='Ob'):
    '''
    Create placeholder to feed observations into of the size appropriate to the observation space, and add input
    encoder of the appropriate type.
    '''

    placeholder = observation_placeholder(ob_space, batch_size, name)
    return placeholder, encode_observation(ob_space, placeholder)


def encode_observation(ob_space, placeholder):
    '''
    Encode input in the way that is appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space             observation space

    placeholder: tf.placeholder     observation input placeholder
    '''
    if isinstance(ob_space, Discrete):
        return tf.to_float(tf.one_hot(placeholder, ob_space.n))
    elif isinstance(ob_space, Box):
        return tf.to_float(placeholder)
    elif isinstance(ob_space, MultiDiscrete):
        placeholder = tf.cast(placeholder, tf.int32)
        one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
        return tf.concat(one_hots, axis=-1)
    else:
        raise NotImplementedError
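The three branches of encode_observation can be illustrated without TensorFlow. What follows is a minimal sketch, not the library's code: plain Python lists stand in for tensors, and the helper names one_hot and encode_multidiscrete are hypothetical. It shows that a Discrete observation becomes a one-hot vector of length n, a Box observation is only cast to float, and a MultiDiscrete observation becomes the concatenation of one one-hot vector per component, so its encoded width is sum(nvec).

```python
def one_hot(index, depth):
    """Float one-hot list of length `depth` (stand-in for tf.one_hot + tf.to_float)."""
    # As with tf.one_hot, an out-of-range index yields an all-zero vector.
    return [1.0 if i == index else 0.0 for i in range(depth)]


def encode_multidiscrete(observation, nvec):
    """One-hot encode each component, then concatenate along the last axis."""
    encoded = []
    for value, depth in zip(observation, nvec):
        encoded.extend(one_hot(value, depth))
    return encoded


# Discrete(4): observation 2 becomes a one-hot vector of length 4.
discrete_encoded = one_hot(2, 4)

# Box: the values are only cast to float, the shape is unchanged.
box_encoded = [float(v) for v in [3, 7, 255]]

# MultiDiscrete([3, 2]): observation [1, 0] becomes 3 + 2 = 5 features.
multi_encoded = encode_multidiscrete([1, 0], nvec=[3, 2])

print(discrete_encoded)  # [0.0, 0.0, 1.0, 0.0]
print(box_encoded)       # [3.0, 7.0, 255.0]
print(multi_encoded)     # [0.0, 1.0, 0.0, 1.0, 0.0]
```

In the real module the same logic runs on batched tensors, so the one-hot axis is appended after the batch dimension.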

 

 

 

The input.py module defines three functions. Of these, observation_input is the one that acts as the outward-facing entry point.

 

 

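Both observation_placeholder and encode_observation are wrapped together by observation_input: it first builds the placeholder, then returns that placeholder along with its encoded form.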
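The wrapping relationship can be sketched with stand-in types so that it runs without TensorFlow or gym. Everything here is an assumption made for illustration: Space and PlaceholderSpec are hypothetical namedtuples, dtypes are plain strings rather than NumPy dtypes, and only the Box case (where the encoding is a float cast that keeps the shape) is modelled.

```python
from collections import namedtuple

# Hypothetical stand-ins for a gym space and a TensorFlow placeholder.
Space = namedtuple("Space", ["shape", "dtype"])
PlaceholderSpec = namedtuple("PlaceholderSpec", ["shape", "dtype", "name"])


def placeholder_spec(ob_space, batch_size=None, name="Ob"):
    """Mirror the shape and dtype logic of observation_placeholder."""
    dtype = ob_space.dtype
    if dtype == "int8":  # the module remaps np.int8 to np.uint8
        dtype = "uint8"
    # A leading batch dimension is prepended; None means "any batch size".
    return PlaceholderSpec((batch_size,) + tuple(ob_space.shape), dtype, name)


def input_spec(ob_space, batch_size=None, name="Ob"):
    """Mirror observation_input for a Box space: placeholder first, then its encoding."""
    spec = placeholder_spec(ob_space, batch_size, name)
    # For Box the encoding is a float cast, so only the dtype changes.
    # The "_encoded" suffix is illustrative, not a name the library uses.
    encoded = PlaceholderSpec(spec.shape, "float32", name + "_encoded")
    return spec, encoded


# An Atari-style stacked-frame observation space.
frames = Space(shape=(84, 84, 4), dtype="uint8")
raw, encoded = input_spec(frames)

print(raw.shape, raw.dtype)          # (None, 84, 84, 4) uint8
print(encoded.shape, encoded.dtype)  # (None, 84, 84, 4) float32
```

Returning both values lets the caller feed raw observations through the first one while building the network on top of the second.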

 

 

In the observation_placeholder function