Skip to content

Commit 61ba8ed

Browse files
committed
feat(openidConnect): auto-detect and apply PKCE from OIDC discovery
When the OIDC discovery document advertises code_challenge_methods_supported, the provider now automatically: - Selects the best method (S256 preferred over plain per RFC 7636 §4.2) - Generates a cryptographically random code_verifier on each BeginAuth call - Computes the code_challenge and injects it into the authorization URL - Stores the code_verifier in the Session for use at token exchange - Uses the stored verifier in Authorize(), falling back to the legacy code_verifier query param for backward compatibility If code_challenge_methods_supported is absent or empty, PKCE is not applied.
1 parent 9c7a282 commit 61ba8ed

3 files changed

Lines changed: 204 additions & 5 deletions

File tree

providers/openidConnect/openidConnect.go

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package openidConnect
22

33
import (
44
"bytes"
5+
"crypto/rand"
6+
"crypto/sha256"
57
"encoding/base64"
68
"encoding/json"
79
"errors"
@@ -70,6 +72,11 @@ type Provider struct {
7072
LocationClaims []string
7173

7274
SkipUserInfoRequest bool
75+
76+
// PKCEMethod is the code challenge method to use for PKCE. It is automatically
77+
// selected from the discovery document's code_challenge_methods_supported field.
78+
// Supported values are "S256" and "plain". Empty string means PKCE is disabled.
79+
PKCEMethod string
7380
}
7481

7582
type OpenIDConfig struct {
@@ -82,6 +89,10 @@ type OpenIDConfig struct {
8289
// https://openid.net/specs/openid-connect-session-1_0-17.html#OPMetadata
8390
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
8491
Issuer string `json:"issuer"`
92+
93+
// CodeChallengeMethodsSupported lists PKCE code challenge methods supported by the provider.
94+
// See https://www.rfc-editor.org/rfc/rfc7636
95+
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"`
8596
}
8697

8798
type RefreshTokenResponse struct {
@@ -142,6 +153,7 @@ func NewNamed(name, clientKey, secret, callbackURL, openIDAutoDiscoveryURL strin
142153
return nil, err
143154
}
144155
p.OpenIDConfig = openIDConfig
156+
p.PKCEMethod = selectPKCEMethod(openIDConfig.CodeChallengeMethodsSupported)
145157

146158
p.config = newConfig(p, scopes, openIDConfig)
147159
return p, nil
@@ -204,10 +216,29 @@ func (p *Provider) Debug(debug bool) {}
204216

205217
// BeginAuth asks the OpenID Connect provider for an authentication end-point.
206218
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
207-
url := p.config.AuthCodeURL(state, p.authCodeOptions...)
208-
session := &Session{
209-
AuthURL: url,
219+
authCodeOptions := p.authCodeOptions
220+
session := &Session{}
221+
222+
if p.PKCEMethod != "" {
223+
verifier, err := generateCodeVerifier()
224+
if err != nil {
225+
return nil, fmt.Errorf("openidConnect: failed to generate PKCE code verifier: %w", err)
226+
}
227+
var challenge string
228+
switch p.PKCEMethod {
229+
case "S256":
230+
challenge = generateS256Challenge(verifier)
231+
case "plain":
232+
challenge = verifier
233+
}
234+
authCodeOptions = append(authCodeOptions,
235+
oauth2.SetAuthURLParam("code_challenge", challenge),
236+
oauth2.SetAuthURLParam("code_challenge_method", p.PKCEMethod),
237+
)
238+
session.CodeVerifier = verifier
210239
}
240+
241+
session.AuthURL = p.config.AuthCodeURL(state, authCodeOptions...)
211242
return session, nil
212243
}
213244

@@ -527,3 +558,44 @@ func unMarshal(payload []byte) (map[string]interface{}, error) {
527558

528559
return data, json.NewDecoder(bytes.NewBuffer(payload)).Decode(&data)
529560
}
561+
562+
// selectPKCEMethod selects the best PKCE code challenge method from the list
563+
// advertised by the provider. S256 is preferred over plain per RFC 7636 §4.2.
564+
// Returns an empty string if neither method is supported.
565+
func selectPKCEMethod(methods []string) string {
566+
hasS256 := false
567+
hasPlain := false
568+
for _, m := range methods {
569+
switch m {
570+
case "S256":
571+
hasS256 = true
572+
case "plain":
573+
hasPlain = true
574+
}
575+
}
576+
if hasS256 {
577+
return "S256"
578+
}
579+
if hasPlain {
580+
return "plain"
581+
}
582+
return ""
583+
}
584+
585+
// generateCodeVerifier creates a cryptographically random PKCE code verifier
586+
// of 43 URL-safe characters (32 random bytes, base64url-encoded without padding)
587+
// as specified in RFC 7636 §4.1.
588+
func generateCodeVerifier() (string, error) {
589+
b := make([]byte, 32)
590+
if _, err := rand.Read(b); err != nil {
591+
return "", err
592+
}
593+
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b), nil
594+
}
595+
596+
// generateS256Challenge computes the S256 PKCE code challenge from a verifier:
597+
// BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) per RFC 7636 §4.2.
598+
func generateS256Challenge(verifier string) string {
599+
h := sha256.Sum256([]byte(verifier))
600+
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(h[:])
601+
}

providers/openidConnect/openidConnect_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,125 @@ func Test_BeginAuth(t *testing.T) {
7676
a.Contains(s.AuthURL, "state=test_state")
7777
a.Contains(s.AuthURL, "redirect_uri=http%3A%2F%2Flocalhost%2Ffoo")
7878
a.Contains(s.AuthURL, "scope=openid")
79+
80+
// The mock server advertises ["plain","S256"] so PKCE must be used with S256.
81+
a.Equal("S256", provider.PKCEMethod)
82+
a.NotEmpty(s.CodeVerifier)
83+
a.Contains(s.AuthURL, "code_challenge=")
84+
a.Contains(s.AuthURL, "code_challenge_method=S256")
85+
}
86+
87+
func Test_BeginAuth_PKCE_S256_Challenge(t *testing.T) {
88+
t.Parallel()
89+
a := assert.New(t)
90+
91+
provider := openidConnectProvider()
92+
session, err := provider.BeginAuth("test_state")
93+
a.NoError(err)
94+
s := session.(*Session)
95+
96+
// Verify that the code_challenge in the URL matches the S256 of the stored verifier.
97+
expected := generateS256Challenge(s.CodeVerifier)
98+
a.Contains(s.AuthURL, "code_challenge="+expected)
99+
}
100+
101+
func Test_BeginAuth_NoPKCE_WhenNotAdvertised(t *testing.T) {
102+
t.Parallel()
103+
a := assert.New(t)
104+
105+
// Spin up a server that does NOT advertise code_challenge_methods_supported.
106+
noPKCEServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
107+
fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo"}`)
108+
}))
109+
defer noPKCEServer.Close()
110+
111+
provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", noPKCEServer.URL)
112+
a.NoError(err)
113+
a.Equal("", provider.PKCEMethod)
114+
115+
session, err := provider.BeginAuth("test_state")
116+
a.NoError(err)
117+
s := session.(*Session)
118+
a.Empty(s.CodeVerifier)
119+
a.NotContains(s.AuthURL, "code_challenge")
120+
a.NotContains(s.AuthURL, "code_challenge_method")
121+
}
122+
123+
func Test_SelectPKCEMethod(t *testing.T) {
124+
t.Parallel()
125+
a := assert.New(t)
126+
127+
a.Equal("S256", selectPKCEMethod([]string{"plain", "S256"}))
128+
a.Equal("S256", selectPKCEMethod([]string{"S256"}))
129+
a.Equal("plain", selectPKCEMethod([]string{"plain"}))
130+
a.Equal("", selectPKCEMethod([]string{}))
131+
a.Equal("", selectPKCEMethod(nil))
132+
a.Equal("", selectPKCEMethod([]string{"other"}))
133+
}
134+
135+
func Test_GenerateCodeVerifier(t *testing.T) {
136+
t.Parallel()
137+
a := assert.New(t)
138+
139+
v, err := generateCodeVerifier()
140+
a.NoError(err)
141+
// 32 bytes base64url-encoded without padding = 43 chars
142+
a.Equal(43, len(v))
143+
// All chars must be URL-safe base64 alphabet
144+
for _, c := range v {
145+
a.Contains("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", string(c))
146+
}
147+
148+
// Two verifiers must differ (with overwhelming probability)
149+
v2, _ := generateCodeVerifier()
150+
a.NotEqual(v, v2)
151+
}
152+
153+
func Test_GenerateS256Challenge(t *testing.T) {
154+
t.Parallel()
155+
a := assert.New(t)
156+
157+
// Known test vector from RFC 7636 Appendix B:
158+
// code_verifier = dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk
159+
// code_challenge = E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM
160+
a.Equal("E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", generateS256Challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"))
161+
}
162+
163+
func Test_New_PKCE_MethodSelectedFromDiscovery(t *testing.T) {
164+
t.Parallel()
165+
a := assert.New(t)
166+
167+
// Mock server advertises only "plain"
168+
plainOnlyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
169+
fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo","code_challenge_methods_supported":["plain"]}`)
170+
}))
171+
defer plainOnlyServer.Close()
172+
173+
provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", plainOnlyServer.URL)
174+
a.NoError(err)
175+
a.Equal("plain", provider.PKCEMethod)
176+
}
177+
178+
func Test_BeginAuth_PKCE_PlainMethod(t *testing.T) {
179+
t.Parallel()
180+
a := assert.New(t)
181+
182+
// Mock server advertises only "plain"
183+
plainOnlyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184+
fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo","code_challenge_methods_supported":["plain"]}`)
185+
}))
186+
defer plainOnlyServer.Close()
187+
188+
provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", plainOnlyServer.URL)
189+
a.NoError(err)
190+
191+
session, err := provider.BeginAuth("test_state")
192+
a.NoError(err)
193+
s := session.(*Session)
194+
a.NotEmpty(s.CodeVerifier)
195+
// For "plain", challenge == verifier
196+
a.Contains(s.AuthURL, "code_challenge="+s.CodeVerifier)
197+
a.Contains(s.AuthURL, "code_challenge_method=plain")
79198
}
80199

