SpringBoot使用Redis+SpringEL表达式实现分布式并发锁注解

SpringBoot使用redis搭配lua脚本实现分布式并发锁一文,我使用redis+lua脚本实现了一个分布式并发锁RedisLockService,本文将对该功能进行扩展,通过注解+aop的形式使其更加方便使用。

简单说一下我的想法:

1、定义一个注解RedisLock,有以下3个属性值

(1)key:锁的key,即redis中的key(这里使用SpringEL表达式进行扩展,使其可以可以实现从入参动态获取key值)

(2)timeout:为了避免死锁,设计了一个超时时间,单位是秒,默认5分钟也就是300秒

(3)msg:加锁失败时的异常信息提示语,默认值为“请求已发起,请勿重复操作!”

2、以注解为切入点,使用aop对需要使用分布式并发锁的方法进行切面拦截

3、使用线程变量来保存锁的key值等信息

4、加锁和解锁流程:

(1)进入注解方法前:加锁,并记录锁的key值等信息

(2)方法执行结束后或方法抛出异常:解锁

5、因为能力有限,所以目前注解仅支持单层,即存在多个注解嵌套的情况下,仅最外层注解会生效

以上就是大概的实现思路,剩下的废话不多说,直接上代码:

注解代码(SpringEL表达式的使用参考下方代码注释)

/**
 * redis分布式锁注解
 * 1、存在多个注解嵌套的情况,仅最外层注解有效,内层注解将被忽略
 * 2、锁的key值就是在redis中的key值,key支持SpringEL表达式,可以实现从入参动态获取key值
 * 例:使用学生的名字和老师的名字作为rediskey
 *      注解:@RedisCoucurrentLock(params="#student.name+#teacher.name")
 *      方法:test(student,teacher)
 *      加前缀:@RedisCoucurrentLock(params="'前缀'+#student.name+#teacher.name")
 *      <b>注意<b/>:使用时需要保证指定规则生成的key值在业务上唯一,避免key值冲突
 *                  建议使用业务单据号或者主键作为key值,如需使用主键id作为key,一定要加业务前缀避免冲突。
 * @author liqingcan
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RedisLock {

    /**
     * 锁的key
     * @return
     */
    String key();

    /**
     * 超时时间(秒,默认300秒)
     * @return
     */
    long timeout() default 300L;

    /**
     * 加锁失败提示语
     * @return
     */
    String msg() default "请求已发起,请勿重复操作!";

}

aop切面代码(锁的key通过getRedisKey解析SpringEL表达式得出,锁的value值使用uuid)aop切点的部分知识可以参考:SpringBoot使用aop切面记录@Scheduled定时任务开始时间和结束时间

/**
 * redis分布式锁aop
 *
 * @author liqingcan
 */
@Aspect
@Component
@Slf4j
public class RedisLockAspect {

    /**
     * 线程变量——锁的key
     */
    private ThreadLocal<String> threadLocalKey = new ThreadLocal<>();

    /**
     * 线程变量——锁的value
     */
    private ThreadLocal<String> threadLocalValue = new ThreadLocal<>();

    /**
     * 线程变量——注解层数,用于保证只有最外层的注解才能生效
     */
    private ThreadLocal<Integer> threadLocalCount = new ThreadLocal<>();

    @Autowired
    private RedisLockService redisLockService;

    @Pointcut("@annotation(com.example.demo.redis.RedisLock)")
    public void pointCut() {

    }

    @Before("pointCut()&&@annotation(redisLock)")
    public void beforeHandle(JoinPoint joinPoint, RedisLock redisLock) {
        //避免threadLocalCount发生空指针问题,如果是null值设置为0
        if (threadLocalCount.get() == null) {
            threadLocalCount.set(0);
        }
        if (isFirst()) {
            //注解层数+1
            threadLocalCount.set(threadLocalCount.get()+1);
            String key = this.getRedisKey(joinPoint, redisLock);
            String value = Thread.currentThread().getName() +"-"+ UUID.randomUUID();
            long expire = redisLock.timeout();
            boolean lock = redisLockService.lock(key, value, expire);
            if (lock) {
                //加锁成功
                threadLocalKey.set(key);
                threadLocalValue.set(value);
            }else{
                //加锁失败
                throw new RedisLockException(redisLock.msg());
            }
        }else{
            //注解层数+1
            threadLocalCount.set(threadLocalCount.get()+1);
        }
    }

    @AfterReturning("pointCut()")
    public void afterHandle(JoinPoint joinPoint) {
        unlock();
    }

    @AfterThrowing(pointcut = "pointCut()", throwing = "e")
    public void afterThrowable(JoinPoint joinPoint, Throwable e) {
        unlock();
    }

    /**
     * 解锁操作
     */
    private void unlock() {
        //注解层数-1
        threadLocalCount.set(threadLocalCount.get()-1);
        if (isFirst()) {
            //进行解锁操作
            String key = threadLocalKey.get();
            String value = threadLocalValue.get();
            redisLockService.unlock(key, value);
            //进行线程remove操作
            removeThreadLocal();
        }
    }

