1. 程式人生 > 實用技巧 >手寫一個RPC框架

手寫一個RPC框架

一、前言

前段時間看到一篇不錯的文章《看了這篇你就會手寫RPC框架了》,於是便來了興趣對著實現了一遍,後面覺得還有很多優化的地方便對其進行了改進。

主要改動點如下:

  1. 除了Java序列化協議,增加了protobuf和kryo序列化協議,配置即用。
  2. 增加多種負載均衡演算法(隨機、輪詢、加權輪詢、平滑加權輪詢),配置即用。
  3. 客戶端增加本地服務列表快取,提高效能。
  4. 修復高併發情況下,netty導致的記憶體洩漏問題
  5. 由原來的每個請求建立一次連線,改為建立TCP長連線,並多次複用。
  6. 服務端增加執行緒池提高訊息處理能力

二、介紹

RPC,即 Remote Procedure Call(遠端過程呼叫),呼叫遠端計算機上的服務,就像呼叫本地服務一樣。RPC可以很好的解耦系統,如WebService就是一種基於Http協議的RPC。


呼叫示意圖

總的來說,就如下幾個步驟:

  1. 客戶端(ServerA)執行遠端方法時就呼叫client stub傳遞類名、方法名和引數等資訊。
  2. client stub會將引數等資訊序列化為二進位制流的形式,然後通過Sockect傳送給服務端(ServerB)
  3. 服務端收到資料包後,server stub 需要進行解析反序列化為類名、方法名和引數等資訊。
  4. server stub呼叫對應的本地方法,並把執行結果返回給客戶端

所以一個RPC框架有如下角色:

服務消費者

遠端方法的呼叫方,即客戶端。一個服務既可以是消費者也可以是提供者。

服務提供者

遠端服務的提供方,即服務端。一個服務既可以是消費者也可以是提供者。

註冊中心

儲存服務提供者的服務地址等資訊,一般由zookeeper、redis等實現。

監控運維(可選)

監控介面的響應時間、統計請求數量等,及時發現系統問題併發出告警通知。

三、實現

本RPC框架rpc-spring-boot-starter涉及技術棧如下:

  • 使用zookeeper作為註冊中心
  • 使用netty作為通訊框架
  • 訊息編解碼:protostuff、kryo、java
  • spring
  • 使用SPI來根據配置動態選擇負載均衡演算法等

由於程式碼過多,這裡只講幾處改動點。

3.1動態負載均衡演算法

1.編寫LoadBalance的實現類


負載均衡演算法實現類

2.自定義註解 @LoadBalanceAno

  1. /**
  2. * 負載均衡註解
  3. */
  4. @Target(ElementType.TYPE)
  5. @Retention(RetentionPolicy.RUNTIME)
  6. @Documented
  7. public @interface LoadBalanceAno {

  8. String value() default "";
  9. }

  10. /**
  11. * 輪詢演算法
  12. */
  13. @LoadBalanceAno(RpcConstant.BALANCE_ROUND)
  14. public class FullRoundBalance implements LoadBalance {

  15. private static Logger logger = LoggerFactory.getLogger(FullRoundBalance.class);

  16. private volatile int index;

  17. @Override
  18. public synchronized Service chooseOne(List<Service> services) {
  19. // 加鎖防止多執行緒情況下,index超出services.size()
  20. if (index == services.size()) {
  21. index = 0;
  22. }
  23. return services.get(index++);
  24. }
  25. }

3.新建在resource目錄下META-INF/servers資料夾並建立檔案


enter description here

4.RpcConfig增加配置項loadBalance

  1. /**
  2. * @author 2YSP
  3. * @date 2020/7/26 15:13
  4. */
  5. @ConfigurationProperties(prefix = "sp.rpc")
  6. public class RpcConfig {

  7. /**
  8. * 服務註冊中心地址
  9. */
  10. private String registerAddress = "127.0.0.1:2181";

  11. /**
  12. * 服務暴露埠
  13. */
  14. private Integer serverPort = 9999;
  15. /**
  16. * 服務協議
  17. */
  18. private String protocol = "java";
  19. /**
  20. * 負載均衡演算法
  21. */
  22. private String loadBalance = "random";
  23. /**
  24. * 權重,預設為1
  25. */
  26. private Integer weight = 1;

  27. // 省略getter setter
  28. }

