From 9b66657128e7a74e043848025152763688415976 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=96=BD=E5=9B=BD=E9=B9=8F?= <1033404553@qq.com>
Date: Wed, 21 Aug 2024 18:11:08 +0800
Subject: [PATCH] Fix/redis reacquire (#2)

* fix: redis locker reAcquire

* fix: locker

* fix: redis del lua script
---
 core/stores/redis/lockscript.lua    |  6 --
 core/stores/redis/redislock.go      | 36 +++++-------
 core/stores/redis/redislock_test.go | 87 +++++++++++++++++++++--------
 3 files changed, 79 insertions(+), 50 deletions(-)
 delete mode 100644 core/stores/redis/lockscript.lua

diff --git a/core/stores/redis/lockscript.lua b/core/stores/redis/lockscript.lua
deleted file mode 100644
index 11a1fe350afb..000000000000
--- a/core/stores/redis/lockscript.lua
+++ /dev/null
@@ -1,6 +0,0 @@
-if redis.call("GET", KEYS[1]) == ARGV[1] then
-    redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2])
-    return "OK"
-else
-    return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2])
-end
\ No newline at end of file
diff --git a/core/stores/redis/redislock.go b/core/stores/redis/redislock.go
index 4677dd1b5959..93b136d78b7c 100644
--- a/core/stores/redis/redislock.go
+++ b/core/stores/redis/redislock.go
@@ -5,7 +5,6 @@ import (
 	_ "embed"
 	"errors"
 	"math/rand"
-	"strconv"
 	"sync/atomic"
 	"time"
 
@@ -15,16 +14,10 @@ import (
 )
 
 const (
-	randomLen       = 16
-	tolerance       = 500 // milliseconds
-	millisPerSecond = 1000
+	randomLen = 16
 )
 
 var (
-	//go:embed lockscript.lua
-	lockLuaScript string
-	lockScript    = NewScript(lockLuaScript)
-
 	//go:embed delscript.lua
 	delLuaScript string
 	delScript    = NewScript(delLuaScript)
@@ -58,26 +51,27 @@ func (rl *RedisLock) Acquire() (bool, error) {
 
 // AcquireCtx acquires the lock with the given ctx.
 func (rl *RedisLock) AcquireCtx(ctx context.Context) (bool, error) {
-	seconds := atomic.LoadUint32(&rl.seconds)
-	resp, err := rl.store.ScriptRunCtx(ctx, lockScript, []string{rl.key}, []string{
-		rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance),
-	})
+
+	var (
+		seconds = atomic.LoadUint32(&rl.seconds)
+		res     bool
+		err     error
+	)
+
+	if seconds == 0 {
+		res, err = rl.store.SetnxCtx(ctx, rl.key, rl.id)
+	} else {
+		res, err = rl.store.SetnxExCtx(ctx, rl.key, rl.id, int(seconds))
+	}
+
 	if errors.Is(err, red.Nil) {
 		return false, nil
 	} else if err != nil {
 		logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
 		return false, err
-	} else if resp == nil {
-		return false, nil
-	}
-
-	reply, ok := resp.(string)
-	if ok && reply == "OK" {
-		return true, nil
 	}
 
-	logx.Errorf("Unknown reply when acquiring lock for %s: %v", rl.key, resp)
-	return false, nil
+	return res, nil
 }
 
 // Release releases the lock.
diff --git a/core/stores/redis/redislock_test.go b/core/stores/redis/redislock_test.go
index 4f1d535d74eb..e8a2a4d6e45c 100644
--- a/core/stores/redis/redislock_test.go
+++ b/core/stores/redis/redislock_test.go
@@ -3,12 +3,45 @@ package redis
 import (
 	"context"
 	"testing"
+	"time"
 
+	"github.com/alicebob/miniredis/v2"
 	"github.com/stretchr/testify/assert"
 
 	"github.com/zeromicro/go-zero/core/stringx"
 )
 
+func TestRedisLock_SameAcquire(t *testing.T) {
+
+	var (
+		s       = miniredis.RunT(t)
+		seconds = 5
+	)
+	client := MustNewRedis(
+		RedisConf{
+			Host: s.Addr(),
+			Type: NodeType,
+		},
+	)
+
+	key := stringx.Rand()
+	firstLock := NewRedisLock(client, key)
+	firstLock.SetExpire(seconds)
+	firstAcquire, err := firstLock.Acquire()
+	assert.Nil(t, err)
+	assert.True(t, firstAcquire)
+
+	secondAcquire, err := firstLock.Acquire()
+	assert.Nil(t, err)
+	assert.False(t, secondAcquire)
+
+	s.FastForward(time.Second * time.Duration(seconds+1))
+
+	thirdAcquire, err := firstLock.Acquire()
+	assert.Nil(t, err)
+	assert.True(t, thirdAcquire)
+}
+
 func TestRedisLock(t *testing.T) {
 	testFn := func(ctx context.Context) func(client *Redis) {
 		return func(client *Redis) {
@@ -35,31 +68,39 @@ func TestRedisLock(t *testing.T) {
 		}
 	}
 
-	t.Run("normal", func(t *testing.T) {
-		runOnRedis(t, testFn(nil))
-	})
+	t.Run(
+		"normal", func(t *testing.T) {
+			runOnRedis(t, testFn(nil))
+		},
+	)
 
-	t.Run("withContext", func(t *testing.T) {
-		runOnRedis(t, testFn(context.Background()))
-	})
+	t.Run(
+		"withContext", func(t *testing.T) {
+			runOnRedis(t, testFn(context.Background()))
+		},
+	)
 }
 
 func TestRedisLock_Expired(t *testing.T) {
-	runOnRedis(t, func(client *Redis) {
-		key := stringx.Rand()
-		redisLock := NewRedisLock(client, key)
-		ctx, cancel := context.WithCancel(context.Background())
-		cancel()
-		_, err := redisLock.AcquireCtx(ctx)
-		assert.NotNil(t, err)
-	})
-
-	runOnRedis(t, func(client *Redis) {
-		key := stringx.Rand()
-		redisLock := NewRedisLock(client, key)
-		ctx, cancel := context.WithCancel(context.Background())
-		cancel()
-		_, err := redisLock.ReleaseCtx(ctx)
-		assert.NotNil(t, err)
-	})
+	runOnRedis(
+		t, func(client *Redis) {
+			key := stringx.Rand()
+			redisLock := NewRedisLock(client, key)
+			ctx, cancel := context.WithCancel(context.Background())
+			cancel()
+			_, err := redisLock.AcquireCtx(ctx)
+			assert.NotNil(t, err)
+		},
+	)
+
+	runOnRedis(
+		t, func(client *Redis) {
+			key := stringx.Rand()
+			redisLock := NewRedisLock(client, key)
+			ctx, cancel := context.WithCancel(context.Background())
+			cancel()
+			_, err := redisLock.ReleaseCtx(ctx)
+			assert.NotNil(t, err)
+		},
+	)
 }