1. 程式人生 > >深度有趣 | 21 從FlappyBird到DQN

深度有趣 | 21 從FlappyBird到DQN

簡介

介紹強化學習(Reinforcement Learning,RL)的概念,並用DQN訓練一個會玩FlappyBird的模型

FlappyBird

如果沒有pygame則安裝

pip install pygame

執行flappy.py即可開始遊戲,如果出現按鍵無法控制的情況,用pythonw執行程式碼即可

pythonw flappy.py

FlappyBird遊戲截圖

原理

無監督學習沒有標籤,例如聚類;有監督學習有標籤,例如分類;而強化學習介於兩者之間,標籤是通過不斷嘗試積累的

RL包括幾個組成部分:

  • State(S):環境的狀態,例如FlappyBird中的當前遊戲介面,可以用一張圖片來表示
  • Action(A):每個S下可採取的行動集合,例如在FlappyBird中可選擇兩個A,“跳一下”或者“什麼都不做”
  • Reward(R):在某個S下執行某個A之後得到的回報,例如在FlappyBird中,可以是成功跳過一根水管(正回報),撞到水管或者掉到地上(負回報)

這樣一來,遊戲的進行過程,無非是從一個初始S開始,執行A、得到R、進入下一個S,如此往復,直到進入一個終止S

s0,a0,r1,s1,a1,r2,s2,...,sn1,an1,rn,sn s_0,a_0,r_1,s_1,a_1,r_2,s_2,...,s_{n-1},a_{n-1},r_n,s_n

定義一個函式,用來計算遊戲過程中回報的總和

R=

r1+r2+r3+...+rn R=r_1+r_2+r_3+...+r_n

以及從某個時刻開始之後的回報總和

Rt=rt+rt+1+rt+2+...+rn R_t=r_t+r_{t+1}+r_{t+2}+...+r_n

但我們對未來每一步能獲取的回報並不是完全肯定的,所以不妨乘上一個0到1之間的衰減係數

Rt=rt+γrt+1+γ2rt+2+...+γntrn R_t=r_t+\gamma r_{t+1}+\gamma^2 r_{t+2}+...+\gamma^{n-t} r_n

+γrt+1+γ2rt+2+...+γntrn

這樣一來,可以得到相鄰兩步總回報之間的遞推關係

Rt=rt+γRt+1 R_t=r_t+\gamma R_{t+1}

DQN是強化學習中的一種常用演算法,主要是引入了Q函式(Quality,價值函式),用於計算在某個S下執行某個A可以得到的最大總回報

Q(st,at)=maxRt+1 Q(s_t,a_t)=\max R_{t+1}

有了Q函式之後,對於當前狀態S,只需要計算每一個A對應的Q值,然後選擇Q值最大的一個A,便是最優的行動策略(策略函式)

π(s)=argmaxaQ(s,a) \pi(s)=argmax_a Q(s,a)

當Q函式收斂後,還可以得到Q函式的遞推公式

Q(st,at)=rt+γmaxQ(st+1,at+1) Q(s_t,a_t)=r_t+\gamma \max Q(s_{t+1},a_{t+1})

可以使用神經網路實現Q函式並訓練:

  • 定義神經網路的結構並隨機初始化,輸入為S,輸出的個數和行動集合的大小一樣
  • 每次以一定概率隨機選擇A,否則使用策略函式選擇最優的A,即隨機探索和有向策略相結合
  • 維護一個記憶模組,用於積累遊戲過程中產生的資料
  • 預熱期:不訓練,主要是為了讓記憶模組先積累一定資料
  • 探索期:逐漸降低隨機概率,從隨機探索過渡到有向策略,並且每次從記憶模組中取出一些資料訓練模型
  • 訓練期:固定隨機概率,進一步訓練模型,使得Q函式進一步收斂

實現

game中的程式碼對之前的flappy.py進行了簡化和修改,去掉了背景圖並固定角色和水管顏色,遊戲會自動開始,掛掉之後也會自動繼續,主要是便於模型自動進行和採集資料

載入庫

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import random
import cv2
import sys
sys.path.append('game/')
import wrapped_flappy_bird as fb
from collections import deque

定義一些引數

ACTIONS = 2
GAMMA = 0.99
OBSERVE = 10000
EXPLORE = 3000000
INITIAL_EPSILON = 0.1
FINAL_EPSILON = 0.0001
REPLAY_MEMORY = 50000
BATCH = 32
IMAGE_SIZE = 80

定義一些網路輸入和輔助函式,每一個S由連續的四幀遊戲截圖組成

S = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 4], name='S')
A = tf.placeholder(dtype=tf.float32, shape=[None, ACTIONS], name='A')
Y = tf.placeholder(dtype=tf.float32, shape=[None], name='Y')
k_initializer = tf.truncated_normal_initializer(0, 0.01)
b_initializer = tf.constant_initializer(0.01)

