利用 Java 與 Thrift 部署 Inception 模型並進行遠端訪問
阿新 • 發佈：2018-10-31
Thrift 是一個支援跨語言遠端呼叫（RPC）的軟體框架。
首先使用 Thrift 生成服務介面程式碼（即下文中的 Tensorflow_Service），再分別編寫客戶端與服務端兩個檔案。
客戶端：設定服務端的 IP 位址、埠號與逾時等資訊，讀取本地圖片並進行 Base64 編碼，然後呼叫服務端介面取得識別結果。
import cn.thrift.Tensorflow_Service; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.transport.TSocket; import org.apache.thrift.transport.TTransport; import sun.misc.BASE64Decoder; import sun.misc.BASE64Encoder; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; public class Client { private static final String SERVER_IP = "172.32.0.60"; private static final int SERVER_PORT = 8090; private static final int TIMEOUT = 30000; public static String GetImageStr(String imgFile) {//將圖片檔案轉化為位元組陣列字串,並對其進行Base64編碼處理 InputStream in = null; byte[] data = null; //讀取圖片位元組陣列 try { in = new FileInputStream(imgFile); data = new byte[in.available()]; in.read(data); in.close(); } catch (IOException e) { e.printStackTrace(); } //對位元組陣列Base64編碼 BASE64Encoder encoder = new BASE64Encoder(); return encoder.encode(data);//返回Base64編碼過的位元組陣列字串 } public void startClient() { TTransport transport = null; try { transport = new TSocket(SERVER_IP, SERVER_PORT, TIMEOUT); // 協議要和服務端一致 TProtocol protocol = new TBinaryProtocol(transport); Tensorflow_Service.Client client = new Tensorflow_Service.Client(protocol); transport.open(); String img_path ="E:\\thrift_tensorflow\\cropped_panda.jpg"; String img_base = GetImageStr(img_path); String result = client.tensorflow(img_base); System.out.println("Thrift client result is: " + result); } catch (TException e) { e.printStackTrace(); } finally { if (null != transport) { transport.close(); } } } /** * @param args */ public static void main(String[] args) { Client client = new Client(); client.startClient(); } }
服務端：監聽指定埠，收到請求後呼叫後臺的演算法處理程式。
import cn.thrift.Tensorflow_Service; import cn.thrift.TensorflowImpl; import org.apache.thrift.TProcessor; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.server.TServer; import org.apache.thrift.server.TSimpleServer; import org.apache.thrift.transport.TServerSocket; public class Server { private static final int SERVER_PORT = 8090; public void startServer() { try { System.out.println("Tensorflow TSimpleServer start ...."); TProcessor tprocessor = new Tensorflow_Service.Processor<>(new TensorflowImpl()); // 簡單的單執行緒服務模型,一般用於測試 TServerSocket serverTransport = new TServerSocket(SERVER_PORT); TServer.Args tArgs = new TServer.Args(serverTransport); tArgs.processor(tprocessor); tArgs.protocolFactory(new TBinaryProtocol.Factory()); TServer server = new TSimpleServer(tArgs); server.serve(); } catch (Exception e) { System.out.println("Server start error!!!"); e.printStackTrace(); } } /** * @param args */ public static void main(String[] args) { Server server = new Server(); server.startServer(); } }
演算法處理程式（使用 TensorFlow Java API 載入預訓練的 Inception 模型，對圖片進行分類）：
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ //package cn.thrift; package cn.thrift; import org.tensorflow.*; import org.tensorflow.types.UInt8; import java.io.IOException; import java.io.PrintStream; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; import java.util.List; //import org.apache.commons.codec.binary.Base64; import sun.misc.BASE64Decoder;//將base64轉換為byte[] /** Sample use of the TensorFlow Java API to label images using a pre-trained model. 
*/ public class LabelImage_personal_model { private static void printUsage(PrintStream s) { final String url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"; s.println( "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)"); s.println("to label JPEG images."); s.println("TensorFlow version: " + TensorFlow.version()); s.println(); s.println("Usage: label_image <model dir> <image file>"); s.println(); s.println("Where:"); s.println("<model dir> is a directory containing the unzipped contents of the inception model"); s.println(" (from " + url + ")"); s.println("<image file> is the path to a JPEG image file"); } public static String Result(String img_string) throws IOException { String modelDir = "E:\\wunanjing\\inception5h"; // String imageFile ="E:\\wunanjing\\cropped_panda.jpg"; // private static final String INPUT_NAME = "input"; byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));//給定的檔案路徑,此方法開啟該檔案,該檔案的內容讀入一個位元組陣列,,然後關閉該檔案。 List<String> labels = readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt")); // byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));//讀取圖片 BASE64Decoder decode = new BASE64Decoder(); //將base64轉換為byte[] byte[] imageBytes = decode.decodeBuffer(img_string); try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) { float[] labelProbabilities = executeInceptionGraph(graphDef, image);//呼叫模型 並輸入圖片 進行預測 ,得到結果 int bestLabelIdx = maxIndex(labelProbabilities); System.out.println( String.format("BEST MATCH: %s (%.2f%% likely)", labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f)); //String lab = labels.get(bestLabelIdx); //String index = labelProbabilities[bestLabelIdx] * 100f; return String.format(labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f); } } private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) 
{//把圖片轉換成inception需要模式 try (Graph g = new Graph()) {//建立一個空的構造方法 GraphBuilder b = new GraphBuilder(g); // Some constants specific to the pre-trained model at: // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip // // - The model was trained with images scaled to 224x224 pixels. // - The colors, represented as R, G, B in 1-byte each were converted to // float using (value - Mean)/Scale. final int H = 224; final int W = 224; final float mean = 117f; final float scale = 1f; // Since the graph is being constructed once per execution here, we can use a constant for the // input image. If the graph were to be re-used for multiple input images, a placeholder would // have been more appropriate. final Output<String> input = b.constant("input", imageBytes);//DecodeJpeg/contents:0 final Output<Float> output = b.div( b.sub( b.resizeBilinear( b.expandDims( b.cast(b.decodeJpeg(input, 3), Float.class), b.constant("make_batch", 0)), b.constant("size", new int[] {H, W})), b.constant("mean", mean)), b.constant("scale", scale)); try (Session s = new Session(g)) { // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks. return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class); } } } private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {//呼叫模型graphDef 並輸入圖片image 進行預測 ,得到結果 try (Graph g = new Graph()) { g.importGraphDef(graphDef); try (Session s = new Session(g); // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks. 
Tensor<Float> result = s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) { final long[] rshape = result.shape(); if (result.numDimensions() != 2 || rshape[0] != 1) { throw new RuntimeException( String.format( "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(rshape))); } int nlabels = (int) rshape[1]; return result.copyTo(new float[1][nlabels])[0]; } } } private static int maxIndex(float[] probabilities) { int best = 0; for (int i = 1; i < probabilities.length; ++i) { if (probabilities[i] > probabilities[best]) { best = i; } } return best; } private static byte[] readAllBytesOrExit(Path path) { try { return Files.readAllBytes(path); } catch (IOException e) { System.err.println("Failed to read [" + path + "]: " + e.getMessage()); System.exit(1); } return null; } private static List<String> readAllLinesOrExit(Path path) { try { return Files.readAllLines(path, Charset.forName("UTF-8")); } catch (IOException e) { System.err.println("Failed to read [" + path + "]: " + e.getMessage()); System.exit(0); } return null; } // In the fullness of time, equivalents of the methods of this class should be auto-generated from // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages // like Python, C++ and Go. 
static class GraphBuilder { GraphBuilder(Graph g) { this.g = g; } Output<Float> div(Output<Float> x, Output<Float> y) { return binaryOp("Div", x, y); } <T> Output<T> sub(Output<T> x, Output<T> y) { return binaryOp("Sub", x, y); } <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) { return binaryOp3("ResizeBilinear", images, size); } <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) { return binaryOp3("ExpandDims", input, dim); } <T, U> Output<U> cast(Output<T> value, Class<U> type) { DataType dtype = DataType.fromClass(type); return g.opBuilder("Cast", "Cast") .addInput(value) .setAttr("DstT", dtype) .build() .<U>output(0); } Output<UInt8> decodeJpeg(Output<String> contents, long channels) { return g.opBuilder("DecodeJpeg", "DecodeJpeg") .addInput(contents) .setAttr("channels", channels) .build() .<UInt8>output(0); } <T> Output<T> constant(String name, Object value, Class<T> type) { try (Tensor<T> t = Tensor.<T>create(value, type)) { return g.opBuilder("Const", name) .setAttr("dtype", DataType.fromClass(type)) .setAttr("value", t) .build() .<T>output(0); } } Output<String> constant(String name, byte[] value) { return this.constant(name, value, String.class); } Output<Integer> constant(String name, int value) { return this.constant(name, value, Integer.class); } Output<Integer> constant(String name, int[] value) { return this.constant(name, value, Integer.class); } Output<Float> constant(String name, float value) { return this.constant(name, value, Float.class); } private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) { return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0); } private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) { return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0); } private Graph g; } }
服務端程式可以直接部署在伺服器上（注意：伺服器上需要存放 Inception 模型檔案，路徑須與程式碼中設定的模型目錄一致）。
在本地執行客戶端程式時，只要在程式中填入正確的伺服器 IP 位址，就可以呼叫服務端程式，
並在本地輸出識別結果。