1. 程式人生 > >[RL學習篇][#3] 自動學習grid_mdp最佳的策略

[RL學習篇][#3] 自動學習grid_mdp最佳的策略

roo dep mina __init__ self targe upa num dom

本文修改 policy_iteration.py 程式,讓它可以執行[#1]的程式,並找出最佳動作。

 1 # /bin/python
 2 import numpy;
 3 import random;
 4 import gym;
 5 #from grid_mdp import Grid_Mdp
 6 
 7 
 8 class Policy_Value:
 9     def __init__(self, grid_mdp):
10         self.v = [0.0 for i in range(len(grid_mdp.env.states) + 1)] # 初始變數v <-- 值函數
11 12 self.pi = dict() 13 for state in grid_mdp.env.states: 14 if state in grid_mdp.env.terminate_states: continue 15 self.pi[state] = grid_mdp.env.action_s[0] #初始pi <-- 策略pi 16 17 def policy_improve(self, grid_mdp): 18 19 for state in
grid_mdp.env.states: 20 grid_mdp.env.setAction(state) # upate state 21 if state in grid_mdp.env.terminate_states: continue 22 23 a1 = grid_mdp.env.action_s[0] 24 s, r, t, z = grid_mdp.env._step(a1) 25 v1 = r + grid_mdp.env.gamma * self.v[s]
26 27 for action in grid_mdp.env.action_s: 28 s, r, t, z = grid_mdp.env._step(action) 29 if v1 < r + grid_mdp.env.gamma * self.v[s]: # 當action有更好的值,則更新動作 30 a1 = action 31 v1 = r + grid_mdp.env.gamma * self.v[s] 32 33 self.pi[state] = a1 # 紀錄最佳動作 34 35 def policy_evaluate(self, grid_mdp): 36 for i in range(1000): 37 delta = 0.0 38 for state in grid_mdp.env.states: 39 grid_mdp.env.setAction(state) # upate state 40 if state in grid_mdp.env.terminate_states: continue 41 action = self.pi[state] 42 43 s, r, t, z = grid_mdp.env.step(action) 44 new_v = r + grid_mdp.env.gamma * self.v[s] 45 delta += abs(self.v[state] - new_v) 46 self.v[state] = new_v 47 48 if delta < 1e-6: 49 break; 50 51 def policy_iterate(self, grid_mdp): 52 for i in range(100): 53 self.policy_evaluate(grid_mdp); 54 self.policy_improve(grid_mdp); 55 56 57 if __name__ == "__main__": 58 #grid_mdp = Grid_Mdp() 59 env = gym.make(GridWorld-v0) 60 61 policy_value = Policy_Value(env) 62 policy_value.policy_iterate(env) 63 print("value:") 64 for i in range(1, 6): 65 print("%d:%f\t" % (i, policy_value.v[i]), ) 66 print("") 67 68 print("policy:") 69 for i in range(1, 6): 70 print("%d->%s\t" % (i, policy_value.pi[i]), ) 71 print("")

執行結果如下:

/home/lsa-dla/anaconda3/envs/tensorflow/bin/python /home/lsa-dla/PycharmProjects/grid_mdp/lsa_test2.py
WARN: Environment '<class 'gym.envs.classic_control.grid_mdp.GridEnv'>' has deprecated methods. Compatibility code invoked.
value:
1:0.640000
2:0.800000
3:1.000000
4:0.800000
5:0.640000

policy:
1->e
2->e
3->s
4->w
5->w


Process finished with exit code 0

reference:

[1] Reinforcement_Learning_Blog/2.強化學習系列之二:模型相關的強化學習/

[RL學習篇][#3] 自動學習grid_mdp最佳的策略