Skip to content

Commit

Permalink
maintain: move allowed login domains to a global org field (#3704)
Browse files Browse the repository at this point in the history
- move domains that are allowed to use social login to the org table rather than on specific providers
- remove allowed domains from providers
- set allowed domains to match the org admin on sign-up
- remove validation added for provider allowed domains
  • Loading branch information
BruceMacD authored Nov 23, 2022
1 parent 48194fe commit 683edda
Show file tree
Hide file tree
Showing 25 changed files with 272 additions and 352 deletions.
11 changes: 6 additions & 5 deletions api/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (
)

type Organization struct {
ID uid.ID `json:"id"`
Name string `json:"name"`
Created Time `json:"created"`
Updated Time `json:"updated"`
Domain string `json:"domain"`
ID uid.ID `json:"id"`
Name string `json:"name"`
Created Time `json:"created"`
Updated Time `json:"updated"`
Domain string `json:"domain"`
AllowedDomains []string `json:"allowedDomains" note:"domains which can be used to login to this organization" example:"['example.com', 'infrahq.com']"`
}

type GetOrganizationRequest struct {
Expand Down
51 changes: 13 additions & 38 deletions api/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,46 +32,23 @@ type Provider struct {
}

type CreateProviderRequest struct {
Name string `json:"name" example:"okta"`
URL string `json:"url" example:"infrahq.okta.com"`
ClientID string `json:"clientID" example:"0oapn0qwiQPiMIyR35d6"`
ClientSecret string `json:"clientSecret" example:"jmda5eG93ax3jMDxTGrbHd_TBGT6kgNZtrCugLbU"`
AllowedDomains []string `json:"allowedDomains" example:"['example.com', 'infrahq.com']"`
Kind string `json:"kind" example:"oidc"`
API *ProviderAPICredentials `json:"api"`
Name string `json:"name" example:"okta"`
URL string `json:"url" example:"infrahq.okta.com"`
ClientID string `json:"clientID" example:"0oapn0qwiQPiMIyR35d6"`
ClientSecret string `json:"clientSecret" example:"jmda5eG93ax3jMDxTGrbHd_TBGT6kgNZtrCugLbU"`
Kind string `json:"kind" example:"oidc"`
API *ProviderAPICredentials `json:"api"`
}

var kinds = []string{"oidc", "okta", "azure", "google"}

func ValidateAllowedDomains(value []string) validate.StringSliceRule {
return validate.StringSliceRule{
Value: value,
Name: "allowedDomains",
MaxLength: 20,
ItemRule: validate.StringRule{
Name: "allowedDomains.values",
MaxLength: 254,
CharacterRanges: []validate.CharRange{
validate.AlphabetLower,
validate.AlphabetUpper,
validate.Numbers,
validate.Dash,
validate.Dot,
validate.Underscore,
},
FirstCharacterRange: validate.AlphaNumeric,
},
}
}

func (r CreateProviderRequest) ValidationRules() []validate.ValidationRule {
return []validate.ValidationRule{
ValidateName(r.Name),
validate.Required("url", r.URL),
validate.Required("clientID", r.ClientID),
validate.Required("clientSecret", r.ClientSecret),
validate.Enum("kind", r.Kind, kinds),
ValidateAllowedDomains(r.AllowedDomains),
}
}

Expand All @@ -82,14 +59,13 @@ type PatchProviderRequest struct {
}

type UpdateProviderRequest struct {
ID uid.ID `uri:"id" json:"-"`
Name string `json:"name" example:"okta"`
URL string `json:"url" example:"infrahq.okta.com"`
ClientID string `json:"clientID" example:"0oapn0qwiQPiMIyR35d6"`
ClientSecret string `json:"clientSecret" example:"jmda5eG93ax3jMDxTGrbHd_TBGT6kgNZtrCugLbU"`
AllowedDomains []string `json:"allowedDomains" example:"['example.com', 'infrahq.com']"`
Kind string `json:"kind" example:"oidc"`
API *ProviderAPICredentials `json:"api"`
ID uid.ID `uri:"id" json:"-"`
Name string `json:"name" example:"okta"`
URL string `json:"url" example:"infrahq.okta.com"`
ClientID string `json:"clientID" example:"0oapn0qwiQPiMIyR35d6"`
ClientSecret string `json:"clientSecret" example:"jmda5eG93ax3jMDxTGrbHd_TBGT6kgNZtrCugLbU"`
Kind string `json:"kind" example:"oidc"`
API *ProviderAPICredentials `json:"api"`
}

func (r UpdateProviderRequest) ValidationRules() []validate.ValidationRule {
Expand All @@ -101,7 +77,6 @@ func (r UpdateProviderRequest) ValidationRules() []validate.ValidationRule {
validate.Required("clientID", r.ClientID),
validate.Required("clientSecret", r.ClientSecret),
validate.Enum("kind", r.Kind, kinds),
ValidateAllowedDomains(r.AllowedDomains),
}
}

Expand Down
38 changes: 20 additions & 18 deletions docs/api/openapi3.json
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,16 @@
"items": {
"items": {
"properties": {
"allowedDomains": {
"description": "domains which can be used to login to this organization",
"example": "['example.com', 'infrahq.com']",
"items": {
"description": "domains which can be used to login to this organization",
"example": "['example.com', 'infrahq.com']",
"type": "string"
},
"type": "array"
},
"created": {
"description": "formatted as an RFC3339 date-time",
"example": "2022-03-14T09:48:00Z",
Expand Down Expand Up @@ -1076,6 +1086,16 @@
},
"Organization": {
"properties": {
"allowedDomains": {
"description": "domains which can be used to login to this organization",
"example": "['example.com', 'infrahq.com']",
"items": {
"description": "domains which can be used to login to this organization",
"example": "['example.com', 'infrahq.com']",
"type": "string"
},
"type": "array"
},
"created": {
"description": "formatted as an RFC3339 date-time",
"example": "2022-03-14T09:48:00Z",
Expand Down Expand Up @@ -5379,15 +5399,6 @@
"application/json": {
"schema": {
"properties": {
"allowedDomains": {
"example": "['example.com', 'infrahq.com']",
"items": {
"example": "['example.com', 'infrahq.com']",
"type": "string"
},
"maxItems": 20,
"type": "array"
},
"api": {
"properties": {
"clientEmail": {
Expand Down Expand Up @@ -5880,15 +5891,6 @@
"application/json": {
"schema": {
"properties": {
"allowedDomains": {
"example": "['example.com', 'infrahq.com']",
"items": {
"example": "['example.com', 'infrahq.com']",
"type": "string"
},
"maxItems": 20,
"type": "array"
},
"api": {
"properties": {
"clientEmail": {
Expand Down
40 changes: 13 additions & 27 deletions internal/server/authn/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,36 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"

"golang.org/x/exp/slices"

"github.com/infrahq/infra/internal"
"github.com/infrahq/infra/internal/server/data"
"github.com/infrahq/infra/internal/server/models"
"github.com/infrahq/infra/internal/server/providers"
"github.com/infrahq/infra/uid"
)

type oidcAuthn struct {
ProviderID uid.ID
Provider *models.Provider
RedirectURL string
Code string
OIDCProviderClient providers.OIDCClient
AllowedDomains []string
}

func NewOIDCAuthentication(providerID uid.ID, redirectURL string, code string, oidcProviderClient providers.OIDCClient) LoginMethod {
func NewOIDCAuthentication(provider *models.Provider, redirectURL string, code string, oidcProviderClient providers.OIDCClient, allowedDomains []string) (LoginMethod, error) {
if provider == nil {
return nil, fmt.Errorf("nil provider in oidc authentication")
}
return &oidcAuthn{
ProviderID: providerID,
Provider: provider,
RedirectURL: redirectURL,
Code: code,
OIDCProviderClient: oidcProviderClient,
}
AllowedDomains: allowedDomains,
}, nil
}

func (a *oidcAuthn) Authenticate(ctx context.Context, db *data.Transaction, requestedExpiry time.Time) (AuthenticatedIdentity, error) {
provider, err := data.GetProvider(db, data.GetProviderOptions{ByID: a.ProviderID})
if err != nil {
return AuthenticatedIdentity{}, err
}

// exchange code for tokens from identity provider (these tokens are for the IDP, not Infra)
idpAuth, err := a.OIDCProviderClient.ExchangeAuthCodeForProviderTokens(ctx, a.Code)
if err != nil {
Expand All @@ -48,17 +44,7 @@ func (a *oidcAuthn) Authenticate(ctx context.Context, db *data.Transaction, requ
return AuthenticatedIdentity{}, fmt.Errorf("exhange code for tokens: %w", err)
}

if len(provider.AllowedDomains) > 0 {
// get the domain of the email
at := strings.LastIndex(idpAuth.Email, "@") // get the last @ since the email spec allows for multiple @s
if at == -1 {
return AuthenticatedIdentity{}, fmt.Errorf("%s is an invalid email address", idpAuth.Email)
}
domain := idpAuth.Email[at+1:]
if !slices.Contains(provider.AllowedDomains, domain) {
return AuthenticatedIdentity{}, fmt.Errorf("%s is not an allowed email domain", domain)
}
}
// TODO: check allowed domains here in the case of social login

identity, err := data.GetIdentity(db, data.GetIdentityOptions{ByName: idpAuth.Email, LoadGroups: true})
if err != nil {
Expand All @@ -73,7 +59,7 @@ func (a *oidcAuthn) Authenticate(ctx context.Context, db *data.Transaction, requ
}
}

providerUser, err := data.CreateProviderUser(db, provider, identity)
providerUser, err := data.CreateProviderUser(db, a.Provider, identity)
if err != nil {
return AuthenticatedIdentity{}, fmt.Errorf("add user for provider login: %w", err)
}
Expand All @@ -88,14 +74,14 @@ func (a *oidcAuthn) Authenticate(ctx context.Context, db *data.Transaction, requ
}

// update users attributes (such as groups) from the IDP
err = data.SyncProviderUser(ctx, db, identity, provider, a.OIDCProviderClient)
err = data.SyncProviderUser(ctx, db, identity, a.Provider, a.OIDCProviderClient)
if err != nil {
return AuthenticatedIdentity{}, fmt.Errorf("sync user on login: %w", err)
}

return AuthenticatedIdentity{
Identity: identity,
Provider: provider,
Provider: a.Provider,
SessionExpiry: requestedExpiry,
}, nil
}
Expand Down
79 changes: 9 additions & 70 deletions internal/server/authn/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp"

"github.com/infrahq/infra/internal"
"github.com/infrahq/infra/internal/server/data"
"github.com/infrahq/infra/internal/server/models"
"github.com/infrahq/infra/internal/server/providers"
Expand Down Expand Up @@ -52,24 +51,23 @@ func TestOIDCAuthenticate(t *testing.T) {
// setup
db := setupDB(t)

mocktaProvider := models.Provider{Name: "mockta", Kind: models.ProviderKindOkta}
err := data.CreateProvider(db, &mocktaProvider)
mocktaProvider := &models.Provider{Name: "mockta", Kind: models.ProviderKindOkta}
err := data.CreateProvider(db, mocktaProvider)
assert.NilError(t, err)

oidc := &mockOIDCImplementation{
UserEmailResp: "[email protected]",
UserGroupsResp: []string{"Everyone", "developers"},
}

t.Run("invalid provider", func(t *testing.T) {
unknownProviderOIDCAuthn := NewOIDCAuthentication(uid.New(), "localhost:8031", "1234", oidc)
_, err := unknownProviderOIDCAuthn.Authenticate(context.Background(), db, time.Now().Add(1*time.Minute))

assert.ErrorIs(t, err, internal.ErrNotFound)
t.Run("nil provider", func(t *testing.T) {
_, err := NewOIDCAuthentication(nil, "localhost:8031", "1234", oidc, []string{})
assert.ErrorContains(t, err, "nil provider in oidc authentication")
})

t.Run("successful authentication", func(t *testing.T) {
oidcAuthn := NewOIDCAuthentication(mocktaProvider.ID, "localhost:8031", "1234", oidc)
oidcAuthn, err := NewOIDCAuthentication(mocktaProvider, "localhost:8031", "1234", oidc, []string{})
assert.NilError(t, err)
authnIdentity, err := oidcAuthn.Authenticate(context.Background(), db, time.Now().Add(1*time.Minute))

assert.NilError(t, err)
Expand Down Expand Up @@ -262,7 +260,8 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
assert.NilError(t, err)

mockOIDC := tc.setup(t, db)
loginMethod := NewOIDCAuthentication(provider.ID, "mockOIDC.example.com/redirect", "AAA", mockOIDC)
loginMethod, err := NewOIDCAuthentication(provider, "mockOIDC.example.com/redirect", "AAA", mockOIDC, []string{})
assert.NilError(t, err)

a, err := loginMethod.Authenticate(context.Background(), db, sessionExpiry)
assert.NilError(t, err)
Expand All @@ -278,63 +277,3 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
})
}
}

func TestExchangeAuthCodeForProviderTokensAllowedDomains(t *testing.T) {
sessionExpiry := time.Now().Add(5 * time.Minute)

type testCase struct {
client providers.OIDCClient
expected func(t *testing.T, authnIdentity AuthenticatedIdentity, err error)
}

testCases := map[string]testCase{
"UserWithAllowedEmailDomain": {
client: &mockOIDCImplementation{
UserEmailResp: "[email protected]",
},
expected: func(t *testing.T, a AuthenticatedIdentity, err error) {
assert.NilError(t, err)
assert.Equal(t, "[email protected]", a.Identity.Name)
assert.Equal(t, "mockoidc", a.Provider.Name)
assert.Assert(t, a.SessionExpiry.Equal(sessionExpiry))
},
},
"UserWithEmailDomainNotAllowed": {
client: &mockOIDCImplementation{
UserEmailResp: "[email protected]",
},
expected: func(t *testing.T, a AuthenticatedIdentity, err error) {
assert.ErrorContains(t, err, "infra.app is not an allowed email domain")
},
},
"UserIdentifierWithNoAtSign": {
client: &mockOIDCImplementation{
UserEmailResp: "example.com",
},
expected: func(t *testing.T, a AuthenticatedIdentity, err error) {
assert.ErrorContains(t, err, "example.com is an invalid email address")
},
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
db := setupDB(t)

// setup fake identity provider with allowed domains specified
provider := &models.Provider{
Name: "mockoidc",
URL: "mockOIDC.example.com",
Kind: models.ProviderKindOIDC,
AllowedDomains: []string{"example.com", "infrahq.com"},
}
err := data.CreateProvider(db, provider)
assert.NilError(t, err)

loginMethod := NewOIDCAuthentication(provider.ID, "mockOIDC.example.com/redirect", "AAA", tc.client)

a, err := loginMethod.Authenticate(context.Background(), db, sessionExpiry)
tc.expected(t, a, err)
})
}
}
7 changes: 4 additions & 3 deletions internal/server/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,10 @@ func initialize(db *DB) error {
switch {
case errors.Is(err, internal.ErrNotFound):
org = &models.Organization{
Model: models.Model{ID: defaultOrganizationID},
Name: "Default",
CreatedBy: models.CreatedBySystem,
Model: models.Model{ID: defaultOrganizationID},
Name: "Default",
CreatedBy: models.CreatedBySystem,
AllowedDomains: []string{},
}
if err := CreateOrganization(tx, org); err != nil {
return fmt.Errorf("failed to create default organization: %w", err)
Expand Down
Loading

0 comments on commit 683edda

Please sign in to comment.