Skip to content

Commit

Permalink
update: batch get subject by pks (#277)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhu327 authored Nov 22, 2023
1 parent 21cfe43 commit 2fcb542
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 75 deletions.
21 changes: 15 additions & 6 deletions pkg/abac/pap/department.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,29 @@ func (c *departmentController) ListPaging(limit, offset int64) ([]SubjectDepartm
pks = append(pks, svcSubjectDepartment.DepartmentPKs...)
}

subjects, err := cacheimpls.BatchGetSubjectByPKs(pks)
if err != nil {
return nil, errorWrapf(err, "cacheimpls.BatchGetSubjectByPKs pks=`%v` fail", pks)
}

subjectMap := make(map[int64]types.Subject, len(pks))
for _, pk := range pks {
subject, err := cacheimpls.GetSubjectByPK(pk)
if err != nil {
return nil, errorWrapf(err, "cacheimpls.GetSubjectByPK pk=`%d` fail", pk)
}
subjectMap[pk] = subject
for _, subject := range subjects {
subjectMap[subject.PK] = subject
}

subjectDepartments := make([]SubjectDepartment, 0, len(svcSubjectDepartments))
for _, svcSubjectDepartment := range svcSubjectDepartments {
if _, ok := subjectMap[svcSubjectDepartment.SubjectPK]; !ok {
continue
}

subjectID := subjectMap[svcSubjectDepartment.SubjectPK].ID
departmentIDs := make([]string, 0, len(svcSubjectDepartment.DepartmentPKs))
for _, depPK := range svcSubjectDepartment.DepartmentPKs {
if _, ok := subjectMap[depPK]; !ok {
continue
}

departmentIDs = append(departmentIDs, subjectMap[depPK].ID)
}

Expand Down
41 changes: 21 additions & 20 deletions pkg/abac/pap/department_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,28 @@ var _ = Describe("DepartmentController", func() {
}, nil,
).AnyTimes()

patches := gomonkey.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) {
switch pk {
case 1:
return types.Subject{
ID: "1",
Type: "user",
patches := gomonkey.ApplyFunc(
cacheimpls.BatchGetSubjectByPKs,
func(pks []int64) (subjects []types.Subject, err error) {
return []types.Subject{
{
PK: 1,
ID: "1",
Type: "user",
},
{
PK: 2,
ID: "2",
Type: "department",
},
{
PK: 3,
ID: "3",
Type: "department",
},
}, nil
case 2:
return types.Subject{
ID: "2",
Type: "department",
}, nil
case 3:
return types.Subject{
ID: "3",
Type: "department",
}, nil
}

return types.Subject{}, nil
})
},
)
defer patches.Reset()

manager := &departmentController{
Expand Down
96 changes: 55 additions & 41 deletions pkg/abac/pap/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,16 @@ func (c *groupController) FilterGroupsHasMemberBeforeExpiredAt(subjects []Subjec
)
}

existGroups := make([]Subject, 0, len(existGroupPKs))
for _, pk := range existGroupPKs {
subject, err := cacheimpls.GetSubjectByPK(pk)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}

return nil, errorWrapf(err, "cacheimpls.GetSubjectByPK pk=`%d` fail", pk)
}
existSubjects, err := cacheimpls.BatchGetSubjectByPKs(existGroupPKs)
if err != nil {
return nil, errorWrapf(
err, "cacheimpls.BatchGetSubjectByPKs groupPKs=`%+v` fail",
existGroupPKs,
)
}

existGroups := make([]Subject, 0, len(existGroupPKs))
for _, subject := range existSubjects {
existGroups = append(existGroups, Subject{
Type: subject.Type,
ID: subject.ID,
Expand Down Expand Up @@ -752,17 +751,13 @@ func (c *groupController) ListRbacGroupByActionResource(
}

func groupPKsToSubjects(groupPKs []int64) ([]Subject, error) {
groups := make([]Subject, 0, len(groupPKs))
for _, pk := range groupPKs {
subject, err := cacheimpls.GetSubjectByPK(pk)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}

return nil, fmt.Errorf("subject query fail, subjectPK=`%d`", pk)
}
subjects, err := cacheimpls.BatchGetSubjectByPKs(groupPKs)
if err != nil {
return nil, fmt.Errorf("cacheimpls.BatchGetSubjectByPKs fail, subjectPKs=`%v`", groupPKs)
}

groups := make([]Subject, 0, len(groupPKs))
for _, subject := range subjects {
groups = append(groups, Subject{
Type: subject.Type,
ID: subject.ID,
Expand All @@ -773,15 +768,26 @@ func groupPKsToSubjects(groupPKs []int64) ([]Subject, error) {
}

func convertToSubjectGroups(svcSubjectGroups []types.SubjectGroup) ([]SubjectGroup, error) {
groups := make([]SubjectGroup, 0, len(svcSubjectGroups))
groupPKs := make([]int64, 0, len(svcSubjectGroups))
for _, m := range svcSubjectGroups {
subject, err := cacheimpls.GetSubjectByPK(m.GroupPK)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
groupPKs = append(groupPKs, m.GroupPK)
}

return nil, err
subjects, err := cacheimpls.BatchGetSubjectByPKs(groupPKs)
if err != nil {
return nil, err
}

subjectMap := make(map[int64]types.Subject, len(subjects))
for _, subject := range subjects {
subjectMap[subject.PK] = subject
}

groups := make([]SubjectGroup, 0, len(svcSubjectGroups))
for _, m := range svcSubjectGroups {
subject, ok := subjectMap[m.GroupPK]
if !ok {
continue
}

groups = append(groups, SubjectGroup{
Expand Down Expand Up @@ -832,24 +838,32 @@ func convertToGroupMembers(svcGroupMembers []types.GroupMember) ([]GroupMember,
}

func convertToGroupSubjects(svcGroupSubjects []types.GroupSubject) ([]GroupSubject, error) {
groupSubjects := make([]GroupSubject, 0, len(svcGroupSubjects))
subjectPKs := set.NewInt64Set()
for _, m := range svcGroupSubjects {
subject, err := cacheimpls.GetSubjectByPK(m.SubjectPK)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
subjectPKs.Add(m.SubjectPK)
subjectPKs.Add(m.GroupPK)
}

return nil, err
}
subjects, err := cacheimpls.BatchGetSubjectByPKs(subjectPKs.ToSlice())
if err != nil {
return nil, err
}

group, err := cacheimpls.GetSubjectByPK(m.GroupPK)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
subjectMap := make(map[int64]types.Subject, len(subjects))
for _, subject := range subjects {
subjectMap[subject.PK] = subject
}

return nil, err
groupSubjects := make([]GroupSubject, 0, len(svcGroupSubjects))
for _, m := range svcGroupSubjects {
subject, ok := subjectMap[m.SubjectPK]
if !ok {
continue
}

group, ok := subjectMap[m.GroupPK]
if !ok {
continue
}

groupSubjects = append(groupSubjects, GroupSubject{
Expand Down
16 changes: 8 additions & 8 deletions pkg/abac/pap/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,8 @@ var _ = Describe("GroupController", func() {
).
AnyTimes()

patches.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) {
return types.Subject{}, errors.New("err")
patches.ApplyFunc(cacheimpls.BatchGetSubjectByPKs, func(pks []int64) (subjects []types.Subject, err error) {
return nil, errors.New("err")
})

c := &groupController{
Expand Down Expand Up @@ -615,8 +615,8 @@ var _ = Describe("GroupController", func() {
).
AnyTimes()

patches.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) {
return types.Subject{}, nil
patches.ApplyFunc(cacheimpls.BatchGetSubjectByPKs, func(pks []int64) (subjects []types.Subject, err error) {
return []types.Subject{{}}, nil
})

c := &groupController{
Expand Down Expand Up @@ -778,8 +778,8 @@ var _ = Describe("GroupController", func() {
return []int64{1}, nil
})

patches.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) {
return types.Subject{}, errors.New("err")
patches.ApplyFunc(cacheimpls.BatchGetSubjectByPKs, func(pks []int64) (subjects []types.Subject, err error) {
return nil, errors.New("err")
})

c := &groupController{}
Expand Down Expand Up @@ -815,8 +815,8 @@ var _ = Describe("GroupController", func() {
return []int64{1}, nil
})

patches.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) {
return types.Subject{}, nil
patches.ApplyFunc(cacheimpls.BatchGetSubjectByPKs, func(pks []int64) (subjects []types.Subject, err error) {
return []types.Subject{{}}, nil
})

c := &groupController{}
Expand Down

0 comments on commit 2fcb542

Please sign in to comment.