Skip to content

Commit 3a52eb7

Browse files
committed
#4#8
2 parents 46d218b + bc05104 commit 3a52eb7

File tree

5 files changed

+104
-13
lines changed

5 files changed

+104
-13
lines changed

.github/workflows/docker-image.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: Docker Image CI
22
on:
33
push:
4-
branches: [ "master" ]
4+
branches: [ "**" ]
55

66
jobs:
77
build:

internal/controllers/users/register.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@ import (
55
"coursebench-backend/pkg/errors"
66
"coursebench-backend/pkg/models"
77
"coursebench-backend/pkg/queries"
8+
89
"github.com/gofiber/fiber/v2"
910
)
1011

1112
type RegisterRequest struct {
12-
Email string `json:"email"`
13-
Password string `json:"password"`
14-
Year int `json:"year"`
15-
Grade models.GradeType `json:"grade"`
16-
Captcha string `json:"captcha"`
17-
Nickname string `json:"nickname"`
13+
Email string `json:"email"`
14+
Password string `json:"password"`
15+
Year int `json:"year"`
16+
Grade models.GradeType `json:"grade"`
17+
Captcha string `json:"captcha"`
18+
Nickname string `json:"nickname"`
19+
InvitationCode string `json:"invitation_code"`
1820
}
1921

