From 9bf356a75394fc3dc425f0fb311e83a1bf4d0f27 Mon Sep 17 00:00:00 2001 From: 0202zc Date: Fri, 24 Feb 2023 09:45:29 +0800 Subject: [PATCH] fix bugs --- cmd/api/handler/comment.go | 7 +++-- cmd/api/handler/message.go | 10 +++++++ cmd/api/handler/user.go | 22 +++++++++++++- cmd/api/handler/video.go | 2 +- cmd/favorite/main.go | 5 ++-- cmd/favorite/service/handler.go | 16 +++++++++- cmd/favorite/service/init.go | 3 +- cmd/favorite/service/timer.go | 8 +++-- cmd/message/main.go | 2 +- cmd/message/service/handler.go | 23 ++++++++++++++- cmd/relation/main.go | 9 +++--- cmd/relation/service/handler.go | 12 +++++++- cmd/relation/service/init.go | 3 +- cmd/relation/service/timer.go | 9 ++++-- cmd/user/main.go | 3 +- cmd/video/main.go | 2 +- dal/db/user.go | 4 +-- dal/redis/favorite.go | 48 +++++++++++------------------- dal/redis/message.go | 6 ++-- dal/redis/relation.go | 52 +++++++++------------------------ pkg/middleware/limit_init.go | 4 +++ 21 files changed, 149 insertions(+), 101 deletions(-) diff --git a/cmd/api/handler/comment.go b/cmd/api/handler/comment.go index 7a839dd..dd5762d 100644 --- a/cmd/api/handler/comment.go +++ b/cmd/api/handler/comment.go @@ -2,12 +2,13 @@ package handler import ( "context" - "github.com/bytedance-youthcamp-jbzx/tiktok/cmd/api/rpc" - "github.com/bytedance-youthcamp-jbzx/tiktok/internal/response" - kitex "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/comment" "github.com/cloudwego/hertz/pkg/app" "net/http" "strconv" + + "github.com/bytedance-youthcamp-jbzx/tiktok/cmd/api/rpc" + "github.com/bytedance-youthcamp-jbzx/tiktok/internal/response" + kitex "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/comment" ) func CommentAction(ctx context.Context, c *app.RequestContext) { diff --git a/cmd/api/handler/message.go b/cmd/api/handler/message.go index 75ae98e..e9a37be 100644 --- a/cmd/api/handler/message.go +++ b/cmd/api/handler/message.go @@ -83,6 +83,16 @@ func MessageAction(ctx context.Context, c *app.RequestContext) { return } + if len(c.Query("content")) == 0 { + c.JSON(http.StatusOK, response.MessageAction{ + Base: response.Base{ + StatusCode: -1, + StatusMsg: "参数 content 不能为空", + }, + }) + return + } + // 调用kitex/kitex_gen req := &kitex.MessageActionRequest{ Token: token, diff --git a/cmd/api/handler/user.go b/cmd/api/handler/user.go index 116af76..ce97fe1 100644 --- a/cmd/api/handler/user.go +++ b/cmd/api/handler/user.go @@ -104,7 +104,27 @@ func Login(ctx context.Context, c *app.RequestContext) { func UserInfo(ctx context.Context, c *app.RequestContext) { userId := c.Query("user_id") token := c.Query("token") - id, _ := strconv.ParseInt(userId, 10, 64) + if len(token) == 0 { + c.JSON(http.StatusOK, response.UserInfo{ + Base: response.Base{ + StatusCode: -1, + StatusMsg: "token 已过期", + }, + User: nil, + }) + return + } + id, err := strconv.ParseInt(userId, 10, 64) + if err != nil { + c.JSON(http.StatusOK, response.UserInfo{ + Base: response.Base{ + StatusCode: -1, + StatusMsg: "user_id 不合法", + }, + User: nil, + }) + return + } //调用kitex/kitex_genit req := &kitex.UserInfoRequest{ diff --git a/cmd/api/handler/video.go b/cmd/api/handler/video.go index 334acdf..150f11e 100644 --- a/cmd/api/handler/video.go +++ b/cmd/api/handler/video.go @@ -31,7 +31,7 @@ func Feed(ctx context.Context, c *app.RequestContext) { } res, _ := rpc.Feed(ctx, req) if res.StatusCode == -1 { - c.JSON(http.StatusOK, response.FavoriteList{ + c.JSON(http.StatusOK, response.Feed{ Base: response.Base{ StatusCode: -1, StatusMsg: res.StatusMsg, diff --git a/cmd/favorite/main.go b/cmd/favorite/main.go index 23f1004..51304f2 100644 --- a/cmd/favorite/main.go +++ b/cmd/favorite/main.go @@ -30,8 +30,7 @@ func init() { } func main() { - // logger = z.InitLogger() - // defer logger.Sync() + defer service.FavoriteMq.Destroy() // 服务注册 r, err := etcd.NewEtcdRegistry([]string{etcdAddr}) @@ -52,7 +51,7 @@ func main() { server.WithRegistry(r), server.WithLimit(&limit.Option{MaxConnections: 1000, MaxQPS: 100}), server.WithMuxTransport(), - // server.WithSuite(tracing.NewServerSuite()), + //server.WithSuite(tracing.NewServerSuite()), server.WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{ServiceName: serviceName}), ) diff --git a/cmd/favorite/service/handler.go b/cmd/favorite/service/handler.go index b2fcc05..5682f6f 100644 --- a/cmd/favorite/service/handler.go +++ b/cmd/favorite/service/handler.go @@ -3,12 +3,17 @@ package service import ( "context" "encoding/json" + "fmt" "github.com/bytedance-youthcamp-jbzx/tiktok/dal/db" "github.com/bytedance-youthcamp-jbzx/tiktok/dal/redis" favorite "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/favorite" user "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/user" video "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/video" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/minio" + "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/rabbitmq" + "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/zap" + amqp "github.com/rabbitmq/amqp091-go" + "strings" ) // FavoriteServiceImpl implements the last service interface defined in the IDL. @@ -16,6 +21,7 @@ type FavoriteServiceImpl struct{} // FavoriteAction implements the FavoriteServiceImpl interface. func (s *FavoriteServiceImpl) FavoriteAction(ctx context.Context, req *favorite.FavoriteActionRequest) (resp *favorite.FavoriteActionResponse, err error) { + logger := zap.InitLogger() // 解析token,获取用户id claims, err := Jwt.ParseToken(req.Token) if err != nil { @@ -36,8 +42,16 @@ func (s *FavoriteServiceImpl) FavoriteAction(ctx context.Context, req *favorite. //CreatedAt: time.Now(), } jsonFC, _ := json.Marshal(fc) + fmt.Println("Publish new message: ", fc) if err = FavoriteMq.PublishSimple(ctx, jsonFC); err != nil { - logger.Errorf("消息队列发布错误:%v", err.Error()) + logger.Errorf("消息队列发布错误:%v", err.Error()) + if strings.Contains(err.Error(), amqp.ErrClosed.Reason) { + // 检测到通道关闭,则重连 + FavoriteMq.Destroy() + FavoriteMq = rabbitmq.NewRabbitMQSimple("favorite") + logger.Errorln("消息队列通道尝试重连:favorite") + go consume() + } res := &favorite.FavoriteActionResponse{ StatusCode: -1, StatusMsg: "操作失败:服务器内部错误", diff --git a/cmd/favorite/service/init.go b/cmd/favorite/service/init.go index b6f63a9..4e97068 100644 --- a/cmd/favorite/service/init.go +++ b/cmd/favorite/service/init.go @@ -15,5 +15,6 @@ var ( func Init(signingKey string) { Jwt = jwt.NewJWT([]byte(signingKey)) - GoCron() + //GoCron() + go consume() } diff --git a/cmd/favorite/service/timer.go b/cmd/favorite/service/timer.go index 36a163a..47fffd7 100644 --- a/cmd/favorite/service/timer.go +++ b/cmd/favorite/service/timer.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/bytedance-youthcamp-jbzx/tiktok/dal/redis" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/gocron" ) @@ -16,19 +15,22 @@ func consume() { msgs, err := FavoriteMq.ConsumeSimple() if err != nil { fmt.Println(err.Error()) + logger.Errorf("FavoriteMQ Err: %s", err.Error()) } // 将消息队列的消息全部取出 for msg := range msgs { - fmt.Printf("==> Get new message: %v", msg.MessageId) fc := new(redis.FavoriteCache) // 解析json if err = json.Unmarshal(msg.Body, &fc); err != nil { + logger.Errorf("json unmarshal error: %s", err.Error()) fmt.Println("json unmarshal error:" + err.Error()) continue } + fmt.Printf("==> Get new message: %v\n", fc) // 将结构体存入redis if err = redis.UpdateFavorite(context.Background(), fc); err != nil { - fmt.Println("add to redis error:" + err.Error()) + logger.Errorf("json unmarshal error: %s", err.Error()) + fmt.Println("json unmarshal error:" + err.Error()) continue } } diff --git a/cmd/message/main.go b/cmd/message/main.go index 1601330..53c89d4 100644 --- a/cmd/message/main.go +++ b/cmd/message/main.go @@ -50,7 +50,7 @@ func main() { server.WithRegistry(r), server.WithLimit(&limit.Option{MaxConnections: 1000, MaxQPS: 100}), server.WithMuxTransport(), - // server.WithSuite(tracing.NewServerSuite()), + //server.WithSuite(tracing.NewServerSuite()), server.WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{ServiceName: serviceName}), ) diff --git a/cmd/message/service/handler.go b/cmd/message/service/handler.go index ab3c0cf..a09316a 100644 --- a/cmd/message/service/handler.go +++ b/cmd/message/service/handler.go @@ -2,6 +2,7 @@ package service import ( "context" + "github.com/bytedance-youthcamp-jbzx/tiktok/dal/db" "github.com/bytedance-youthcamp-jbzx/tiktok/dal/redis" "github.com/bytedance-youthcamp-jbzx/tiktok/internal/tool" @@ -39,8 +40,9 @@ func (s *MessageServiceImpl) MessageChat(ctx context.Context, req *message.Messa } var results []*db.Message - if lastTimestamp == 0 { + if lastTimestamp == -1 { results, err = db.GetMessagesByUserIDs(ctx, userID, req.ToUserId, int64(lastTimestamp)) + lastTimestamp = 0 } else { results, err = db.GetMessagesByUserToUser(ctx, req.ToUserId, userID, int64(lastTimestamp)) } @@ -123,6 +125,25 @@ func (s *MessageServiceImpl) MessageAction(ctx context.Context, req *message.Mes toUserID, actionType := req.ToUserId, req.ActionType + if userID == toUserID { + logger.Errorln("不能给自己发送消息") + res := &message.MessageActionResponse{ + StatusCode: -1, + StatusMsg: "消息发送失败:不能给自己发送消息", + } + return res, nil + } + + relation, err := db.GetRelationByUserIDs(ctx, userID, toUserID) + if relation == nil { + logger.Errorf("消息发送失败:非朋友关系,无法发送") + res := &message.MessageActionResponse{ + StatusCode: -1, + StatusMsg: "消息发送失败:非朋友关系,无法发送", + } + return res, nil + } + rsaContent, err := tool.RsaEncrypt([]byte(req.Content), publicKey) if err != nil { logger.Errorf("rsa encrypt error: %v\n", err.Error()) diff --git a/cmd/relation/main.go b/cmd/relation/main.go index e48f72d..5de28a7 100644 --- a/cmd/relation/main.go +++ b/cmd/relation/main.go @@ -4,15 +4,14 @@ import ( "fmt" "net" - "github.com/cloudwego/kitex/pkg/limit" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/bytedance-youthcamp-jbzx/tiktok/cmd/relation/service" "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/relation/relationservice" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/etcd" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/middleware" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/viper" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/zap" + "github.com/cloudwego/kitex/pkg/limit" + "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/server" ) @@ -30,7 +29,7 @@ func init() { } func main() { - // defer logger.Sync() + defer service.RelationMq.Destroy() // 服务注册 r, err := etcd.NewEtcdRegistry([]string{etcdAddr}) @@ -51,7 +50,7 @@ func main() { server.WithRegistry(r), server.WithLimit(&limit.Option{MaxConnections: 1000, MaxQPS: 100}), server.WithMuxTransport(), - // server.WithSuite(tracing.NewServerSuite()), + //server.WithSuite(tracing.NewServerSuite()), server.WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{ServiceName: serviceName}), ) diff --git a/cmd/relation/service/handler.go b/cmd/relation/service/handler.go index a117dc7..f1b8ea9 100644 --- a/cmd/relation/service/handler.go +++ b/cmd/relation/service/handler.go @@ -9,7 +9,10 @@ import ( relation "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/relation" user "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/user" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/minio" + "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/rabbitmq" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/zap" + amqp "github.com/rabbitmq/amqp091-go" + "strings" ) // RelationServiceImpl implements the last service interface defined in the IDL. @@ -68,7 +71,14 @@ func (s *RelationServiceImpl) RelationAction(ctx context.Context, req *relation. } jsonRc, _ := json.Marshal(relationCache) if err = RelationMq.PublishSimple(ctx, jsonRc); err != nil { - logger.Errorln(err.Error()) + logger.Errorf("消息队列发布错误:%v", err.Error()) + if strings.Contains(err.Error(), amqp.ErrClosed.Reason) { + // 检测到通道关闭,则重连 + RelationMq.Destroy() + RelationMq = rabbitmq.NewRabbitMQSimple("relation") + logger.Errorln("消息队列通道尝试重连:relation") + go consume() + } res := &relation.RelationActionResponse{ StatusCode: -1, StatusMsg: "服务器内部错误:操作失败", diff --git a/cmd/relation/service/init.go b/cmd/relation/service/init.go index d852747..5800da2 100644 --- a/cmd/relation/service/init.go +++ b/cmd/relation/service/init.go @@ -19,5 +19,6 @@ var ( func Init(signingKey string) { Jwt = jwt.NewJWT([]byte(signingKey)) privateKey, _ = tool.ReadKeyFromFile(tool.PrivateKeyFilePath) - GoCron() + //GoCron() + go consume() } diff --git a/cmd/relation/service/timer.go b/cmd/relation/service/timer.go index 3121039..5f4c39d 100644 --- a/cmd/relation/service/timer.go +++ b/cmd/relation/service/timer.go @@ -12,26 +12,31 @@ import ( const frequency = 10 // 点赞服务消息队列消费者 -func consume() { +func consume() error { msgs, err := RelationMq.ConsumeSimple() if err != nil { fmt.Println(err.Error()) + logger.Errorf("RelationMQ Err: %s", err.Error()) + return err } // 将消息队列的消息全部取出 for msg := range msgs { - fmt.Printf("==> Get new message: %v", msg.MessageId) rc := new(redis.RelationCache) // 解析json if err = json.Unmarshal(msg.Body, &rc); err != nil { fmt.Println("json unmarshal error:" + err.Error()) + logger.Errorf("RelationMQ Err: %s", err.Error()) continue } + fmt.Printf("==> Get new message: %v\n", rc) // 将结构体存入redis if err = redis.UpdateRelation(context.Background(), rc); err != nil { fmt.Println("add to redis error:" + err.Error()) + logger.Errorf("RelationMQ Err: %s", err.Error()) continue } } + return nil } // gocron定时任务,每隔一段时间就让Consumer消费消息队列的所有消息 diff --git a/cmd/user/main.go b/cmd/user/main.go index 65c0509..acb77bb 100644 --- a/cmd/user/main.go +++ b/cmd/user/main.go @@ -5,7 +5,6 @@ import ( "net" "github.com/bytedance-youthcamp-jbzx/tiktok/cmd/user/service" - "github.com/kitex-contrib/obs-opentelemetry/tracing" "github.com/bytedance-youthcamp-jbzx/tiktok/kitex/kitex_gen/user/userservice" "github.com/bytedance-youthcamp-jbzx/tiktok/pkg/etcd" @@ -52,7 +51,7 @@ func main() { server.WithRegistry(r), server.WithLimit(&limit.Option{MaxConnections: 1000, MaxQPS: 100}), server.WithMuxTransport(), - server.WithSuite(tracing.NewServerSuite()), + //server.WithSuite(tracing.NewServerSuite()), server.WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{ServiceName: serviceName}), ) diff --git a/cmd/video/main.go b/cmd/video/main.go index eaa9b6f..b123ba8 100644 --- a/cmd/video/main.go +++ b/cmd/video/main.go @@ -50,7 +50,7 @@ func main() { server.WithRegistry(r), server.WithLimit(&limit.Option{MaxConnections: 1000, MaxQPS: 100}), server.WithMuxTransport(), - // server.WithSuite(tracing.NewServerSuite()), + //server.WithSuite(tracing.NewServerSuite()), server.WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{ServiceName: serviceName}), ) diff --git a/dal/db/user.go b/dal/db/user.go index 38a57bd..6a46a60 100644 --- a/dal/db/user.go +++ b/dal/db/user.go @@ -96,9 +96,9 @@ func CreateUsers(ctx context.Context, users []*User) error { // CreateUser // // @Description: 新增一条用户数据 -// @Date 2023-02-22 11:47:43 +// @Date 2023-02-22 11:46:43 // @param ctx 数据库操作上下文 -// @param users 用户数据 +// @param user 用户数据 // @return error func CreateUser(ctx context.Context, user *User) error { err := GetDB().Clauses(dbresolver.Write).WithContext(ctx).Transaction(func(tx *gorm.DB) error { diff --git a/dal/redis/favorite.go b/dal/redis/favorite.go index afe348b..e792f0c 100644 --- a/dal/redis/favorite.go +++ b/dal/redis/favorite.go @@ -20,49 +20,35 @@ type FavoriteCache struct { * 2. set存储某个类型点赞的记录,Key为video::vid::user::uid,hashKey为点赞视频+点赞人,Value为action_type */ func UpdateFavorite(ctx context.Context, favorite *FavoriteCache) error { - //keyVideo := fmt.Sprintf("video::%d", favorite.VideoID) + errLock := LockByMutex(ctx, favoriteMutex) // Read 用于与前端同步,且创建定时器检查是否过期;Write 用于与前端同步,不设置过期,但是需要定时与MySQL同步后进行删除 keyUserIDRead := fmt.Sprintf("video::%d::user::%d::r", favorite.VideoID, favorite.UserID) keyUserIDWrite := fmt.Sprintf("video::%d::user::%d::w", favorite.VideoID, favorite.UserID) - //if favorite.ActionType == 1 { - // _, err := GetRedisHelper().SAdd(ctx, keyVideo, favorite.UserID).Result() - // if err != nil { - // zapLogger.Errorln(err.Error()) - // return err - // } - // //userMap := make(map[string]interface{}) - // //userMap["created_at"] = favorite.CreatedAt.Format("2006-01-02 15:04:05") - // //_, err = GetRedisHelper().Set(ctx, keyUserID, favorite.CreatedAt.Format("2006-01-02 15:04:05"), 0).Result() - //} else if favorite.ActionType == 2 { - // err := GetRedisHelper().SRem(ctx, keyVideo, 1, keyUserIDRead).Err() - // if err != nil { - // zapLogger.Errorln(err.Error()) - // return err - // } - //} else { - // zapLogger.Errorln("\"action_type\" is not equal to 1 or 2") - // return errors.New("\"action_type\" is not equal to 1 or 2") - //} + if errLock != nil { + zapLogger.Errorf("lock failed: %s", errLock.Error()) + return errLock + } _, err := GetRedisHelper().Set(ctx, keyUserIDRead, favorite.ActionType, ExpireTime).Result() if err != nil { + errUnlock := UnlockByMutex(ctx, favoriteMutex) + if errUnlock != nil { + zapLogger.Errorf("unlock failed: %s", errUnlock.Error()) + return errUnlock + } zapLogger.Errorln(err.Error()) return err } - err = LockByMutex(ctx, favoriteMutex) - if err != nil { - zapLogger.Errorf("lock failed: %s", err.Error()) - return err + fmt.Println(keyUserIDWrite, " => ", favorite.ActionType) + _, err = GetRedisHelper().Set(ctx, keyUserIDWrite, favorite.ActionType, 0).Result() + errUnlock := UnlockByMutex(ctx, favoriteMutex) + if errUnlock != nil { + zapLogger.Errorf("unlock failed: %s", errUnlock.Error()) + return errUnlock } - _, err1 := GetRedisHelper().Set(ctx, keyUserIDWrite, favorite.ActionType, 0).Result() - err = UnlockByMutex(ctx, favoriteMutex) if err != nil { - zapLogger.Errorf("unlock failed: %s", err.Error()) + zapLogger.Errorln(err.Error()) return err } - if err1 != nil { - zapLogger.Errorln(err1.Error()) - return err1 - } return nil } diff --git a/dal/redis/message.go b/dal/redis/message.go index 9993272..206d964 100644 --- a/dal/redis/message.go +++ b/dal/redis/message.go @@ -16,14 +16,14 @@ redis里的键值,即便没有新的消息传来。 func GetMessageTimestamp(ctx context.Context, token string, toUserID int64) (int, error) { key := fmt.Sprintf("%s_%d", token, toUserID) if ec, err := GetRedisHelper().Exists(ctx, key).Result(); err != nil { - return 0, err + return -1, err } else if ec == 0 { - return 0, nil //errors.New("key not found") + return -1, nil //errors.New("key not found") } val, err := GetRedisHelper().Get(ctx, key).Result() if err != nil { - return 0, err + return -1, err } return strconv.Atoi(val) diff --git a/dal/redis/relation.go b/dal/redis/relation.go index c3320f7..08fa01a 100644 --- a/dal/redis/relation.go +++ b/dal/redis/relation.go @@ -16,56 +16,32 @@ type RelationCache struct { // UpdateRelation 更新关系 func UpdateRelation(ctx context.Context, relation *RelationCache) error { // 在userID的关注列表中加入toUserID,同时在toUserID的粉丝列表中加入userID - //keyFollower, keyFollowing := fmt.Sprintf("follower::%d", relation.ToUserID), fmt.Sprintf("following::%d", relation.UserID) + errLock := LockByMutex(ctx, relationMutex) + if errLock != nil { + zapLogger.Errorf("lock failed: %s", errLock.Error()) + return errLock + } + keyRelationRead := fmt.Sprintf("user::%d::to_user::%d::r", relation.UserID, relation.ToUserID) keyRelationWrite := fmt.Sprintf("user::%d::to_user::%d::w", relation.UserID, relation.ToUserID) - //if relation.ActionType == 1 { - // // 添加user的关注者id - // if err := GetRedisHelper().SAdd(ctx, keyFollower, relation.UserID).Err(); err != nil { - // zapLogger.Errorln(err.Error()) - // return err - // } - // // 添加to_user的粉丝id - // if err := GetRedisHelper().SAdd(ctx, keyFollowing, relation.ToUserID).Err(); err != nil { - // zapLogger.Errorln(err.Error()) - // return err - // } - //} else if relation.ActionType == 2 { - // // 删除user的关注者id - // if err := GetRedisHelper().SRem(ctx, keyFollowing, 1, keyFollower).Err(); err != nil { - // zapLogger.Errorln(err.Error()) - // return err - // } - // // 删除to_user的粉丝id - // if err := GetRedisHelper().SRem(ctx, keyFollower, 1, keyFollowing).Err(); err != nil { - // zapLogger.Errorln(err.Error()) - // return err - // } - //} else { - // zapLogger.Errorln("\"action_type\" is not equal to 1 or 2") - // return errors.New("\"action_type\" is not equal to 1 or 2") - //} err := GetRedisHelper().Set(ctx, keyRelationRead, relation.ActionType, ExpireTime).Err() if err != nil { zapLogger.Errorln(err.Error()) return err } - err = LockByMutex(ctx, relationMutex) - if err != nil { - zapLogger.Errorf("lock failed: %s", err.Error()) - return err + err = GetRedisHelper().Set(ctx, keyRelationWrite, relation.ActionType, 0).Err() + + errUnlock := UnlockByMutex(ctx, relationMutex) + if errUnlock != nil { + zapLogger.Errorf("lock failed: %s", errUnlock.Error()) + return errUnlock } - err1 := GetRedisHelper().Set(ctx, keyRelationWrite, relation.ActionType, 0).Err() - err = UnlockByMutex(ctx, relationMutex) + if err != nil { - zapLogger.Errorf("lock failed: %s", err.Error()) + zapLogger.Errorln(err.Error()) return err } - if err1 != nil { - zapLogger.Errorln(err1.Error()) - return err1 - } return nil } diff --git a/pkg/middleware/limit_init.go b/pkg/middleware/limit_init.go index 852c03d..1622d48 100644 --- a/pkg/middleware/limit_init.go +++ b/pkg/middleware/limit_init.go @@ -45,6 +45,8 @@ type TokenBuckets struct { buckets map[string]*TokenBucket capacity int64 rate int64 + + lock sync.Mutex } func MakeTokenBuckets(capacity, rate int64) *TokenBuckets { @@ -56,6 +58,8 @@ func MakeTokenBuckets(capacity, rate int64) *TokenBuckets { } func (tbs *TokenBuckets) Allow(token string) bool { + tbs.lock.Lock() + defer tbs.lock.Unlock() if bucket, ok := tbs.buckets[token]; ok { return bucket.Allow() } else {