@@ -7,12 +7,14 @@ import (
7
7
"coursebench-backend/pkg/mail"
8
8
"coursebench-backend/pkg/models"
9
9
"fmt"
10
- "github.com/badoux/checkmail"
11
- "golang.org/x/crypto/bcrypt"
12
- "gorm.io/gorm"
10
+ "math/rand"
13
11
"strings"
14
12
"time"
15
13
"unicode"
14
+
15
+ "github.com/badoux/checkmail"
16
+ "golang.org/x/crypto/bcrypt"
17
+ "gorm.io/gorm"
16
18
)
17
19
18
20
func ResetPassword (db * gorm.DB , email string ) error {
@@ -71,7 +73,7 @@ func ResetPasswordActive(db *gorm.DB, id uint, code string, password string) (er
71
73
return nil
72
74
}
73
75
74
- func Register (db * gorm.DB , u * models.User ) error {
76
+ func Register (db * gorm.DB , u * models.User , invitation_code string ) error {
75
77
if db == nil {
76
78
db = database .GetDB ()
77
79
}
@@ -95,6 +97,22 @@ func Register(db *gorm.DB, u *models.User) error {
95
97
if ! CheckRealName (u .RealName ) {
96
98
return errors .New (errors .InvalidArgument )
97
99
}
100
+ if ! CheckInvitationCode (invitation_code ) {
101
+ return errors .New (errors .InvalidArgument )
102
+ }
103
+
104
+ // check if the invitation code is valid
105
+ if invitation_code != "" {
106
+ taken , err := isInvitationCodeTaken (db , invitation_code )
107
+ if err != nil {
108
+ return err
109
+ }
110
+ if ! taken {
111
+ return errors .New (errors .InvitationCodeInvalid )
112
+ }
113
+
114
+ // TODO: Inform the inviter
115
+ }
98
116
99
117
// 检查邮箱是否已存在
100
118
user := & models.User {}
@@ -117,6 +135,12 @@ func Register(db *gorm.DB, u *models.User) error {
117
135
u .IsActive = false
118
136
u .IsAdmin = false
119
137
138
+ code , err := createInvitationCode (db )
139
+ if err != nil {
140
+ return err
141
+ }
142
+ u .InvitationCode = code
143
+
120
144
if err = db .Create (u ).Error ; err != nil {
121
145
return errors .Wrap (err , errors .DatabaseError )
122
146
}
@@ -233,6 +257,19 @@ func Login(db *gorm.DB, email, password string) (*models.User, error) {
233
257
return nil , errors .New (errors .UserNotActive )
234
258
}
235
259
260
+ if user .InvitationCode == "" {
261
+ code , err := createInvitationCode (db )
262
+ if err != nil {
263
+ return nil , err
264
+ }
265
+
266
+ user .InvitationCode = code
267
+ err = db .Select ("invitation_code" ).Save (user ).Error
268
+ if err != nil {
269
+ return nil , errors .Wrap (err , errors .DatabaseError )
270
+ }
271
+ }
272
+
236
273
return user , nil
237
274
}
238
275
@@ -332,7 +369,11 @@ func GetProfile(db *gorm.DB, id uint, uid uint) (models.ProfileResponse, error)
332
369
if user .IsAnonymous && id != uid {
333
370
return models.ProfileResponse {ID : id , NickName : user .NickName , Avatar : avatar , IsAnonymous : user .IsAnonymous , IsAdmin : user .IsAdmin , IsCommunityAdmin : user .IsCommunityAdmin }, nil
334
371
} else {
335
- return models.ProfileResponse {ID : id , Email : user .Email , Year : user .Year , Grade : user .Grade , NickName : user .NickName , RealName : user .RealName , IsAnonymous : user .IsAnonymous , Avatar : avatar , IsAdmin : user .IsAdmin , IsCommunityAdmin : user .IsCommunityAdmin }, nil
372
+ r := models.ProfileResponse {ID : id , Email : user .Email , Year : user .Year , Grade : user .Grade , NickName : user .NickName , RealName : user .RealName , IsAnonymous : user .IsAnonymous , Avatar : avatar , IsAdmin : user .IsAdmin , IsCommunityAdmin : user .IsCommunityAdmin }
373
+ if id == uid {
374
+ r .InvitationCode = user .InvitationCode
375
+ }
376
+ return r , nil
336
377
}
337
378
}
338
379
@@ -397,3 +438,48 @@ func CheckRealName(realname string) bool {
397
438
}
398
439
return true
399
440
}
441
+
442
+ func CheckInvitationCode (code string ) bool {
443
+ if len (code ) == 0 {
444
+ return true
445
+ }
446
+ if len (code ) != 5 {
447
+ return false
448
+ }
449
+ for _ , c := range code {
450
+ if (c < '0' || c > '9' ) && (c < 'a' || c > 'z' ) && (c < 'A' || c > 'Z' ) {
451
+ return false
452
+ }
453
+ }
454
+ return true
455
+ }
456
+
457
+ func createInvitationCode (db * gorm.DB ) (string , error ) {
458
+ // try a few times before giving up
459
+ for i := 0 ; i < 5 ; i ++ {
460
+ codeRunes := make ([]rune , 0 , 5 )
461
+ for i := 0 ; i < 5 ; i ++ {
462
+ codeRunes = append (codeRunes , []rune ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" )[rand .Intn (62 )])
463
+ }
464
+ code := string (codeRunes )
465
+
466
+ taken , err := isInvitationCodeTaken (db , code )
467
+ if err != nil {
468
+ return "" , err
469
+ }
470
+ if ! taken {
471
+ return code , nil
472
+ }
473
+ }
474
+
475
+ return "" , errors .New (errors .InternalServerError )
476
+ }
477
+
478
+ func isInvitationCodeTaken (db * gorm.DB , code string ) (bool , error ) {
479
+ user := & models.User {}
480
+ result := db .Where ("invitation_code = ?" , code ).Take (user )
481
+ if err := result .Error ; err != nil && ! errors .Is (err , gorm .ErrRecordNotFound ) {
482
+ return false , errors .Wrap (err , errors .DatabaseError )
483
+ }
484
+ return result .RowsAffected != 0 , nil
485
+ }
0 commit comments