5.在自動配置類RpcAutoConfiguration根據配置選擇對應的演算法實現類

  1. /**
  2. * 使用spi匹配符合配置的負載均衡演算法
  3. *
  4. * @param name
  5. * @return
  6. */
  7. private LoadBalance getLoadBalance(String name) {
  8. ServiceLoader<LoadBalance> loader = ServiceLoader.load(LoadBalance.class);
  9. Iterator<LoadBalance> iterator = loader.iterator();
  10. while (iterator.hasNext()) {
  11. LoadBalance loadBalance = iterator.next();
  12. LoadBalanceAno ano = loadBalance.getClass().getAnnotation(LoadBalanceAno.class);
  13. Assert.notNull(ano, "load balance name can not be empty!");
  14. if (name.equals(ano.value())) {
  15. return loadBalance;
  16. }
  17. }
  18. throw new RpcException("invalid load balance config");
  19. }

  20. @Bean
  21. public ClientProxyFactory proxyFactory(@Autowired RpcConfig rpcConfig) {
  22. ClientProxyFactory clientProxyFactory = new ClientProxyFactory();
  23. // 設定服務發現著
  24. clientProxyFactory.setServerDiscovery(new ZookeeperServerDiscovery(rpcConfig.getRegisterAddress()));

  25. // 設定支援的協議
  26. Map<String, MessageProtocol> supportMessageProtocols = buildSupportMessageProtocols();
  27. clientProxyFactory.setSupportMessageProtocols(supportMessageProtocols);
  28. // 設定負載均衡演算法
  29. LoadBalance loadBalance = getLoadBalance(rpcConfig.getLoadBalance());
  30. clientProxyFactory.setLoadBalance(loadBalance);
  31. // 設定網路層實現
  32. clientProxyFactory.setNetClient(new NettyNetClient());

  33. return clientProxyFactory;
  34. }

3.2本地服務列表快取

使用Map來快取資料

  1. /**
  2. * 服務發現本地快取
  3. */
  4. public class ServerDiscoveryCache {
  5. /**
  6. * key: serviceName
  7. */
  8. private static final Map<String, List<Service>> SERVER_MAP = new ConcurrentHashMap<>();
  9. /**
  10. * 客戶端注入的遠端服務service class
  11. */
  12. public static final List<String> SERVICE_CLASS_NAMES = new ArrayList<>();

  13. public static void put(String serviceName, List<Service> serviceList) {
  14. SERVER_MAP.put(serviceName, serviceList);
  15. }

  16. /**
  17. * 去除指定的值
  18. * @param serviceName
  19. * @param service
  20. */
  21. public static void remove(String serviceName, Service service) {
  22. SERVER_MAP.computeIfPresent(serviceName, (key, value) ->
  23. value.stream().filter(o -> !o.toString().equals(service.toString())).collect(Collectors.toList())
  24. );
  25. }

  26. public static void removeAll(String serviceName) {
  27. SERVER_MAP.remove(serviceName);
  28. }


  29. public static boolean isEmpty(String serviceName) {
  30. return SERVER_MAP.get(serviceName) == null || SERVER_MAP.get(serviceName).size() == 0;
  31. }

  32. public static List<Service> get(String serviceName) {
  33. return SERVER_MAP.get(serviceName);
  34. }
  35. }

ClientProxyFactory,先查本地快取,快取沒有再查詢zookeeper。

  1. /**
  2. * 根據服務名獲取可用的服務地址列表
  3. * @param serviceName
  4. * @return
  5. */
  6. private List<Service> getServiceList(String serviceName) {
  7. List<Service> services;
  8. synchronized (serviceName){
  9. if (ServerDiscoveryCache.isEmpty(serviceName)) {
  10. services = serverDiscovery.findServiceList(serviceName);
  11. if (services == null || services.size() == 0) {
  12. throw new RpcException("No provider available!");
  13. }
  14. ServerDiscoveryCache.put(serviceName, services);
  15. } else {
  16. services = ServerDiscoveryCache.get(serviceName);
  17. }
  18. }
  19. return services;
  20. }

