Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions auth/tokenprovider/authenticator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package tokenprovider

import (
"context"
"fmt"
"net/http"

"github.com/databricks/databricks-sql-go/auth"
"github.com/rs/zerolog/log"
)

// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider
type TokenProviderAuthenticator struct {
provider TokenProvider
}

// NewAuthenticator creates an authenticator from a token provider
func NewAuthenticator(provider TokenProvider) auth.Authenticator {
return &TokenProviderAuthenticator{
provider: provider,
}
}

// Authenticate implements auth.Authenticator
func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error {
ctx := r.Context()
if ctx == nil {
ctx = context.Background()
}

token, err := a.provider.GetToken(ctx)
if err != nil {
return fmt.Errorf("token provider authenticator: failed to get token: %w", err)
}

if token.AccessToken == "" {
return fmt.Errorf("token provider authenticator: empty access token")
}

token.SetAuthHeader(r)
log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name())

return nil
}
107 changes: 107 additions & 0 deletions auth/tokenprovider/authenticator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package tokenprovider

import (
"context"
"errors"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestTokenProviderAuthenticator(t *testing.T) {
t.Run("successful_authentication", func(t *testing.T) {
provider := NewStaticTokenProvider("test-token-123")
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization"))
})

t.Run("authentication_with_custom_token_type", func(t *testing.T) {
provider := NewStaticTokenProviderWithType("test-token", "MAC")
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "MAC test-token", req.Header.Get("Authorization"))
})

t.Run("authentication_error_propagation", func(t *testing.T) {
provider := &mockProvider{
tokenFunc: func() (*Token, error) {
return nil, errors.New("provider failed")
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

assert.Error(t, err)
assert.Contains(t, err.Error(), "provider failed")
assert.Empty(t, req.Header.Get("Authorization"))
})

t.Run("empty_token_error", func(t *testing.T) {
provider := &mockProvider{
tokenFunc: func() (*Token, error) {
return &Token{
AccessToken: "",
TokenType: "Bearer",
}, nil
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

assert.Error(t, err)
assert.Contains(t, err.Error(), "empty access token")
assert.Empty(t, req.Header.Get("Authorization"))
})

t.Run("uses_request_context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately

provider := &mockProvider{
tokenFunc: func() (*Token, error) {
// This would normally check context cancellation
return &Token{
AccessToken: "test-token",
TokenType: "Bearer",
}, nil
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

// Even with cancelled context, this should work as our mock doesn't check it
require.NoError(t, err)
assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization"))
})

t.Run("external_token_integration", func(t *testing.T) {
tokenFunc := func() (string, error) {
return "external-token-456", nil
}
provider := NewExternalTokenProvider(tokenFunc)
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("POST", "http://example.com/api", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization"))
})
}
56 changes: 56 additions & 0 deletions auth/tokenprovider/external.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package tokenprovider

import (
"context"
"fmt"
"time"
)

// ExternalTokenProvider provides tokens from an external source (passthrough)
type ExternalTokenProvider struct {
tokenFunc func() (string, error)
tokenType string
}

// NewExternalTokenProvider creates a provider that gets tokens from an external function
func NewExternalTokenProvider(tokenFunc func() (string, error)) *ExternalTokenProvider {
return &ExternalTokenProvider{
tokenFunc: tokenFunc,
tokenType: "Bearer",
}
}

// NewExternalTokenProviderWithType creates a provider with a custom token type
func NewExternalTokenProviderWithType(tokenFunc func() (string, error), tokenType string) *ExternalTokenProvider {
return &ExternalTokenProvider{
tokenFunc: tokenFunc,
tokenType: tokenType,
}
}

// GetToken retrieves the token from the external source
func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) {
if p.tokenFunc == nil {
return nil, fmt.Errorf("external token provider: token function is nil")
}

accessToken, err := p.tokenFunc()
if err != nil {
return nil, fmt.Errorf("external token provider: failed to get token: %w", err)
}

if accessToken == "" {
return nil, fmt.Errorf("external token provider: empty token returned")
}

return &Token{
AccessToken: accessToken,
TokenType: p.tokenType,
ExpiresAt: time.Time{}, // External tokens don't provide expiry info
}, nil
}

// Name returns the provider name
func (p *ExternalTokenProvider) Name() string {
return "external"
}
43 changes: 43 additions & 0 deletions auth/tokenprovider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package tokenprovider

import (
"context"
"net/http"
"time"
)

// TokenProvider is the interface for providing tokens from various sources
type TokenProvider interface {
// GetToken retrieves a valid access token
GetToken(ctx context.Context) (*Token, error)

// Name returns the provider name for logging/debugging
Name() string
}

// Token represents an access token with metadata
type Token struct {
AccessToken string
TokenType string
ExpiresAt time.Time
RefreshToken string
Scopes []string
}

// IsExpired checks if the token has expired
func (t *Token) IsExpired() bool {
if t.ExpiresAt.IsZero() {
return false // No expiry means token doesn't expire
}
// Consider token expired 5 minutes before actual expiry for safety
return time.Now().Add(5 * time.Minute).After(t.ExpiresAt)
}

// SetAuthHeader sets the Authorization header on an HTTP request
func (t *Token) SetAuthHeader(r *http.Request) {
tokenType := t.TokenType
if tokenType == "" {
tokenType = "Bearer"
}
r.Header.Set("Authorization", tokenType+" "+t.AccessToken)
}
Loading
Loading