
Implementing RPC with Netty + ZooKeeper + protobuf

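Last time I implemented RPC with Java serialization and a blocking I/O model, which was quite inefficient. This time I use NIO instead, through Netty. There is a fair amount of code, so I'll try to explain it clearly.
Here are the Maven dependencies first, so you can copy them straight away.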

    <dependency>
        <groupId>junit</groupId>
        <artifactId>junit</artifactId>
        <version>4.11</version>
        <scope>test</scope>
    </dependency>
    <!-- SLF4J -->
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-log4j12</artifactId>
        <version>1.7.7</version>
    </dependency>
    <!-- Netty -->
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-all</artifactId>
        <version>5.0.0.Alpha1</version>
    </dependency>
    <!-- protostuff -->
    <dependency>
        <groupId>io.protostuff</groupId>
        <artifactId>protostuff-core</artifactId>
        <version>1.6.0</version>
    </dependency>
    <dependency>
        <groupId>io.protostuff</groupId>
        <artifactId>protostuff-runtime</artifactId>
        <version>1.6.0</version>
    </dependency>
    <!-- ZooKeeper -->
    <dependency>
        <groupId>org.apache.zookeeper</groupId>
        <artifactId>zookeeper</artifactId>
        <version>3.4.6</version>
    </dependency>
    <!-- Apache Commons Collections -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-collections4</artifactId>
        <version>4.0</version>
    </dependency>
    <!-- Objenesis -->
    <dependency>
        <groupId>org.objenesis</groupId>
        <artifactId>objenesis</artifactId>
        <version>2.1</version>
    </dependency>
    <!-- CGLib -->
    <dependency>
        <groupId>cglib</groupId>
        <artifactId>cglib</artifactId>
        <version>3.1</version>
    </dependency>

The finished project's directory structure:
[Figure: project directory structure]
First, the two services: addition and subtraction.
The interfaces and their implementations:

public interface IDiff {
    double diff(double a,double b);
}

public interface ISum {
    public int sum(int a, int b);
}


public class DiffImpl implements IDiff {
    @Override
    public double diff(double a, double b) {
        return a - b;
    }

}


public class SumImpl implements ISum {

    @Override
    public int sum(int a, int b) {
        return a + b;
    }

}

1. Server side: registering service nodes dynamically in ZooKeeper

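Here is the service registration code. The thing to note is CreateMode.EPHEMERAL_SEQUENTIAL. The service is registered as an ephemeral node, and ZooKeeper deletes that node automatically when the server's session ends (on a normal shutdown, or after the session timeout if the server crashes). This way no dead services are left behind.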

/**
 * Connects to the ZooKeeper registry and creates the service registration node.
 */
public class ServiceRegistry {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceRegistry.class);

    private final CountDownLatch latch = new CountDownLatch(1);

    private ZooKeeper zk;


    public ServiceRegistry() {
    }

    public void register(String data) {
        if (data != null) {
            zk = connectServer();
            if (zk != null) {
                createNode(Constant.ZK_DATA_PATH, data);
            }
        }
    }

    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(Constant.ZK_CONNECT, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    // Once connected to ZooKeeper, count the latch down.
                    if (event.getState() == Event.KeeperState.SyncConnected) {
                        latch.countDown();
                    }
                }
            });

            // Block until the latch reaches 0, i.e. until the connection is established.
            latch.await();
        } catch (IOException | InterruptedException e) {
            LOGGER.error("", e);
        }
        return zk;
    }



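    private void createNode(String dir, String data) {
        try {
            // The persistent parent (/registry) must exist before children can be created under it
            if (zk.exists(Constant.ZK_REGISTRY_PATH, false) == null) {
                try {
                    zk.create(Constant.ZK_REGISTRY_PATH, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
                } catch (KeeperException.NodeExistsException ignored) {
                    // Another server created it concurrently; that is fine
                }
            }
            byte[] bytes = data.getBytes();
            String path = zk.create(dir, bytes, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
            LOGGER.info("create zookeeper node ({} => {})", path, data);
        } catch (KeeperException | InterruptedException e) {
            LOGGER.error("", e);
        }
    }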
}

The constants interface it uses:

public interface Constant {

    int ZK_SESSION_TIMEOUT = 10000;
    String ZK_CONNECT = "s1:2181,s2:2181,s3:2181";
    String ZK_REGISTRY_PATH = "/registry";
    String ZK_DATA_PATH = ZK_REGISTRY_PATH + "/data";
    String ZK_IP_SPLIT = ":";
}
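The following is a small in-memory model in plain Java that makes the ephemeral-sequential behaviour concrete. EphemeralRegistryModel and its methods are hypothetical and are not ZooKeeper's API. The model mimics only two documented properties. First, a sequential node gets a monotonically increasing counter, maintained by its parent and formatted as 10 zero-padded digits, appended to its name. Second, an ephemeral node disappears when the session that created it ends. The model starts the counter at 0 and ignores session timeouts.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * Hypothetical in-memory model (NOT the ZooKeeper API) of EPHEMERAL_SEQUENTIAL nodes:
 * each create gets a zero-padded sequence suffix, and every node belongs to the
 * session that created it, so closing that session removes the node.
 */
class EphemeralRegistryModel {
    private final Map<String, String> data = new TreeMap<>();      // path -> node data
    private final Map<String, Long> owner = new TreeMap<>();       // path -> owning session id
    private final Map<String, Integer> counters = new TreeMap<>(); // parent path -> next sequence number

    /** Creates "<pathPrefix><10-digit sequence>" owned by the given session and returns the full path. */
    synchronized String createEphemeralSequential(String pathPrefix, String value, long sessionId) {
        String parent = pathPrefix.substring(0, pathPrefix.lastIndexOf('/'));
        int seq = counters.merge(parent, 1, Integer::sum) - 1; // first call yields 0
        String path = pathPrefix + String.format("%010d", seq);
        data.put(path, value);
        owner.put(path, sessionId);
        return path;
    }

    /** Simulates a session ending (shutdown or crash): every node it created disappears. */
    synchronized void closeSession(long sessionId) {
        owner.entrySet().removeIf(e -> {
            if (e.getValue() == sessionId) {
                data.remove(e.getKey());
                return true;
            }
            return false;
        });
    }

    /** Returns the data of every live child under the given parent path, like a client's discovery step. */
    synchronized List<String> childData(String parent) {
        List<String> result = new ArrayList<>();
        for (Map.Entry<String, String> e : data.entrySet()) {
            if (e.getKey().startsWith(parent + "/")) {
                result.add(e.getValue());
            }
        }
        return result;
    }
}
```

Suppose two servers register under /registry/data. When the first server's session is closed, only the second address remains visible, and this is why dead servers do not linger in the registry.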

2. Server development

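The NIO server does three things. It handles partial packets (a message split across several reads) by putting a 2-byte length header in front of every message with LengthFieldPrepender and reassembling complete frames with LengthFieldBasedFrameDecoder. It registers itself in ZooKeeper. It also fixes the set of services it exposes.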

public class RPCServer {

    private Map<String, Object> getServices() {
        Map<String, Object> services = new HashMap<String, Object>();
        // Services are fixed here before they can be called; clients cannot add services dynamically
        services.put(ISum.class.getName(), new SumImpl());
        services.put(IDiff.class.getName(), new DiffImpl());
        return services;
    }

    private void bind(int port) {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        Map<String, Object> handlerMap = getServices();
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class).option(ChannelOption.SO_BACKLOG, 100)
                    .childHandler(new ChannelInitializer<SocketChannel>() {

                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ch.pipeline()
                                .addLast(new LengthFieldBasedFrameDecoder(65535, 0, 2, 0, 2))
                                .addLast(new RPCDecoder(RPCRequest.class))
                                .addLast(new LengthFieldPrepender(2))
                                .addLast(new RPCEncoder(RPCResponse.class))
                                .addLast(new RPCServerHandler(handlerMap));
                        }
                    });
            ChannelFuture f = b.bind(port).sync();
            f.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }


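    public static String getAddress() {
        try {
            // Get this machine's IP address
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            // Fail loudly instead of registering a null address (the old code threw a NullPointerException here)
            throw new IllegalStateException("cannot resolve local host address", e);
        }
    }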

    public void initService(int port) {
        ServiceRegistry serviceRegistry = new ServiceRegistry();
        String ip = getAddress();
        // Register this server's address (ip:port) with ZooKeeper
        serviceRegistry.register(ip + Constant.ZK_IP_SPLIT + port);
        bind(port);
    }

    public static void main(String[] args) {
        int port = 9090;
        new RPCServer().initService(port);
    }
}
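The two Netty framing handlers can be hard to picture, so the following JDK-only sketch reproduces the same wire format. Each message is a 2-byte big-endian length followed by the body, which is what LengthFieldPrepender(2) writes. The decoder works like LengthFieldBasedFrameDecoder(65535, 0, 2, 0, 2): it buffers bytes until a whole frame has arrived and then strips the 2-byte header. LengthFieldFraming is a hypothetical helper for illustration only and is not a Netty class.

```java
import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical JDK-only model of the 2-byte length-field framing used in the pipeline. */
class LengthFieldFraming {
    static final int MAX_FRAME = 65535; // largest value a 2-byte unsigned length can hold

    /** Encoder side: prefix the body with its length as an unsigned 16-bit big-endian value. */
    static byte[] encode(byte[] body) {
        if (body.length > MAX_FRAME) {
            throw new IllegalArgumentException("body too long for a 2-byte length field: " + body.length);
        }
        byte[] frame = new byte[2 + body.length];
        frame[0] = (byte) (body.length >>> 8);
        frame[1] = (byte) body.length;
        System.arraycopy(body, 0, frame, 2, body.length);
        return frame;
    }

    /** Decoder side: reads may split a frame (half packet) or merge several frames into one read. */
    static class Decoder {
        private final ByteArrayOutputStream pending = new ByteArrayOutputStream();

        /** Feeds one network read and returns every frame body that is now complete. */
        List<byte[]> feed(byte[] chunk) {
            pending.write(chunk, 0, chunk.length);
            byte[] buf = pending.toByteArray();
            List<byte[]> frames = new ArrayList<>();
            int pos = 0;
            while (buf.length - pos >= 2) {
                int len = ((buf[pos] & 0xFF) << 8) | (buf[pos + 1] & 0xFF);
                if (buf.length - pos - 2 < len) {
                    break; // body not fully received yet: wait for more bytes
                }
                frames.add(Arrays.copyOfRange(buf, pos + 2, pos + 2 + len)); // strip the header
                pos += 2 + len;
            }
            // Keep only the unconsumed tail for the next read
            pending.reset();
            pending.write(buf, pos, buf.length - pos);
            return frames;
        }
    }
}
```

If you feed the decoder only part of one frame, it returns nothing. If you then feed the rest of that frame together with a second frame, it returns both bodies. This covers both the split-message and the merged-message cases.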

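The server's request-handling logic. It extends ChannelHandlerAdapter and overrides channelRead, which matches the netty-all 5.0.0.Alpha1 dependency above. With Netty 4.x the handler would extend ChannelInboundHandlerAdapter instead.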

public class RPCServerHandler extends ChannelHandlerAdapter {

    private static final Logger LOGGER = LoggerFactory.getLogger(RPCServerHandler.class);

    private final Map<String, Object> handlerMap;


    public RPCServerHandler(Map<String, Object> handlerMap) {
        this.handlerMap = handlerMap;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        RPCResponse response = new RPCResponse();
        RPCRequest request = (RPCRequest) msg;
        response.setRequestId(request.getRequestId());
        try {
            Object result = handle(request);
            response.setResult(result);
        } catch (Throwable t) {
            response.setError(t);
        }
        ctx.writeAndFlush(response);
    }

    private Object handle(RPCRequest request) throws Throwable {
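        String className = request.getClassName();
        Object serviceBean = handlerMap.get(className);
        if (serviceBean == null) {
            throw new IllegalArgumentException("No service registered for " + className);
        }

        Class<?> serviceClass = serviceBean.getClass();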
        String methodName = request.getMethodName();
        Class<?>[] parameterTypes = request.getParameterTypes();
        Object[] parameters = request.getParameters();

        FastClass serviceFastClass = FastClass.create(serviceClass);
        FastMethod serviceFastMethod = serviceFastClass.getMethod(methodName, parameterTypes);
        return serviceFastMethod.invoke(serviceBean, parameters);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        LOGGER.error("server caught exception", cause);
        ctx.close();
    }
}
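RPCServerHandler.handle() uses cglib's FastClass to invoke the target method. For comparison, here is a sketch of the same dispatch using plain java.lang.reflect. ReflectiveDispatcher is a hypothetical name. It looks up the implementation by interface name, then resolves the method by name plus parameter types. Those parameter types are why RPCRequest carries parameterTypes: overloaded methods share a name. The sketch also unwraps InvocationTargetException, so the caller sees the service's own exception rather than the reflection wrapper.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;

/** Hypothetical JDK-reflection version of RPCServerHandler.handle(), without cglib. */
class ReflectiveDispatcher {
    private final Map<String, Object> handlerMap; // interface name -> implementation

    ReflectiveDispatcher(Map<String, Object> handlerMap) {
        this.handlerMap = handlerMap;
    }

    Object invoke(String className, String methodName, Class<?>[] parameterTypes, Object[] parameters)
            throws Throwable {
        Object serviceBean = handlerMap.get(className);
        if (serviceBean == null) {
            throw new IllegalArgumentException("no service registered for " + className);
        }
        // Resolve by name AND exact parameter types, so overloads are distinguished
        Method method = serviceBean.getClass().getMethod(methodName, parameterTypes);
        try {
            return method.invoke(serviceBean, parameters); // boxed arguments are unboxed for primitive params
        } catch (InvocationTargetException e) {
            // Rethrow the service's own exception instead of the reflection wrapper
            throw e.getCause();
        }
    }
}
```

FastClass avoids some of the overhead of Method.invoke. The reflective version needs no extra dependency, and you could cache the Method objects if call volume matters.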

3. Wrapping RPC requests and responses

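These classes are simple. The one thing worth noting is that both requests and responses carry an ID. NIO communication is asynchronous, so when a client sends several requests and receives several responses, the responses are not guaranteed to arrive in the order the requests were sent. The ID lets each response be matched to the request it answers.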

public class RPCRequest {
    private String requestId;
    private String className;
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] parameters;

    public String getRequestId() {
        return requestId;
    }

    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    public String getClassName() {
        return className;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class<?>[] getParameterTypes() {
        return parameterTypes;
    }

    public void setParameterTypes(Class<?>[] parameterTypes) {
        this.parameterTypes = parameterTypes;
    }

    public Object[] getParameters() {
        return parameters;
    }

    public void setParameters(Object[] parameters) {
        this.parameters = parameters;
    }
}

public class RPCResponse {
    private String requestId;
    private Throwable error;
    private Object result;

    public String getRequestId() {
        return requestId;
    }

    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    public Throwable getError() {
        return error;
    }

    public void setError(Throwable error) {
        this.error = error;
    }

    public Object getResult() {
        return result;
    }

    public void setResult(Object result) {
        this.result = result;
    }
}
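The IDs pay off when several requests are in flight on one connection. A common approach is a table that maps each request ID to a pending future. The I/O thread completes the matching future when the response arrives. PendingRequests below is a hypothetical sketch of that idea built on ConcurrentHashMap and CompletableFuture. The RPCClient in section 8 opens a new connection for every call, so it does not actually need such a table.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Hypothetical table of in-flight requests keyed by requestId. The caller registers
 * before sending; the I/O thread completes the matching future on each response.
 */
class PendingRequests<R> {
    private final Map<String, CompletableFuture<R>> pending = new ConcurrentHashMap<>();

    /** Before sending: returns a future the caller can wait on. */
    CompletableFuture<R> register(String requestId) {
        CompletableFuture<R> future = new CompletableFuture<>();
        if (pending.putIfAbsent(requestId, future) != null) {
            throw new IllegalStateException("duplicate request id " + requestId);
        }
        return future;
    }

    /** On response: completes the matching future. Responses may arrive in any order. */
    boolean complete(String requestId, R response) {
        CompletableFuture<R> future = pending.remove(requestId);
        if (future == null) {
            return false; // unknown id, or already completed
        }
        return future.complete(response);
    }
}
```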

4. Serialization with protostuff (protobuf-style)

Serialization is wrapped in a utility class. If you later switch to another serialization library, only this class has to change.

public class SerializationUtil {

    private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<>();

    private static Objenesis objenesis = new ObjenesisStd(true);

    private SerializationUtil() {
    }

    @SuppressWarnings("unchecked")
    private static <T> Schema<T> getSchema(Class<T> cls) {
        Schema<T> schema = (Schema<T>) cachedSchema.get(cls);
        if (schema == null) {
            schema = RuntimeSchema.createFrom(cls);
            if (schema != null) {
                cachedSchema.put(cls, schema);
            }
        }
        return schema;
    }

    @SuppressWarnings("unchecked")
    public static <T> byte[] serialize(T obj) {
        Class<T> cls = (Class<T>) obj.getClass();
        LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);
        try {
            Schema<T> schema = getSchema(cls);
            return ProtostuffIOUtil.toByteArray(obj, schema, buffer);
        } catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        } finally {
            buffer.clear();
        }
    }

    public static <T> T deserialize(byte[] data, Class<T> cls) {
        try {
            T message = (T) objenesis.newInstance(cls);
            Schema<T> schema = getSchema(cls);
            ProtostuffIOUtil.mergeFrom(data, message, schema);
            return message;
        } catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        }
    }
}
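You can take the benefit of keeping serialization in one class a step further by putting it behind an interface. The sketch below is an assumption and not the article's design. It defines a hypothetical Serializer interface, with a JDK-serialization implementation standing in for the approach from the previous article. A protostuff-backed implementation would simply delegate to SerializationUtil.serialize and SerializationUtil.deserialize. Unlike protostuff, the JDK version only works for types that implement Serializable.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

/** Hypothetical abstraction: encoders and decoders would depend only on this interface. */
interface Serializer {
    byte[] serialize(Object obj);

    <T> T deserialize(byte[] data, Class<T> cls);
}

/** Example implementation using JDK serialization; works only for Serializable types. */
class JdkSerializer implements Serializer {
    @Override
    public byte[] serialize(Object obj) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } catch (IOException e) {
            throw new IllegalStateException(e.getMessage(), e);
        }
        return bytes.toByteArray(); // still valid after the stream is closed
    }

    @Override
    public <T> T deserialize(byte[] data, Class<T> cls) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            // cast() fails fast if the bytes held a different type
            return cls.cast(in.readObject());
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e.getMessage(), e);
        }
    }
}
```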

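5. Encoder and decoder classes

These convert byte[] -> Object and Object -> byte[]. The decoder can treat all readable bytes as one message because the LengthFieldBasedFrameDecoder placed before it in the pipeline has already cut the stream into complete frames.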

public class RPCDecoder extends ByteToMessageDecoder {
    private Class<?> genericClass;

    public RPCDecoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        final int length = in.readableBytes();
        final byte[] bytes = new byte[length];
        in.readBytes(bytes, 0, length);
        Object obj = SerializationUtil.deserialize(bytes, genericClass);
        out.add(obj);
    }

}


public class RPCEncoder extends MessageToByteEncoder<Object> {

    private Class<?> genericClass;

    public RPCEncoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) throws Exception {
        if (genericClass.isInstance(in)) {
            byte[] data = SerializationUtil.serialize(in);
            out.writeBytes(data);
        }
    }
}

6. Client proxy

public class RPCProxy {
    private String serverAddress;
    private ServiceDiscovery serviceDiscovery;

    public RPCProxy(String serverAddress) {
        this.serverAddress = serverAddress;
    }

    public RPCProxy(ServiceDiscovery serviceDiscovery) {
        this.serviceDiscovery = serviceDiscovery;
    }

    @SuppressWarnings("unchecked")
    public <T> T create(Class<T> interfaceClass) {
        return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class<?>[] { interfaceClass },
                new InvocationHandler() {
                    @Override
                    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                        // Build and initialize the RPC request
                        RPCRequest request = new RPCRequest();
                        request.setRequestId(UUID.randomUUID().toString());
                        request.setClassName(method.getDeclaringClass().getName());
                        request.setMethodName(method.getName());
                        request.setParameterTypes(method.getParameterTypes());
                        request.setParameters(args);

                        // Discover a server per call; a local variable keeps concurrent calls from overwriting each other
                        String address = serviceDiscovery != null ? serviceDiscovery.discover() : serverAddress;
                        if (address == null) {
                            throw new IllegalStateException("No available server address");
                        }
                        // Address format: "123.23.213.23:9090"
                        String[] array = address.split(Constant.ZK_IP_SPLIT);
                        String host = array[0];
                        int port = Integer.parseInt(array[1]);

                        // Create an RPC client and send the request
                        RPCClient client = new RPCClient(host, port);
                        RPCResponse response = client.send(request);
                        if (response == null) {
                            throw new IllegalStateException("No response from " + address);
                        }
                        if (response.getError() != null) {
                            throw response.getError();
                        }
                        return response.getResult();
                    }
                });
    }
}
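RPCProxy relies on the JDK dynamic proxy. Every call on the returned interface lands in InvocationHandler.invoke, which packs the interface name, method name, parameter types and arguments into a request. The sketch below isolates that mechanism. Call, ProxyFactory and SumService are hypothetical names, and a Function&lt;Call, Object&gt; replaces the network hop so the example runs locally.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.UUID;
import java.util.function.Function;

/** Hypothetical service interface, mirroring ISum, used to exercise the proxy. */
interface SumService {
    int sum(int a, int b);
}

/** Hypothetical request object carrying the same fields as RPCRequest. */
final class Call {
    final String requestId;
    final String className;
    final String methodName;
    final Class<?>[] parameterTypes;
    final Object[] parameters;

    Call(String requestId, String className, String methodName,
         Class<?>[] parameterTypes, Object[] parameters) {
        this.requestId = requestId;
        this.className = className;
        this.methodName = methodName;
        this.parameterTypes = parameterTypes;
        this.parameters = parameters;
    }
}

/** Sketch of RPCProxy.create() with the network call replaced by a transport function. */
final class ProxyFactory {
    static <T> T create(Class<T> interfaceClass, Function<Call, Object> transport) {
        InvocationHandler handler = (proxy, method, args) -> {
            // Describe the invoked method so the other side can find and call it
            Call call = new Call(UUID.randomUUID().toString(),
                    method.getDeclaringClass().getName(), method.getName(),
                    method.getParameterTypes(), args);
            // A real client would serialize the call and send it; here the transport answers directly
            return transport.apply(call);
        };
        Object proxy = Proxy.newProxyInstance(
                interfaceClass.getClassLoader(), new Class<?>[] { interfaceClass }, handler);
        return interfaceClass.cast(proxy);
    }
}
```

The proxy also routes toString, equals and hashCode through the handler, so a real client may want to answer those locally instead of sending them to the server.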

7. Discovering the services under the ZooKeeper /registry node

/**
 * Service discovery: connects to ZooKeeper and sets a watch on the registry node.
 */
public class ServiceDiscovery {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceDiscovery.class);

    private final CountDownLatch latch = new CountDownLatch(1);

    private volatile List<String> dataList = new ArrayList<>();

    private final String registryAddress;

    public ServiceDiscovery(String registryAddress) {
        this.registryAddress = registryAddress;

        ZooKeeper zk = connectServer();
        if (zk != null) {
            watchNode(zk);
        }
    }

    public String discover() {
        // Read the volatile field once, so a concurrent update cannot swap the list between size() and get()
        List<String> snapshot = dataList;
        String data = null;
        int size = snapshot.size();
        if (size > 0) {
            if (size == 1) {
                data = snapshot.get(0);
                LOGGER.debug("using only data: {}", data);
            } else {
                data = snapshot.get(ThreadLocalRandom.current().nextInt(size));
                LOGGER.debug("using random data: {}", data);
            }
        }
        return data;
    }

    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (event.getState() == Event.KeeperState.SyncConnected) {
                        latch.countDown();
                    }
                }
            });
            latch.await();
        } catch (IOException | InterruptedException e) {
            LOGGER.error("", e);
        }
        return zk;
    }

    private void watchNode(final ZooKeeper zk) {
        try {
            List<String> nodeList = zk.getChildren(Constant.ZK_REGISTRY_PATH, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (event.getType() == Event.EventType.NodeChildrenChanged) {
                        watchNode(zk);
                    }
                }
            });
            List<String> dataList = new ArrayList<>();
            for (String node : nodeList) {
                byte[] bytes = zk.getData(Constant.ZK_REGISTRY_PATH + "/" + node, false, null);
                dataList.add(new String(bytes));
            }
            LOGGER.debug("node data: {}", dataList);
            this.dataList = dataList;
        } catch (KeeperException | InterruptedException e) {
            LOGGER.error("", e);
        }
    }
}
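ServiceDiscovery keeps its address list with a copy-on-write pattern. The watcher builds a new list and replaces the volatile field. Readers take one snapshot and pick a random entry from it. Here is a minimal JDK-only version; the class name AddressCache is hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

/** Hypothetical copy-on-write address cache with random selection, like ServiceDiscovery's dataList. */
class AddressCache {
    private volatile List<String> addresses = Collections.emptyList();

    /** Called when the children of /registry change: publish a fresh, immutable list. */
    void update(List<String> latest) {
        addresses = Collections.unmodifiableList(new ArrayList<>(latest));
    }

    /** Returns a random address, or null when no server is registered. */
    String discover() {
        List<String> snapshot = addresses; // read the volatile field exactly once
        if (snapshot.isEmpty()) {
            return null;
        }
        return snapshot.get(ThreadLocalRandom.current().nextInt(snapshot.size()));
    }
}
```

Random selection spreads calls across servers without any coordination. The trade-off is that it ignores server load.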

8. Client development

public class RPCClient {
    private final String host;
    private final int port;
    private final CountDownLatch latch;
    public RPCClient(String host,int port) {
        this.host = host;
        this.port = port;
        this.latch = new CountDownLatch(1);
    }

    public RPCResponse send(RPCRequest request){
        EventLoopGroup group = new NioEventLoopGroup();
        final RPCClientHandler handler = new RPCClientHandler(request,latch);
        RPCResponse response = null;
        try {
            Bootstrap b = new Bootstrap().group(group)
                .channel(NioSocketChannel.class).option(ChannelOption.TCP_NODELAY, true)
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 3000)
                .handler(new ChannelInitializer<SocketChannel>() {

                    @Override
                    protected void initChannel(SocketChannel ch) throws Exception {
                        ch.pipeline()
                            .addLast(new LengthFieldBasedFrameDecoder(65535, 0, 2, 0, 2))
                            .addLast(new RPCDecoder(RPCResponse.class))
                            .addLast(new LengthFieldPrepender(2))
                            .addLast(new RPCEncoder(RPCRequest.class))
                            .addLast(handler);
                    }
            });

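            ChannelFuture f = b.connect(host, port).sync();
            // Wait for the handler to receive a response (or an error), giving up after 10 seconds (an arbitrary timeout)
            latch.await(10, TimeUnit.SECONDS);
            response = handler.getResponse();
            f.channel().close();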
        } catch (InterruptedException e) {
            e.printStackTrace();
        }finally{
            group.shutdownGracefully();
        }
        return response;
    }

}

The client handler:

public class RPCClientHandler extends ChannelHandlerAdapter {

    private final RPCRequest request;
    private RPCResponse response;
    private final CountDownLatch latch;

    public RPCClientHandler(RPCRequest request, CountDownLatch latch) {
        this.request = request;
        this.latch = latch;
    }

    public RPCResponse getResponse() {
        return response;
    }

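    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // cause.getCause() may be null, so print the exception itself
        cause.printStackTrace();
        // Release the waiting caller; getResponse() stays null
        latch.countDown();
        ctx.close();
    }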

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        ctx.writeAndFlush(request);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        response = (RPCResponse) msg;
        latch.countDown();
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        ctx.flush();
    }

}
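RPCClient and RPCClientHandler pass the response from Netty's I/O thread to the calling thread through a CountDownLatch. The sketch below shows the same handoff using a hypothetical class, ResponseSlot. It adds two things the original lacks: a failure path, so an exception also releases the caller, and a timed wait. The caller chooses the timeout value.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * Hypothetical handoff between the I/O thread (succeed/fail) and the calling thread (await).
 * countDown() happens-before a successful await(), so fields written before it are visible.
 */
class ResponseSlot<R> {
    private final CountDownLatch latch = new CountDownLatch(1);
    private R response;
    private Throwable failure;

    /** I/O thread, e.g. from channelRead: record the response. The first outcome wins. */
    synchronized void succeed(R value) {
        if (latch.getCount() == 0) {
            return;
        }
        response = value;
        latch.countDown();
    }

    /** I/O thread, e.g. from exceptionCaught: release the caller with an error. */
    synchronized void fail(Throwable cause) {
        if (latch.getCount() == 0) {
            return;
        }
        failure = cause;
        latch.countDown();
    }

    /** Calling thread: wait for either outcome, at most for the given time. */
    R await(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException {
        if (!latch.await(timeout, unit)) {
            throw new TimeoutException("no response within " + timeout + " " + unit);
        }
        if (failure != null) {
            throw new IllegalStateException("RPC call failed", failure);
        }
        return response;
    }
}
```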

That's basically all the code. Now let's test it.
Start the server:

public class RPCServiceTest {
    public static void main(String[] args) {
        int port = 9090;
        new RPCServer().initService(port);
    }
}


Check ZooKeeper:
[Figure: ZooKeeper screenshot]
The registered service address and port are clearly visible.

Start the client:

public class RPCTest {
    public static void serviceTest() {
        ServiceDiscovery discovery = new ServiceDiscovery(Constant.ZK_CONNECT);

        RPCProxy rpcProxy = new RPCProxy(discovery);
        IDiff diff = rpcProxy.create(IDiff.class);
        double result = diff.diff(1321, 32.2);
        ISum sum = rpcProxy.create(ISum.class);
        int result2 = sum.sum(1000, 1000);
        System.out.println(result+":"+result2);
        // prints 1288.8:2000
    }

    public static void main(String[] args) {
        serviceTest();

    }
}

The client gets a server address from ZooKeeper and then makes the RPC call to that server.
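Summary: it took me a day to implement RPC calls with Netty + ZooKeeper + protobuf. The main goal was to understand the RPC layers underneath big-data systems (Hadoop, Spark, HBase) more deeply, instead of stopping at knowing how to use them.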