用 Java 訓練深度學習模型,原來可以這麼簡單!
阿新 • • 發佈:2020-11-02
> 本文適合有 Java 基礎的人群
![](https://img2020.cnblogs.com/blog/759200/202011/759200-20201101171629904-336726111.jpg)
作者:**DJL-Keerthan&Lanking**
HelloGitHub 推出的[《講解開源專案》](https://github.com/HelloGitHub-Team/Article) 系列。這一期是由亞馬遜工程師:[Keerthan Vasist](https://github.com/keerthanvasist),為我們講解 DJL(完全由 Java 構建的深度學習平臺)系列的第 4 篇。
## 一、前言
很長時間以來,Java 都是一個很受企業歡迎的程式語言。得益於豐富的生態以及完善維護的包和框架,Java 擁有著龐大的開發者社群。儘管深度學習應用的不斷演進和落地,提供給 Java 開發者的框架和庫卻十分短缺。現今主要流行的深度學習模型都是用 Python 編譯和訓練的。對於 Java 開發者而言,如果要進軍深度學習界,就需要重新學習並接受一門新的程式語言同時還要學習深度學習的複雜知識。這使得大部分 Java 開發者學習和轉型深度學習開發變得困難重重。
為了減少 Java 開發者學習深度學習的成本,AWS 構建了 Deep Java Library (DJL),一個為 Java 開發者定製的開源深度學習框架。它為 Java 開發者對接主流深度學習框架提供了一個橋樑。
![](https://img2020.cnblogs.com/blog/759200/202011/759200-20201101171507537-1333339590.png)
在這篇文章中,我們會嘗試用 DJL 構建一個深度學習模型並用它訓練 MNIST 手寫數字識別任務。
## 二、什麼是深度學習?
在我們正式開始之前,我們先來了解一下機器學習和深度學習的基本概念。
機器學習是一個通過利用統計學知識,將資料輸入到計算機中進行訓練並完成特定目標任務的過程。這種歸納學習的方法可以讓計算機學習一些特徵並進行一系列複雜的任務,比如識別照片中的物體。由於需要寫複雜的邏輯以及測量標準,這些任務在傳統計算科學領域中很難實現。
深度學習是機器學習的一個分支,主要側重於對於人工神經網路的開發。人工神經網路是通過研究人腦如何學習和實現目標的過程中歸納而得出一套計算邏輯。它通過模擬部分人腦神經間資訊傳遞的過程,從而實現各類複雜的任務。深度學習中的“深度”來源於我們會在人工神經網路中編織構建出許多層(layer)從而進一步對資料資訊進行更深層的傳導。深度學習技術應用範圍十分廣泛,現在被用來做目標檢測、動作識別、機器翻譯、語意分析等各類現實應用中。
## 三、訓練 MNIST 手寫數字識別
### 3.1 專案配置
你可以用如下的 `gradle` 配置來引入依賴項。在這個案例中,我們用 DJL 的 api 包 (核心 DJL 元件) 和 basicdataset 包 (DJL 資料集) 來構建神經網路和資料集。這個案例中我們使用了 MXNet 作為深度學習引擎,所以我們會引入 `mxnet-engine` 和 `mxnet-native-auto` 兩個包。這個案例也可以執行在 PyTorch 引擎下,只需要替換成對應的軟體包即可。
```
plugins {
id 'java'
}
repositories {
jcenter()
}
dependencies {
implementation platform("ai.djl:bom:0.8.0")
implementation "ai.djl:api"
implementation "ai.djl:basicdataset"
// MXNet
runtimeOnly "ai.djl.mxnet:mxnet-engine"
runtimeOnly "ai.djl.mxnet:mxnet-native-auto"
}
```
### 3.2 NDArray 和 NDManager
NDArray 是 DJL 儲存資料結構和數學運算的基本結構。一個 NDArray 表達了一個定長的多維陣列。NDArray 的使用方法類似於 Python 中的 `numpy.ndarray`。
NDManager 是 NDArray 的老闆。它負責管理 NDArray 的產生和回收過程,這樣可以幫助我們更好的對 Java 記憶體進行優化。每一個 NDArray 都會是由一個 NDManager 創造出來,同時它們會在 NDManager 關閉時一同關閉。NDManager 和 NDArray 都是由 Java 的 AutoClosable 構建,這樣可以確保在執行結束時及時進行回收。想了解更多關於它們的用法和實踐,請參閱我們前一期文章:
[DJL 之 Java 玩轉多維陣列,就像 NumPy 一樣](https://www.cnblogs.com/xueweihan/p/13603551.html)
### Model
在 DJL 中,訓練和推理都是從 Model class 開始構建的。我們在這裡主要講訓練過程中的構建方法。下面我們為 Model 建立一個新的目標。因為 Model 也是繼承了 AutoClosable 結構體,我們會用一個 try block 實現:
```java
try (Model model = Model.newInstance()) {
...
// 主體訓練程式碼
...
}
```
### 準備資料
MNIST(Modified National Institute of Standards and Technology)資料庫包含大量手寫數字的圖,通常被用來訓練影象處理系統。DJL 已經將 MNIST 的資料集收錄到了 basicdataset 資料集裡,每個 MNIST 的圖的大小是 `28 x 28`。如果你有自己的資料集,你也可以通過 DJL 資料集匯入教程來匯入資料集到你的訓練任務中。
> 資料集匯入教程: http://docs.djl.ai/docs/development/how_to_use_dataset.html#how-to-create-your-own-dataset
```java
int batchSize = 32; // 批大小
Mnist trainingDataset = Mnist.builder()
.optUsage(Usage.TRAIN) // 訓練集
.setSampling(batchSize, true)
.build();
Mnist validationDataset = Mnist.builder()
.optUsage(Usage.TEST) // 驗證集
.setSampling(batchSize, true)
.build();
```
這段程式碼分別製作出了訓練和驗證集。同時我們也隨機排列了資料集從而更好的訓練。除了這些配置以外,你也可以新增對於圖片的進一步處理,比如設定圖片大小,對圖片進行歸一化等處理。
### 製作 model(建立 Block)
當你的資料集準備就緒後,我們就可以構建神經網路了。在 DJL 中,神經網路是由 Block(程式碼塊)構成的。一個 Block 是一個具備多種神經網路特性的結構。它們可以代表 一個操作, 神經網路的一部分,甚至是一個完整的神經網路。然後 Block 可以順序執行或者並行。同時 Block 本身也可以帶引數和子 Block。這種巢狀結構可以幫助我們構造一個複雜但又不失維護性的神經網路。在訓練過程中,每個 Block 中附帶的引數會被實時更新,同時也包括它們的各個子 Block。這種遞迴更新的過程可以確保整個神經網路得到充分訓練。
當我們構建這些 Block 的過程中,最簡單的方式就是將它們一個一個的巢狀起來。直接使用準備好 DJL 的 Block 種類,我們就可以快速製作出各類神經網路。
根據幾種基本的神經網路工作模式,我們提供了幾種 Block 的變體。SequentialBlock 是為了應對順序執行每一個子 Block 構造而成的。它會將前一個子 Block 的輸出作為下一個 Block 的輸入 繼續執行到底。與之對應的,是 ParallelBlock 它用於將一個輸入並行輸入到每一個子 Block 中,同時將輸出結果根據特定的合併方程合併起來。最後我們說一下 LambdaBlock,它是幫助使用者進行快速操作的一個 Block,其中並不具備任何引數,所以也沒有任何部分在訓練過程中更新。
![](https://img2020.cnblogs.com/blog/759200/202011/759200-20201101171531098-336448006.png)
我們來嘗試建立一個基本的 多層感知機(MLP)神經網路吧。多層感知機是一個簡單的前向型神經網路,它只包含了幾個全連線層 (LinearBlock)。那麼構建這個網路,我們可以直接使用 SequentialBlock。
```java
int input = 28 * 28; // 輸入層大小
int output = 10; // 輸出層大小
int[] hidden = new int[] {128, 64}; // 隱藏層大小
SequentialBlock sequentialBlock = new SequentialBlock();
sequentialBlock.add(Blocks.batchFlattenBlock(input));
for (int hiddenSize : hidden) {
// 全連線層
sequentialBlock.add(Linear.builder().setUnits(hiddenSize).build());
// 啟用函式
sequentialBlock.add(activation);
}
sequentialBlock.add(Linear.builder().setUnits(output).build());
```
當然 DJL 也提供了直接就可以拿來用的 MLP Block :
```java
Block block = new Mlp(
Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
Mnist.NUM_CLASSES,
new int[] {128, 64});
```
### 訓練
當我們準備好資料集和神經網路之後,就可以開始訓練模型了。在深度學習中,一般會由下面幾步來完成一個訓練過程:
![](https://img2020.cnblogs.com/blog/759200/202011/759200-20201101171547247-1850623355.png)
* 初始化:我們會對每一個 Block 的引數進行初始化,初始化每個引數的函式都是由 設定的 Initializer 決定的。
* 前向傳播:這一步將輸入資料在神經網路中逐層傳遞,然後產生輸出資料。
* 計算損失:我們會根據特定的損失函式 Loss 來計算輸出和標記結果的偏差。
* 反向傳播:在這一步中,你可以利用損失反向求導算出每一個引數的梯度。
* 更新權重:我們會根據選擇的優化器(Optimizer)更新每一個在 Block 上引數的值。
DJL 利用了 Trainer 結構體精簡了整個過程。開發者只需要建立 Trainer 並指定對應的 Initializer、Loss 和 Optimizer 即可。這些引數都是由 TrainingConfig 設定的。下面我們來看一下具體的引數設定:
* `TrainingListener`:這個是對訓練過程設定的監聽器。它可以實時反饋每個階段的訓練結果。這些結果可以用於記錄訓練過程或者幫助 debug 神經網路訓練過程中的問題。使用者也可以定製自己的 TrainingListener 來對訓練過程進行監聽。
```java
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
try (Trainer trainer = model.newTrainer(config)){
// 訓練程式碼
}
```
當訓練器產生後,我們可以定義輸入的 Shape。之後就可以呼叫 fit 函式來進行訓練。fit 函式會對輸入資料,訓練多個 epoch 是並最終將結果儲存在本地目錄下。
```java
/*
* MNIST 包含 28x28 灰度圖片並匯入成 28 * 28 NDArray。
* 第一個維度是批大小, 在這裡我們設定批大小為 1 用於初始化。
*/
Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
int numEpoch = 5;
String outputDir = "/build/model";
// 用輸入初始化 trainer
trainer.initialize(inputShape);
TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp");
```
這就是訓練過程的全部流程了!用 DJL 訓練是不是還是很輕鬆的?之後看一下輸出每一步的訓練結果。如果你用了我們預設的監聽器,那麼輸出是類似於下圖:
```
[INFO ] - Downloading libmxnet.dylib ...
[INFO ] - Training on: cpu().
[INFO ] - Load MXNet Engine Version 1.7.0 in 0.131 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/sec
Validating: 100% |████████████████████████████████████████|
[INFO ] - Epoch 1 finished.
[INFO ] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24
[INFO ] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
Training: 100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/sec
Validating: 100% |████████████████████████████████████████|
[INFO ] - Epoch 2 finished.NG [1m 41s]
[INFO ] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10
[INFO ] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
[INFO ] - train P50: 12.756 ms, P90: 21.044 ms
[INFO ] - forward P50: 0.375 ms, P90: 0.607 ms
[INFO ] - training-metrics P50: 0.021 ms, P90: 0.034 ms
[INFO ] - backward P50: 0.608 ms, P90: 0.973 ms
[INFO ] - step P50: 0.543 ms, P90: 0.869 ms
[INFO ] - epoch P50: 35.989 s, P90: 35.989 s
```
當訓練結果完成後,我們可以用剛才的模型進行推理來識別手寫數字。如果剛才的內容哪裡有不是很清楚的,可以參照下面兩個連結直接嘗試訓練。
> 手寫資料集訓練:https://docs.djl.ai/examples/docs/train_mnist_mlp.html
>
> 手寫資料集推理:https://docs.djl.ai/jupyter/tutorial/03_image_classification_with_your_model.html
## 四、最後
在這個文章中,我們介紹了深度學習的基本概念,同時還有如何優雅的利用 DJL 構建深度學習模型並進行訓練。DJL 也提供了更加多樣的資料集和神經網路。如果有興趣學習深度學習,可以參閱我們的 Java 深度學習書。
> Java 深度學習書:https://zh.d2l.ai/
![](https://img2020.cnblogs.com/blog/759200/202011/759200-20201101171605480-21733676.png)
Deep Java Library(DJL)是一個基於 Java 的深度學習框架,同時支援訓練以及推理。DJL 博取眾長,構建在多個深度學習框架之上 (TenserFlow、PyTorch、MXNet 等) 也同時具備多個框架的優良特性。你可以輕鬆使用 DJL 來進行訓練然後部署你的模型。
它同時擁有著強大的模型庫支援:只需一行便可以輕鬆讀取各種預訓練的模型。現在 DJL 的模型庫同時支援高達 70 多個來自 GluonCV、 HuggingFace、TorchHub 以及 Keras 的模型。
> 專案地址:https://github.com/awsla