問題: 如果服務端因為宕機或網路問題下線了,快取卻還在就會導致客戶端請求已經不可用的服務端,增加請求失敗率。
解決方案:由於服務端註冊的是臨時節點,所以如果服務端下線節點會被移除。只要監聽zookeeper的子節點,如果新增或刪除子節點就直接清空本地快取即可。
DefaultRpcProcessor

  1. /**
  2. * Rpc處理者,支援服務啟動暴露,自動注入Service
  3. * @author 2YSP
  4. * @date 2020/7/26 14:46
  5. */
  6. public class DefaultRpcProcessor implements ApplicationListener<ContextRefreshedEvent> {



  7. @Override
  8. public void onApplicationEvent(ContextRefreshedEvent event) {
  9. // Spring啟動完畢過後會收到一個事件通知
  10. if (Objects.isNull(event.getApplicationContext().getParent())){
  11. ApplicationContext context = event.getApplicationContext();
  12. // 開啟服務
  13. startServer(context);
  14. // 注入Service
  15. injectService(context);
  16. }
  17. }

  18. private void injectService(ApplicationContext context) {
  19. String[] names = context.getBeanDefinitionNames();
  20. for(String name : names){
  21. Class<?> clazz = context.getType(name);
  22. if (Objects.isNull(clazz)){
  23. continue;
  24. }

  25. Field[] declaredFields = clazz.getDeclaredFields();
  26. for(Field field : declaredFields){
  27. // 找出標記了InjectService註解的屬性
  28. InjectService injectService = field.getAnnotation(InjectService.class);
  29. if (injectService == null){
  30. continue;
  31. }

  32. Class<?> fieldClass = field.getType();
  33. Object object = context.getBean(name);
  34. field.setAccessible(true);
  35. try {
  36. field.set(object,clientProxyFactory.getProxy(fieldClass));
  37. } catch (IllegalAccessException e) {
  38. e.printStackTrace();
  39. }
  40. // 新增本地服務快取
  41. ServerDiscoveryCache.SERVICE_CLASS_NAMES.add(fieldClass.getName());
  42. }
  43. }
  44. // 註冊子節點監聽
  45. if (clientProxyFactory.getServerDiscovery() instanceof ZookeeperServerDiscovery){
  46. ZookeeperServerDiscovery serverDiscovery = (ZookeeperServerDiscovery) clientProxyFactory.getServerDiscovery();
  47. ZkClient zkClient = serverDiscovery.getZkClient();
  48. ServerDiscoveryCache.SERVICE_CLASS_NAMES.forEach(name ->{
  49. String servicePath = RpcConstant.ZK_SERVICE_PATH + RpcConstant.PATH_DELIMITER + name + "/service";
  50. zkClient.subscribeChildChanges(servicePath, new ZkChildListenerImpl());
  51. });
  52. logger.info("subscribe service zk node successfully");
  53. }

  54. }

  55. private void startServer(ApplicationContext context) {
  56. ...

  57. }
  58. }

ZkChildListenerImpl

  1. /**
  2. * 子節點事件監聽處理類
  3. */
  4. public class ZkChildListenerImpl implements IZkChildListener {

  5. private static Logger logger = LoggerFactory.getLogger(ZkChildListenerImpl.class);

  6. /**
  7. * 監聽子節點的刪除和新增事件
  8. * @param parentPath /rpc/serviceName/service
  9. * @param childList
  10. * @throws Exception
  11. */
  12. @Override
  13. public void handleChildChange(String parentPath, List<String> childList) throws Exception {
  14. logger.debug("Child change parentPath:[{}] -- childList:[{}]", parentPath, childList);
  15. // 只要子節點有改動就清空快取
  16. String[] arr = parentPath.split("/");
  17. ServerDiscoveryCache.removeAll(arr[2]);
  18. }
  19. }

3.3nettyClient支援TCP長連線

這部分的改動最多,先增加新的sendRequest介面。


新增介面

