Skip to content

自定义注解防重复提交(redis)

核心类

PreventDuplicate 注解

java
/**
 * Marks a method whose repeated submissions should be rejected for a time window.
 * The duplicate key is built from selected request-body fields plus optional literal values.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface PreventDuplicate {

    /**
     * Names of request-body fields whose values make up the duplicate key.
     * An empty array disables the duplicate check.
     */
    String[] includeFieldKeys() default {};

    /**
     * Literal strings appended verbatim to the duplicate key.
     */
    String[] optionalValues() default {};

    /**
     * Lifetime of the duplicate key, in milliseconds (default: 10 seconds).
     */
    long expire() default 10L * 1000L;

}

PreventDuplicateAspect 切面

java
/**
 * Aspect that rejects repeated submissions of methods annotated with {@link PreventDuplicate}.
 * <p>
 * Selected fields of the {@code @RequestBody} argument, plus optional literal values, form a key.
 * The key is hashed with MD5 and claimed in Redis for {@link PreventDuplicate#expire()} milliseconds.
 * A second request that produces the same key while it is still held raises {@link DuplicateException}.
 * <p>
 * NOTE(review): the key does not include the target method. Two annotated endpoints whose selected
 * field values coincide therefore share a key; use {@code optionalValues} to namespace them.
 */
@Aspect
@Component
public class PreventDuplicateAspect {

    private static final Logger log = LoggerFactory.getLogger(PreventDuplicateAspect.class);

    /**
     * Claims KEYS[1] only if it is absent, with a TTL of ARGV[1] milliseconds.
     * Returns 1 when the key was claimed and 0 when it already existed.
     * A single {@code SET ... NX PX} keeps check-and-set atomic and keeps millisecond precision.
     * {@code EXPIRE} only takes whole seconds, and {@code EXPIRE key 0} deletes the key at once.
     */
    private static final DefaultRedisScript<Long> CLAIM_KEY_SCRIPT = new DefaultRedisScript<>("""
               if redis.call("SET", KEYS[1], KEYS[1], "NX", "PX", ARGV[1]) then
                  return 1
               else
                  return 0
               end
            """, Long.class);

    private final StringRedisTemplate stringRedisTemplate;
    private final ObjectMapper objectMapper;

    public PreventDuplicateAspect(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper) {
        this.stringRedisTemplate = stringRedisTemplate;
        this.objectMapper = objectMapper;
    }

    @Pointcut("@annotation(pdv)")
    private void preventPc(PreventDuplicate pdv) {
    }

    /**
     * Runs the duplicate check before proceeding with the advised method.
     *
     * @param pjp the advised invocation
     * @param pdv the annotation on the advised method
     * @return whatever the advised method returns
     * @throws DuplicateException if the same key is already held in Redis
     * @throws Throwable anything thrown by the advised method
     */
    @Around("preventPc(pdv)")
    public Object aroundAdvice(ProceedingJoinPoint pjp, PreventDuplicate pdv) throws Throwable {
        String[] includeFieldKeys = pdv.includeFieldKeys();
        String[] optionalValues = pdv.optionalValues();
        long expireTime = pdv.expire();

        if (includeFieldKeys == null || includeFieldKeys.length == 0) {
            return pjp.proceed();
        }

        // Request-body argument (the parameter annotated with @RequestBody)
        Object requestBody = HashUtils.getBody(pjp);
        if (requestBody == null) {
            return pjp.proceed();
        }

        Map<String, Object> requestBodyMap = convertJsonToMap(requestBody);

        String keyRedis = buildKey(includeFieldKeys, optionalValues, requestBodyMap);
        if (keyRedis == null) {
            // No configured field has a value, or the body could not be read as a map.
            // Such requests would otherwise all share a single key and reject each other as
            // duplicates, so let them through unchecked.
            log.debug("No value for any of {} in request body; skipping duplicate check",
                    Arrays.toString(includeFieldKeys));
            return pjp.proceed();
        }

        // Fixed-length key regardless of how long the field values are
        String keyRedisMD5 = HashUtils.hashMD5(keyRedis);

        checkRequestByKey(keyRedisMD5, expireTime);
        return pjp.proceed();
    }

    /**
     * Joins the configured field values, then the optional literal values, with {@code ':'}.
     * A missing field contributes an empty segment, so field positions stay unambiguous.
     *
     * @return the raw key, or {@code null} when none of the configured fields has a value
     */
    private String buildKey(String[] includeFieldKeys, String[] optionalValues, Map<String, Object> requestBodyMap) {
        List<Object> fieldValues = Arrays.stream(includeFieldKeys)
                .map(requestBodyMap::get)
                .collect(Collectors.toList());
        if (fieldValues.stream().allMatch(Objects::isNull)) {
            return null;
        }

        String keyWithIncludeKey = fieldValues.stream()
                .map(value -> Objects.toString(value, ""))
                .collect(Collectors.joining(":"));

        if (optionalValues.length > 0) {
            return keyWithIncludeKey + ":" + String.join(":", optionalValues);
        }
        return keyWithIncludeKey;
    }