def conv2d(inputs, kernel_size, filters, strides):
    return tf.layers.conv2d(inputs, kernel_size=kernel_size, filters=filters, strides=strides, padding='same', kernel_initializer=k_initializer, bias_initializer=b_initializer)

def max_pool(inputs):
    return tf.layers.max_pooling2d(inputs, pool_size=2, strides=2, padding='same')

def relu(inputs):
    return tf.nn.relu(inputs)

定義網路結構,典型的卷積、池化、全連線層結構

h0 = max_pool(relu(conv2d(S, 8, 32, 4)))
h0 = relu(conv2d(h0, 4, 64, 2))
h0 = relu(conv2d(h0, 3, 64, 1))
h0 = tf.contrib.layers.flatten(h0)
h0 = tf.layers.dense(h0, units=512, activation=tf.nn.relu, bias_initializer=b_initializer)

Q = tf.layers.dense(h0, units=ACTIONS, bias_initializer=b_initializer, name='Q')
Q_ = tf.reduce_sum(tf.multiply(Q, A), axis=1)
loss = tf.losses.mean_squared_error(Y, Q_)
optimizer = tf.train.AdamOptimizer(1e-6).minimize(loss)

用一個佇列實現記憶模組,開始遊戲,對於初始狀態選擇什麼都不做

game_state = fb.GameState()
D = deque()

do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
img, reward, terminal = game_state.frame_step(do_nothing)
img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
_, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
S0 = np.stack((img, img, img, img), axis=2)

繼續進行遊戲並訓練模型

sess = tf.Session()
sess.run(tf.global_variables_initializer())

t = 0
success = 0
saver = tf.train.Saver()
epsilon = INITIAL_EPSILON
while True:
    if epsilon > FINAL_EPSILON and t > OBSERVE:
        epsilon = INITIAL_EPSILON - (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE * (t - OBSERVE)

    Qv = sess.run(Q, feed_dict={S: [S0]})[0]
    Av = np.zeros(ACTIONS)
    if np.random.random() <= epsilon:
        action_index = np.random.randint(ACTIONS)
    else:
        action_index = np.argmax(Qv) 
    Av[action_index] = 1

    img, reward, terminal = game_state.frame_step(Av)
    if reward == 1:
        success += 1
    img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
    img = np.reshape(img, (IMAGE_SIZE, IMAGE_SIZE, 1))
    S1 = np.append(S0[:, :, 1:], img, axis=2)

    D.append((S0, Av, reward, S1, terminal))
    if len(D) > REPLAY_MEMORY:
        D.popleft()

    if t > OBSERVE:
        minibatch = random.sample(D, BATCH)
        S_batch = [d[0] for d in minibatch]
        A_batch = [d[1] for d in minibatch]
        R_batch = [d[2] for d in minibatch]
        S_batch_next = [d[3] for d in minibatch]
        T_batch = [d[4] for d in minibatch]

        Y_batch = []
        Q_batch_next = sess.run(Q, feed_dict={S: S_batch_next})
        for i in range(BATCH):
            if T_batch[i]:
                Y_batch.append(R_batch[i])
            else:
                Y_batch.append(R_batch[i] + GAMMA * np.max(Q_batch_next[i]))

        sess.run(optimizer, feed_dict={S: S_batch, A: A_batch, Y: Y_batch})

    S0 = S1
    t += 1

    if t > OBSERVE and t % 10000 == 0:
        saver.save(sess, './flappy_bird_dqn', global_step=t)

    if t <= OBSERVE:
        state = 'observe'
    elif t <= OBSERVE + EXPLORE:
        state = 'explore'
    else:
        state = 'train'
    print('Current Step %d Success %d State %s Epsilon %.6f Action %d Reward %f Q_MAX %f' % (t, success, state, epsilon, action_index, reward, np.max(Qv)))

執行dqn_flappy.py即可從零開始訓練模型,一開始角色各種亂跳,一根水管都跳不過去,但隨著訓練的進行,角色會通過學習獲得越來越穩定的表現

DQN模型執行結果

也可以直接使用以下程式碼執行訓練好的模型

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import cv2
import sys
sys.path.append('game/')
import wrapped_flappy_bird as fb

ACTIONS = 2
IMAGE_SIZE = 80

sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.import_meta_graph('./flappy_bird_dqn-8500000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()

S = graph.get_tensor_by_name('S:0')
Q = graph.get_tensor_by_name('Q/BiasAdd:0')

game_state = fb.GameState()

do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
img, reward, terminal = game_state.frame_step(do_nothing)
img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
_, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
S0 = np.stack((img, img, img, img), axis=2)

while True:
    Qv = sess.run(Q, feed_dict={S: [S0]})[0]
    Av = np.zeros(ACTIONS) 
    Av[np.argmax(Qv)] = 1

    img, reward, terminal = game_state.frame_step(Av)
    img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
    img = np.reshape(img, (IMAGE_SIZE, IMAGE_SIZE, 1))
    S0 = np.append(S0[:, :, 1:], img, axis=2)

參考

視訊講解課程