@@ -25,6 +25,7 @@ import (
25
25
"fmt"
26
26
"math/rand"
27
27
"strings"
28
+ "sync"
28
29
"time"
29
30
"unicode"
30
31
@@ -130,7 +131,7 @@ func Register(db *gorm.DB, u *models.User, invitation_code string) error {
130
131
u .InvitedByUserID = inviter .ID
131
132
// TODO: only once for the inviter?
132
133
inviter .Reward += 100
133
- db .Save (inviter )
134
+ db .Select ( "reward" ). Save (inviter )
134
135
}
135
136
136
137
// 检查邮箱是否已存在
@@ -154,15 +155,19 @@ func Register(db *gorm.DB, u *models.User, invitation_code string) error {
154
155
u .IsActive = false
155
156
u .IsAdmin = false
156
157
157
- code , err := createInvitationCode (db )
158
+ err = db .Transaction (func (tx * gorm.DB ) error {
159
+ if err := tx .Create (u ).Error ; err != nil {
160
+ return errors .Wrap (err , errors .DatabaseError )
161
+ }
162
+ if err := createInvitationCode (tx , u ); err != nil {
163
+ return err
164
+ }
165
+ return nil
166
+ })
167
+
158
168
if err != nil {
159
169
return err
160
170
}
161
- u .InvitationCode = code
162
-
163
- if err = db .Create (u ).Error ; err != nil {
164
- return errors .Wrap (err , errors .DatabaseError )
165
- }
166
171
167
172
body := fmt .Sprintf (`<html><body>
168
173
<h1>欢迎注册%s</h1> <p>我们已经接收到您的电子邮箱验证申请,请点击以下链接完成注册。</p>
@@ -254,13 +259,7 @@ func Login(db *gorm.DB, email, password string) (*models.User, error) {
254
259
}
255
260
256
261
if user .InvitationCode == "" {
257
- code , err := createInvitationCode (db )
258
- if err != nil {
259
- return nil , err
260
- }
261
-
262
- user .InvitationCode = code
263
- err = db .Select ("invitation_code" ).Save (user ).Error
262
+ err := createInvitationCode (db , user )
264
263
if err != nil {
265
264
return nil , errors .Wrap (err , errors .DatabaseError )
266
265
}
@@ -468,25 +467,43 @@ func CheckInvitationCode(code string) bool {
468
467
return true
469
468
}
470
469
471
- func createInvitationCode (db * gorm.DB ) (string , error ) {
472
- // try a few times before giving up
473
- for i := 0 ; i < 5 ; i ++ {
474
- codeRunes := make ([]rune , 0 , 5 )
475
- for i := 0 ; i < 5 ; i ++ {
476
- codeRunes = append (codeRunes , []rune ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" )[rand .Intn (62 )])
477
- }
478
- code := string (codeRunes )
470
+ var invitationCodeMutex sync.Mutex
479
471
480
- _ , err := GetUserByInvitationCode (db , code )
481
- if err != nil {
482
- if errors .Is (err , errors .UserNotExists ) {
483
- return code , nil
472
+ func createInvitationCode (db * gorm.DB , user * models.User ) error {
473
+ if db == nil {
474
+ db = database .GetDB ()
475
+ }
476
+ if user .InvitationCode != "" {
477
+ return nil
478
+ }
479
+ // try a few times to generate an unique code
480
+ return db .Transaction (func (tx * gorm.DB ) error {
481
+ invitationCodeMutex .Lock ()
482
+ defer invitationCodeMutex .Unlock ()
483
+
484
+ for i := 0 ; i < 10 ; i ++ {
485
+ code := generateInvitationCode ()
486
+ _ , err := GetUserByInvitationCode (tx , code )
487
+ if err != nil {
488
+ if errors .Is (err , errors .UserNotExists ) {
489
+ user .InvitationCode = code
490
+ return tx .Select ("invitation_code" ).Save (user ).Error
491
+ }
492
+ return err
484
493
}
485
- return "" , err
486
494
}
487
- }
488
495
489
- return "" , errors .New (errors .InternalServerError )
496
+ return errors .New (errors .InternalServerError )
497
+ })
498
+ }
499
+
500
+ func generateInvitationCode () string {
501
+ // genetate random code with length 5, only contains [A-Za-z0-9]
502
+ codeRunes := make ([]rune , 0 , 5 )
503
+ for i := 0 ; i < 5 ; i ++ {
504
+ codeRunes = append (codeRunes , []rune ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" )[rand .Intn (62 )])
505
+ }
506
+ return string (codeRunes )
490
507
}
491
508
492
509
func GetUserByInvitationCode (db * gorm.DB , code string ) (* models.User , error ) {
0 commit comments