1. 程式人生 > 其它 >Reinforcement Learning (DQN) 中經驗池詳細解釋

Reinforcement Learning (DQN) 中經驗池詳細解釋

技術標籤:python大資料多型強化學習

一般DQN中的經驗池類,都類似於下面這段程式碼。

import random
from collections import namedtuple, deque

Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward'))

# 經驗池類
class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity        # 容量
        self.memory = []
        self.position = 0

    # 將四元組壓入經驗池
    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    # 從經驗池中隨機壓出一個四元組
    def sample(self, batch_size):
        transitions = random.sample(self.memory, batch_size)
        batch = Transition(*zip(*transitions))
        return batch

    def __len__(self):
        return len(self.memory)

對Python不太熟悉的我裡邊就有兩點比較迷惑,一個是namedtuple()方法,一個是sample方法的倒數第二行,為什麼要這樣處理。

第一點,namedtuple()是繼承自tuple的子類,namedtuple()方法能夠建立一個和tuple類似的物件,而且物件擁有可訪問的屬性。

第二點,也就是sample方法中的倒數第二行,這裡進行了一個轉換, 將batch_size個四元組,轉換成,四個元祖,每個元祖一共有batch_size項,這裡放個程式解釋一下。

import random
from collections import namedtuple

if __name__ == '__main__':

    batch_size = 3
    Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward'))

    a=Transition(state=1,next_state=2,action=3,reward=4)
    b=Transition(state=11,next_state=12,action=13,reward=14)
    c=Transition(state=21,next_state=22,action=23,reward=24)
    d=Transition(state=31,next_state=32,action=33,reward=34)
    e=Transition(state=41,next_state=42,action=43,reward=44)

    f=[a,b,c,d,e]

    # 從f中隨機抽取batch_size個數據
    t=random.sample(f,batch_size)

    print("隨機抽取的batch_size個四元祖是:")
    for i in range(batch_size):
        print(t[i])
    print()

    # 將t進行解壓操作
    print("將四元組進行解壓後是:")
    print(*zip(*t))
    print()

    # 將t進行解壓操作,再進行Transition轉換
    # 將batch_size個四元組,轉換成,四個元祖,每個元祖一共有batch_size項
    print("將四元組進行解壓後再進行Transition轉換後是:")
    batch=Transition(*zip(*t))
    print(batch)

輸出結果:

隨機抽取的batch_size個四元祖是:
Transition(state=21, next_state=22, action=23, reward=24)
Transition(state=11, next_state=12, action=13, reward=14)
Transition(state=41, next_state=42, action=43, reward=44)

將四元組進行解壓後是:
(21, 11, 41) (22, 12, 42) (23, 13, 43) (24, 14, 44)

將四元組進行解壓後再進行Transition轉換後是:
Transition(state=(21, 11, 41), next_state=(22, 12, 42), action=(23, 13, 43), reward=(24, 14, 44))