ratelimit+redis+lua對介面限流
阿新 • • 發佈:2021-11-27
背景:為防止介面QPS太大而造成系統執行卡頓的現象,在這兒以ratelimit+redis+lua對系統介面做了個限流。當時也考慮過使用其他的限流方法,比如微服務生態中使用的sentinel中介軟體,但是這個如果要實現持久化要進行特殊的配置,比如使用nacos進行持久化,需要修改sntinel原始碼,相比較而言單純為了限流兒整合兩個中介軟體會顯得比較臃腫,所以到最後還是使用了retelimit+redis+lua這個方案,本身redis系統中就會使用,儲存token、部門資訊等一些讀取次數多的資料。
一、主要邏輯實現:
- 首先確定的是要採用切面的方式,後期如果相對某一個介面進行限流可以直接採用註解的方式。
- 其二在redis儲存的key的名稱要以方法名+ip的方式,這樣可以更好的實現思路1指出的問題。
- 使用lua指令碼直接傳到redis中操作,這樣可以減少網路開銷以及複用,並且可以保證是原子操作。
- 第四點就是lua指令碼的編寫啦,redis 以有序佇列進行儲存,每一個key值都帶有當前得分為當前時間戳的元素,每次新增的的時候都會將過時的元素進行清理,並進行判斷是否達到限流條件。
二、程式碼實現:
程式碼結構:
限流注解介面類
package com.heyu.ratelimit.annotation; import org.aspectj.lang.annotation.Aspect; import org.springframework.core.annotation.AliasFor; import org.springframework.core.annotation.AnnotationUtils; import java.lang.annotation.*; import java.util.concurrent.TimeUnit; /** * <p> * 限流注解,添加了 {@link AliasFor} 必須通過 {@link AnnotationUtils} 獲取,才會生效 * </p> * * @author: 程鵬 * @date: 2021-02-24 14:45 * @Description: 限流切面 */ @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface RateLimiter { long DEFAULT_REQUEST = 5; /** * max 最大請求數 */ @AliasFor("max") long value() default DEFAULT_REQUEST; /** * max 最大請求數 */ @AliasFor("value") long max() default DEFAULT_REQUEST; /** * 限流key */ String key() default ""; /** * 超時時長,預設1分鐘 */ long timeout() default 1; /** * 超時時間單位,預設 分鐘 */ TimeUnit timeUnit() default TimeUnit.MINUTES; }
切面操作類
package com.heyu.ratelimit.aspect; import cn.hutool.core.util.StrUtil; import com.heyu.ratelimit.annotation.RateLimiter; import com.heyu.ratelimit.util.IpUtil; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Pointcut; import org.aspectj.lang.reflect.MethodSignature; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.script.RedisScript; import org.springframework.stereotype.Component; import java.lang.reflect.Method; import java.time.Instant; import java.util.Collections; import java.util.concurrent.TimeUnit; /** * @author: 程鵬 * @date: 2021-02-24 14:13 * @Description: 限流切面 */ @Slf4j @Aspect @Component @RequiredArgsConstructor(onConstructor_ = @Autowired) public class RateLimiterAspect { private final static String SEPARATOR = ":"; private final static String REDIS_LIMIT_KEY_PREFIX = "limit:"; private final StringRedisTemplate stringRedisTemplate; private final RedisScript<Long> limitRedisScript; @Pointcut("@annotation(com.heyu.ratelimit.annotation.RateLimiter)") public void rateLimit() { } @Around("rateLimit()") public Object pointcut(ProceedingJoinPoint point) throws Throwable { MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); // 通過 AnnotationUtils.findAnnotation 獲取 RateLimiter 註解 RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class); if (rateLimiter != null) { String key = rateLimiter.key(); // 預設用類名+方法名做限流的 key 字首 if (StrUtil.isBlank(key)) { key = method.getDeclaringClass().getName() + StrUtil.DOT + method.getName(); } // 最終限流的 key 為 字首 + IP地址 // TODO: 此時需要考慮區域網多使用者訪問的情況,因此 key 後續需要加上方法引數更加合理 key = key + SEPARATOR + IpUtil.getIpAddr(); long max = rateLimiter.max(); long timeout = rateLimiter.timeout(); TimeUnit timeUnit = rateLimiter.timeUnit(); boolean limited = shouldLimited(key, max, timeout, timeUnit); if (limited) { throw new RuntimeException("手速太快了,慢點兒吧~"); } } return point.proceed(); } private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) { // 最終的 key 格式為: // limit:自定義key:IP // limit:類名.方法名:IP key = REDIS_LIMIT_KEY_PREFIX + key; // 統一使用單位毫秒 long ttl = timeUnit.toMillis(timeout); // 當前時間毫秒數 long now = Instant.now().toEpochMilli(); long expired = now - ttl; // 注意這裡必須轉為 String,否則會報錯 java.lang.Long cannot be cast to java.lang.String Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + ""); if (executeTimes != null) { if (executeTimes == 0) { log.error("【{}】在單位時間 {} 毫秒內已達到訪問上限,當前介面上限 {}", key, ttl, max); return true; } else { log.info("【{}】在單位時間 {} 毫秒內訪問 {} 次", key, ttl, executeTimes); return false; } } return false; } }
redis配置類
package com.heyu.ratelimit.config;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;
/**
* @author: 程鵬
* @date: 2021-02-26 14:35
* @Description:
*/
@Component
public class RedisConfig {
@Bean
@SuppressWarnings("unchecked")
public RedisScript<Long> limitRedisScript() {
DefaultRedisScript redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("redis/limit.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}
}
Ip解析類
package com.heyu.ratelimit.util;
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
/**
* @author: 程鵬
* @date: 2021-02-26 14:28
* @Description:
*/
@Slf4j
public class IpUtil {
private final static String UNKNOWN = "unknown";
private final static int MAX_LENGTH = 15;
/**
* 獲取IP地址
* 使用Nginx等反向代理軟體, 則不能通過request.getRemoteAddr()獲取IP地址
* 如果使用了多級反向代理的話,X-Forwarded-For的值並不止一個,而是一串IP地址,X-Forwarded-For中第一個非unknown的有效IP字串,則為真實IP地址
*/
public static String getIpAddr() {
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
String ip = null;
try {
ip = request.getHeader("x-forwarded-for");
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (StrUtil.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
} catch (Exception e) {
log.error("IPUtils ERROR ", e);
}
// 使用代理,則獲取第一個IP地址
if (!StrUtil.isEmpty(ip) && ip.length() > MAX_LENGTH) {
if (ip.indexOf(StrUtil.COMMA) > 0) {
ip = ip.substring(0, ip.indexOf(StrUtil.COMMA));
}
}
log.error("訪客ip:"+ip);
return ip;
}
}
lua指令碼
-- 下標從 1 開始
local key = KEYS[1]
local now = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local expired = tonumber(ARGV[3])
-- 最大訪問量
local max = tonumber(ARGV[4])
-- 清除過期的資料
-- 移除指定分數區間內的所有元素,expired 即已經過期的 score
-- 根據當前時間毫秒數 - 超時毫秒數,得到過期時間 expired
redis.call('zremrangebyscore', key, 0, expired)
-- 獲取 zset 中的當前元素個數
local current = tonumber(redis.call('zcard', key))
local next = current + 1
if next > max then
-- 達到限流大小 返回 0
return 0;
else
-- 往 zset 中新增一個值、得分均為當前時間戳的元素,[value,score]
redis.call("zadd", key, now, now)
-- 每次訪問均重新設定 zset 的過期時間,單位毫秒
redis.call("pexpire", key, ttl)
return next
end
controller層測試