1. 程式人生 > 其它 >ratelimit+redis+lua對介面限流

ratelimit+redis+lua對介面限流

背景:為防止介面QPS太大而造成系統執行卡頓的現象,在這兒以ratelimit+redis+lua對系統介面做了個限流。當時也考慮過使用其他的限流方法,比如微服務生態中使用的sentinel中介軟體,但是這個如果要實現持久化要進行特殊的配置,比如使用nacos進行持久化,需要修改sntinel原始碼,相比較而言單純為了限流兒整合兩個中介軟體會顯得比較臃腫,所以到最後還是使用了retelimit+redis+lua這個方案,本身redis系統中就會使用,儲存token、部門資訊等一些讀取次數多的資料。

一、主要邏輯實現:

  1. 首先確定的是要採用切面的方式,後期如果相對某一個介面進行限流可以直接採用註解的方式。
  2. 其二在redis儲存的key的名稱要以方法名+ip的方式,這樣可以更好的實現思路1指出的問題。
  3. 使用lua指令碼直接傳到redis中操作,這樣可以減少網路開銷以及複用,並且可以保證是原子操作。
  4. 第四點就是lua指令碼的編寫啦,redis 以有序佇列進行儲存,每一個key值都帶有當前得分為當前時間戳的元素,每次新增的的時候都會將過時的元素進行清理,並進行判斷是否達到限流條件。

二、程式碼實現:

程式碼結構:

限流注解介面類

package com.heyu.ratelimit.annotation;

import org.aspectj.lang.annotation.Aspect;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.annotation.AnnotationUtils;

import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;

/**
 * <p>
 * 限流注解,添加了 {@link AliasFor} 必須通過 {@link AnnotationUtils} 獲取,才會生效
 * </p>
 *
 * @author: 程鵬
 * @date: 2021-02-24 14:45
 * @Description: 限流切面
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
    long DEFAULT_REQUEST = 5;


    /**
     * max 最大請求數
     */
    @AliasFor("max") long value() default DEFAULT_REQUEST;

    /**
     * max 最大請求數
     */
    @AliasFor("value") long max() default DEFAULT_REQUEST;

    /**
     * 限流key
     */
    String key() default "";

    /**
     * 超時時長,預設1分鐘
     */
    long timeout() default 1;

    /**
     * 超時時間單位,預設 分鐘
     */
    TimeUnit timeUnit() default TimeUnit.MINUTES;
}

切面操作類

package com.heyu.ratelimit.aspect;

import cn.hutool.core.util.StrUtil;
import com.heyu.ratelimit.annotation.RateLimiter;
import com.heyu.ratelimit.util.IpUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.time.Instant;
import java.util.Collections;
import java.util.concurrent.TimeUnit;

/**
 * @author: 程鵬
 * @date: 2021-02-24 14:13
 * @Description: 限流切面
 */
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor(onConstructor_ = @Autowired)
public class RateLimiterAspect {
    private final static String SEPARATOR = ":";
    private final static String REDIS_LIMIT_KEY_PREFIX = "limit:";
    private final StringRedisTemplate stringRedisTemplate;
    private final RedisScript<Long> limitRedisScript;

    @Pointcut("@annotation(com.heyu.ratelimit.annotation.RateLimiter)")
    public void rateLimit() {

    }

    @Around("rateLimit()")
    public Object pointcut(ProceedingJoinPoint point) throws Throwable {
        MethodSignature signature = (MethodSignature) point.getSignature();

        Method method = signature.getMethod();
        // 通過 AnnotationUtils.findAnnotation 獲取 RateLimiter 註解
        RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
        if (rateLimiter != null) {
            String key = rateLimiter.key();
            // 預設用類名+方法名做限流的 key 字首
            if (StrUtil.isBlank(key)) {
                key = method.getDeclaringClass().getName() + StrUtil.DOT + method.getName();
            }
            // 最終限流的 key 為 字首 + IP地址
            // TODO: 此時需要考慮區域網多使用者訪問的情況,因此 key 後續需要加上方法引數更加合理
            key = key + SEPARATOR + IpUtil.getIpAddr();

            long max = rateLimiter.max();
            long timeout = rateLimiter.timeout();
            TimeUnit timeUnit = rateLimiter.timeUnit();
            boolean limited = shouldLimited(key, max, timeout, timeUnit);
            if (limited) {
                throw new RuntimeException("手速太快了,慢點兒吧~");
            }
        }

        return point.proceed();
    }