    /**
     * Atomically claims {@code key} in Redis for {@code expireTime} milliseconds.
     *
     * @param key        the (hashed) duplicate key
     * @param expireTime lifetime of the claim in milliseconds; values below 1 are treated as 1
     * @throws DuplicateException    if the key is already held
     * @throws IllegalStateException if Redis returned no result for the script
     */
    public void checkRequestByKey(String key, long expireTime) throws DuplicateException {
        // PX rejects non-positive TTLs, so always hold the key for at least 1 ms.
        String ttlMillis = String.valueOf(Math.max(1L, expireTime));
        Long ret = this.stringRedisTemplate.execute(CLAIM_KEY_SCRIPT, List.of(key), ttlMillis);
        if (ret == null) {
            // execute() may return null (e.g. inside a pipeline/transaction). Fail with a clear
            // message instead of a NullPointerException from unboxing.
            throw new IllegalStateException("Redis returned no result for duplicate-check key " + key);
        }
        if (ret == 0L) {
            throw new DuplicateException("重复提交");
        }
    }

    /**
     * Converts a deserialized request body into a field-name to value map via the ObjectMapper.
     *
     * @param jsonObject the request-body object; may be {@code null}
     * @return the converted map, or an empty map if the input is {@code null} or cannot be converted
     */
    public Map<String, Object> convertJsonToMap(Object jsonObject) {
        if (jsonObject == null) {
            return Collections.emptyMap();
        }
        try {
            return objectMapper.convertValue(jsonObject, new TypeReference<>() {
            });
        } catch (Exception e) {
            // Best effort: a body that is not object-shaped yields no fields for the key.
            log.warn("Could not convert request body of type {} to a map",
                    jsonObject.getClass().getName(), e);
            return Collections.emptyMap();
        }
    }

}

工具类

Hex

java
public class Hex {

    /** Lower-case hexadecimal digits, indexed by nibble value (0-15). */
    private static final char[] HEX = "0123456789abcdef".toCharArray();

    /**
     * Parses a hexadecimal string into bytes; every two characters form one byte.
     *
     * @param s hexadecimal text (either letter case accepted) of even length
     * @return the decoded bytes
     * @throws IllegalArgumentException if the length is odd or a character is not a hex digit
     */
    public static byte[] decode(CharSequence s) {
        final int length = s.length();
        if ((length & 1) == 1) {
            throw new IllegalArgumentException("16进制数据错误");
        }
        final byte[] out = new byte[length >> 1];
        int pos = 0;
        while (pos < length) {
            final int high = Character.digit(s.charAt(pos), 16);
            final int low = Character.digit(s.charAt(pos + 1), 16);
            if (high < 0 || low < 0) {
                // Reported positions are 1-based
                throw new IllegalArgumentException("检测到非十六进制字符,位置在 " + (pos + 1) + " 或 " + (pos + 2) + " 处");
            }
            out[pos >> 1] = (byte) (high * 16 + low);
            pos += 2;
        }
        return out;
    }

    /**
     * Renders bytes as lower-case hexadecimal, two characters per byte.
     *
     * @param buf the bytes to render
     * @return the hexadecimal string (empty for an empty array)
     */
    public static String encode(byte[] buf) {
        final char[] chars = new char[buf.length * 2];
        int j = 0;
        for (byte b : buf) {
            chars[j++] = HEX[(b >> 4) & 0x0F];
            chars[j++] = HEX[b & 0x0F];
        }
        return new String(chars);
    }

}

HashUtils

java
/**
 * Helpers for locating the {@code @RequestBody} argument of an advised method and for MD5 hashing.
 */
public class HashUtils {

    /**
     * Returns the first non-null argument whose parameter is annotated with {@code @RequestBody}.
     *
     * @param pjp the join point of the advised method
     * @return the request-body argument, or {@code null} if there is none (or it is null)
     */
    public static Object getBody(ProceedingJoinPoint pjp) {
        // Fetch the arguments and parameter annotations once, rather than once per parameter.
        Object[] args = pjp.getArgs();
        Annotation[][] parameterAnnotations = getMethod(pjp).getParameterAnnotations();
        for (int i = 0; i < args.length; i++) {
            if (args[i] != null && hasRequestBody(parameterAnnotations[i])) {
                return args[i];
            }
        }
        return null;
    }

