Skip to content

Commit

Permalink
Fix subject template add member (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhu327 authored Dec 13, 2023
1 parent 52387f3 commit 5b561bf
Show file tree
Hide file tree
Showing 10 changed files with 330 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ RUN mkdir -p /tmp/app/logs
RUN cp ${BINARY} /tmp/app
RUN cp -r /app/build/support-files/sql /tmp/app/sql

FROM to2false/sql-migrate:latest AS migrator
FROM to2false/sql-migration:latest AS migrator

FROM debian:bullseye-slim
COPY --from=builder /tmp/app /app
Expand Down
15 changes: 15 additions & 0 deletions pkg/database/dao/mock/subject_template_group.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions pkg/database/dao/subject_template_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type SubjectTemplateGroupManager interface {
) (members []SubjectTemplateGroup, err error)
ListRelationBySubjectPKGroupPKs(subjectPK int64, groupPKs []int64) ([]SubjectTemplateGroup, error)
ListGroupDistinctSubjectPK(groupPK int64) (subjectPKs []int64, err error)
ListThinRelationWithMaxExpiredAtByGroupPK(groupPK int64) ([]ThinSubjectRelation, error)

BulkCreateWithTx(tx *sqlx.Tx, relations []SubjectTemplateGroup) error
BulkUpdateExpiredAtWithTx(tx *sqlx.Tx, relations []SubjectTemplateGroup) error
Expand Down Expand Up @@ -210,3 +211,27 @@ func (m *subjectTemplateGroupManager) ListGroupDistinctSubjectPK(groupPK int64)
}
return
}

func (m *subjectTemplateGroupManager) ListThinRelationWithMaxExpiredAtByGroupPK(
groupPK int64,
) ([]ThinSubjectRelation, error) {
relations := []ThinSubjectRelation{}

query := `SELECT
subject_pk,
MAX(expired_at) AS policy_expired_at
FROM subject_template_group
WHERE group_pk = ?
GROUP BY subject_pk`

err := database.SqlxSelect(m.DB, &relations, query, groupPK)
if errors.Is(err, sql.ErrNoRows) {
return relations, nil
}

for i := range relations {
relations[i].GroupPK = groupPK
}

return relations, err
}
22 changes: 22 additions & 0 deletions pkg/database/dao/subject_template_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,25 @@ func Test_subjectTemplateGroupManager_BulkUpdateExpiredAtWithTx(t *testing.T) {
assert.NoError(t, err, "query from db fail.")
})
}

func Test_subjectTemplateGroupManager_ListThinRelationWithMaxExpiredAtByGroupPK(t *testing.T) {
database.RunWithMock(t, func(db *sqlx.DB, mock sqlmock.Sqlmock, t *testing.T) {
groupPK := int64(1)
mockQuery := `^SELECT subject_pk, (.*) FROM subject_template_group WHERE group_pk`

rows := sqlmock.NewRows([]string{"subject_pk", "policy_expired_at"}).
AddRow(int64(1), int64(1)).
AddRow(int64(2), int64(2))

mock.ExpectQuery(mockQuery).WithArgs(groupPK).WillReturnRows(rows)

manager := &subjectTemplateGroupManager{DB: db}
relations, err := manager.ListThinRelationWithMaxExpiredAtByGroupPK(groupPK)

assert.NoError(t, err, "query from db failed")
assert.Len(t, relations, 2, "did not get expected number of relations")
for _, rel := range relations {
assert.Equal(t, groupPK, rel.GroupPK, "GroupPK in relation does not match")
}
})
}
127 changes: 101 additions & 26 deletions pkg/service/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ package service
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"

"github.com/TencentBlueKing/gopkg/collection/set"
Expand Down Expand Up @@ -441,7 +444,7 @@ func (l *groupService) UpdateGroupMembersExpiredAtWithTx(
}

