Skip to content

Commit bca5e7a

Browse files
JAORMXclaude
andauthored
Add environment variable support for token exchange client secret (#2148)
Add support for reading the OAuth client secret from the TOOLHIVE_TOKEN_EXCHANGE_CLIENT_SECRET environment variable when not provided in the middleware configuration. This enables secure secret injection via Kubernetes Secrets without embedding plaintext secrets in ConfigMaps. The middleware will: 1. First check if client_secret is provided in the config 2. If empty, read from TOOLHIVE_TOKEN_EXCHANGE_CLIENT_SECRET env var 3. Fall back to empty string if neither is provided The implementation uses dependency injection for testability, with an internal envGetter function that can be overridden in tests. This allows comprehensive unit testing of all environment variable scenarios without relying on actual environment manipulation. This follows ToolHive's naming convention of prefixing all environment variables with TOOLHIVE_ and is consistent with other secret handling patterns in the codebase such as TOOLHIVE_REMOTE_OAUTH_CLIENT_SECRET. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <[email protected]>
1 parent 22428ff commit bca5e7a

File tree

2 files changed

+146
-1
lines changed

2 files changed

+146
-1
lines changed

pkg/auth/tokenexchange/middleware.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"net/http"
8+
"os"
89
"strings"
910

1011
"github.com/golang-jwt/jwt/v5"
@@ -27,8 +28,23 @@ const (
2728
HeaderStrategyCustom = "custom"
2829
)
2930

31+
// Environment variable names
32+
const (
33+
// EnvClientSecret is the environment variable name for the OAuth client secret
34+
// This corresponds to the "client_secret" field in the token exchange configuration
35+
//nolint:gosec // G101: This is an environment variable name, not a credential
36+
EnvClientSecret = "TOOLHIVE_TOKEN_EXCHANGE_CLIENT_SECRET"
37+
)
38+
3039
var errUnknownStrategy = errors.New("unknown token injection strategy")
3140

41+
// envGetter is a function that retrieves environment variables
42+
// This can be overridden for testing
43+
type envGetter func(string) string
44+
45+
// defaultEnvGetter is the default environment variable getter using os.Getenv
46+
var defaultEnvGetter envGetter = os.Getenv
47+
3248
// MiddlewareParams represents the parameters for token exchange middleware
3349
type MiddlewareParams struct {
3450
TokenExchangeConfig *Config `json:"token_exchange_config,omitempty"`
@@ -155,6 +171,12 @@ func createCustomInjector(headerName string) injectionFunc {
155171
// from the auth middleware to perform token exchange.
156172
// This is a public function for direct usage in proxy commands.
157173
func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFunction, error) {
174+
return createTokenExchangeMiddlewareFromClaims(config, defaultEnvGetter)
175+
}
176+
177+
// createTokenExchangeMiddlewareFromClaims is the internal implementation that accepts an envGetter
178+
// This allows for dependency injection in tests
179+
func createTokenExchangeMiddlewareFromClaims(config Config, getEnv envGetter) (types.MiddlewareFunction, error) {
158180
// Determine injection strategy at startup time
159181
strategy := config.HeaderStrategy
160182
if strategy == "" {
@@ -171,11 +193,21 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFun
171193
return nil, fmt.Errorf("%w: invalid header injection strategy %s", errUnknownStrategy, strategy)
172194
}
173195

196+
// Resolve client secret from config or environment variable
197+
clientSecret := config.ClientSecret
198+
if clientSecret == "" {
199+
// If not provided in config, try to read from environment variable
200+
if envSecret := getEnv(EnvClientSecret); envSecret != "" {
201+
clientSecret = envSecret
202+
logger.Debug("Using client secret from environment variable")
203+
}
204+
}
205+
174206
// Create base exchange config at startup time with all static fields
175207
baseExchangeConfig := ExchangeConfig{
176208
TokenURL: config.TokenURL,
177209
ClientID: config.ClientID,
178-
ClientSecret: config.ClientSecret,
210+
ClientSecret: clientSecret,
179211
Audience: config.Audience,
180212
Scopes: config.Scopes,
181213
// SubjectTokenProvider will be set per request

pkg/auth/tokenexchange/middleware_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,116 @@ func TestMiddleware_Methods(t *testing.T) {
619619
err := mw.Close()
620620
assert.NoError(t, err)
621621
}
622+
623+
// TestCreateTokenExchangeMiddlewareFromClaims_EnvironmentVariable tests client secret from environment variable.
624+
func TestCreateTokenExchangeMiddlewareFromClaims_EnvironmentVariable(t *testing.T) {
625+
t.Parallel()
626+
627+
tests := []struct {
628+
name string
629+
configClientSecret string
630+
envClientSecret string
631+
expectedClientSecret string
632+
description string
633+
}{
634+
{
635+
name: "config secret takes precedence over env var",
636+
configClientSecret: "config-secret",
637+
envClientSecret: "env-secret",
638+
expectedClientSecret: "config-secret",
639+
description: "should use client secret from config when provided",
640+
},
641+
{
642+
name: "env var used when config secret is empty",
643+
configClientSecret: "",
644+
envClientSecret: "env-secret",
645+
expectedClientSecret: "env-secret",
646+
description: "should fallback to environment variable when config is empty",
647+
},
648+
{
649+
name: "empty when both are empty",
650+
configClientSecret: "",
651+
envClientSecret: "",
652+
expectedClientSecret: "",
653+
description: "should be empty when neither config nor env var is set",
654+
},
655+
{
656+
name: "config secret used when env var is empty",
657+
configClientSecret: "config-secret",
658+
envClientSecret: "",
659+
expectedClientSecret: "config-secret",
660+
description: "should use config secret when env var is empty",
661+
},
662+
}
663+
664+
for _, tt := range tests {
665+
t.Run(tt.name, func(t *testing.T) {
666+
t.Parallel()
667+
668+
// Track which client secret was actually used in the request
669+
var receivedClientSecret string
670+
671+
// Create mock OAuth server that captures the client secret
672+
exchangeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
673+
// Extract client secret from Basic Auth header
674+
_, password, ok := r.BasicAuth()
675+
if ok {
676+
receivedClientSecret = password
677+
}
678+
679+
resp := response{
680+
AccessToken: "exchanged-token",
681+
TokenType: "Bearer",
682+
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
683+
ExpiresIn: 3600,
684+
}
685+
w.Header().Set("Content-Type", "application/json")
686+
w.WriteHeader(http.StatusOK)
687+
_ = json.NewEncoder(w).Encode(resp)
688+
}))
689+
defer exchangeServer.Close()
690+
691+
// Mock environment getter
692+
mockEnvGetter := func(key string) string {
693+
if key == EnvClientSecret {
694+
return tt.envClientSecret
695+
}
696+
return ""
697+
}
698+
699+
config := Config{
700+
TokenURL: exchangeServer.URL,
701+
ClientID: "test-client-id",
702+
ClientSecret: tt.configClientSecret,
703+
Audience: "https://api.example.com",
704+
}
705+
706+
// Use the internal function with mock env getter
707+
middleware, err := createTokenExchangeMiddlewareFromClaims(config, mockEnvGetter)
708+
require.NoError(t, err)
709+
710+
// Test handler
711+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
712+
w.WriteHeader(http.StatusOK)
713+
})
714+
715+
// Create request with claims and token
716+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
717+
req.Header.Set("Authorization", "Bearer original-token")
718+
claims := jwt.MapClaims{
719+
"sub": "user123",
720+
"aud": "test-audience",
721+
}
722+
ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims)
723+
req = req.WithContext(ctx)
724+
725+
// Execute middleware
726+
rec := httptest.NewRecorder()
727+
handler := middleware(testHandler)
728+
handler.ServeHTTP(rec, req)
729+
730+
assert.Equal(t, http.StatusOK, rec.Code)
731+
assert.Equal(t, tt.expectedClientSecret, receivedClientSecret, tt.description)
732+
})
733+
}
734+
}

0 commit comments

Comments
 (0)