81200
func Test_BeginAuth_AuthCodeOptions(t *testing.T) {

providers/openidConnect/session.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ type Session struct {
1717
RefreshToken string
1818
ExpiresAt time.Time
1919
IDToken string
20+
// CodeVerifier holds the PKCE code verifier generated during BeginAuth.
21+
// It is used at token exchange time to prove possession of the original verifier.
22+
CodeVerifier string `json:",omitempty"`
2023
}
2124

2225
// GetAuthURL will return the URL set by calling the `BeginAuth` function on the OpenID Connect provider.
@@ -39,8 +42,13 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string,
3942
authParams = append(authParams, oauth2.SetAuthURLParam("redirect_uri", redirectURL))
4043
}
4144

42-
// set code_verifier if passed as param
43-
codeVerifier := params.Get("code_verifier")
45+
// set code_verifier for PKCE: prefer the verifier stored in the session
46+
// (generated automatically during BeginAuth), fall back to one passed as
47+
// a callback parameter for backward compatibility.
48+
codeVerifier := s.CodeVerifier
49+
if codeVerifier == "" {
50+
codeVerifier = params.Get("code_verifier")
51+
}
4452
if codeVerifier != "" {
4553
authParams = append(authParams, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
4654
}

0 commit comments

Comments
 (0)