diff --git a/service/aiproxy/common/balance/balance.go b/service/aiproxy/common/balance/balance.go index a1cce1709536..faf7172638e7 100644 --- a/service/aiproxy/common/balance/balance.go +++ b/service/aiproxy/common/balance/balance.go @@ -12,7 +12,6 @@ type GroupBalance interface { type PostGroupConsumer interface { PostGroupConsume(ctx context.Context, tokenName string, usage float64) (float64, error) - GetBalance(ctx context.Context) (float64, error) } var Default GroupBalance = NewMockGroupBalance() diff --git a/service/aiproxy/common/balance/mock.go b/service/aiproxy/common/balance/mock.go index f38c9d257a49..fe94e2247ca8 100644 --- a/service/aiproxy/common/balance/mock.go +++ b/service/aiproxy/common/balance/mock.go @@ -25,7 +25,3 @@ func (q *MockGroupBalance) GetGroupRemainBalance(_ context.Context, _ model.Grou func (q *MockGroupBalance) PostGroupConsume(_ context.Context, _ string, usage float64) (float64, error) { return usage, nil } - -func (q *MockGroupBalance) GetBalance(_ context.Context) (float64, error) { - return mockBalance, nil -} diff --git a/service/aiproxy/common/balance/sealos.go b/service/aiproxy/common/balance/sealos.go index c49894cf3003..9f7f2f316c7c 100644 --- a/service/aiproxy/common/balance/sealos.go +++ b/service/aiproxy/common/balance/sealos.go @@ -26,6 +26,7 @@ const ( appType = "LLM-TOKEN" sealosRequester = "sealos-admin" sealosGroupBalanceKey = "sealos:balance:%s" + sealosUserRealNameKey = "sealos:realName:%s" getBalanceRetry = 3 ) @@ -39,6 +40,11 @@ var ( sealosCacheExpire = 3 * time.Minute ) +var ( + sealosCheckRealNameEnable = env.Bool("BALANCE_SEALOS_CHECK_REAL_NAME_ENABLE", false) + sealosNoRealNameUsedAmountLimit = env.Float64("BALANCE_SEALOS_NO_REAL_NAME_USED_AMOUNT_LIMIT", 1) +) + type Sealos struct { accountURL string } @@ -146,12 +152,20 @@ func cacheDecreaseGroupBalance(ctx context.Context, group string, amount int64) return decreaseGroupBalanceScript.Run(ctx, common.RDB, []string{fmt.Sprintf(sealosGroupBalanceKey, group)}, amount).Err() } +var ErrRealNameUsedAmountLimit = errors.New("real name used amount limit reached") + func (s *Sealos) GetGroupRemainBalance(ctx context.Context, group model.GroupCache) (float64, PostGroupConsumer, error) { var errs []error for i := 0; ; i++ { - balance, consumer, err := s.getGroupRemainBalance(ctx, group.ID) + balance, userUID, err := s.getGroupRemainBalance(ctx, group.ID) if err == nil { - return balance, consumer, nil + if sealosCheckRealNameEnable && + !s.checkRealName(ctx, userUID) && + group.UsedAmount > sealosNoRealNameUsedAmountLimit { + return 0, nil, ErrRealNameUsedAmountLimit + } + return decimal.NewFromInt(balance).Div(decimalBalancePrecision).InexactFloat64(), + newSealosPostGroupConsumer(s.accountURL, group.ID, userUID), nil } errs = append(errs, err) if i == getBalanceRetry-1 { @@ -161,26 +175,105 @@ func (s *Sealos) GetGroupRemainBalance(ctx context.Context, group model.GroupCac } } +func cacheGetUserRealName(ctx context.Context, userUID string) (bool, error) { + if !common.RedisEnabled || !sealosRedisCacheEnable { + return true, redis.Nil + } + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + realName, err := common.RDB.Get(ctx, fmt.Sprintf(sealosUserRealNameKey, userUID)).Bool() + if err != nil { + return false, err + } + return realName, nil +} + +func cacheSetUserRealName(ctx context.Context, userUID string, realName bool) error { + if !common.RedisEnabled || !sealosRedisCacheEnable { + return nil + } + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + var expireTime time.Duration + if realName { + expireTime = time.Hour * 12 + } else { + expireTime = time.Minute * 1 + } + return common.RDB.Set(ctx, fmt.Sprintf(sealosUserRealNameKey, userUID), realName, expireTime).Err() +} + +func (s *Sealos) checkRealName(ctx context.Context, userUID string) bool { + if cache, err := cacheGetUserRealName(ctx, userUID); err == nil { + return cache + } else if err != nil && !errors.Is(err, redis.Nil) { + log.Errorf("get user (%s) real name cache failed: %s", userUID, err) + } + + realName, err := s.fetchRealNameFromAPI(ctx, userUID) + if err != nil { + log.Errorf("fetch user (%s) real name failed: %s", userUID, err) + return true + } + + if err := cacheSetUserRealName(ctx, userUID, realName); err != nil { + log.Errorf("set user (%s) real name cache failed: %s", userUID, err) + } + + return realName +} + +type sealosGetRealNameInfoResp struct { + IsRealName bool `json:"isRealName"` + Error string `json:"error"` +} + +func (s *Sealos) fetchRealNameFromAPI(ctx context.Context, userUID string) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + fmt.Sprintf("%s/admin/v1alpha1/real-name-info?userUID=%s", s.accountURL, userUID), nil) + if err != nil { + return false, err + } + + req.Header.Set("Authorization", "Bearer "+jwtToken) + resp, err := sealosHTTPClient.Do(req) + if err != nil { + return false, err + } + defer resp.Body.Close() + + var sealosResp sealosGetRealNameInfoResp + if err := json.NewDecoder(resp.Body).Decode(&sealosResp); err != nil { + return false, err + } + + if resp.StatusCode != http.StatusOK || sealosResp.Error != "" { + return false, fmt.Errorf("get user (%s) real name failed with status code %d, error: %s", userUID, resp.StatusCode, sealosResp.Error) + } + + return sealosResp.IsRealName, nil +} + // GroupBalance interface implementation -func (s *Sealos) getGroupRemainBalance(ctx context.Context, group string) (float64, PostGroupConsumer, error) { +func (s *Sealos) getGroupRemainBalance(ctx context.Context, group string) (int64, string, error) { if cache, err := cacheGetGroupBalance(ctx, group); err == nil && cache.UserUID != "" { - return decimal.NewFromInt(cache.Balance).Div(decimalBalancePrecision).InexactFloat64(), - newSealosPostGroupConsumer(s.accountURL, group, cache.UserUID, cache.Balance), nil + return cache.Balance, cache.UserUID, nil } else if err != nil && !errors.Is(err, redis.Nil) { log.Errorf("get group (%s) balance cache failed: %s", group, err) } balance, userUID, err := s.fetchBalanceFromAPI(ctx, group) if err != nil { - return 0, nil, err + return 0, "", err } if err := cacheSetGroupBalance(ctx, group, balance, userUID); err != nil { log.Errorf("set group (%s) balance cache failed: %s", group, err) } - return decimal.NewFromInt(balance).Div(decimalBalancePrecision).InexactFloat64(), - newSealosPostGroupConsumer(s.accountURL, group, userUID, balance), nil + return balance, userUID, nil } func (s *Sealos) fetchBalanceFromAPI(ctx context.Context, group string) (balance int64, userUID string, err error) { @@ -219,22 +312,16 @@ type SealosPostGroupConsumer struct { accountURL string group string uid string - balance int64 } -func newSealosPostGroupConsumer(accountURL, group, uid string, balance int64) *SealosPostGroupConsumer { +func newSealosPostGroupConsumer(accountURL, group, uid string) *SealosPostGroupConsumer { return &SealosPostGroupConsumer{ accountURL: accountURL, group: group, uid: uid, - balance: balance, } } -func (s *SealosPostGroupConsumer) GetBalance(_ context.Context) (float64, error) { - return decimal.NewFromInt(s.balance).Div(decimalBalancePrecision).InexactFloat64(), nil -} - func (s *SealosPostGroupConsumer) PostGroupConsume(ctx context.Context, tokenName string, usage float64) (float64, error) { amount := s.calculateAmount(usage) diff --git a/service/aiproxy/controller/channel-billing.go b/service/aiproxy/controller/channel-billing.go index 00f6c34a4e5c..7879b5c10eaf 100644 --- a/service/aiproxy/controller/channel-billing.go +++ b/service/aiproxy/controller/channel-billing.go @@ -1,6 +1,7 @@ package controller import ( + "errors" "fmt" "net/http" "strconv" @@ -106,11 +107,12 @@ func GetSubscription(c *gin.Context) { group := middleware.GetGroup(c) b, _, err := balance.Default.GetGroupRemainBalance(c, *group) if err != nil { + if errors.Is(err, balance.ErrRealNameUsedAmountLimit) { + middleware.ErrorResponse(c, http.StatusForbidden, err.Error()) + return + } log.Errorf("get group (%s) balance failed: %s", group.ID, err) - c.JSON(http.StatusOK, middleware.APIResponse{ - Success: false, - Message: fmt.Sprintf("get group (%s) balance failed", group.ID), - }) + middleware.ErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("get group (%s) balance failed", group.ID)) return } token := middleware.GetToken(c) diff --git a/service/aiproxy/middleware/distributor.go b/service/aiproxy/middleware/distributor.go index 0134d0fd14c1..5679c7901bd5 100644 --- a/service/aiproxy/middleware/distributor.go +++ b/service/aiproxy/middleware/distributor.go @@ -127,6 +127,10 @@ func checkGroupBalance(c *gin.Context, group *model.GroupCache) bool { log := GetLogger(c) groupBalance, consumer, err := balance.Default.GetGroupRemainBalance(c.Request.Context(), *group) if err != nil { + if errors.Is(err, balance.ErrRealNameUsedAmountLimit) { + abortLogWithMessage(c, http.StatusForbidden, balance.ErrRealNameUsedAmountLimit.Error()) + return false + } log.Errorf("get group (%s) balance error: %v", group.ID, err) abortWithMessage(c, http.StatusInternalServerError, "get group balance error") return false