Java呼叫Keras、Tensorflow模型
實現python離線訓練模型,Java線上預測部署。檢視原文
目前深度學習主流使用python訓練自己的模型,有非常多的框架提供了能快速搭建神經網路的功能,其中Keras提供了high-level的語法,底層可以使用tensorflow或者theano。
但是有很多公司後臺應用是用Java開發的,如果用python提供HTTP介面,對業務延遲要求比較高的話,仍然會有一定得延遲,所以能不能使用Java呼叫模型,python可以離線的訓練模型?(tensorflow也提供了成熟的部署方案TensorFlow Serving)
手頭上有一個用Keras訓練的模型,網上關於Java呼叫Keras模型的資料不是很多,而且大部分是重複的,並且也沒有講的很詳細。大致有兩種方案,一種是基於Java的深度學習庫匯入Keras模型實現,另外一種是用tensorflow提供的Java介面呼叫。
Deeplearning4J
Eclipse Deeplearning4j is the first commercial-grade, open-source, distributed deep-learning library written for Java and Scala. Integrated with Hadoop and Spark, DL4J brings AIAI to business environments for use on distributed GPUs and CPUs.
Deeplearning4j目前支援匯入Keras訓練的模型,並且提供了類似python中numpy的一些功能,更方便地處理結構化的資料。遺憾的是,Deeplearning4j現在只覆蓋了Keras <2.0版本的大部分Layer,如果你是用Keras 2.0以上的版本,在匯入模型的時候可能會報錯。
Tensorflow
文件,Java的文件很少,不過呼叫模型的過程也很簡單。採用這種方式呼叫模型需要先將Keras匯出的模型轉成tensorflow的protobuf協議的模型。
1、Keras的h5模型轉為pb模型
在Keras中使用model.save(model.h5)
儲存當前模型為HDF5格式的檔案中。
Keras的後端框架使用的是tensorflow,所以先把模型匯出為pb模型。在Java中只需要呼叫模型進行預測,所以將當前的graph中的Variable全部變成Constant,並且使用訓練後的weight。以下是freeze graph的程式碼:
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): """ :param session: 需要轉換的tensorflow的session :param keep_var_names:需要保留的variable,預設全部轉換constant :param output_names:output的名字 :param clear_devices:是否移除裝置指令以獲得更好的可移植性 :return: """ from tensorflow.python.framework.graph_util import convert_variables_to_constants graph = session.graph with graph.as_default(): freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) output_names = output_names or [] # 如果指定了output名字,則複製一個新的Tensor,並且以指定的名字命名 if len(output_names) > 0: for i in range(output_names): # 當前graph中複製一個新的Tensor,指定名字 tf.identity(model.model.outputs[i], name=output_names[i]) output_names += [v.op.name for v in tf.global_variables()] input_graph_def = graph.as_graph_def() if clear_devices: for node in input_graph_def.node: node.device = "" frozen_graph = convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names) return frozen_graph
該方法可以將tensor為Variable的graph全部轉為constant並且使用訓練後的weight。注意output_name比較重要,後面Java呼叫模型的時候會用到。
在Keras中,模型是這麼定義的:
def create_model(self):
input_tensor = Input(shape=(self.maxlen,), name="input")
x = Embedding(len(self.text2id) + 1, 200)(input_tensor)
x = Bidirectional(LSTM(128))(x)
x = Dense(256, activation="relu")(x)
x = Dropout(self.dropout)(x)
x = Dense(len(self.id2class), activation='softmax', name="output_softmax")(x)
model = Model(inputs=input_tensor, outputs=x)
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
下面的程式碼可以檢視定義好的Keras模型的輸入、輸出的name,這對之後Java呼叫有幫助。
print(model.input.op.name)
print(model.output.op.name)
訓練好Keras模型後,轉換為pb模型:
from keras import backend as K
import tensorflow as tf
model.load_model("model.h5")
print(model.input.op.name)
print(model.output.op.name)
# 自定義output_names
frozen_graph = freeze_session(K.get_session(), output_names=["output"])
tf.train.write_graph(frozen_graph, "./", "model.pb", as_text=False)
### 輸出:
# input
# output_softmax/Softmax
# 如果不自定義output_name,則生成的pb模型的output_name為output_softmax/Softmax,如果自定義則以自定義名為output_name
執行之後會生成model.pb的模型,這將是之後呼叫的模型。
2、Java呼叫
新建一個maven專案,pom裡面匯入tensorflow包:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.6.0</version>
</dependency>
核心程式碼:
public void predict() throws Exception {
try (Graph graph = new Graph()) {
graph.importGraphDef(Files.readAllBytes(Paths.get(
"path/to/model.pb"
)));
try (Session sess = new Session(graph)) {
// 自己構造一個輸入
float[][] input = {{56, 632, 675, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
try (Tensor x = Tensor.create(input);
// input是輸入的name,output是輸出的name
Tensor y = sess.runner().feed("input", x).fetch("output").run().get(0)) {
float[][] result = new float[1][y.shape[1]];
y.copyTo(result);
System.out.println(Arrays.toString(y.shape()));
System.out.println(Arrays.toString(result[0]));
}
}
}
}
Graph和Tensor物件都是需要通過close()
方法顯式地釋放佔用的資源,程式碼中使用了try-with-resources
的方法實現的。
至此,已經可以實現Keras離線訓練,Java線上預測的功能。