    /**
     * 线程变量清理(线程池重复使用线程,需要remove一下,防止脏数据带到下一个线程)
     */
    private void removeThreadLocal() {
        threadLocalCount.remove();
        threadLocalKey.remove();
        threadLocalValue.remove();
    }

    /**
     * 通过threadLocalCount值是否为0判断是否是最外层的注解
     * @return
     */
    private boolean isFirst() {
        return threadLocalCount.get() == 0;
    }

    /**
     * 获取redis缓存的key
     * @param joinPoint
     * @param redisLock
     * @return
     */
    private String getRedisKey(JoinPoint joinPoint, RedisLock redisLock) {
        //获取注解上的key
        String key = redisLock.key();

        //使用SpringEL表达式解析注解上的key
        SpelExpressionParser parser = new SpelExpressionParser();
        Expression expression = parser.parseExpression(key);
        //获取方法入参
        Object[] parameterValues = joinPoint.getArgs();
        //获取方法形参
        MethodSignature signature = (MethodSignature)joinPoint.getSignature();
        Method method = signature.getMethod();
        DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
        String[] parameterNames = nameDiscoverer.getParameterNames(method);
        if (parameterNames == null || parameterNames.length == 0) {
            //方法没有入参,直接返回注解上的key
            return key;
        }
        //解析表达式
        EvaluationContext evaluationContext = new StandardEvaluationContext();
        // 给上下文赋值
        for(int i = 0 ; i < parameterNames.length ; i++) {
            evaluationContext.setVariable(parameterNames[i], parameterValues[i]);
        }
        try {
            Object expressionValue = expression.getValue(evaluationContext);
            if (expressionValue != null && !"".equals(expressionValue.toString())) {
                //返回el解析后的key
                return expressionValue.toString();
            }else{
                //使用注解上的key
                return key;
            }
        } catch (Exception e) {
            //解析失败,默认使用注解上的key
            return key;
        }
    }

}

分布式锁工具代码(实现原理参考SpringBoot使用redis搭配lua脚本实现分布式并发锁一文,此处直接引用复制)

/**
 * redis分布式锁工具
 * 使用lua脚本实现加锁和解锁操作,保证原子性
 * @author liqingcan
 */
@Component
public class RedisLockService {

    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    /**
     * 加锁的lua脚本
     */
    private final static RedisScript<Long> LOCK_LUA_SCRIPT = new DefaultRedisScript<>(
            "if redis.call(\"setnx\", KEYS[1], KEYS[2]) == 1 then return redis.call(\"expire\", KEYS[1], KEYS[3]) else return 0 end"
            , Long.class
    );

    /**
     * 加锁失败结果
     */
    private final static Long LOCK_FAIL = 0L;

    /**
     * 解锁的lua脚本
     */
    private final static RedisScript<Long> UNLOCK_LUA_SCRIPT = new DefaultRedisScript<>(
            "if redis.call(\"get\",KEYS[1]) == KEYS[2] then return redis.call(\"del\",KEYS[1]) else return -1 end"
            , Long.class
    );

    /**
     * 解锁失败结果
     */
    private final static Long UNLOCK_FAIL = -1L;

    /**
     * 加锁方法
     * 对key加锁,value为key对应的值,expire是锁自动过期时间防止死锁
     * @param key key
     * @param value value
     * @param expire 锁自动过期时间(秒)
     * @return
     */
    public boolean lock(String key, String value, Long expire){
        if (key == null || value == null || expire == null) {
            return false;
        }
        List<String> keys = Arrays.asList(key, value, expire.toString());
        Long res = redisTemplate.execute(LOCK_LUA_SCRIPT, keys);
        return !LOCK_FAIL.equals(res);
    }

    /**
     * 解锁方法
     * 对key解锁,只有value值等于redis中key对应的值才能解锁,避免误解锁
     * @param key
     * @param value
     * @return
     */
    public boolean unlock(String key, String value){
        if (key == null || value == null) {
            return false;
        }
        List<String> keys = Arrays.asList(key, value);
        Long res = redisTemplate.execute(UNLOCK_LUA_SCRIPT, keys);
        return !UNLOCK_FAIL.equals(res);
    }

}

异常类代码

/**
 * redis分布式锁异常
 * @author liqingcan
 */
public class RedisLockException extends RuntimeException {

    public RedisLockException() {
    }

    public RedisLockException(String message) {
        super(message);
    }

    public RedisLockException(String message, Throwable cause) {
        super(message, cause);
    }

    public RedisLockException(Throwable cause) {
        super(cause);
    }

    public RedisLockException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
        super(message, cause, enableSuppression, writableStackTrace);
    }

}

到这里注解已经可以使用了,具体的使用以及一个并发扣库存的样例可以参考:https://gitee.com/lqccan/blog-demo demo14中的测试代码及注释


觉得内容还不错?打赏个钢镚鼓励鼓励!!👍