    /**
     * Checks whether one parameter's annotations include {@code @RequestBody}.
     *
     * @param annotations the annotations declared on a single parameter
     * @return {@code true} if {@code @RequestBody} is among them
     */
    private static boolean hasRequestBody(Annotation[] annotations) {
        for (Annotation annotation : annotations) {
            if (RequestBody.class.isAssignableFrom(annotation.annotationType())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Resolves the reflective {@link Method} behind the join point.
     *
     * @param pjp the join point; its signature is assumed to be a method signature
     * @return the advised method
     */
    private static Method getMethod(ProceedingJoinPoint pjp) {
        MethodSignature methodSignature = (MethodSignature) pjp.getSignature();
        return methodSignature.getMethod();
    }

    /**
     * Computes the lower-case hex MD5 digest of a string's UTF-8 bytes.
     *
     * @param source the text to hash
     * @return the 32-character hex digest, or {@code null} if {@code source} is {@code null}
     * @throws IllegalStateException if the runtime provides no MD5 implementation
     */
    public static String hashMD5(String source) {
        if (source == null) {
            return null;
        }
        try {
            MessageDigest messageDigest = MessageDigest.getInstance("MD5");
            // Use an explicit charset so every JVM, whatever its default encoding, hashes
            // the same text to the same key.
            byte[] bytes = messageDigest.digest(source.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            return Hex.encode(bytes);
        } catch (java.security.NoSuchAlgorithmException e) {
            // Fail loudly: returning null would hand callers a null key.
            throw new IllegalStateException("MD5 MessageDigest is not available", e);
        }
    }

}

响应类

REnum

java
/**
 * Result status codes paired with their default descriptions.
 */
public enum REnum {

    SUCCESS(200, "成功"),
    FAIL(500, "失败");

    private final Integer code;
    private final String desc;

    REnum(Integer code, String desc) {
        this.code = code;
        this.desc = desc;
    }

    /**
     * Looks up the constant with the given code.
     * This is static because the lookup does not depend on any particular constant.
     *
     * @param code the code to find; may be {@code null}
     * @return the matching constant, or {@code null} if none matches
     */
    public static REnum getByCode(Integer code) {
        return Arrays.stream(REnum.values())
                .filter(r -> r.getCode().equals(code))
                .findFirst().orElse(null);
    }

    public Integer getCode() {
        return code;
    }

    public String getDesc() {
        return desc;
    }

}

R

java
/**
 * Response envelope carrying a status code, a message and an optional payload.
 * The factories return a parameterized {@code R<T>} rather than the raw type, so callers get a
 * type-checked payload without unchecked warnings.
 *
 * @param <T> type of the payload in {@code data}
 */
public class R<T> {

    private Integer code;
    private String msg;
    private T data;

    public R(Integer code, String msg, T data) {
        this.code = code;
        this.msg = msg;
        this.data = data;
    }

    /** Success response with the default success message and no payload. */
    public static <T> R<T> success() {
        return R.of(REnum.SUCCESS.getCode(), REnum.SUCCESS.getDesc(), null);
    }

    /** Success response with the default success message and the given payload. */
    public static <T> R<T> success(T data) {
        return R.of(REnum.SUCCESS.getCode(), REnum.SUCCESS.getDesc(), data);
    }

    /** Success response with a custom message and the given payload. */
    public static <T> R<T> success(String msg, T data) {
        return R.of(REnum.SUCCESS.getCode(), msg, data);
    }

    /** Failure response with the default failure message and no payload. */
    public static <T> R<T> fail() {
        return R.of(REnum.FAIL.getCode(), REnum.FAIL.getDesc(), null);
    }

    /** Failure response with a custom message and no payload. */
    public static <T> R<T> fail(String msg) {
        return R.of(REnum.FAIL.getCode(), msg, null);
    }

    /** Builds a response from explicit parts. */
    public static <T> R<T> of(Integer code, String msg, T data) {
        return new R<>(code, msg, data);
    }

    public Integer getCode() {
        return code;
    }

    public void setCode(Integer code) {
        this.code = code;
    }

    public String getMsg() {
        return msg;
    }

    public void setMsg(String msg) {
        this.msg = msg;
    }

    public T getData() {
        return data;
    }

    public void setData(T data) {
        this.data = data;
    }
}

异常类

DuplicateException

java
/**
 * Checked exception raised when a request is rejected as a duplicate submission.
 */
public class DuplicateException extends Exception {

    /**
     * @param message detail message describing why the request was rejected
     */
    public DuplicateException(final String message) {
        super(message);
    }
}

GlobalExceptionAdvice

java
/**
 * Converts {@link DuplicateException} thrown from controllers into a failure {@link R} body.
 * NOTE(review): no {@code @ResponseStatus} is declared, so the HTTP status is left to Spring's
 * default; only the body carries the failure code.
 */
@RestControllerAdvice
public class GlobalExceptionAdvice {

    /**
     * Handles a rejected duplicate submission.
     *
     * @param ex the exception raised for the duplicate request
     * @return a failure response whose message is the exception's message
     */
    @ExceptionHandler(DuplicateException.class)
    public R handleDuplicateException(DuplicateException ex) {
        final String reason = ex.getMessage();
        return R.fail(reason);
    }

}

更新: 2025-07-17 20:52:44
原文: https://www.yuque.com/lsxxyg/sz/gyzh0zb3z4m2o80w