    private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) {
        // 最終的 key 格式為:
        // limit:自定義key:IP
        // limit:類名.方法名:IP
        key = REDIS_LIMIT_KEY_PREFIX + key;
        // 統一使用單位毫秒
        long ttl = timeUnit.toMillis(timeout);
        // 當前時間毫秒數
        long now = Instant.now().toEpochMilli();
        long expired = now - ttl;
        // 注意這裡必須轉為 String,否則會報錯 java.lang.Long cannot be cast to java.lang.String
        Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + "");
        if (executeTimes != null) {
            if (executeTimes == 0) {
                log.error("【{}】在單位時間 {} 毫秒內已達到訪問上限,當前介面上限 {}", key, ttl, max);
                return true;
            } else {
                log.info("【{}】在單位時間 {} 毫秒內訪問 {} 次", key, ttl, executeTimes);
                return false;
            }
        }
        return false;
    }
}

redis配置類

package com.heyu.ratelimit.config;

import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;

/**
 * @author: 程鵬
 * @date: 2021-02-26 14:35
 * @Description:
 */
@Component
public class RedisConfig {
    @Bean
    @SuppressWarnings("unchecked")
    public RedisScript<Long> limitRedisScript() {
        DefaultRedisScript redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("redis/limit.lua")));
        redisScript.setResultType(Long.class);
        return redisScript;
    }
}

Ip解析類

package com.heyu.ratelimit.util;

import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;

/**
 * @author: 程鵬
 * @date: 2021-02-26 14:28
 * @Description:
 */
@Slf4j
public class IpUtil {
    private final static String UNKNOWN = "unknown";
    private final static int MAX_LENGTH = 15;

    /**
     * 獲取IP地址
     * 使用Nginx等反向代理軟體, 則不能通過request.getRemoteAddr()獲取IP地址
     * 如果使用了多級反向代理的話,X-Forwarded-For的值並不止一個,而是一串IP地址,X-Forwarded-For中第一個非unknown的有效IP字串,則為真實IP地址
     */
    public static String getIpAddr() {
        HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        String ip = null;
        try {
            ip = request.getHeader("x-forwarded-for");
            if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader("Proxy-Client-IP");
            }
            if (StrUtil.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader("WL-Proxy-Client-IP");
            }
            if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader("HTTP_CLIENT_IP");
            }
            if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader("HTTP_X_FORWARDED_FOR");
            }
            if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getRemoteAddr();
            }
        } catch (Exception e) {
            log.error("IPUtils ERROR ", e);
        }
        // 使用代理,則獲取第一個IP地址
        if (!StrUtil.isEmpty(ip) && ip.length() > MAX_LENGTH) {
            if (ip.indexOf(StrUtil.COMMA) > 0) {
                ip = ip.substring(0, ip.indexOf(StrUtil.COMMA));
            }
        }
        log.error("訪客ip:"+ip);
        return ip;
    }
}

lua指令碼

-- 下標從 1 開始
local key = KEYS[1]
local now = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local expired = tonumber(ARGV[3])
-- 最大訪問量
local max = tonumber(ARGV[4])

-- 清除過期的資料
-- 移除指定分數區間內的所有元素,expired 即已經過期的 score
-- 根據當前時間毫秒數 - 超時毫秒數,得到過期時間 expired
redis.call('zremrangebyscore', key, 0, expired)

-- 獲取 zset 中的當前元素個數
local current = tonumber(redis.call('zcard', key))
local next = current + 1

if next > max then
  -- 達到限流大小 返回 0
  return 0;
else
  -- 往 zset 中新增一個值、得分均為當前時間戳的元素,[value,score]
  redis.call("zadd", key, now, now)
  -- 每次訪問均重新設定 zset 的過期時間,單位毫秒
  redis.call("pexpire", key, ttl)
  return next
end

controller層測試