實現類NettyNetClient

  1. /**
  2. * @author 2YSP
  3. * @date 2020/7/25 20:12
  4. */
  5. public class NettyNetClient implements NetClient {

  6. private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class);

  7. private static ExecutorService threadPool = new ThreadPoolExecutor(4, 10, 200,
  8. TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000), new ThreadFactoryBuilder()
  9. .setNameFormat("rpcClient-%d")
  10. .build());

  11. private EventLoopGroup loopGroup = new NioEventLoopGroup(4);

  12. /**
  13. * 已連線的服務快取
  14. * key: 服務地址,格式:ip:port
  15. */
  16. public static Map<String, SendHandlerV2> connectedServerNodes = new ConcurrentHashMap<>();

  17. @Override
  18. public byte[] sendRequest(byte[] data, Service service) throws InterruptedException {
  19. ....
  20. return respData;
  21. }

  22. @Override
  23. public RpcResponse sendRequest(RpcRequest rpcRequest, Service service, MessageProtocol messageProtocol) {

  24. String address = service.getAddress();
  25. synchronized (address) {
  26. if (connectedServerNodes.containsKey(address)) {
  27. SendHandlerV2 handler = connectedServerNodes.get(address);
  28. logger.info("使用現有的連線");
  29. return handler.sendRequest(rpcRequest);
  30. }

  31. String[] addrInfo = address.split(":");
  32. final String serverAddress = addrInfo[0];
  33. final String serverPort = addrInfo[1];
  34. final SendHandlerV2 handler = new SendHandlerV2(messageProtocol, address);
  35. threadPool.submit(() -> {
  36. // 配置客戶端
  37. Bootstrap b = new Bootstrap();
  38. b.group(loopGroup).channel(NioSocketChannel.class)
  39. .option(ChannelOption.TCP_NODELAY, true)
  40. .handler(new ChannelInitializer<SocketChannel>() {
  41. @Override
  42. protected void initChannel(SocketChannel socketChannel) throws Exception {
  43. ChannelPipeline pipeline = socketChannel.pipeline();
  44. pipeline
  45. .addLast(handler);
  46. }
  47. });
  48. // 啟用客戶端連線
  49. ChannelFuture channelFuture = b.connect(serverAddress, Integer.parseInt(serverPort));
  50. channelFuture.addListener(new ChannelFutureListener() {
  51. @Override
  52. public void operationComplete(ChannelFuture channelFuture) throws Exception {
  53. connectedServerNodes.put(address, handler);
  54. }
  55. });
  56. }
  57. );
  58. logger.info("使用新的連線。。。");
  59. return handler.sendRequest(rpcRequest);
  60. }
  61. }
  62. }

每次請求都會呼叫sendRequest()方法,用執行緒池非同步和服務端建立TCP長連線,連線成功後將SendHandlerV2快取到ConcurrentHashMap中方便複用,後續請求的請求地址(ip+port)如果在connectedServerNodes中存在則使用connectedServerNodes中的handler處理不再重新建立連線。

SendHandlerV2

  1. /**
  2. * @author 2YSP
  3. * @date 2020/8/19 20:06
  4. */
  5. public class SendHandlerV2 extends ChannelInboundHandlerAdapter {

  6. private static Logger logger = LoggerFactory.getLogger(SendHandlerV2.class);

  7. /**
  8. * 等待通道建立最大時間
  9. */
  10. static final int CHANNEL_WAIT_TIME = 4;
  11. /**
  12. * 等待響應最大時間
  13. */
  14. static final int RESPONSE_WAIT_TIME = 8;

  15. private volatile Channel channel;

  16. private String remoteAddress;

  17. private static Map<String, RpcFuture<RpcResponse>> requestMap = new ConcurrentHashMap<>();

  18. private MessageProtocol messageProtocol;

  19. private CountDownLatch latch = new CountDownLatch(1);

  20. public SendHandlerV2(MessageProtocol messageProtocol,String remoteAddress) {
  21. this.messageProtocol = messageProtocol;
  22. this.remoteAddress = remoteAddress;
  23. }

  24. @Override
  25. public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
  26. this.channel = ctx.channel();
  27. latch.countDown();
  28. }

  29. @Override
  30. public void channelActive(ChannelHandlerContext ctx) throws Exception {
  31. logger.debug("Connect to server successfully:{}", ctx);
  32. }

  33. @Override
  34. public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
  35. logger.debug("Client reads message:{}", msg);
  36. ByteBuf byteBuf = (ByteBuf) msg;
  37. byte[] resp = new byte[byteBuf.readableBytes()];
  38. byteBuf.readBytes(resp);
  39. // 手動回收
  40. ReferenceCountUtil.release(byteBuf);
  41. RpcResponse response = messageProtocol.unmarshallingResponse(resp);
  42. RpcFuture<RpcResponse> future = requestMap.get(response.getRequestId());
  43. future.setResponse(response);
  44. }

  45. @Override
  46. public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
  47. cause.printStackTrace();
  48. logger.error("Exception occurred:{}", cause.getMessage());
  49. ctx.close();
  50. }

  51. @Override
  52. public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
  53. ctx.flush();
  54. }

  55. @Override
  56. public void channelInactive(ChannelHandlerContext ctx) throws Exception {
  57. super.channelInactive(ctx);
  58. logger.error("channel inactive with remoteAddress:[{}]",remoteAddress);
  59. NettyNetClient.connectedServerNodes.remove(remoteAddress);

  60. }

  61. @Override
  62. public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
  63. super.userEventTriggered(ctx, evt);
  64. }

  65. public RpcResponse sendRequest(RpcRequest request) {
  66. RpcResponse response;
  67. RpcFuture<RpcResponse> future = new RpcFuture<>();
  68. requestMap.put(request.getRequestId(), future);
  69. try {
  70. byte[] data = messageProtocol.marshallingRequest(request);
  71. ByteBuf reqBuf = Unpooled.buffer(data.length);
  72. reqBuf.writeBytes(data);
  73. if (latch.await(CHANNEL_WAIT_TIME,TimeUnit.SECONDS)){
  74. channel.writeAndFlush(reqBuf);
  75. // 等待響應
  76. response = future.get(RESPONSE_WAIT_TIME, TimeUnit.SECONDS);
  77. }else {
  78. throw new RpcException("establish channel time out");
  79. }
  80. } catch (Exception e) {
  81. throw new RpcException(e.getMessage());
  82. } finally {
  83. requestMap.remove(request.getRequestId());
  84. }
  85. return response;
  86. }
  87. }

RpcFuture

  1. package cn.sp.rpc.client.net;

  2. import java.util.concurrent.*;

  3. /**
  4. * @author 2YSP
  5. * @date 2020/8/19 22:31
  6. */
  7. public class RpcFuture<T> implements Future<T> {

  8. private T response;
  9. /**
  10. * 因為請求和響應是一一對應的,所以這裡是1
  11. */
  12. private CountDownLatch countDownLatch = new CountDownLatch(1);
  13. /**
  14. * Future的請求時間,用於計算Future是否超時
  15. */
  16. private long beginTime = System.currentTimeMillis();

  17. @Override
  18. public boolean cancel(boolean mayInterruptIfRunning) {
  19. return false;
  20. }

  21. @Override
  22. public boolean isCancelled() {
  23. return false;
  24. }

  25. @Override
  26. public boolean isDone() {
  27. if (response != null) {
  28. return true;
  29. }
  30. return false;
  31. }

  32. /**
  33. * 獲取響應,直到有結果才返回
  34. * @return
  35. * @throws InterruptedException
  36. * @throws ExecutionException
  37. */
  38. @Override
  39. public T get() throws InterruptedException, ExecutionException {
  40. countDownLatch.await();
  41. return response;
  42. }

  43. @Override
  44. public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
  45. if (countDownLatch.await(timeout,unit)){
  46. return response;
  47. }
  48. return null;
  49. }

  50. public void setResponse(T response) {
  51. this.response = response;
  52. countDownLatch.countDown();
  53. }

  54. public long getBeginTime() {
  55. return beginTime;
  56. }
  57. }

此處邏輯,第一次執行 SendHandlerV2#sendRequest() 時channel需要等待通道建立好之後才能傳送請求,所以用CountDownLatch來控制,等待通道建立。
自定義Future+requestMap快取來實現netty的請求和阻塞等待響應,RpcRequest物件在建立時會生成一個請求的唯一標識requestId,傳送請求前先將RpcFuture快取到requestMap中,key為requestId,讀取到服務端的響應資訊後(channelRead方法),將響應結果放入對應的RpcFuture中。
SendHandlerV2#channelInactive() 方法中,如果連線的服務端異常斷開連線了,則及時清理快取中對應的serverNode。

四、壓力測試

測試環境:
(英特爾)Intel(R) Core(TM) i5-6300HQ CPU @ 2.30GHz
4核
windows10家庭版(64位)
16G記憶體

1.本地啟動zookeeper
2.本地啟動一個消費者,兩個服務端,輪詢演算法
3.使用ab進行壓力測試,4個執行緒傳送10000個請求

ab -c 4 -n 10000 http://localhost:8080/test/user?id=1

測試結果


測試結果

從圖片可以看出,10000個請求只用了11s,比之前的130+秒耗時減少了10倍以上。

程式碼地址:
https://github.com/2YSP/rpc-spring-boot-starter
https://github.com/2YSP/rpc-example

參考:
看了這篇你就會手寫RPC框架了