
A Reinforcement Learning Approach to LeetCode 847. Shortest Path Visiting All Nodes


This problem is not really meant to test machine learning. Since the model (the graph) is fully known, it can be solved directly, without Monte Carlo, temporal-difference, or similar methods, as sketched below. Even when the code in this post finds a solution, it will not pass the problem's tests. It is, however, a handy playground for basic hyperparameter tuning: try different reward schemes, exploration decay rates, and so on, and watch how the agent finds its way around. What the problem actually tests are skills such as graph traversal and state encoding.
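For reference, the direct solution that does pass the judge is a breadth-first search over (node, visited-set) states, with the visited set packed into a bitmask. A minimal standalone sketch (LeetCode wraps the function in a Solution class):

from collections import deque

def shortestPathLength(graph):
    # BFS over (node, bitmask of visited nodes) states,
    # starting simultaneously from every node.
    n = len(graph)
    full = (1 << n) - 1
    queue = deque((i, 1 << i) for i in range(n))
    seen = set((i, 1 << i) for i in range(n))
    steps = 0
    while queue:
        for _ in range(len(queue)):
            node, mask = queue.popleft()
            if mask == full:        # every node visited at least once
                return steps
            for nxt in graph[node]:
                state = (nxt, mask | (1 << nxt))
                if state not in seen:
                    seen.add(state)
                    queue.append(state)
        steps += 1
    return 0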

The reinforcement learning code is as follows:

import random as rnd


class Env(object):
    def __init__(self):
        self.Connection = []   # adjacency list of the graph
        self.Visited = []      # visit count per node
        self.NodeIndex = -1    # current node; -1 means "not started yet"

    def setConnection(self, con):
        self.Connection = con

    def reset(self):
        self.Visited = [0] * len(self.Connection)
        self.NodeIndex = -1
        return self.NodeIndex

    def stateSpace(self):
        return self.Connection

    def actionSpace(self, node):
        return self.Connection[node]

    def actionSample(self, state):
        # From the start state (-1) any node may be chosen;
        # otherwise pick a random neighbour of the current node.
        if -1 == state:
            return rnd.randint(0, len(self.Connection) - 1)
        return self.Connection[state][rnd.randint(0, len(self.Connection[state]) - 1)]

    def isDone(self):
        return all(v > 0 for v in self.Visited)

    def step(self, action):
        self.NodeIndex = action
        # Revisiting a node is penalised by its visit count;
        # reaching full coverage yields a reward of 1.
        reward = -1 * self.Visited[self.NodeIndex]
        self.Visited[self.NodeIndex] += 1
        is_done = self.isDone()
        if is_done:
            reward = 1
        return self.NodeIndex, reward, is_done


class SarsaAgent(object):
    def __init__(self, env: Env):
        self.env = env
        self.Q = {}  # action-value table Q[s][a]
        self.E = {}  # eligibility traces E[s][a]
        self.initAgent()

    def initStateValues(self, randomized=True):
        self.Q, self.E = {}, {}
        # The start state -1 can move to any node.
        self.Q[-1], self.E[-1] = {}, {}
        for actionIndex in range(len(self.env.stateSpace())):
            default_v = rnd.random() / 10 if randomized else 0.0
            self.Q[-1][actionIndex] = default_v
            self.E[-1][actionIndex] = 0.0
        for stateIndex in range(len(self.env.stateSpace())):
            self.Q[stateIndex], self.E[stateIndex] = {}, {}
            for action in self.env.actionSpace(stateIndex):
                # fixed: the original called bare random(); only rnd is imported
                default_v = rnd.random() / 10 if randomized else 0.0
                self.Q[stateIndex][action] = default_v
                self.E[stateIndex][action] = 0.0

    def get(self, QorE, s, a):
        return QorE[s][a]

    def set(self, QorE, s, a, value):
        QorE[s][a] = value

    def resetEValue(self):
        # Eligibility traces are cleared at the start of every episode.
        for s in self.E:
            for a in self.E[s]:
                self.E[s][a] = 0.0

    def initAgent(self):
        self.state = self.env.reset()
        self.initStateValues(randomized=False)

    # simple decaying epsilon-greedy exploration
    def curPolicy(self, s, episode_num, use_epsilon):
        epsilon = 1.00 / (episode_num + 1)
        if use_epsilon and rnd.random() < epsilon:
            return self.env.actionSample(s)
        Q_s = self.Q[s]
        return max(Q_s, key=Q_s.get)

    # the agent picks the next action from the current policy and state
    def performPolicy(self, s, episode_num, use_epsilon=False):
        return self.curPolicy(s, episode_num, use_epsilon)

    def act(self, a):
        return self.env.step(a)

    # SARSA(lambda) learning
    def learning(self, lambda_, gamma, alpha, max_episode_num):
        total_time = 0
        num_episode = 1
        while num_episode <= max_episode_num:
            self.state = self.env.reset()
            self.resetEValue()
            s0 = self.state
            # note: use_epsilon defaults to False, so learning as written is
            # purely greedy; the negative revisit rewards are what push the
            # agent toward unvisited nodes. Pass use_epsilon=True to enable
            # the decaying exploration mentioned above.
            a0 = self.performPolicy(s0, num_episode)
            time_in_episode = 0
            is_done = False
            while not is_done:
                s1, r1, is_done = self.act(a0)
                print(a0, end="")
                a1 = self.performPolicy(s1, num_episode)
                q = self.get(self.Q, s0, a0)
                q_prime = self.get(self.Q, s1, a1)
                delta = r1 + gamma * q_prime - q   # TD error
                e = self.get(self.E, s0, a0) + 1   # bump the trace before the sweep
                self.set(self.E, s0, a0, e)
                # Sweep every stored (s, a) pair: update Q along the traces, decay E.
                for s, a_es in self.E.items():
                    # fixed: iterate the stored actions rather than actionSpace(s),
                    # which is wrong for the start state -1
                    for a in a_es:
                        e_value = a_es[a]
                        old_q = self.get(self.Q, s, a)
                        self.set(self.Q, s, a, old_q + alpha * delta * e_value)
                        self.set(self.E, s, a, gamma * lambda_ * e_value)
                s0, a0 = s1, a1
                time_in_episode += 1
            print("\nEpisode {0} takes {1} steps.".format(num_episode, time_in_episode))
            total_time += time_in_episode
            num_episode += 1


def main():
    env = Env()
    env.setConnection([[1], [0, 2, 4], [1, 3, 4], [2], [1, 2]])
    agent = SarsaAgent(env)
    env.reset()
    print("Learning...")
    agent.learning(lambda_=0.01, gamma=1.0, alpha=0.1, max_episode_num=50)


if __name__ == "__main__":
    main()
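To inspect what the agent has learned, one can run a purely greedy rollout after training. The helper below is not part of the original code; it is a small sketch assuming the Env and SarsaAgent classes above:

def greedy_rollout(agent, max_steps=50):
    # Hypothetical helper: follow the greedy policy with no
    # exploration and record the sequence of visited nodes.
    s = agent.env.reset()
    path = []
    done = False
    while not done and len(path) < max_steps:
        Q_s = agent.Q[s]
        a = max(Q_s, key=Q_s.get)
        s, _, done = agent.env.step(a)
        path.append(a)
    return path

Calling greedy_rollout(agent) after agent.learning(...) returns the node sequence the greedy policy follows, which makes it easy to compare the routes produced by different reward or decay settings.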
