Deeplearning4j 實戰 (9):強化學習 -- Cartpole任務的訓練和效果測試
在之前的部落格中,我用Deeplearning4j構建深度神經網路來解決監督、無監督的機器學習問題。但除了這兩類問題外,強化學習也是機器學習中一個重要的分支,並且Deeplearning4j的子專案--Rl4j提供了對部分強化學習演算法的支援。這裡,就以強化學習中的經典任務--Cartpole問題作為學習Rl4j的入門例子,講解從環境搭建、模型訓練再到最後的效果評估的結果。
Cartpole描述的問題可以認為是:在一輛小車上豎立一根杆子,然後給小車一個推或者拉的力,使得杆子儘量保持平衡不滑倒。更詳細的描述可參見openai官網上關於Cartpole問題的解釋:https://gym.openai.com/envs/CartPole-v0
接著給出強化學習的一些概念:environment,action,reward
environment:描述強化學習問題中的外部環境,比如:Cartpole問題中杆子的角度,小車的位置、速度等。
action:在不同外部環境條件下采取的動作,比如:Cartpole問題中對於小車施加推或者拉的力。action可以是離散的集合,也可以是連續的。
reward:對於agent/network作出的action後獲取的回報/評價。比如:Cartpole問題中如果施加的力可以繼續讓杆子保持平衡,那reward就+1。
在描述reward這個概念時,提到了agent這個概念,在實際應用中,agent可以用神經網路來實現。
對於強化學習訓練後的agent來說,學習到的是如何在變化中的environment和reward選擇action的能力。通常有兩種學習策略可以選擇:Policy-Based和Value-Based。 Policy-Based直接學習action,通過Policy Gradient來更新模型引數,而相對的,Value-Based是最優化action所帶來的reward(action-value function,Q-function)來間接選取action。一般認為如果action是離散的,那麼Value-Based會優於Policy-Based,而連續的action則相反。在這裡主要討論Value-Based的學習策略,或者更具體的說Q-learning的問題。對於Policy-Based還有Model-Based不做討論。
Q-learning的概念早在20多年前就已經提出,再與近年來流行的深度神經網路結合產生了DQN的概念。Q-learning的目標是最大化Q值從而學習到選取action的策略。Q-leaning學習的策略公式:
Q(st,at)←Q(st,at)+α[rt+1+λmaxaQ(st+1,a)−Q(st,at)]
對於這裡主要討論的Catpole問題,我們也採用Q-learning來實現。
可以看到,與監督學習相比,強化學習多了action,environment等概念。雖然可以將reward類比成監督學習中的label(或者反過來,label也可以認為是強化學習中最終的reward),但通過action與environment不斷的互動甚至改變environment這一特點,是監督學習中所沒有的。在構建應用的時候,監督學習的學習的目標:label,灌入的資料都是一個定值。比如,影象的分類的問題,在用CNN訓練的時候,圖片本身不發生變化,label也不會發生變化,唯一變化的是神經網路中的權重值。但強化學習在訓練的時候,除了神經網路中的權重會發生變化(如果用NN建模的話),environment、reward等都會發生動態的變化。這樣構建合適正確的訓練資料會比較麻煩,容易出錯,所以對於CartPole問題,我們可以採用openAI提供的強化學習開發環境gym來訓練/測試agent。
gym的官方地址:https://gym.openai.com/
gym提供了棋類、視訊遊戲等強化學習問題的學習/測試/演算法效果比較的環境。這裡要處理的Cartpole問題,gym也提供了環境的支援。但是,除了python,gym對其他語言的支援不是很友好,所以為了可以獲取gym中的資料,RL4j提供了對gym-http-api(https://github.com/openai/gym-http-api)呼叫的包裝類。gym-http-api是為了方便除python外的其他語言也可以使用gym環境資料的一個REST介面。簡單來說,對於像RL4j這樣以Java實現的強化學習演算法庫可以通過gym-http-api獲取gym提供的資料。
gym的REST介面的安裝可以參見之前給出的github地址,裡面有詳細的描述。下面先給出gym-http-api的安裝和啟動過程的截圖:
下面就結合上面說的內容,給出RL4j的Catpole實現邏輯
1. 定義Q-learning的引數以及神經網路結構,兩者共同決定DQN的屬性
2. 定義讀取gym資料的包裝類物件
3. 訓練DQN並儲存模型
4. 載入儲存的模型並測試
這裡先貼下需要的Maven依賴以及程式碼版本
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<nd4j.version>0.8.0</nd4j.version>
<dl4j.version>0.8.0</dl4j.version>
<datavec.version>0.8.0</datavec.version>
<rl4j.version>0.8.0</rl4j.version>
<scala.binary.version>2.10</scala.binary.version>
</properties>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>rl4j-core</artifactId>
<version>${rl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>rl4j-gym</artifactId>
<version>${rl4j.version}</version>
</dependency>
</dependencies>
第一部分的程式碼如下:
public static QLearning.QLConfiguration CARTPOLE_QL =
new QLearning.QLConfiguration(
123, //Random seed
200, //Max step By epoch
150000, //Max step
150000, //Max size of experience replay
32, //size of batches
500, //target update (hard)
10, //num step noop warmup
0.01, //reward scaling
0.99, //gamma
1.0, //td-error clipping
0.1f, //min epsilon
1000, //num step for eps greedy anneal
true //double DQN
);
public static DQNFactoryStdDense.Configuration CARTPOLE_NET = DQNFactoryStdDense.Configuration.builder() .l2(0.001) .learningRate(0.0005)
.numHiddenNodes(16)
.numLayer(3)
.build();
第一部分中定義Q-learning的引數,包括每一輪的訓練的可採取的action的步數,最大步數以及儲存過往action的最大步數等。除此以外,DQNFactoryStdDense用來定義基於MLP的DQN網路結構,包括網路深度等常見引數。這裡的程式碼定義的是一個三層(只有一層隱藏層)的全連線神經網路。
接下來,定義兩個方法分別用於訓練和測試。catpole方法用於訓練DQN,而loadCartpole則用於測試。
訓練:
public static void cartPole() {
//record the training data in rl4j-data in a new folder (save)
DataManager manager = new DataManager(true);
//define the mdp from gym (name, render)
GymEnv<Box, Integer, DiscreteSpace> mdp = null;
try {
mdp = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);
} catch (RuntimeException e){
System.out.print("To run this example, download and start the gym-http-api repo found at https://github.com/openai/gym-http-api.");
}
//define the training
QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL, manager);
//train
dql.train();
//get the final policy
DQNPolicy<Box> pol = dql.getPolicy();
//serialize and save (serialization showcase, but not required)
pol.save("/tmp/pol1");
//close the mdp (close http)
mdp.close();
}
測試:
public static void loadCartpole(){
//showcase serialization by using the trained agent on a new similar mdp (but render it this time)
//define the mdp from gym (name, render)
GymEnv<Box, Integer, DiscreteSpace> mdp2 = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", true, false);
//load the previous agent
DQNPolicy<Box> pol2 = DQNPolicy.load("/tmp/pol1");
//evaluate the agent
double rewards = 0;
for (int i = 0; i < 1000; i++) {
mdp2.reset();
double reward = pol2.play(mdp2);
rewards += reward;
Logger.getAnonymousLogger().info("Reward: " + reward);
}
Logger.getAnonymousLogger().info("average: " + rewards/1000);
mdp2.close();
}
在訓練模型的方法中,包含了第二、三步的內容。首先需要定義gym資料讀取物件,即程式碼中的GymEnv<Box, Integer, DiscreteSpace> mdp。它會通過gym-http-api的介面讀取訓練資料。接著,將第一步中定義的Q-learning的相關引數,神經網路結構作為引數傳入DQN訓練的包裝類中。其中DataManager的作用是用來管理訓練資料。
測試部分的程式碼實現了之前說的第四步,即載入策略模型並進行測試的過程。在測試的過程中,將每次action的reward打上log,並最後求取平均的reward。
訓練的過程截圖如下:
最後我們其實最關心的還是這個模型的效果。純粹通過平均reward的數值大小可能並不是非常的直觀,因此這裡給出一張gif的效果圖: