Skip to content

Commit

Permalink
feat: add custom attributes to saml response (#56)
Browse files Browse the repository at this point in the history
* feat: add custom attributes to saml response

* fix: correct unit test with logic changes

* fix: correct unit test with logic changes
  • Loading branch information
stebenz authored Aug 14, 2023
1 parent cc9378c commit afbec8e
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 19 deletions.
38 changes: 32 additions & 6 deletions pkg/provider/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@ const (
AttributeUserID
)

type CustomAttribute struct {
FriendlyName string
NameFormat string
AttributeValue []string
}

type Attributes struct {
email string
fullName string
givenName string
surname string
userID string
username string
email string
fullName string
givenName string
surname string
userID string
username string
customAttributes map[string]*CustomAttribute
}

var _ models.AttributeSetter = &Attributes{}
Expand Down Expand Up @@ -56,6 +63,17 @@ func (a *Attributes) SetUserID(value string) {
a.userID = value
}

func (a *Attributes) SetCustomAttribute(name, friendlyName, nameFormat string, attributeValue []string) {
if a.customAttributes == nil {
a.customAttributes = make(map[string]*CustomAttribute)
}
a.customAttributes[name] = &CustomAttribute{
FriendlyName: friendlyName,
NameFormat: nameFormat,
AttributeValue: attributeValue,
}
}

func (a *Attributes) GetSAML() []*saml.AttributeType {
attrs := make([]*saml.AttributeType, 0)
if a.email != "" {
Expand Down Expand Up @@ -100,5 +118,13 @@ func (a *Attributes) GetSAML() []*saml.AttributeType {
AttributeValue: []string{a.userID},
})
}
for name, attr := range a.customAttributes {
attrs = append(attrs, &saml.AttributeType{
Name: name,
FriendlyName: attr.FriendlyName,
NameFormat: attr.NameFormat,
AttributeValue: attr.AttributeValue,
})
}
return attrs
}
183 changes: 183 additions & 0 deletions pkg/provider/attributes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package provider

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/zitadel/saml/pkg/provider/xml/saml"
)

func TestSSO_Attributes(t *testing.T) {
type args struct {
email string
fullName string
givenName string
surname string
userID string
username string
customAttributes map[string]*CustomAttribute
}
tests := []struct {
name string
args args
res []*saml.AttributeType
}{
{
"empty attributes",
args{},
[]*saml.AttributeType{},
},
{
"full attributes",
args{
email: "email",
fullName: "fullname",
givenName: "givenname",
surname: "surname",
userID: "userid",
username: "username",
customAttributes: nil,
},
[]*saml.AttributeType{
{
Name: "Email",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"email"},
},
{
Name: "SurName",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"surname"},
},
{
Name: "FirstName",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"givenname"},
},
{
Name: "FullName",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"fullname"},
},
{
Name: "UserName",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"username"},
},
{
Name: "UserID",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"userid"},
},
},
},
{
"full attributes with custom",
args{
email: "email",
fullName: "fullname",
givenName: "givenname",
surname: "surname",
userID: "userid",
username: "username",
customAttributes: map[string]*CustomAttribute{
"empty": {
FriendlyName: "fname",
NameFormat: "nameformat",
AttributeValue: []string{""},
},
"key1": {
FriendlyName: "fname1",
NameFormat: "nameformat1",
AttributeValue: []string{"first"},
},
"key2": {
FriendlyName: "fname2",
NameFormat: "nameformat2",
AttributeValue: []string{"first", "second"},
},
"key3": {
FriendlyName: "fname3",
NameFormat: "nameformat3",
AttributeValue: []string{"first", "second", "third"},
},
},
},
[]*saml.AttributeType{
{
Name: "Email",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"email"},
},
{
Name: "SurName",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"surname"},
},
{
Name: "FirstName",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"givenname"},
},
{
Name: "FullName",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"fullname"},
},
{
Name: "UserName",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"username"},
},
{
Name: "UserID",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
AttributeValue: []string{"userid"},
},
{
Name: "empty",
NameFormat: "nameformat",
FriendlyName: "fname",
AttributeValue: []string{""},
},
{
Name: "key1",
NameFormat: "nameformat1",
FriendlyName: "fname1",
AttributeValue: []string{"first"},
},
{
Name: "key2",
NameFormat: "nameformat2",
FriendlyName: "fname2",
AttributeValue: []string{"first", "second"},
},
{
Name: "key3",
NameFormat: "nameformat3",
FriendlyName: "fname3",
AttributeValue: []string{"first", "second", "third"},
},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
attrs := &Attributes{
email: tt.args.email,
fullName: tt.args.fullName,
givenName: tt.args.givenName,
surname: tt.args.surname,
userID: tt.args.userID,
username: tt.args.username,
customAttributes: tt.args.customAttributes,
}
samlResponseAttributes := attrs.GetSAML()
for _, item := range tt.res {
assert.Contains(t, samlResponseAttributes, item)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/provider/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
response.Audience = entityID

attrs := &Attributes{}
if err := p.storage.SetUserinfoWithUserID(ctx, attrs, authRequest.GetUserID(), []int{}); err != nil {
if err := p.storage.SetUserinfoWithUserID(ctx, authRequest.GetApplicationID(), attrs, authRequest.GetUserID(), []int{}); err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get userinfo: %w", err).Error(), http.StatusInternalServerError)
return
Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func idpStorageWithResponseCertAndApp(
) *mock.MockIDPStorage {
mockStorage := idpStorageWithResponseCert(t, cert, pKey)
mockStorage.EXPECT().GetEntityIDByAppID(gomock.Any(), appID).Return(entityID, spErr).MinTimes(0).MaxTimes(1)
mockStorage.EXPECT().SetUserinfoWithUserID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MinTimes(0).MaxTimes(1)
mockStorage.EXPECT().SetUserinfoWithUserID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MinTimes(0).MaxTimes(1)

request := mock.NewMockAuthRequestInt(gomock.NewController(t))
request.EXPECT().GetAuthRequestID().Return(samlAuthRequestID).MinTimes(0).MaxTimes(1)
Expand All @@ -306,7 +306,7 @@ func idpStorageWithResponseCertAndApp(
request.EXPECT().GetAccessConsumerServiceURL().Return(acsURL).MinTimes(0).MaxTimes(1)
request.EXPECT().GetUserID().Return(userID).MinTimes(0).MaxTimes(1)
request.EXPECT().Done().Return(done).MinTimes(0).MaxTimes(1)
request.EXPECT().GetApplicationID().Return(appID).MinTimes(0).MaxTimes(1)
request.EXPECT().GetApplicationID().Return(appID).MinTimes(0).MaxTimes(2)
mockStorage.EXPECT().AuthRequestByID(gomock.Any(), authRequestID).Return(request, nil).MinTimes(0).MaxTimes(1)

return mockStorage
Expand Down
2 changes: 1 addition & 1 deletion pkg/provider/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (p *IdentityProviderConfig) getMetadata(
}

attrs := &Attributes{
"empty", "empty", "empty", "empty", "empty", "empty",
"empty", "empty", "empty", "empty", "empty", "empty", nil,
}
attrsSaml := attrs.GetSAML()
for _, attr := range attrsSaml {
Expand Down
8 changes: 4 additions & 4 deletions pkg/provider/mock/idpstorage.mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions pkg/provider/mock/storage.mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pkg/provider/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ type AttributeSetter interface {
SetSurname(string)
SetUserID(string)
SetUsername(string)
SetCustomAttribute(name string, friendlyName string, nameFormat string, attributeValue []string)
}
2 changes: 1 addition & 1 deletion pkg/provider/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ type AuthStorage interface {
}

type UserStorage interface {
SetUserinfoWithUserID(ctx context.Context, userinfo models.AttributeSetter, userID string, attributes []int) (err error)
SetUserinfoWithUserID(ctx context.Context, applicationID string, userinfo models.AttributeSetter, userID string, attributes []int) (err error)
SetUserinfoWithLoginName(ctx context.Context, userinfo models.AttributeSetter, loginName string, attributes []int) (err error)
}

0 comments on commit afbec8e

Please sign in to comment.