強化學習DDPG的tensorflow程式碼(二)
我們有了像Q-learning這麼好的演算法,為什麼還要再折騰出一個Actor-Critic演算法呢?原來 Actor-Critic 的 Actor 的前生是 Policy Gradients, 這能讓它毫不費力地在連續動作中選取合適的動作,而Q-learning 做這件事會癱瘓。那為什麼不直接用 Policy Gradients呢? 原來Actor Critic中的Critic的前生是 Q-learning 之類以值為基礎的學習演算法,能進行單步更新,而傳統的 Policy Gradients 則是回合更新,這降低了學習效率。
上面的一段話不僅解釋了為什麼會有Actor-Critic這麼一個演算法,同時也告訴了我們,這個演算法具體是怎麼做的。既然Actor是一個策略網路(Policy Network),那麼他就需要獎懲資訊來進行調節不同狀態下采取各種動作的概率,在傳統的Policy Gradient演算法中,這種獎懲資訊是通過走完一個完整的episode來計算得到的,這導致了學習速率很慢。而既然Critic是一個以值為基礎的學習法,那麼他可以進行單步更新,計算每一步的獎懲值。那麼二者相結合,Actor來選擇動作,Critic來告訴Actor它選擇的動作是否合適。在這一過程中,Actor不斷迭代,得到每一個狀態下選擇每一動作的合理概率,Critic也不斷迭代,不斷完善每個狀態下選擇每一個動作的獎懲值。
1、Actor
1.1 定義Actor的初始化引數
在這裡,a_dim是動作的維度,action_bound是動作的範圍值:
def __init__(self, sess, action_dim, action_bound, learning_rate, replacement): self.sess = sess self.a_dim = action_dim self.action_bound = action_bound self.lr = learning_rate self.replacement = replacement self.t_replace_counter = 0 with tf.variable_scope('Actor'): # input s, output a self.a = self._build_net(S, scope='eval_net', trainable=True) # input s_, output a, get a_ for critic self.a_ = self._build_net(S_, scope='target_net', trainable=False) self.e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/eval_net') self.t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/target_net') if self.replacement['name'] == 'hard': self.t_replace_counter = 0 self.hard_replace = [tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)] else: self.soft_replace = [tf.assign(t, (1 - self.replacement['tau']) * t + self.replacement['tau'] * e) for t, e in zip(self.t_params, self.e_params)]
replacement表示是"hard“,“soft“哪種引數值替代方法,hard是全盤直接複製,soft是每次修改一點點。a和a_分別是Actor的eval網路和target網路。
1.2 Actor的網路定義
Actor的神經網路結構,輸出是一個a_dim的動作維度:
def _build_net(self, s, scope, trainable): with tf.variable_scope(scope): init_w = tf.random_normal_initializer(0., 0.3) init_b = tf.constant_initializer(0.1) net = tf.layers.dense(s, 30, activation=tf.nn.relu, kernel_initializer=init_w, bias_initializer=init_b, name='l1', trainable=trainable) with tf.variable_scope('a'): actions = tf.layers.dense(net, self.a_dim, activation=tf.nn.tanh, kernel_initializer=init_w, bias_initializer=init_b, name='a', trainable=trainable) scaled_a = tf.multiply(actions, self.action_bound, name='scaled_a') # Scale output to -action_bound to action_bound return scaled_a
1.3 損失函式
這裡不是直接用損失來訓練,而是將Critic目標網路中的Q值對動作a的梯度傳遞過來,來疊加到Actor的各個引數對本地動作a的梯度上。
def add_grad_to_graph(self, a_grads):
with tf.variable_scope('policy_grads'):
# ys = policy;
# xs = policy's parameters;
# a_grads = the gradients of the policy to get more Q
# tf.gradients will calculate dys/dxs with a initial gradients for ys, so this is dq/da * da/dparams
self.policy_grads = tf.gradients(ys=self.a, xs=self.e_params, grad_ys=a_grads)
with tf.variable_scope('A_train'):
opt = tf.train.AdamOptimizer(-self.lr) # (- learning rate) for ascent policy
self.train_op = opt.apply_gradients(zip(self.policy_grads, self.e_params))
1.4 Actor訓練
Actor的訓練只需要將狀態,動作以及時間差分值餵給網路就可以。
def learn(self, s): # batch update
self.sess.run(self.train_op, feed_dict={S: s})
if self.replacement['name'] == 'soft':
self.sess.run(self.soft_replace)
else:
if self.t_replace_counter % self.replacement['rep_iter_a'] == 0:
self.sess.run(self.hard_replace)
self.t_replace_counter += 1
1.5 選擇動作
選擇動作和Policy Gradient一樣,根據計算出的softmax值來選擇動作
def choose_action(self, s):
s = s[np.newaxis, :] # single state
return self.sess.run(self.a, feed_dict={S: s})[0] # single action
2、critic
2.1 定義Critic的初始化過程
Critic網路的初始化過程和Actor網路非常相似。
def __init__(self, sess, state_dim, action_dim, learning_rate, gamma, replacement, a, a_):
self.sess = sess
self.s_dim = state_dim
self.a_dim = action_dim
self.lr = learning_rate
self.gamma = gamma
self.replacement = replacement
with tf.variable_scope('Critic'):
# Input (s, a), output q
self.a = tf.stop_gradient(a) # stop critic update flows to actor
self.q = self._build_net(S, self.a, 'eval_net', trainable=True)
# Input (s_, a_), output q_ for q_target
self.q_ = self._build_net(S_, a_, 'target_net', trainable=False) # target_q is based on a_ from Actor's target_net
self.e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/eval_net')
self.t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/target_net')
with tf.variable_scope('target_q'):
self.target_q = R + self.gamma * self.q_
with tf.variable_scope('TD_error'):
self.loss = tf.reduce_mean(tf.squared_difference(self.target_q, self.q))
with tf.variable_scope('C_train'):
self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss)
with tf.variable_scope('a_grad'):
self.a_grads = tf.gradients(self.q, a)[0] # tensor of gradients of each sample (None, a_dim)
if self.replacement['name'] == 'hard':
self.t_replace_counter = 0
self.hard_replacement = [tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)]
else:
self.soft_replacement = [tf.assign(t, (1 - self.replacement['tau']) * t + self.replacement['tau'] * e)
for t, e in zip(self.t_params, self.e_params)]
Critic網路對a不求導,因為對於Critic來說動作a是確定的。損失函式是真實Q值和預測Q值之間的歐氏距離。
2.2 定義網路結構
我們的Critic網路輸出一個Q值。
def _build_net(self, s, a, scope, trainable):
with tf.variable_scope(scope):
init_w = tf.random_normal_initializer(0., 0.1)
init_b = tf.constant_initializer(0.1)
with tf.variable_scope('l1'):
n_l1 = 30
w1_s = tf.get_variable('w1_s', [self.s_dim, n_l1], initializer=init_w, trainable=trainable)
w1_a = tf.get_variable('w1_a', [self.a_dim, n_l1], initializer=init_w, trainable=trainable)
b1 = tf.get_variable('b1', [1, n_l1], initializer=init_b, trainable=trainable)
net = tf.nn.relu(tf.matmul(s, w1_s) + tf.matmul(a, w1_a) + b1)
with tf.variable_scope('q'):
q = tf.layers.dense(net, 1, kernel_initializer=init_w, bias_initializer=init_b, trainable=trainable) # Q(s,a)
return q
2.3 定義損失
Critic的損失定義為Q值的歐式距離:
self.loss = tf.reduce_mean(tf.squared_difference(self.target_q, self.q))
2.4 訓練Critic
Critic的任務就是使得預測的Q值儘可能準確:
def learn(self, s, a, r, s_):
self.sess.run(self.train_op, feed_dict={S: s, self.a: a, R: r, S_: s_})
if self.replacement['name'] == 'soft':
self.sess.run(self.soft_replacement)
else:
if self.t_replace_counter % self.replacement['rep_iter_c'] == 0:
self.sess.run(self.hard_replacement)
self.t_replace_counter += 1
3、經驗回放
經驗回放的程式碼都是類似的,通用的。
class Memory(object):
def __init__(self, capacity, dims):
self.capacity = capacity
self.data = np.zeros((capacity, dims))
self.pointer = 0
def store_transition(self, s, a, r, s_):
transition = np.hstack((s, a, [r], s_))
index = self.pointer % self.capacity # replace the old memory with new memory
self.data[index, :] = transition
self.pointer += 1
def sample(self, n):
assert self.pointer >= self.capacity, 'Memory has not been fulfilled'
indices = np.random.choice(self.capacity, size=n)
return self.data[indices, :]
4、整體模型訓練
有了Critic之後,Actor就可以進行單步訓練和更新了,完整的訓練程式碼如下:
env = gym.make(ENV_NAME)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high
# all placeholder for tf
with tf.name_scope('S'):
S = tf.placeholder(tf.float32, shape=[None, state_dim], name='s')
with tf.name_scope('R'):
R = tf.placeholder(tf.float32, [None, 1], name='r')
with tf.name_scope('S_'):
S_ = tf.placeholder(tf.float32, shape=[None, state_dim], name='s_')
sess = tf.Session()
actor = Actor(sess, action_dim, action_bound, LR_A, REPLACEMENT)
critic = Critic(sess, state_dim, action_dim, LR_C, GAMMA, REPLACEMENT, actor.a, actor.a_)
actor.add_grad_to_graph(critic.a_grads)
sess.run(tf.global_variables_initializer())
M = Memory(MEMORY_CAPACITY, dims=2 * state_dim + action_dim + 1)
for i in range(2000):
s = env.reset()
ep_reward = 0
for j in range(500):
a = actor.choose_action(s)
s_, r, done, info = env.step(a)
M.store_transition(s, a, r , s_)
if M.pointer > MEMORY_CAPACITY:
var *= .9995 # decay the action randomness
b_M = M.sample(BATCH_SIZE)
b_s = b_M[:, :state_dim]
b_a = b_M[:, state_dim: state_dim + action_dim]
b_r = b_M[:, -state_dim - 1: -state_dim]
b_s_ = b_M[:, -state_dim:]
critic.learn(b_s, b_a, b_r, b_s_)
actor.learn(b_s)
s = s_
ep_reward += r
if j == MAX_EP_STEPS-1:
print('Episode:', i, ' Reward: %i' % int(ep_reward), 'Explore: %.2f' % var, )
break