A Small Example: Solving a Maze with the Sarsa Reinforcement Learning Algorithm
Posted by 阿新 on 2018-11-26
The Sarsa algorithm:
What is the difference between the Sarsa algorithm and the Q-learning algorithm?
A simple explanation, quoting Morvan (莫凡):
- At the current `state`, the agent has already decided on the `action` for that state, and it has also already decided on the next `state_` and the next `action_` (Q-learning has not yet decided on the next `action_`).
- When updating `Q(s, a)`, the target is based on the next `Q(s_, a_)` (Q-learning is based on `max Q(s_)`); the two update rules are written out below.
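Written out as update rules (standard Sarsa / Q-learning notation, added here for reference rather than quoted from the original post):

$$Q(s,a) \leftarrow Q(s,a) + \alpha \big[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \big] \quad \text{(Q-learning)}$$

$$Q(s,a) \leftarrow Q(s,a) + \alpha \big[ r + \gamma\, Q(s',a') - Q(s,a) \big] \quad \text{(Sarsa, where } a' \text{ is the action actually taken in } s')$$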
The second point can be seen directly in the maze code (code written by Morvan, source: https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/3_Sarsa_maze/RL_brain.py):
    # Both tables inherit choose_action() and check_state_exist() from the RL base
    # class defined earlier in RL_brain.py (not shown here).

    # off-policy
    class QLearningTable(RL):

        def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
            super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        def learn(self, s, a, r, s_):
            self.check_state_exist(s_)
            q_predict = self.q_table.loc[s, a]
            if s_ != 'terminal':
                q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
            else:
                q_target = r  # next state is terminal
            self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update


    # on-policy
    class SarsaTable(RL):

        def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
            super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        def learn(self, s, a, r, s_, a_):
            self.check_state_exist(s_)
            q_predict = self.q_table.loc[s, a]
            if s_ != 'terminal':
                q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
            else:
                q_target = r  # next state is terminal
            self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update
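The on-policy / off-policy difference also shows up in how each algorithm's episode loop is written. Below is a minimal sketch of the two loops, modeled on the update() functions in run_this.py of the same repo; the env.reset()/env.step() interface and the choose_action()/learn() calls follow that repo, while the function names run_sarsa/run_qlearning and the episodes parameter are just illustrative.

    # Sarsa: the next action a_ is chosen before learning, and it is the action
    # that will actually be executed on the next step (on-policy).
    def run_sarsa(env, agent, episodes=100):
        for episode in range(episodes):
            s = env.reset()
            a = agent.choose_action(str(s))
            while True:
                s_, r, done = env.step(a)
                a_ = agent.choose_action(str(s_))       # commit to the next action now
                agent.learn(str(s), a, r, str(s_), a_)  # target uses Q(s_, a_)
                s, a = s_, a_                           # carry the chosen action forward
                if done:
                    break

    # Q-learning: learning only needs s_; the next action is chosen afterwards
    # and may differ from the greedy action used inside the target (off-policy).
    def run_qlearning(env, agent, episodes=100):
        for episode in range(episodes):
            s = env.reset()
            while True:
                a = agent.choose_action(str(s))
                s_, r, done = env.step(a)
                agent.learn(str(s), a, r, str(s_))      # target uses max over Q(s_, :)
                s = s_
                if done:
                    break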
You can see that the two differ in q_target: Q-learning takes the maximum Q-value over the next state's actions, even though the action achieving that maximum may never actually be selected, whereas Sarsa uses the next action a_ directly, i.e. the action for the next state that will definitely be executed on the next step.
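To make the q_target difference concrete, here is a small self-contained check with made-up numbers, using the same pandas-based q_table layout as the code above:

    import pandas as pd

    # Toy Q-table: two states ('s' and 's_') and two actions (0 and 1).
    q_table = pd.DataFrame(
        [[0.0, 0.5],   # values for state 's'
         [0.2, 0.8]],  # values for state 's_'
        index=['s', 's_'],
        columns=[0, 1],
    )

    r, gamma = 0.0, 0.9
    a_ = 0  # suppose the epsilon-greedy policy happens to pick the non-greedy action next

    q_target_qlearning = r + gamma * q_table.loc['s_', :].max()  # 0.9 * 0.8: best next action
    q_target_sarsa = r + gamma * q_table.loc['s_', a_]           # 0.9 * 0.2: action actually taken

    print(q_target_qlearning, q_target_sarsa)  # Sarsa's target is lower here

Because Sarsa's target tracks the action it will really take under the exploratory policy, it tends to learn a more cautious path through the maze than Q-learning.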