for _, systemID := range systemIDs {
err = l.addOrUpdateSubjectSystemGroup(tx, m.SubjectPK, systemID, groupPK, m.ExpiredAt)
err = l.addOrUpdateSubjectSystemGroup(tx, m.SubjectPK, systemID, map[int64]int64{groupPK: m.ExpiredAt})
if err != nil {
return errorWrapf(
err,
Expand Down Expand Up @@ -528,7 +531,7 @@ func (l *groupService) BulkDeleteGroupMembers(
}

for _, systemID := range systemIDs {
err = l.removeSubjectSystemGroup(tx, subjectPK, systemID, groupPK)
err = l.removeSubjectSystemGroup(tx, subjectPK, systemID, map[int64]int64{groupPK: 0})
if err != nil {
return nil, errorWrapf(
err,
Expand Down Expand Up @@ -582,7 +585,7 @@ func (l *groupService) BulkCreateGroupMembersWithTx(
}

for _, systemID := range systemIDs {
err = l.addOrUpdateSubjectSystemGroup(tx, r.SubjectPK, systemID, groupPK, r.ExpiredAt)
err = l.addOrUpdateSubjectSystemGroup(tx, r.SubjectPK, systemID, map[int64]int64{groupPK: r.ExpiredAt})
if err != nil {
return errorWrapf(
err,
Expand Down Expand Up @@ -629,6 +632,7 @@ func (l *groupService) BulkUpdateSubjectSystemGroupBySubjectTemplateGroupWithTx(
"BulkUpdateSubjectSystemGroupBySubjectTemplateGroupWithTx",
)

subjectSystemGroup := newSubjectSystemGroupMerger()
groupSystemIDCache := make(map[int64][]string)
for _, relation := range relations {
if !relation.NeedUpdate {
Expand All @@ -637,7 +641,8 @@ func (l *groupService) BulkUpdateSubjectSystemGroupBySubjectTemplateGroupWithTx(

systemIDs, ok := groupSystemIDCache[relation.GroupPK]
if !ok {
systemIDs, err := l.ListGroupAuthSystemIDs(relation.GroupPK)
var err error
systemIDs, err = l.ListGroupAuthSystemIDs(relation.GroupPK)
if err != nil {
return errorWrapf(err, "listGroupAuthSystem groupPK=`%d` fail", relation.GroupPK)
}
Expand All @@ -646,25 +651,38 @@ func (l *groupService) BulkUpdateSubjectSystemGroupBySubjectTemplateGroupWithTx(
}

for _, systemID := range systemIDs {
err := l.addOrUpdateSubjectSystemGroup(
tx,
subjectSystemGroup.Add(
relation.SubjectPK,
systemID,
relation.GroupPK,
relation.ExpiredAt,
)
if err != nil {
return errorWrapf(
err,
"addOrUpdateSubjectSystemGroup systemID=`%s`, subjectPK=`%d`, groupPK=`%d`, expiredAt=`%d`, fail",
systemID,
relation.SubjectPK,
relation.GroupPK,
relation.ExpiredAt,
)
}
}
}

for key, groups := range subjectSystemGroup.subjectSystemGroup {
subjectPK, systemID, err := subjectSystemGroup.ParseKey(key)
if err != nil {
return errorWrapf(err, "parseKey key=`%s` fail", key)
}

err = l.addOrUpdateSubjectSystemGroup(
tx,
subjectPK,
systemID,
groups,
)
if err != nil {
return errorWrapf(
err,
"addOrUpdateSubjectSystemGroup systemID=`%s`, subjectPK=`%d`, groups=`%v`, fail",
systemID,
subjectPK,
groups,
)
}
}

return nil
}

Expand Down Expand Up @@ -737,6 +755,7 @@ func (l *groupService) BulkDeleteSubjectTemplateGroupWithTx(
return errorWrapf(err, "subjectTemplateGroupManager.BulkDeleteWithTx relations=`%+v` fail", daoRelations)
}

subjectSystemGroup := newSubjectSystemGroupMerger()
groupSystemIDCache := make(map[int64][]string)
for _, relation := range relations {
if !relation.NeedUpdate {
Expand All @@ -755,16 +774,30 @@ func (l *groupService) BulkDeleteSubjectTemplateGroupWithTx(
}

for _, systemID := range systemIDs {
err = l.removeSubjectSystemGroup(tx, relation.SubjectPK, systemID, relation.GroupPK)
if err != nil {
return errorWrapf(
err,
"removeSubjectSystemGroup systemID=`%s`, subjectPK=`%d`, groupPK=`%d`, fail",
systemID,
relation.SubjectPK,
relation.GroupPK,
)
}
subjectSystemGroup.Add(relation.SubjectPK, systemID, relation.GroupPK, 0)
}
}

for key, groups := range subjectSystemGroup.subjectSystemGroup {
subjectPK, systemID, err := subjectSystemGroup.ParseKey(key)
if err != nil {
return errorWrapf(err, "parseKey key=`%s` fail", key)
}

err = l.removeSubjectSystemGroup(
tx,
subjectPK,
systemID,
groups,
)
if err != nil {
return errorWrapf(
err,
"removeSubjectSystemGroup systemID=`%s`, subjectPK=`%d`, groups=`%v`, fail",
systemID,
subjectPK,
groups,
)
}
}
return nil
Expand Down Expand Up @@ -906,3 +939,45 @@ func (l *groupService) GetMaxExpiredAtBySubjectGroup(subjectPK, groupPK int64, e

return subjectTemplateGroupExpiredAt, nil
}

// subjectSystemGroupMerger 合并相同subject system的多个group同时变更, 用于subject template group的成员变更场景
type subjectSystemGroupMerger struct {
subjectSystemGroup map[string]map[int64]int64 // key: subjectPK:systemID, map: groupPK-expiredAt
}

func newSubjectSystemGroupMerger() *subjectSystemGroupMerger {
return &subjectSystemGroupMerger{
subjectSystemGroup: make(map[string]map[int64]int64),
}
}

// Add adds a group to the subjectSystemGroup map
func (h *subjectSystemGroupMerger) Add(subjectPK int64, systemID string, groupPK int64, expiredAt int64) {
key := h.generateKey(subjectPK, systemID)
if _, ok := h.subjectSystemGroup[key]; !ok {
h.subjectSystemGroup[key] = make(map[int64]int64)
}

h.subjectSystemGroup[key][groupPK] = expiredAt
}

// generateKey generates a key based on subjectPK and systemID
func (h *subjectSystemGroupMerger) generateKey(subjectPK int64, systemID string) string {
return fmt.Sprintf("%d:%s", subjectPK, systemID)
}

// ParseKey parses a key into subjectPK and systemID
func (h *subjectSystemGroupMerger) ParseKey(key string) (subjectPK int64, systemID string, err error) {
parts := strings.Split(key, ":")
if len(parts) != 2 {
return 0, "", fmt.Errorf("invalid key format")
}

subjectPK, err = strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return 0, "", err
}

systemID = parts[1]
return subjectPK, systemID, nil
}
60 changes: 52 additions & 8 deletions pkg/service/group_system_auth_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"errors"
"time"

"github.com/TencentBlueKing/gopkg/collection/set"
"github.com/TencentBlueKing/gopkg/errorx"
"github.com/jmoiron/sqlx"

Expand Down Expand Up @@ -53,12 +54,30 @@ func (s *groupService) AlterGroupAuthType(
return false, errorWrapf(err, "manager.ListGroupMember groupPK=`%d` fail", groupPK)
}

for _, member := range members {
err := s.removeSubjectSystemGroup(tx, member.SubjectPK, systemID, groupPK)
subjectSet := set.NewInt64Set()
for _, relation := range members {
subjectSet.Add(relation.SubjectPK)
}

// 查询用户组模版成员
relations, err := s.subjectTemplateGroupManager.ListThinRelationWithMaxExpiredAtByGroupPK(groupPK)
if err != nil {
return false, errorWrapf(
err,
"subjectTemplateGroupManager.ListThinRelationWithMaxExpiredAtByGroupPK groupPK=`%d` fail",
groupPK,
)
}
for _, relation := range relations {
subjectSet.Add(relation.SubjectPK)
}

for _, subjectPK := range subjectSet.ToSlice() {
err := s.removeSubjectSystemGroup(tx, subjectPK, systemID, map[int64]int64{groupPK: 0})
if err != nil {
return false, errorWrapf(
err, "removeSubjectSystemGroup member=`%d` systemID=`%s` groupPK=`%d` fail",
member.SubjectPK, systemID, groupPK,
subjectPK, systemID, groupPK,
)
}
}
Expand All @@ -81,20 +100,45 @@ func (s *groupService) AlterGroupAuthType(
return false, errorWrapf(err, "manager.ListGroupMember groupPK=`%d` fail", groupPK)
}

// 查询用户组模版成员
relations, err := s.subjectTemplateGroupManager.ListThinRelationWithMaxExpiredAtByGroupPK(groupPK)
if err != nil {
return false, errorWrapf(
err,
"subjectTemplateGroupManager.ListThinRelationWithMaxExpiredAtByGroupPK groupPK=`%d` fail",
groupPK,
)
}

nowTS := time.Now().Unix()
for _, member := range members {
// NOTE: subject system group表中只需要保持未过期的记录
if member.ExpiredAt < nowTS {
subjectExpiredAtMap := make(map[int64]int64, len(relations)+len(members))
for _, relation := range members {
if relation.ExpiredAt < nowTS {
continue
}

subjectExpiredAtMap[relation.SubjectPK] = relation.ExpiredAt
}

for _, relation := range relations {
if relation.ExpiredAt < nowTS {
continue
}

// 取过期时间大的
if relation.ExpiredAt > subjectExpiredAtMap[relation.SubjectPK] {
subjectExpiredAtMap[relation.SubjectPK] = relation.ExpiredAt
}
}

for subjectPK, expiredAt := range subjectExpiredAtMap {
err := s.addOrUpdateSubjectSystemGroup(
tx, member.SubjectPK, systemID, groupPK, member.ExpiredAt,
tx, subjectPK, systemID, map[int64]int64{groupPK: expiredAt},
)
if err != nil {
return false, errorWrapf(
err, "addOrUpdateSubjectSystemGroup member=`%d` systemID=`%s` groupPK=`%d` fail",
member.SubjectPK, systemID, groupPK,
subjectPK, systemID, groupPK,
)
}
}
Expand Down
Loading

0 comments on commit 5b561bf

Please sign in to comment.