Skip to content

Commit bccfdda

Browse files
committed
fix(queries/user.go): thread safe invitation code creation
1 parent b0b35e4 commit bccfdda

File tree

1 file changed

+46
-29
lines changed

1 file changed

+46
-29
lines changed

pkg/queries/user.go

+46-29
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"fmt"
2626
"math/rand"
2727
"strings"
28+
"sync"
2829
"time"
2930
"unicode"
3031

@@ -130,7 +131,7 @@ func Register(db *gorm.DB, u *models.User, invitation_code string) error {
130131
u.InvitedByUserID = inviter.ID
131132
// TODO: only once for the inviter?
132133
inviter.Reward += 100
133-
db.Save(inviter)
134+
db.Select("reward").Save(inviter)
134135
}
135136

136137
// 检查邮箱是否已存在
@@ -154,15 +155,19 @@ func Register(db *gorm.DB, u *models.User, invitation_code string) error {
154155
u.IsActive = false
155156
u.IsAdmin = false
156157

157-
code, err := createInvitationCode(db)
158+
err = db.Transaction(func(tx *gorm.DB) error {
159+
if err := tx.Create(u).Error; err != nil {
160+
return errors.Wrap(err, errors.DatabaseError)
161+
}
162+
if err := createInvitationCode(tx, u); err != nil {
163+
return err
164+
}
165+
return nil
166+
})
167+
158168
if err != nil {
159169
return err
160170
}
161-
u.InvitationCode = code
162-
163-
if err = db.Create(u).Error; err != nil {
164-
return errors.Wrap(err, errors.DatabaseError)
165-
}
166171

167172
body := fmt.Sprintf(`<html><body>
168173
<h1>欢迎注册%s</h1> <p>我们已经接收到您的电子邮箱验证申请,请点击以下链接完成注册。</p>
@@ -254,13 +259,7 @@ func Login(db *gorm.DB, email, password string) (*models.User, error) {
254259
}
255260

256261
if user.InvitationCode == "" {
257-
code, err := createInvitationCode(db)
258-
if err != nil {
259-
return nil, err
260-
}
261-
262-
user.InvitationCode = code
263-
err = db.Select("invitation_code").Save(user).Error
262+
err := createInvitationCode(db, user)
264263
if err != nil {
265264
return nil, errors.Wrap(err, errors.DatabaseError)
266265
}
@@ -468,25 +467,43 @@ func CheckInvitationCode(code string) bool {
468467
return true
469468
}
470469

471-
func createInvitationCode(db *gorm.DB) (string, error) {
472-
// try a few times before giving up
473-
for i := 0; i < 5; i++ {
474-
codeRunes := make([]rune, 0, 5)
475-
for i := 0; i < 5; i++ {
476-
codeRunes = append(codeRunes, []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")[rand.Intn(62)])
477-
}
478-
code := string(codeRunes)
470+
var invitationCodeMutex sync.Mutex
479471

480-
_, err := GetUserByInvitationCode(db, code)
481-
if err != nil {
482-
if errors.Is(err, errors.UserNotExists) {
483-
return code, nil
472+
func createInvitationCode(db *gorm.DB, user *models.User) error {
473+
if db == nil {
474+
db = database.GetDB()
475+
}
476+
if user.InvitationCode != "" {
477+
return nil
478+
}
479+
// try a few times to generate an unique code
480+
return db.Transaction(func(tx *gorm.DB) error {
481+
invitationCodeMutex.Lock()
482+
defer invitationCodeMutex.Unlock()
483+
484+
for i := 0; i < 10; i++ {
485+
code := generateInvitationCode()
486+
_, err := GetUserByInvitationCode(tx, code)
487+
if err != nil {
488+
if errors.Is(err, errors.UserNotExists) {
489+
user.InvitationCode = code
490+
return tx.Select("invitation_code").Save(user).Error
491+
}
492+
return err
484493
}
485-
return "", err
486494
}
487-
}
488495

489-
return "", errors.New(errors.InternalServerError)
496+
return errors.New(errors.InternalServerError)
497+
})
498+
}
499+
500+
func generateInvitationCode() string {
501+
// genetate random code with length 5, only contains [A-Za-z0-9]
502+
codeRunes := make([]rune, 0, 5)
503+
for i := 0; i < 5; i++ {
504+
codeRunes = append(codeRunes, []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")[rand.Intn(62)])
505+
}
506+
return string(codeRunes)
490507
}
491508

492509
func GetUserByInvitationCode(db *gorm.DB, code string) (*models.User, error) {

0 commit comments

Comments
 (0)