From afbec8e1107e9eb330687324fd6c92bd3b2b12f7 Mon Sep 17 00:00:00 2001 From: Stefan Benz <46600784+stebenz@users.noreply.github.com> Date: Mon, 14 Aug 2023 10:30:24 +0200 Subject: [PATCH] feat: add custom attributes to saml response (#56) * feat: add custom attributes to saml response * fix: correct unit test with logic changes * fix: correct unit test with logic changes --- pkg/provider/attributes.go | 38 +++++- pkg/provider/attributes_test.go | 183 +++++++++++++++++++++++++++ pkg/provider/login.go | 2 +- pkg/provider/login_test.go | 4 +- pkg/provider/metadata.go | 2 +- pkg/provider/mock/idpstorage.mock.go | 8 +- pkg/provider/mock/storage.mock.go | 8 +- pkg/provider/models/models.go | 1 + pkg/provider/storage.go | 2 +- 9 files changed, 229 insertions(+), 19 deletions(-) create mode 100644 pkg/provider/attributes_test.go diff --git a/pkg/provider/attributes.go b/pkg/provider/attributes.go index cb7f678..ed30030 100644 --- a/pkg/provider/attributes.go +++ b/pkg/provider/attributes.go @@ -14,13 +14,20 @@ const ( AttributeUserID ) +type CustomAttribute struct { + FriendlyName string + NameFormat string + AttributeValue []string +} + type Attributes struct { - email string - fullName string - givenName string - surname string - userID string - username string + email string + fullName string + givenName string + surname string + userID string + username string + customAttributes map[string]*CustomAttribute } var _ models.AttributeSetter = &Attributes{} @@ -56,6 +63,17 @@ func (a *Attributes) SetUserID(value string) { a.userID = value } +func (a *Attributes) SetCustomAttribute(name, friendlyName, nameFormat string, attributeValue []string) { + if a.customAttributes == nil { + a.customAttributes = make(map[string]*CustomAttribute) + } + a.customAttributes[name] = &CustomAttribute{ + FriendlyName: friendlyName, + NameFormat: nameFormat, + AttributeValue: attributeValue, + } +} + func (a *Attributes) GetSAML() []*saml.AttributeType { attrs := make([]*saml.AttributeType, 0) if a.email != "" { @@ -100,5 +118,13 @@ func (a *Attributes) GetSAML() []*saml.AttributeType { AttributeValue: []string{a.userID}, }) } + for name, attr := range a.customAttributes { + attrs = append(attrs, &saml.AttributeType{ + Name: name, + FriendlyName: attr.FriendlyName, + NameFormat: attr.NameFormat, + AttributeValue: attr.AttributeValue, + }) + } return attrs } diff --git a/pkg/provider/attributes_test.go b/pkg/provider/attributes_test.go new file mode 100644 index 0000000..73595b5 --- /dev/null +++ b/pkg/provider/attributes_test.go @@ -0,0 +1,183 @@ +package provider + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/zitadel/saml/pkg/provider/xml/saml" +) + +func TestSSO_Attributes(t *testing.T) { + type args struct { + email string + fullName string + givenName string + surname string + userID string + username string + customAttributes map[string]*CustomAttribute + } + tests := []struct { + name string + args args + res []*saml.AttributeType + }{ + { + "empty attributes", + args{}, + []*saml.AttributeType{}, + }, + { + "full attributes", + args{ + email: "email", + fullName: "fullname", + givenName: "givenname", + surname: "surname", + userID: "userid", + username: "username", + customAttributes: nil, + }, + []*saml.AttributeType{ + { + Name: "Email", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"email"}, + }, + { + Name: "SurName", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"surname"}, + }, + { + Name: "FirstName", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"givenname"}, + }, + { + Name: "FullName", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"fullname"}, + }, + { + Name: "UserName", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"username"}, + }, + { + Name: "UserID", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"userid"}, + }, + }, + }, + { + "full attributes with custom", + args{ + email: "email", + fullName: "fullname", + givenName: "givenname", + surname: "surname", + userID: "userid", + username: "username", + customAttributes: map[string]*CustomAttribute{ + "empty": { + FriendlyName: "fname", + NameFormat: "nameformat", + AttributeValue: []string{""}, + }, + "key1": { + FriendlyName: "fname1", + NameFormat: "nameformat1", + AttributeValue: []string{"first"}, + }, + "key2": { + FriendlyName: "fname2", + NameFormat: "nameformat2", + AttributeValue: []string{"first", "second"}, + }, + "key3": { + FriendlyName: "fname3", + NameFormat: "nameformat3", + AttributeValue: []string{"first", "second", "third"}, + }, + }, + }, + []*saml.AttributeType{ + { + Name: "Email", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"email"}, + }, + { + Name: "SurName", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"surname"}, + }, + { + Name: "FirstName", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"givenname"}, + }, + { + Name: "FullName", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"fullname"}, + }, + { + Name: "UserName", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"username"}, + }, + { + Name: "UserID", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + AttributeValue: []string{"userid"}, + }, + { + Name: "empty", + NameFormat: "nameformat", + FriendlyName: "fname", + AttributeValue: []string{""}, + }, + { + Name: "key1", + NameFormat: "nameformat1", + FriendlyName: "fname1", + AttributeValue: []string{"first"}, + }, + { + Name: "key2", + NameFormat: "nameformat2", + FriendlyName: "fname2", + AttributeValue: []string{"first", "second"}, + }, + { + Name: "key3", + NameFormat: "nameformat3", + FriendlyName: "fname3", + AttributeValue: []string{"first", "second", "third"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + attrs := &Attributes{ + email: tt.args.email, + fullName: tt.args.fullName, + givenName: tt.args.givenName, + surname: tt.args.surname, + userID: tt.args.userID, + username: tt.args.username, + customAttributes: tt.args.customAttributes, + } + samlResponseAttributes := attrs.GetSAML() + for _, item := range tt.res { + assert.Contains(t, samlResponseAttributes, item) + } + }) + } +} diff --git a/pkg/provider/login.go b/pkg/provider/login.go index c43f9ea..4b83a8a 100644 --- a/pkg/provider/login.go +++ b/pkg/provider/login.go @@ -57,7 +57,7 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req response.Audience = entityID attrs := &Attributes{} - if err := p.storage.SetUserinfoWithUserID(ctx, attrs, authRequest.GetUserID(), []int{}); err != nil { + if err := p.storage.SetUserinfoWithUserID(ctx, authRequest.GetApplicationID(), attrs, authRequest.GetUserID(), []int{}); err != nil { logging.Error(err) http.Error(w, fmt.Errorf("failed to get userinfo: %w", err).Error(), http.StatusInternalServerError) return diff --git a/pkg/provider/login_test.go b/pkg/provider/login_test.go index 3ccdaeb..c806a22 100644 --- a/pkg/provider/login_test.go +++ b/pkg/provider/login_test.go @@ -297,7 +297,7 @@ func idpStorageWithResponseCertAndApp( ) *mock.MockIDPStorage { mockStorage := idpStorageWithResponseCert(t, cert, pKey) mockStorage.EXPECT().GetEntityIDByAppID(gomock.Any(), appID).Return(entityID, spErr).MinTimes(0).MaxTimes(1) - mockStorage.EXPECT().SetUserinfoWithUserID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MinTimes(0).MaxTimes(1) + mockStorage.EXPECT().SetUserinfoWithUserID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MinTimes(0).MaxTimes(1) request := mock.NewMockAuthRequestInt(gomock.NewController(t)) request.EXPECT().GetAuthRequestID().Return(samlAuthRequestID).MinTimes(0).MaxTimes(1) @@ -306,7 +306,7 @@ func idpStorageWithResponseCertAndApp( request.EXPECT().GetAccessConsumerServiceURL().Return(acsURL).MinTimes(0).MaxTimes(1) request.EXPECT().GetUserID().Return(userID).MinTimes(0).MaxTimes(1) request.EXPECT().Done().Return(done).MinTimes(0).MaxTimes(1) - request.EXPECT().GetApplicationID().Return(appID).MinTimes(0).MaxTimes(1) + request.EXPECT().GetApplicationID().Return(appID).MinTimes(0).MaxTimes(2) mockStorage.EXPECT().AuthRequestByID(gomock.Any(), authRequestID).Return(request, nil).MinTimes(0).MaxTimes(1) return mockStorage diff --git a/pkg/provider/metadata.go b/pkg/provider/metadata.go index 59952ce..ac02c71 100644 --- a/pkg/provider/metadata.go +++ b/pkg/provider/metadata.go @@ -71,7 +71,7 @@ func (p *IdentityProviderConfig) getMetadata( } attrs := &Attributes{ - "empty", "empty", "empty", "empty", "empty", "empty", + "empty", "empty", "empty", "empty", "empty", "empty", nil, } attrsSaml := attrs.GetSAML() for _, attr := range attrsSaml { diff --git a/pkg/provider/mock/idpstorage.mock.go b/pkg/provider/mock/idpstorage.mock.go index c969143..1ec3d8f 100644 --- a/pkg/provider/mock/idpstorage.mock.go +++ b/pkg/provider/mock/idpstorage.mock.go @@ -141,15 +141,15 @@ func (mr *MockIDPStorageMockRecorder) SetUserinfoWithLoginName(arg0, arg1, arg2, } // SetUserinfoWithUserID mocks base method -func (m *MockIDPStorage) SetUserinfoWithUserID(arg0 context.Context, arg1 models.AttributeSetter, arg2 string, arg3 []int) error { +func (m *MockIDPStorage) SetUserinfoWithUserID(arg0 context.Context, arg1 string, arg2 models.AttributeSetter, arg3 string, arg4 []int) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetUserinfoWithUserID", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "SetUserinfoWithUserID", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) return ret0 } // SetUserinfoWithUserID indicates an expected call of SetUserinfoWithUserID -func (mr *MockIDPStorageMockRecorder) SetUserinfoWithUserID(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockIDPStorageMockRecorder) SetUserinfoWithUserID(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoWithUserID", reflect.TypeOf((*MockIDPStorage)(nil).SetUserinfoWithUserID), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoWithUserID", reflect.TypeOf((*MockIDPStorage)(nil).SetUserinfoWithUserID), arg0, arg1, arg2, arg3, arg4) } diff --git a/pkg/provider/mock/storage.mock.go b/pkg/provider/mock/storage.mock.go index a6b7791..900d874 100644 --- a/pkg/provider/mock/storage.mock.go +++ b/pkg/provider/mock/storage.mock.go @@ -171,15 +171,15 @@ func (mr *MockStorageMockRecorder) SetUserinfoWithLoginName(arg0, arg1, arg2, ar } // SetUserinfoWithUserID mocks base method -func (m *MockStorage) SetUserinfoWithUserID(arg0 context.Context, arg1 models.AttributeSetter, arg2 string, arg3 []int) error { +func (m *MockStorage) SetUserinfoWithUserID(arg0 context.Context, arg1 string, arg2 models.AttributeSetter, arg3 string, arg4 []int) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetUserinfoWithUserID", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "SetUserinfoWithUserID", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) return ret0 } // SetUserinfoWithUserID indicates an expected call of SetUserinfoWithUserID -func (mr *MockStorageMockRecorder) SetUserinfoWithUserID(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SetUserinfoWithUserID(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoWithUserID", reflect.TypeOf((*MockStorage)(nil).SetUserinfoWithUserID), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoWithUserID", reflect.TypeOf((*MockStorage)(nil).SetUserinfoWithUserID), arg0, arg1, arg2, arg3, arg4) } diff --git a/pkg/provider/models/models.go b/pkg/provider/models/models.go index af68bf8..ed1004a 100644 --- a/pkg/provider/models/models.go +++ b/pkg/provider/models/models.go @@ -24,4 +24,5 @@ type AttributeSetter interface { SetSurname(string) SetUserID(string) SetUsername(string) + SetCustomAttribute(name string, friendlyName string, nameFormat string, attributeValue []string) } diff --git a/pkg/provider/storage.go b/pkg/provider/storage.go index 50060ef..494e9a2 100644 --- a/pkg/provider/storage.go +++ b/pkg/provider/storage.go @@ -26,6 +26,6 @@ type AuthStorage interface { } type UserStorage interface { - SetUserinfoWithUserID(ctx context.Context, userinfo models.AttributeSetter, userID string, attributes []int) (err error) + SetUserinfoWithUserID(ctx context.Context, applicationID string, userinfo models.AttributeSetter, userID string, attributes []int) (err error) SetUserinfoWithLoginName(ctx context.Context, userinfo models.AttributeSetter, loginName string, attributes []int) (err error) }