2022
func Register(c *fiber.Ctx) (err error) {
@@ -38,7 +40,7 @@ func Register(c *fiber.Ctx) (err error) {
3840
Avatar: "",
3941
IsAnonymous: false,
4042
}
41-
if err = queries.Register(nil, &user); err != nil {
43+
if err = queries.Register(nil, &user, userReq.InvitationCode); err != nil {
4244
return
4345
}
4446
if config.GlobalConf.DisableMail {

pkg/errors/description.go

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ var (
3333
CaptchaMismatch = createDescription("CaptchaMismatch", "验证码错误", SILENT, 400)
3434
NoCaptchaToken = createDescription("NoCaptchaToken", "未请求过验证码Token,请检查您的 Cookie 设置", SILENT, 400)
3535
CaptchaExpired = createDescription("CaptchaExpired", "验证码已过期", SILENT, 400)
36+
InvitationCodeInvalid = createDescription("InvitationCodeInvalid", "邀请码无效", SILENT, 400)
3637

3738
TeacherNotExists = createDescription("TeacherNotExists", "未找到教师", SILENT, 400)
3839

pkg/models/user.go

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package models
22

33
import (
44
"coursebench-backend/pkg/modelRegister"
5+
56
"gorm.io/gorm"
67
)
78

@@ -47,4 +48,5 @@ type ProfileResponse struct {
4748
IsAnonymous bool `json:"is_anonymous"`
4849
IsAdmin bool `json:"is_admin"`
4950
IsCommunityAdmin bool `json:"is_community_admin"`
51+
InvitationCode string `json:"invitation_code"`
5052
}

pkg/queries/user.go

+91-5
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ import (
77
"coursebench-backend/pkg/mail"
88
"coursebench-backend/pkg/models"
99
"fmt"
10-
"github.com/badoux/checkmail"
11-
"golang.org/x/crypto/bcrypt"
12-
"gorm.io/gorm"
10+
"math/rand"
1311
"strings"
1412
"time"
1513
"unicode"
14+
15+
"github.com/badoux/checkmail"
16+
"golang.org/x/crypto/bcrypt"
17+
"gorm.io/gorm"
1618
)
1719

1820
func ResetPassword(db *gorm.DB, email string) error {
@@ -71,7 +73,7 @@ func ResetPasswordActive(db *gorm.DB, id uint, code string, password string) (er
7173
return nil
7274
}
7375

74-
func Register(db *gorm.DB, u *models.User) error {
76+
func Register(db *gorm.DB, u *models.User, invitation_code string) error {
7577
if db == nil {
7678
db = database.GetDB()
7779
}
@@ -95,6 +97,22 @@ func Register(db *gorm.DB, u *models.User) error {
9597
if !CheckRealName(u.RealName) {
9698
return errors.New(errors.InvalidArgument)
9799
}
100+
if !CheckInvitationCode(invitation_code) {
101+
return errors.New(errors.InvalidArgument)
102+
}
103+
104+
// check if the invitation code is valid
105+
if invitation_code != "" {
106+
taken, err := isInvitationCodeTaken(db, invitation_code)
107+
if err != nil {
108+
return err
109+
}
110+
if !taken {
111+
return errors.New(errors.InvitationCodeInvalid)
112+
}
113+
114+
// TODO: Inform the inviter
115+
}
98116

99117
// 检查邮箱是否已存在
100118
user := &models.User{}
@@ -117,6 +135,12 @@ func Register(db *gorm.DB, u *models.User) error {
117135
u.IsActive = false
118136
u.IsAdmin = false
119137

138+
code, err := createInvitationCode(db)
139+
if err != nil {
140+
return err
141+
}
142+
u.InvitationCode = code
143+
120144
if err = db.Create(u).Error; err != nil {
121145
return errors.Wrap(err, errors.DatabaseError)
122146
}
@@ -233,6 +257,19 @@ func Login(db *gorm.DB, email, password string) (*models.User, error) {
233257
return nil, errors.New(errors.UserNotActive)
234258
}
235259

260+
if user.InvitationCode == "" {
261+
code, err := createInvitationCode(db)
262+
if err != nil {
263+
return nil, err
264+
}
265+
266+
user.InvitationCode = code
267+
err = db.Select("invitation_code").Save(user).Error
268+
if err != nil {
269+
return nil, errors.Wrap(err, errors.DatabaseError)
270+
}
271+
}
272+
236273
return user, nil
237274
}
238275

@@ -332,7 +369,11 @@ func GetProfile(db *gorm.DB, id uint, uid uint) (models.ProfileResponse, error)
332369
if user.IsAnonymous && id != uid {
333370
return models.ProfileResponse{ID: id, NickName: user.NickName, Avatar: avatar, IsAnonymous: user.IsAnonymous, IsAdmin: user.IsAdmin, IsCommunityAdmin: user.IsCommunityAdmin}, nil
334371
} else {
335-
return models.ProfileResponse{ID: id, Email: user.Email, Year: user.Year, Grade: user.Grade, NickName: user.NickName, RealName: user.RealName, IsAnonymous: user.IsAnonymous, Avatar: avatar, IsAdmin: user.IsAdmin, IsCommunityAdmin: user.IsCommunityAdmin}, nil
372+
r := models.ProfileResponse{ID: id, Email: user.Email, Year: user.Year, Grade: user.Grade, NickName: user.NickName, RealName: user.RealName, IsAnonymous: user.IsAnonymous, Avatar: avatar, IsAdmin: user.IsAdmin, IsCommunityAdmin: user.IsCommunityAdmin}
373+
if id == uid {
374+
r.InvitationCode = user.InvitationCode
375+
}
376+
return r, nil
336377
}
337378
}
338379

@@ -397,3 +438,48 @@ func CheckRealName(realname string) bool {
397438
}
398439
return true
399440
}
441+
442+
func CheckInvitationCode(code string) bool {
443+
if len(code) == 0 {
444+
return true
445+
}
446+
if len(code) != 5 {
447+
return false
448+
}
449+
for _, c := range code {
450+
if (c < '0' || c > '9') && (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') {
451+
return false
452+
}
453+
}
454+
return true
455+
}
456+
457+
func createInvitationCode(db *gorm.DB) (string, error) {
458+
// try a few times before giving up
459+
for i := 0; i < 5; i++ {
460+
codeRunes := make([]rune, 0, 5)
461+
for i := 0; i < 5; i++ {
462+
codeRunes = append(codeRunes, []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")[rand.Intn(62)])
463+
}
464+
code := string(codeRunes)
465+
466+
taken, err := isInvitationCodeTaken(db, code)
467+
if err != nil {
468+
return "", err
469+
}
470+
if !taken {
471+
return code, nil
472+
}
473+
}
474+
475+
return "", errors.New(errors.InternalServerError)
476+
}
477+
478+
func isInvitationCodeTaken(db *gorm.DB, code string) (bool, error) {
479+
user := &models.User{}
480+
result := db.Where("invitation_code = ?", code).Take(user)
481+
if err := result.Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
482+
return false, errors.Wrap(err, errors.DatabaseError)
483+
}
484+
return result.RowsAffected != 0, nil
485+
}

0 commit comments

Comments
 (0)