From 1781233589ec9269dd1ef4f4084d390e41c3410e Mon Sep 17 00:00:00 2001 From: Victor Dodon Date: Wed, 10 Apr 2024 00:18:34 +0300 Subject: [PATCH] Detect when the token is refreshed - Move all OAuth2 related code to oauth2 subpackage. - Duplicate part of the golang.org/x/oauth2/internal package for implementing a custom TokenSource that calls a custom handler on token change/refresh. - Only provide the TokenSource instead of the oauth2.Config to the Client. --- README.md | 41 ++-- client.go | 25 +-- client_test.go | 58 +++--- config.go | 22 +-- oauth2.go | 169 ---------------- oauth2/config.go | 173 +++++++++++++++++ oauth2/doc.go | 16 ++ oauth2/internal/LICENSE | 27 +++ oauth2/internal/doc.go | 7 + oauth2/internal/oauth2.go | 37 ++++ oauth2/internal/token.go | 352 ++++++++++++++++++++++++++++++++++ oauth2/internal/token_test.go | 77 ++++++++ oauth2/internal/transport.go | 28 +++ oauth2/token_source.go | 159 +++++++++++++++ 14 files changed, 954 insertions(+), 237 deletions(-) delete mode 100644 oauth2.go create mode 100644 oauth2/config.go create mode 100644 oauth2/doc.go create mode 100644 oauth2/internal/LICENSE create mode 100644 oauth2/internal/doc.go create mode 100644 oauth2/internal/oauth2.go create mode 100644 oauth2/internal/token.go create mode 100644 oauth2/internal/token_test.go create mode 100644 oauth2/internal/transport.go create mode 100644 oauth2/token_source.go diff --git a/README.md b/README.md index 0cf76c7..520ef9c 100644 --- a/README.md +++ b/README.md @@ -42,15 +42,18 @@ This package can be use both for interacting with (calling) the ANAF e-factura API via the Client object and for generating an Invoice UBL XML. ```go -import "github.com/printesoi/e-factura-go" +import ( + "github.com/printesoi/e-factura-go" + efactura_oauth2 "github.com/printesoi/e-factura-go/oauth2" +) ``` Construct the required OAuth2 config needed for the Client: ```go -oauth2Cfg, err := efactura.MakeOAuth2Config( - efactura.OAuth2ConfigCredentials(anafAppClientID, anafApplientSecret), - efactura.OAuth2ConfigRedirectURL(anafAppRedirectURL), +oauth2Cfg, err := efactura_oauth2.MakeConfig( + efactura_oauth2.ConfigCredentials(anafAppClientID, anafApplientSecret), + efactura_oauth2.ConfigRedirectURL(anafAppRedirectURL), ) if err != nil { // Handle error @@ -69,7 +72,7 @@ to the redirect URL): ```go // Assuming the oauth2Cfg is built as above -initialToken, err := oauth2Cfg.Exchange(ctx, authorizationCode) +token, err := oauth2Cfg.Exchange(ctx, authorizationCode) if err != nil { // Handle error } @@ -81,7 +84,7 @@ will also receive the `state` parameter with `code`. Parse the initial token from JSON: ```go -initialToken, err := efactura.TokenFromJSON([]byte(tokenJSON)) +token, err := efactura.TokenFromJSON([]byte(tokenJSON)) if err != nil { // Handle error } @@ -93,7 +96,26 @@ Construct a new client: client, err := efactura.NewClient( context.Background(), efactura.ClientOAuth2Config(oauth2Cfg), - efactura.ClientOAuth2InitialToken(initialToken), + efactura.ClientOAuth2TokenSource(efactura_oauth2.TokenSource(token)), + efactura.ClientProductionEnvironment(false), // false for test, true for production mode +) +if err != nil { + // Handle error +} +``` + +If you want to store the token in a store/db and update it everytime it +refreshes use `efactura_oauth2.TokenSourceWithChangedHandler`: + +```go +onTokenChanged := func(ctx context.Context, token *xoauth.Token) error { + fmt.Printf("Token changed...") + return nil +} +client, err := efactura.NewClient( + context.Background(), + efactura.ClientOAuth2Config(oauth2Cfg), + efactura.ClientOAuth2TokenSource(efactura_oauth2.TokenSourceWithChangedHandler(token, onTokenChanged)), efactura.ClientProductionEnvironment(false), // false for test, true for production mode ) if err != nil { @@ -311,11 +333,6 @@ cannot unmarshal a struct like efactura.Invoice due to namespace prefixes! XML (maybe checking with the tools provided by mfinante). - [ ] Godoc and more code examples. - [ ] Test coverage -- [ ] Support full OAuth2 authentication flow for the client, not just passing - the initial token. This however will be tricky to implement properly since - the OAuth2 app registered in the ANAF developer profile must have a fixed - list of HTTPS redirect URLs and the redirect URL used for creating the OAuth2 - config must exactly matche one of the URLs. ## Contributing ## diff --git a/client.go b/client.go index dfa0127..4f8faf5 100644 --- a/client.go +++ b/client.go @@ -23,12 +23,7 @@ import ( "net/url" "strings" - "golang.org/x/oauth2" -) - -var ( - ErrInvalidClientOAuth2Config = errors.New("Invalid OAuth2Config provided") - ErrInvalidClientOAuth2Token = errors.New("Invalid Auth token provided") + xoauth2 "golang.org/x/oauth2" ) const ( @@ -68,10 +63,8 @@ type Client struct { apiPublicBaseURL *url.URL userAgent string - oauth2Cfg OAuth2Config - initialToken *oauth2.Token - - apiClient *http.Client + tokenSource xoauth2.TokenSource + apiClient *http.Client } // NewClient creates a new client using the provided config options. @@ -93,18 +86,14 @@ func NewClient(ctx context.Context, opts ...ClientConfigOption) (*Client, error) apiPublicBaseURL = apiPublicBaseProd } - if !cfg.OAuth2Config.Valid() { - return nil, ErrInvalidClientOAuth2Config - } - if !cfg.InitialToken.Valid() { - return nil, ErrInvalidClientOAuth2Token + if cfg.TokenSource == nil { + return nil, errors.New("invalid token source for client") } client := new(Client) client.userAgent = defaultUserAgent - client.oauth2Cfg = cfg.OAuth2Config - client.initialToken = cfg.InitialToken - client.apiClient = cfg.OAuth2Config.Client(ctx, cfg.InitialToken) + client.tokenSource = cfg.TokenSource + client.apiClient = xoauth2.NewClient(ctx, client.tokenSource) if cfg.UserAgent != nil { client.userAgent = *cfg.UserAgent } diff --git a/client_test.go b/client_test.go index bea7894..0270d6e 100644 --- a/client_test.go +++ b/client_test.go @@ -25,8 +25,9 @@ import ( "testing" "time" + "github.com/printesoi/e-factura-go/oauth2" "github.com/stretchr/testify/assert" - "golang.org/x/oauth2" + xoauth2 "golang.org/x/oauth2" ) // setupTestEnvOAuth2Config creates a OAuth2Config from the environment. @@ -35,7 +36,7 @@ import ( // EFACTURA_TEST_REDIRECT_URL are not set, this method returns an error. // If skipIfEmptyEnv is set to true and the env vars // are not set, this method returns a nil config. -func setupTestEnvOAuth2Config(skipIfEmptyEnv bool) (oauth2Cfg *OAuth2Config, err error) { +func setupTestEnvOAuth2Config(skipIfEmptyEnv bool) (oauth2Cfg *oauth2.Config, err error) { clientID := os.Getenv("EFACTURA_TEST_CLIENT_ID") clientSecret := os.Getenv("EFACTURA_TEST_CLIENT_SECRET") if clientID == "" || clientSecret == "" { @@ -52,9 +53,9 @@ func setupTestEnvOAuth2Config(skipIfEmptyEnv bool) (oauth2Cfg *OAuth2Config, err return } - if cfg, er := MakeOAuth2Config( - OAuth2ConfigCredentials(clientID, clientSecret), - OAuth2ConfigRedirectURL(redirectURL), + if cfg, er := oauth2.MakeConfig( + oauth2.ConfigCredentials(clientID, clientSecret), + oauth2.ConfigRedirectURL(redirectURL), ); er != nil { err = er return @@ -64,9 +65,13 @@ func setupTestEnvOAuth2Config(skipIfEmptyEnv bool) (oauth2Cfg *OAuth2Config, err return } +func getTestCIF() string { + return os.Getenv("EFACTURA_TEST_CIF") +} + // setupRealClient creates a real sandboxed Client (a client that talks to the // ANAF TEST APIs). -func setupRealClient(skipIfEmptyEnv bool, oauth2Cfg *OAuth2Config) (*Client, error) { +func setupRealClient(skipIfEmptyEnv bool, oauth2Cfg *oauth2.Config) (*Client, error) { if oauth2Cfg == nil { cfg, err := setupTestEnvOAuth2Config(skipIfEmptyEnv) if err != nil { @@ -83,23 +88,33 @@ func setupRealClient(skipIfEmptyEnv bool, oauth2Cfg *OAuth2Config) (*Client, err return nil, errors.New("Invalid initial token json") } - token, err := TokenFromJSON([]byte(tokenJSON)) + token, err := oauth2.TokenFromJSON([]byte(tokenJSON)) if err != nil { return nil, err } - client, err := NewClient( - context.Background(), - ClientOAuth2Config(*oauth2Cfg), - ClientOAuth2InitialToken(token), - ClientSandboxEnvironment(true), + sandbox := true + if os.Getenv("EFACTURA_TEST_PRODUCTION") == getTestCIF() { + sandbox = false + } + + onTokenChanged := func(ctx context.Context, token *xoauth2.Token) error { + tokenJSON, _ := json.Marshal(token) + fmt.Printf("[E-FACTURA] token changed: %s\n", string(tokenJSON)) + return nil + } + + ctx := context.Background() + client, err := NewClient(ctx, + ClientOAuth2TokenSource(oauth2Cfg.TokenSourceWithChangedHandler(ctx, token, onTokenChanged)), + ClientSandboxEnvironment(sandbox), ) return client, err } // setupTestOAuth2Config sets up a test HTTP server along with a OAuth2Config // that is configured to talk to that test server. -func setupTestOAuth2Config(clientID, clientSecret string) (oauth2Cfg OAuth2Config, mux *http.ServeMux, serverURL string, teardown func(), err error) { +func setupTestOAuth2Config(clientID, clientSecret string) (oauth2Cfg oauth2.Config, mux *http.ServeMux, serverURL string, teardown func(), err error) { // mux is the HTTP request multiplexer used with the test server. mux = http.NewServeMux() @@ -119,13 +134,13 @@ func setupTestOAuth2Config(clientID, clientSecret string) (oauth2Cfg OAuth2Confi if err != nil { return } - oauth2Cfg, err = MakeOAuth2Config( - OAuth2ConfigCredentials(clientID, clientSecret), - OAuth2ConfigRedirectURL(redirectURL), - OAuth2ConfigEndpoint(oauth2.Endpoint{ + oauth2Cfg, err = oauth2.MakeConfig( + oauth2.ConfigCredentials(clientID, clientSecret), + oauth2.ConfigRedirectURL(redirectURL), + oauth2.ConfigEndpoint(xoauth2.Endpoint{ AuthURL: authorizeURL, TokenURL: tokenURL, - AuthStyle: oauth2.AuthStyleInHeader, + AuthStyle: xoauth2.AuthStyleInHeader, }), ) if err != nil { @@ -141,7 +156,7 @@ func setupTestOAuth2Config(clientID, clientSecret string) (oauth2Cfg OAuth2Confi // setupTestClient sets up a test HTTP server along with a Client that is // configured to talk to that test server. Tests should register handlers on // mux which provide mock responses for the API method being tested. -func setupTestClient(oauth2Cfg OAuth2Config, initialToken *oauth2.Token) (client *Client, mux *http.ServeMux, serverURL string, teardown func(), err error) { +func setupTestClient(token *xoauth2.Token) (client *Client, mux *http.ServeMux, serverURL string, teardown func(), err error) { // mux is the HTTP request multiplexer used with the test server. mux = http.NewServeMux() @@ -157,8 +172,7 @@ func setupTestClient(oauth2Cfg OAuth2Config, initialToken *oauth2.Token) (client serverURL = server.URL client, err = NewClient( context.Background(), - ClientOAuth2Config(oauth2Cfg), - ClientOAuth2InitialToken(initialToken), + ClientOAuth2TokenSource(xoauth2.StaticTokenSource(token)), ClientSandboxEnvironment(true), ClientBaseURL(serverURL+apiBasePathSandbox), ClientBasePublicURL(serverURL+apiPublicBasePathProd), @@ -261,7 +275,7 @@ func TestClientAuth(t *testing.T) { return } - client, clientMux, serverURL, clientTeardown, err := setupTestClient(oauth2Cfg, token) + client, clientMux, serverURL, clientTeardown, err := setupTestClient(token) if clientTeardown != nil { defer clientTeardown() } diff --git a/config.go b/config.go index 133bcdd..653b754 100644 --- a/config.go +++ b/config.go @@ -15,18 +15,15 @@ package efactura import ( - "golang.org/x/oauth2" + xoauth2 "golang.org/x/oauth2" ) // ClientConfig is the config used to create a Client type ClientConfig struct { - // OAuth2Config is the OAuth2 config used for creating the http.Client that - // autorefreshes the Token. - OAuth2Config OAuth2Config - // Token is the starting oauth2 Token (including the refresh token). + // TokenSource is the token source used for generating OAuth2 tokens. // Until this library will support authentication with the SPV certificate, // this must always be provided. - InitialToken *oauth2.Token + TokenSource xoauth2.TokenSource // Unless BaseURL is set, Sandbox controlls whether to use production // endpoints (if set to false) or test endpoints (if set to true). Sandbox bool @@ -47,17 +44,10 @@ type ClientConfig struct { // ClientConfigOption allows gradually modifying a ClientConfig type ClientConfigOption func(*ClientConfig) -// ClientOAuth2Config sets the OAuth2 config -func ClientOAuth2Config(oauth2Cfg OAuth2Config) ClientConfigOption { +// ClientOAuth2TokenSource sets the token source to use. +func ClientOAuth2TokenSource(tokenSource xoauth2.TokenSource) ClientConfigOption { return func(c *ClientConfig) { - c.OAuth2Config = oauth2Cfg - } -} - -// ClientOAuth2InitialToken sets the initial OAuth2 Token -func ClientOAuth2InitialToken(token *oauth2.Token) ClientConfigOption { - return func(c *ClientConfig) { - c.InitialToken = token + c.TokenSource = tokenSource } } diff --git a/oauth2.go b/oauth2.go deleted file mode 100644 index 36c9e1a..0000000 --- a/oauth2.go +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2024 Victor Dodon -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License - -package efactura - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "math" - "time" - - "golang.org/x/oauth2" -) - -var ( - // Endpoint is the default endpoint for ANAF OAuth2 protocol - Endpoint = oauth2.Endpoint{ - AuthURL: "https://logincert.anaf.ro/anaf-oauth2/v1/authorize", - TokenURL: "https://logincert.anaf.ro/anaf-oauth2/v1/token", - AuthStyle: oauth2.AuthStyleInHeader, - } - - ErrInvalidOAuth2Credentials = errors.New("Invalid OAuth2 credentials") - ErrInvalidOAuth2Endpoint = errors.New("Invalid OAuth2 endpoint") -) - -type OAuth2Config struct { - oauth2.Config -} - -// OAuth2ConfigOption allows gradually modifying an OAuth2Config -type OAuth2ConfigOption func(*OAuth2Config) - -// OAuth2ConfigCredentials set client ID and client secret -func OAuth2ConfigCredentials(clientID, clientSecret string) OAuth2ConfigOption { - return func(c *OAuth2Config) { - c.ClientID, c.ClientSecret = clientID, clientSecret - } -} - -// OAuth2ConfigRedirectURL set the redirect URL -func OAuth2ConfigRedirectURL(redirectURL string) OAuth2ConfigOption { - return func(c *OAuth2Config) { - c.RedirectURL = redirectURL - } -} - -// OAuth2ConfigEndpoint sets the auth endpoint for the config. This should only -// be used with debugging/testing auth requests. -func OAuth2ConfigEndpoint(endpoint oauth2.Endpoint) OAuth2ConfigOption { - return func(c *OAuth2Config) { - c.Config.Endpoint = endpoint - } -} - -// MakeOAuth2Config creates a OAuth2Config using provided options. At least -// OAuth2ConfigCredentials must be provided, otherwise -// ErrInvoiceOAuth2Credentials will be returned. If an invalid endpoint if -// provided usingaOAuth2ConfigEndpoint, then ErrInvalidOAuth2Endpoint is -// returned. -func MakeOAuth2Config(opts ...OAuth2ConfigOption) (cfg OAuth2Config, err error) { - cfg.Endpoint = Endpoint - for _, opt := range opts { - opt(&cfg) - } - if !cfg.validCredentials() { - err = ErrInvalidOAuth2Credentials - } - if !cfg.validEndpoint() { - err = ErrInvalidOAuth2Endpoint - } - return -} - -func (c OAuth2Config) validCredentials() bool { - return c.ClientID != "" && c.ClientSecret != "" -} - -func (c OAuth2Config) validEndpoint() bool { - return c.Endpoint.AuthURL != "" && c.Endpoint.TokenURL != "" -} - -// Valid returns true if the config is valid (ie. is non-nil, has non-empty -// credentials, non-empty endpoint, and non-empty redirect URL). -func (c *OAuth2Config) Valid() bool { - return c != nil && c.validCredentials() && c.validEndpoint() && c.RedirectURL != "" -} - -// AuthCodeURL generates the code authorization URL. -func (c OAuth2Config) AuthCodeURL(state string) string { - return c.Config.AuthCodeURL(state, - oauth2.SetAuthURLParam("token_content_type", "jwt")) -} - -// Exchange converts an authorization code into a token. -func (c OAuth2Config) Exchange(ctx context.Context, code string) (*oauth2.Token, error) { - return c.Config.Exchange(ctx, code, - oauth2.SetAuthURLParam("token_content_type", "jwt")) -} - -// tokenJSON is the struct representing the HTTP response from OAuth2 -// providers returning a token or error in JSON form. -// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 -type tokenJSON struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - RefreshToken string `json:"refresh_token"` - ExpiresIn expirationTime `json:"expires_in"` -} - -func (e *tokenJSON) expiry() (t time.Time) { - if v := e.ExpiresIn; v != 0 { - return timeNow().Add(time.Duration(v) * time.Second) - } - return -} - -type expirationTime int32 - -func (e *expirationTime) UnmarshalJSON(b []byte) error { - if len(b) == 0 || string(b) == "null" { - return nil - } - var n json.Number - err := json.Unmarshal(b, &n) - if err != nil { - return err - } - i, err := n.Int64() - if err != nil { - return err - } - if i > math.MaxInt32 { - i = math.MaxInt32 - } - *e = expirationTime(i) - return nil -} - -// TokenFromJSON is a convenience method that parses an oauth2.Token from a -// JSON encoded value. -func TokenFromJSON(body []byte) (token *oauth2.Token, err error) { - var tj tokenJSON - if err = json.Unmarshal(body, &tj); err != nil { - err = fmt.Errorf("oauth2: cannot parse json: %v", err) - return - } - - token = &oauth2.Token{ - AccessToken: tj.AccessToken, - TokenType: tj.TokenType, - RefreshToken: tj.RefreshToken, - Expiry: tj.expiry(), - } - return -} diff --git a/oauth2/config.go b/oauth2/config.go new file mode 100644 index 0000000..c5f53e6 --- /dev/null +++ b/oauth2/config.go @@ -0,0 +1,173 @@ +// Copyright 2024 Victor Dodon +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package oauth2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + xoauth2 "golang.org/x/oauth2" +) + +const ( + // AuthURL is the ANAF authorize URL for the OAuth2 protocol + AuthURL = "https://logincert.anaf.ro/anaf-oauth2/v1/authorize" + // TokenURL is the ANAF token URL for the OAuth2 protocol + TokenURL = "https://logincert.anaf.ro/anaf-oauth2/v1/token" + // RevokeTokenURL is the ANAF token revocation URL for the OAuth2 protocol + RevokeTokenURL = "https://logincert.anaf.ro/anaf-oauth2/v1/revoke" +) + +var ( + // Endpoint is the default endpoint for ANAF OAuth2 protocol + Endpoint = xoauth2.Endpoint{ + AuthURL: AuthURL, + TokenURL: TokenURL, + AuthStyle: xoauth2.AuthStyleInHeader, + } + + ErrInvalidOAuth2Credentials = errors.New("Invalid OAuth2 credentials") + ErrInvalidOAuth2Endpoint = errors.New("Invalid OAuth2 endpoint") + ErrInvalidOAuth2RedirectURL = errors.New("Invalid OAuth2 redirect URL") +) + +// Config is a wrapper over the golang.org/x/oauth2.Config. +type Config struct { + xoauth2.Config +} + +// ConfigOption allows gradually modifying a Config +type ConfigOption func(*Config) + +// ConfigCredentials set client ID and client secret +func ConfigCredentials(clientID, clientSecret string) ConfigOption { + return func(c *Config) { + c.ClientID, c.ClientSecret = clientID, clientSecret + } +} + +// ConfigRedirectURL set the redirect URL +func ConfigRedirectURL(redirectURL string) ConfigOption { + return func(c *Config) { + c.RedirectURL = redirectURL + } +} + +// ConfigEndpoint sets the auth endpoint for the config. This should only +// be used with debugging/testing auth requests. +func ConfigEndpoint(endpoint xoauth2.Endpoint) ConfigOption { + return func(c *Config) { + c.Config.Endpoint = endpoint + } +} + +// MakeConfig creates a Config using provided options. At least +// ConfigCredentials must be provided, otherwise +// ErrInvoiceOAuth2Credentials will be returned. If an invalid endpoint if +// provided using ConfigEndpoint, then ErrInvalidOAuth2Endpoint is +// returned. +func MakeConfig(opts ...ConfigOption) (cfg Config, err error) { + cfg.Endpoint = Endpoint + for _, opt := range opts { + opt(&cfg) + } + if !cfg.validCredentials() { + err = ErrInvalidOAuth2Credentials + return + } + if !cfg.validEndpoint() { + err = ErrInvalidOAuth2Endpoint + return + } + if cfg.RedirectURL == "" { + err = ErrInvalidOAuth2RedirectURL + return + } + return +} + +func (c Config) validCredentials() bool { + return c.ClientID != "" && c.ClientSecret != "" +} + +func (c Config) validEndpoint() bool { + return c.Endpoint.AuthURL != "" && c.Endpoint.TokenURL != "" +} + +// Valid returns true if the config is valid (ie. is non-nil, has non-empty +// credentials, non-empty endpoint, and non-empty redirect URL). +func (c *Config) Valid() bool { + return c != nil && c.validCredentials() && c.validEndpoint() && c.RedirectURL != "" +} + +// AuthCodeURL generates the code authorization URL. +func (c Config) AuthCodeURL(state string) string { + return c.Config.AuthCodeURL(state, + xoauth2.SetAuthURLParam("token_content_type", "jwt")) +} + +// Exchange converts an authorization code into a token. +func (c Config) Exchange(ctx context.Context, code string) (*xoauth2.Token, error) { + return c.Config.Exchange(ctx, code, + xoauth2.SetAuthURLParam("token_content_type", "jwt")) +} + +type tokenJSON struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in,omitempty"` + Expiry *time.Time `json:"expiry,omitempty"` +} + +// timeNow is time.Now but pulled out as a variable for tests. +var timeNow = time.Now + +func (tj *tokenJSON) expiry() (t time.Time) { + if e := tj.Expiry; e != nil { + return *e + } + if ei := tj.ExpiresIn; ei != 0 { + return timeNow().Add(time.Duration(ei) * time.Second) + } + return +} + +// TokenFromJSON is a convenience function that parses a xoauth2.Token from a +// JSON encoded value. This can parse both a JSON from the ANAF OAuth2 provider +// (with an expires_in field) or a JSON encoded xoauth2.Token (with an expiry +// field). +func TokenFromJSON(jsonData []byte) (token *xoauth2.Token, err error) { + var tj tokenJSON + if err = json.Unmarshal(jsonData, &tj); err != nil { + err = fmt.Errorf("efactura.oauth2: cannot parse json: %v", err) + return + } + + t := xoauth2.Token{ + AccessToken: tj.AccessToken, + TokenType: tj.TokenType, + RefreshToken: tj.RefreshToken, + Expiry: tj.expiry(), + } + if t.Type() != "Bearer" || t.AccessToken == "" || t.RefreshToken == "" { + err = fmt.Errorf("efactura.oauth2: malformed or incomplete token") + return + } + return &t, nil +} diff --git a/oauth2/doc.go b/oauth2/doc.go new file mode 100644 index 0000000..77cc470 --- /dev/null +++ b/oauth2/doc.go @@ -0,0 +1,16 @@ +// Copyright 2024 Victor Dodon +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +// Package oauth2 contains utilities for ANAF OAuth2 protocol. +package oauth2 diff --git a/oauth2/internal/LICENSE b/oauth2/internal/LICENSE new file mode 100644 index 0000000..6a66aea --- /dev/null +++ b/oauth2/internal/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/oauth2/internal/doc.go b/oauth2/internal/doc.go new file mode 100644 index 0000000..9f7636d --- /dev/null +++ b/oauth2/internal/doc.go @@ -0,0 +1,7 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package internal contains support packages for oauth2 package. +// This a clone of golang.org/x/oauth2/internal +package internal diff --git a/oauth2/internal/oauth2.go b/oauth2/internal/oauth2.go new file mode 100644 index 0000000..14989be --- /dev/null +++ b/oauth2/internal/oauth2.go @@ -0,0 +1,37 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" +) + +// ParseKey converts the binary contents of a private key file +// to an *rsa.PrivateKey. It detects whether the private key is in a +// PEM container or not. If so, it extracts the private key +// from PEM container before conversion. It only supports PEM +// containers with no passphrase. +func ParseKey(key []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(key) + if block != nil { + key = block.Bytes + } + parsedKey, err := x509.ParsePKCS8PrivateKey(key) + if err != nil { + parsedKey, err = x509.ParsePKCS1PrivateKey(key) + if err != nil { + return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err) + } + } + parsed, ok := parsedKey.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("private key is invalid") + } + return parsed, nil +} diff --git a/oauth2/internal/token.go b/oauth2/internal/token.go new file mode 100644 index 0000000..e83ddee --- /dev/null +++ b/oauth2/internal/token.go @@ -0,0 +1,352 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "math" + "mime" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Token represents the credentials used to authorize +// the requests to access protected resources on the OAuth 2.0 +// provider's backend. +// +// This type is a mirror of oauth2.Token and exists to break +// an otherwise-circular dependency. Other internal packages +// should convert this Token into an oauth2.Token before use. +type Token struct { + // AccessToken is the token that authorizes and authenticates + // the requests. + AccessToken string + + // TokenType is the type of token. + // The Type method returns either this or "Bearer", the default. + TokenType string + + // RefreshToken is a token that's used by the application + // (as opposed to the user) to refresh the access token + // if it expires. + RefreshToken string + + // Expiry is the optional expiration time of the access token. + // + // If zero, TokenSource implementations will reuse the same + // token forever and RefreshToken or equivalent + // mechanisms for that TokenSource will not be used. + Expiry time.Time + + // Raw optionally contains extra metadata from the server + // when updating a token. + Raw interface{} +} + +// tokenJSON is the struct representing the HTTP response from OAuth2 +// providers returning a token or error in JSON form. +// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 +type tokenJSON struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number + // error fields + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` +} + +func (e *tokenJSON) expiry() (t time.Time) { + if v := e.ExpiresIn; v != 0 { + return time.Now().Add(time.Duration(v) * time.Second) + } + return +} + +type expirationTime int32 + +func (e *expirationTime) UnmarshalJSON(b []byte) error { + if len(b) == 0 || string(b) == "null" { + return nil + } + var n json.Number + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + i, err := n.Int64() + if err != nil { + return err + } + if i > math.MaxInt32 { + i = math.MaxInt32 + } + *e = expirationTime(i) + return nil +} + +// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. +// +// Deprecated: this function no longer does anything. Caller code that +// wants to avoid potential extra HTTP requests made during +// auto-probing of the provider's auth style should set +// Endpoint.AuthStyle. +func RegisterBrokenAuthHeaderProvider(tokenURL string) {} + +// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. +type AuthStyle int + +const ( + AuthStyleUnknown AuthStyle = 0 + AuthStyleInParams AuthStyle = 1 + AuthStyleInHeader AuthStyle = 2 +) + +// LazyAuthStyleCache is a backwards compatibility compromise to let Configs +// have a lazily-initialized AuthStyleCache. +// +// The two users of this, oauth2.Config and oauth2/clientcredentials.Config, +// both would ideally just embed an unexported AuthStyleCache but because both +// were historically allowed to be copied by value we can't retroactively add an +// uncopyable Mutex to them. +// +// We could use an atomic.Pointer, but that was added recently enough (in Go +// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03 +// still pass. By using an atomic.Value, it supports both Go 1.17 and +// copying by value, even if that's not ideal. +type LazyAuthStyleCache struct { + v atomic.Value // of *AuthStyleCache +} + +func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { + if c, ok := lc.v.Load().(*AuthStyleCache); ok { + return c + } + c := new(AuthStyleCache) + if !lc.v.CompareAndSwap(nil, c) { + c = lc.v.Load().(*AuthStyleCache) + } + return c +} + +// AuthStyleCache is the set of tokenURLs we've successfully used via +// RetrieveToken and which style auth we ended up using. +// It's called a cache, but it doesn't (yet?) shrink. It's expected that +// the set of OAuth2 servers a program contacts over time is fixed and +// small. +type AuthStyleCache struct { + mu sync.Mutex + m map[string]AuthStyle // keyed by tokenURL +} + +// lookupAuthStyle reports which auth style we last used with tokenURL +// when calling RetrieveToken and whether we have ever done so. +func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + style, ok = c.m[tokenURL] + return +} + +// setAuthStyle adds an entry to authStyleCache, documented above. +func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) { + c.mu.Lock() + defer c.mu.Unlock() + if c.m == nil { + c.m = make(map[string]AuthStyle) + } + c.m[tokenURL] = v +} + +// newTokenRequest returns a new *http.Request to retrieve a new token +// from tokenURL using the provided clientID, clientSecret, and POST +// body parameters. +// +// inParams is whether the clientID & clientSecret should be encoded +// as the POST body. An 'inParams' value of true means to send it in +// the POST body (along with any values in v); false means to send it +// in the Authorization header. +func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) { + if authStyle == AuthStyleInParams { + v = cloneURLValues(v) + if clientID != "" { + v.Set("client_id", clientID) + } + if clientSecret != "" { + v.Set("client_secret", clientSecret) + } + } + req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if authStyle == AuthStyleInHeader { + req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) + } + return req, nil +} + +func cloneURLValues(v url.Values) url.Values { + v2 := make(url.Values, len(v)) + for k, vv := range v { + v2[k] = append([]string(nil), vv...) + } + return v2 +} + +func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) { + needsAuthStyleProbe := authStyle == 0 + if needsAuthStyleProbe { + if style, ok := styleCache.lookupAuthStyle(tokenURL); ok { + authStyle = style + needsAuthStyleProbe = false + } else { + authStyle = AuthStyleInHeader // the first way we'll try + } + } + req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) + if err != nil { + return nil, err + } + token, err := doTokenRoundTrip(ctx, req) + if err != nil && needsAuthStyleProbe { + // If we get an error, assume the server wants the + // clientID & clientSecret in a different form. + // See https://code.google.com/p/goauth2/issues/detail?id=31 for background. + // In summary: + // - Reddit only accepts client secret in the Authorization header + // - Dropbox accepts either it in URL param or Auth header, but not both. + // - Google only accepts URL param (not spec compliant?), not Auth header + // - Stripe only accepts client secret in Auth header with Bearer method, not Basic + // + // We used to maintain a big table in this code of all the sites and which way + // they went, but maintaining it didn't scale & got annoying. + // So just try both ways. + authStyle = AuthStyleInParams // the second way we'll try + req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) + token, err = doTokenRoundTrip(ctx, req) + } + if needsAuthStyleProbe && err == nil { + styleCache.setAuthStyle(tokenURL, authStyle) + } + // Don't overwrite `RefreshToken` with an empty value + // if this was a token refreshing request. + if token != nil && token.RefreshToken == "" { + token.RefreshToken = v.Get("refresh_token") + } + return token, err +} + +func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { + r, err := ContextClient(ctx).Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) + r.Body.Close() + if err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + + failureStatus := r.StatusCode < 200 || r.StatusCode > 299 + retrieveError := &RetrieveError{ + Response: r, + Body: body, + // attempt to populate error detail below + } + + var token *Token + content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) + switch content { + case "application/x-www-form-urlencoded", "text/plain": + // some endpoints return a query string + vals, err := url.ParseQuery(string(body)) + if err != nil { + if failureStatus { + return nil, retrieveError + } + return nil, fmt.Errorf("oauth2: cannot parse response: %v", err) + } + retrieveError.ErrorCode = vals.Get("error") + retrieveError.ErrorDescription = vals.Get("error_description") + retrieveError.ErrorURI = vals.Get("error_uri") + token = &Token{ + AccessToken: vals.Get("access_token"), + TokenType: vals.Get("token_type"), + RefreshToken: vals.Get("refresh_token"), + Raw: vals, + } + e := vals.Get("expires_in") + expires, _ := strconv.Atoi(e) + if expires != 0 { + token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) + } + default: + var tj tokenJSON + if err = json.Unmarshal(body, &tj); err != nil { + if failureStatus { + return nil, retrieveError + } + return nil, fmt.Errorf("oauth2: cannot parse json: %v", err) + } + retrieveError.ErrorCode = tj.ErrorCode + retrieveError.ErrorDescription = tj.ErrorDescription + retrieveError.ErrorURI = tj.ErrorURI + token = &Token{ + AccessToken: tj.AccessToken, + TokenType: tj.TokenType, + RefreshToken: tj.RefreshToken, + Expiry: tj.expiry(), + Raw: make(map[string]interface{}), + } + json.Unmarshal(body, &token.Raw) // no error checks for optional fields + } + // according to spec, servers should respond status 400 in error case + // https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + // but some unorthodox servers respond 200 in error case + if failureStatus || retrieveError.ErrorCode != "" { + return nil, retrieveError + } + if token.AccessToken == "" { + return nil, errors.New("oauth2: server response missing access_token") + } + return token, nil +} + +// mirrors oauth2.RetrieveError +type RetrieveError struct { + Response *http.Response + Body []byte + ErrorCode string + ErrorDescription string + ErrorURI string +} + +func (r *RetrieveError) Error() string { + if r.ErrorCode != "" { + s := fmt.Sprintf("oauth2: %q", r.ErrorCode) + if r.ErrorDescription != "" { + s += fmt.Sprintf(" %q", r.ErrorDescription) + } + if r.ErrorURI != "" { + s += fmt.Sprintf(" %q", r.ErrorURI) + } + return s + } + return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) +} diff --git a/oauth2/internal/token_test.go b/oauth2/internal/token_test.go new file mode 100644 index 0000000..c08862a --- /dev/null +++ b/oauth2/internal/token_test.go @@ -0,0 +1,77 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "context" + "fmt" + "io" + "math" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestRetrieveToken_InParams(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got, want := r.FormValue("client_id"), clientID; got != want { + t.Errorf("client_id = %q; want %q", got, want) + } + if got, want := r.FormValue("client_secret"), ""; got != want { + t.Errorf("client_secret = %q; want empty", got) + } + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + })) + defer ts.Close() + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache) + if err != nil { + t.Errorf("RetrieveToken = %v; want no error", err) + } +} + +func TestRetrieveTokenWithContexts(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + })) + defer ts.Close() + + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache) + if err != nil { + t.Errorf("RetrieveToken (with background context) = %v; want no error", err) + } + + retrieved := make(chan struct{}) + cancellingts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-retrieved + })) + defer cancellingts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache) + close(retrieved) + if err == nil { + t.Errorf("RetrieveToken (with cancelled context) = nil; want error") + } +} + +func TestExpiresInUpperBound(t *testing.T) { + var e expirationTime + if err := e.UnmarshalJSON([]byte(fmt.Sprint(int64(math.MaxInt32) + 1))); err != nil { + t.Fatal(err) + } + const want = math.MaxInt32 + if e != want { + t.Errorf("expiration time = %v; want %v", e, want) + } +} diff --git a/oauth2/internal/transport.go b/oauth2/internal/transport.go new file mode 100644 index 0000000..b9db01d --- /dev/null +++ b/oauth2/internal/transport.go @@ -0,0 +1,28 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "context" + "net/http" +) + +// HTTPClient is the context key to use with golang.org/x/net/context's +// WithValue function to associate an *http.Client value with a context. +var HTTPClient ContextKey + +// ContextKey is just an empty struct. It exists so HTTPClient can be +// an immutable public variable with a unique type. It's immutable +// because nobody else can create a ContextKey, being unexported. +type ContextKey struct{} + +func ContextClient(ctx context.Context) *http.Client { + if ctx != nil { + if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { + return hc + } + } + return http.DefaultClient +} diff --git a/oauth2/token_source.go b/oauth2/token_source.go new file mode 100644 index 0000000..fb0abf6 --- /dev/null +++ b/oauth2/token_source.go @@ -0,0 +1,159 @@ +// Copyright 2024 Victor Dodon +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "context" + "errors" + "net/url" + "sync" + + "github.com/printesoi/e-factura-go/oauth2/internal" + xoauth2 "golang.org/x/oauth2" +) + +type TokenChangedHandler func(ctx context.Context, t *xoauth2.Token) error + +// tokenFromInternal maps an *internal.Token struct into +// a *Token struct. +func tokenFromInternal(t *internal.Token) *xoauth2.Token { + if t == nil { + return nil + } + return &xoauth2.Token{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + Expiry: t.Expiry, + } +} + +// retrieveToken takes a *Config and uses that to retrieve an *internal.Token. +// This token is then mapped from *internal.Token into an *oauth2.Token which +// is returned along with an error.. +func retrieveToken(ctx context.Context, c *xoauth2.Config, v url.Values) (*xoauth2.Token, error) { + tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), nil) + if err != nil { + if rErr, ok := err.(*internal.RetrieveError); ok { + return nil, (*xoauth2.RetrieveError)(rErr) + } + return nil, err + } + return tokenFromInternal(tk), nil +} + +// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" +// HTTP requests to renew a token using a RefreshToken. When the token is +// refreshed, onTokenChanged is called. +type tokenRefresher struct { + ctx context.Context // used to get HTTP requests + conf *xoauth2.Config + refreshToken string + onTokenChanged TokenChangedHandler +} + +// WARNING: Token is not safe for concurrent access, as it +// updates the tokenRefresher's refreshToken field. +// Within this package, it is used by reuseTokenSource which +// synchronizes calls to this method with its own mutex. +func (tf *tokenRefresher) Token() (*xoauth2.Token, error) { + if tf.refreshToken == "" { + return nil, errors.New("oauth2: token expired and refresh token is not set") + } + + tk, err := retrieveToken(tf.ctx, tf.conf, url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {tf.refreshToken}, + }) + + if err != nil { + return nil, err + } + if tf.refreshToken != tk.RefreshToken { + tf.refreshToken = tk.RefreshToken + if err := tf.onTokenChanged(tf.ctx, tk); err != nil { + return tk, err + } + } + return tk, err +} + +// reuseTokenSource is a TokenSource that holds a single token in memory +// and validates its expiry before each call to retrieve it with +// Token. If it's expired, it will be auto-refreshed using the +// new TokenSource. +type reuseTokenSource struct { + new xoauth2.TokenSource // called when t is expired. + + mu sync.Mutex // guards t + t *xoauth2.Token +} + +// Token returns the current token if it's still valid, else will +// refresh the current token (using r.Context for HTTP client +// information) and return the new one. +func (s *reuseTokenSource) Token() (*xoauth2.Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.t.Valid() { + return s.t, nil + } + t, err := s.new.Token() + if err != nil { + return nil, err + } + s.t = t + return t, nil +} + +// TokenSource returns a TokenSource that returns t until t expires, +// automatically refreshing it as necessary using the provided context. +func (c *Config) TokenSource(ctx context.Context, t *xoauth2.Token) xoauth2.TokenSource { + tkr := &tokenRefresher{ + ctx: ctx, + conf: &c.Config, + } + if t != nil { + tkr.refreshToken = t.RefreshToken + } + return &reuseTokenSource{ + t: t, + new: tkr, + } +} + +// TokenSourceWithChangedHandler returns a TokenSource that returns t until t +// expires, automatically refreshing it as necessary using the provided +// context. Every time the access token is refreshed, the onTokenChanged +// handler is called. This is useful if you need to update the token in a +// store/db. +func (c *Config) TokenSourceWithChangedHandler(ctx context.Context, t *xoauth2.Token, onTokenChanged TokenChangedHandler) xoauth2.TokenSource { + tkr := &tokenRefresher{ + ctx: ctx, + conf: &c.Config, + onTokenChanged: onTokenChanged, + } + if t != nil { + tkr.refreshToken = t.RefreshToken + } + return &reuseTokenSource{ + t: t, + new: tkr, + } +}