diff --git a/cli/README.md b/cli/README.md index 4f9d71b80f..b57240fac2 100644 --- a/cli/README.md +++ b/cli/README.md @@ -34,6 +34,70 @@ supermq-cli users create supermq-cli users token ``` +#### OAuth Authentication + +Authenticate using OAuth providers (e.g., Google) to obtain access tokens using the device authorization flow. + +```bash +supermq-cli users oauth +``` + +Example with Google: + +```bash +supermq-cli users oauth google +``` + +**How it works (Device Authorization Flow):** + +The CLI uses the OAuth 2.0 Device Authorization Flow, which is specifically designed for command-line applications and devices with limited input capabilities: + +1. The CLI requests a device code from the OAuth provider +2. You receive a short user code (e.g., `ABCD-EFGH`) and a verification URL +3. Visit the verification URL in any browser (on any device) +4. Enter the user code when prompted +5. Authorize the application in your browser +6. The CLI automatically detects the authorization and retrieves your tokens +7. Tokens are displayed in JSON format + +**Example Output:** + +``` +=== OAuth Device Authorization === + +Please complete authentication in your browser: + + 1. Visit: https://auth.example.com/device + 2. Enter code: ABCD-EFGH + +Waiting for authorization... + +✓ Authentication successful! + +{ + "access_token": "eyJhbGc...", + "refresh_token": "eyJhbGc..." +} +``` + +**Benefits of Device Flow:** + +- **No local server needed** - Works in any environment (containers, SSH, etc.) +- **Cross-device authentication** - Authenticate on your phone/tablet while using CLI +- **Better security** - No need to open ports or handle callbacks +- **Works everywhere** - Headless servers, restricted networks, Docker containers + +**Prerequisites:** + +The OAuth provider must be configured on the SuperMQ server. For Google OAuth, the following environment variables must be set: + +- `SMQ_GOOGLE_CLIENT_ID` - Google OAuth client ID +- `SMQ_GOOGLE_CLIENT_SECRET` - Google OAuth client secret +- `SMQ_GOOGLE_REDIRECT_URL` - OAuth redirect URL (e.g., `http://localhost:9002/oauth/callback/google`) +- `SMQ_GOOGLE_STATE` - OAuth state parameter for security + +See the [Users Service README](../users/README.md#oauth-configuration) for more details. + #### Get User ```bash diff --git a/cli/oauth.go b/cli/oauth.go new file mode 100644 index 0000000000..0d54783fb6 --- /dev/null +++ b/cli/oauth.go @@ -0,0 +1,191 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +import ( + "context" + "fmt" + "net" + "net/http" + "os/exec" + "runtime" + "strings" + "sync" + "time" + + "github.com/spf13/cobra" +) + +const ( + callbackPath = "/callback" + localServerPort = "9090" + callbackTimeout = 5 * time.Minute + shutdownTimeout = 5 * time.Second +) + +type oauthCallbackResult struct { + code string + state string + err error +} + +type browserOpener interface { + Open(url string) error +} + +type defaultBrowserOpener struct{} + +func (defaultBrowserOpener) Open(url string) error { + return openBrowser(url) +} + +type callbackServer struct { + listener net.Listener + server *http.Server +} + +func newCallbackServer(resultChan chan<- oauthCallbackResult) (*callbackServer, error) { + listener, err := net.Listen("tcp", "127.0.0.1:"+localServerPort) + if err != nil { + return nil, fmt.Errorf("failed to start local server: %w", err) + } + + mux := http.NewServeMux() + server := &http.Server{Handler: mux} + + var once sync.Once + mux.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) { + handleOAuthCallback(w, r, resultChan, &once) + }) + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + once.Do(func() { + resultChan <- oauthCallbackResult{err: fmt.Errorf("server error: %w", err)} + }) + } + }() + + return &callbackServer{ + listener: listener, + server: server, + }, nil +} + +func (cs *callbackServer) Shutdown(cmd *cobra.Command) { + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + if err := cs.server.Shutdown(ctx); err != nil { + logErrorCmd(*cmd, err) + } +} + +func performOAuthLogin(cmd *cobra.Command, provider string) error { + return performOAuthLoginWithBrowser(cmd, provider, defaultBrowserOpener{}) +} + +func performOAuthLoginWithBrowser(cmd *cobra.Command, provider string, browser browserOpener) error { + ctx := cmd.Context() + callbackChan := make(chan oauthCallbackResult, 1) + + server, err := newCallbackServer(callbackChan) + if err != nil { + return err + } + defer server.Shutdown(cmd) + + callbackURL := fmt.Sprintf("http://127.0.0.1:%s%s", localServerPort, callbackPath) + authURL, state, err := sdk.OAuthAuthorizationURL(ctx, provider, callbackURL) + if err != nil { + return fmt.Errorf("failed to get authorization URL: %w", err) + } + + printAuthInstructions(authURL) + if err := browser.Open(authURL); err != nil { + fmt.Printf("Failed to open browser automatically: %v\n", err) + } + + fmt.Println("Waiting for authentication callback...") + + result, err := waitForCallback(callbackChan) + if err != nil { + return err + } + + fmt.Println("Exchanging authorization code for tokens...") + token, err := sdk.OAuthCallback(ctx, provider, result.code, state, callbackURL) + if err != nil { + return fmt.Errorf("failed to exchange code for token: %w", err) + } + + logJSONCmd(*cmd, token) + fmt.Println("\nAuthentication successful! You can now use the access_token for API requests.") + + return nil +} + +func handleOAuthCallback(w http.ResponseWriter, r *http.Request, resultChan chan<- oauthCallbackResult, once *sync.Once) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + w.Header().Set("Content-Type", "text/html") + + if errParam != "" { + html := strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", errParam, 1) + fmt.Fprint(w, html) + once.Do(func() { + resultChan <- oauthCallbackResult{err: fmt.Errorf("oauth error: %s", errParam)} + }) + return + } + + if code == "" { + html := strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", "missing authorization code", 1) + fmt.Fprint(w, html) + once.Do(func() { + resultChan <- oauthCallbackResult{err: fmt.Errorf("missing authorization code")} + }) + return + } + + fmt.Fprint(w, successHTML) + once.Do(func() { + resultChan <- oauthCallbackResult{code: code, state: state} + }) +} + +func printAuthInstructions(authURL string) { + fmt.Printf("Opening browser for authentication...\n") + fmt.Printf("If the browser doesn't open automatically, please visit:\n%s\n\n", authURL) +} + +func waitForCallback(callbackChan <-chan oauthCallbackResult) (oauthCallbackResult, error) { + select { + case result := <-callbackChan: + if result.err != nil { + return oauthCallbackResult{}, fmt.Errorf("callback error: %w", result.err) + } + return result, nil + case <-time.After(callbackTimeout): + return oauthCallbackResult{}, fmt.Errorf("authentication timeout after %v", callbackTimeout) + } +} + +func openBrowser(url string) error { + var cmd *exec.Cmd + + switch runtime.GOOS { + case "linux": + cmd = exec.Command("xdg-open", url) + case "darwin": + cmd = exec.Command("open", url) + case "windows": + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url) + default: + return fmt.Errorf("unsupported platform") + } + + return cmd.Start() +} diff --git a/cli/oauth_device.go b/cli/oauth_device.go new file mode 100644 index 0000000000..2d78eba4f0 --- /dev/null +++ b/cli/oauth_device.go @@ -0,0 +1,100 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +const ( + pollInterval = 3 * time.Second + pollTimeout = 10 * time.Minute +) + +func performOAuthDeviceLogin(cmd *cobra.Command, provider string) error { + ctx := cmd.Context() + + // Step 1: Get device code + deviceCode, err := sdk.OAuthDeviceCode(ctx, provider) + if err != nil { + return fmt.Errorf("failed to get device code: %w", err) + } + + // Step 2: Display instructions to user + printDeviceInstructions(deviceCode.VerificationURI, deviceCode.UserCode) + + // Step 3: Poll for authorization + token, pollErr := pollForAuthorization(ctx, provider, deviceCode.DeviceCode, deviceCode.Interval) + if pollErr != nil { + return fmt.Errorf("authorization failed: %w", pollErr) + } + + // Step 4: Display success message + logJSONCmd(*cmd, token) + successMsg := color.New(color.FgGreen, color.Bold).SprintFunc() + fmt.Printf("\n%s\n", successMsg("✓ Authentication successful!")) + fmt.Println("You can now use the access_token for API requests.") + + return nil +} + +func printDeviceInstructions(verificationURI, userCode string) { + fmt.Println() + fmt.Println(color.New(color.FgCyan, color.Bold).Sprint("=== OAuth Device Authorization ===")) + fmt.Println() + fmt.Println(color.New(color.FgYellow).Sprint("Please complete authentication in your browser:")) + fmt.Println() + fmt.Printf(" 1. Visit: %s\n", color.New(color.FgBlue, color.Underline).Sprint(verificationURI)) + fmt.Printf(" 2. Enter code: %s\n", color.New(color.FgGreen, color.Bold).Sprint(userCode)) + fmt.Println() + fmt.Println(color.New(color.FgWhite).Sprint("Waiting for authorization...")) + fmt.Println() +} + +func pollForAuthorization(ctx context.Context, provider, deviceCode string, interval int) (interface{}, error) { + pollDuration := time.Duration(interval) * time.Second + if pollDuration < pollInterval { + pollDuration = pollInterval + } + + ticker := time.NewTicker(pollDuration) + defer ticker.Stop() + + timeout := time.After(pollTimeout) + spinner := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"} + spinnerIdx := 0 + + for { + select { + case <-timeout: + return nil, fmt.Errorf("authentication timeout after %v", pollTimeout) + + case <-ticker.C: + // Show spinner + fmt.Printf("\r%s Polling for authorization...", color.CyanString(spinner[spinnerIdx])) + spinnerIdx = (spinnerIdx + 1) % len(spinner) + + token, err := sdk.OAuthDeviceToken(ctx, provider, deviceCode) + if err != nil { + errMsg := err.Error() + // Check if it's a pending error (expected during polling) + if strings.Contains(errMsg, "authorization pending") || strings.Contains(errMsg, "slow down") { + continue + } + // Any other error is a real failure + return nil, fmt.Errorf("failed to get token: %w", err) + } + + // Clear the spinner line + fmt.Print("\r" + string(make([]byte, 50)) + "\r") + return token, nil + } + } +} diff --git a/cli/oauth_device_test.go b/cli/oauth_device_test.go new file mode 100644 index 0000000000..5f52b5d19c --- /dev/null +++ b/cli/oauth_device_test.go @@ -0,0 +1,315 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + smqsdk "github.com/absmach/supermq/pkg/sdk" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPerformOAuthDeviceLogin(t *testing.T) { + tests := []struct { + name string + provider string + mockResponses []mockResponse + expectedErr bool + errContains string + }{ + { + name: "successful device flow", + provider: "google", + mockResponses: []mockResponse{ + { + path: "/oauth/device/code/google", + response: `{ + "device_code": "device_code_123", + "user_code": "ABCD-EFGH", + "verification_uri": "https://example.com/device", + "expires_in": 600, + "interval": 1 + }`, + status: http.StatusOK, + }, + { + path: "/oauth/device/token/google", + response: `{"error": "authorization pending"}`, + status: http.StatusAccepted, + }, + { + path: "/oauth/device/token/google", + response: `{ + "access_token": "access_token_123", + "refresh_token": "refresh_token_456" + }`, + status: http.StatusOK, + }, + }, + expectedErr: false, + }, + { + name: "device code request fails", + provider: "google", + mockResponses: []mockResponse{ + { + path: "/oauth/device/code/google", + response: `{"error": "oauth provider is disabled"}`, + status: http.StatusNotFound, + }, + }, + expectedErr: true, + errContains: "failed to get device code", + }, + { + name: "authorization denied", + provider: "google", + mockResponses: []mockResponse{ + { + path: "/oauth/device/code/google", + response: `{ + "device_code": "device_code_123", + "user_code": "ABCD-EFGH", + "verification_uri": "https://example.com/device", + "expires_in": 600, + "interval": 1 + }`, + status: http.StatusOK, + }, + { + path: "/oauth/device/token/google", + response: `{"error": "access denied"}`, + status: http.StatusUnauthorized, + }, + }, + expectedErr: true, + errContains: "failed to get token", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if callCount < len(tc.mockResponses) { + mock := tc.mockResponses[callCount] + assert.Equal(t, mock.path, r.URL.Path) + w.WriteHeader(mock.status) + w.Write([]byte(mock.response)) + callCount++ + } + })) + defer server.Close() + + // Set up SDK for testing + sdkConf := smqsdk.Config{ + UsersURL: server.URL, + } + sdk = smqsdk.NewSDK(sdkConf) + + cmd := &cobra.Command{} + cmd.SetContext(context.Background()) + + err := performOAuthDeviceLogin(cmd, tc.provider) + + if tc.expectedErr { + require.Error(t, err) + if tc.errContains != "" { + assert.Contains(t, err.Error(), tc.errContains) + } + return + } + + require.NoError(t, err) + }) + } +} + +func TestPollForAuthorization(t *testing.T) { + tests := []struct { + name string + deviceCode string + interval int + mockResponses []pollResponse + expectedErr bool + errContains string + }{ + { + name: "successful authorization after pending", + deviceCode: "device123", + interval: 1, + mockResponses: []pollResponse{ + { + response: `{"error": "authorization pending"}`, + status: http.StatusAccepted, + delay: 0, + }, + { + response: `{ + "access_token": "access_token_123", + "refresh_token": "refresh_token_456" + }`, + status: http.StatusOK, + delay: 0, + }, + }, + expectedErr: false, + }, + { + name: "slow down response", + deviceCode: "device123", + interval: 1, + mockResponses: []pollResponse{ + { + response: `{"error": "slow down"}`, + status: http.StatusAccepted, + delay: 0, + }, + { + response: `{ + "access_token": "access_token_123", + "refresh_token": "refresh_token_456" + }`, + status: http.StatusOK, + delay: 0, + }, + }, + expectedErr: false, + }, + { + name: "access denied", + deviceCode: "device123", + interval: 1, + mockResponses: []pollResponse{ + { + response: `{"error": "access denied"}`, + status: http.StatusUnauthorized, + delay: 0, + }, + }, + expectedErr: true, + errContains: "failed to get token", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if callCount < len(tc.mockResponses) { + mock := tc.mockResponses[callCount] + time.Sleep(mock.delay) + w.WriteHeader(mock.status) + w.Write([]byte(mock.response)) + callCount++ + } + })) + defer server.Close() + + sdkConf := smqsdk.Config{ + UsersURL: server.URL, + } + sdk = smqsdk.NewSDK(sdkConf) + + ctx := context.Background() + _, err := pollForAuthorization(ctx, "google", tc.deviceCode, tc.interval) + + if tc.expectedErr { + require.Error(t, err) + if tc.errContains != "" { + assert.Contains(t, err.Error(), tc.errContains) + } + return + } + + require.NoError(t, err) + }) + } +} + +func TestPrintDeviceInstructions(t *testing.T) { + // This test just ensures the function doesn't panic + t.Run("prints instructions without panic", func(t *testing.T) { + assert.NotPanics(t, func() { + printDeviceInstructions("https://example.com/device", "ABCD-EFGH") + }) + }) +} + +// Helper types for testing +type mockResponse struct { + path string + response string + status int +} + +type pollResponse struct { + response string + status int + delay time.Duration +} + +func TestDeviceCodeGeneration(t *testing.T) { + t.Run("device code is generated correctly", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := `{ + "device_code": "ABCDEFGHIJK123456", + "user_code": "WXYZ-1234", + "verification_uri": "https://example.com/verify", + "expires_in": 600, + "interval": 3 + }` + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) + })) + defer server.Close() + + sdkConf := smqsdk.Config{ + UsersURL: server.URL, + } + testSDK := smqsdk.NewSDK(sdkConf) + + deviceCode, err := testSDK.OAuthDeviceCode(context.Background(), "google") + require.NoError(t, err) + + assert.NotEmpty(t, deviceCode.DeviceCode) + assert.NotEmpty(t, deviceCode.UserCode) + assert.Contains(t, deviceCode.UserCode, "-") + assert.NotEmpty(t, deviceCode.VerificationURI) + assert.Greater(t, deviceCode.ExpiresIn, 0) + assert.Greater(t, deviceCode.Interval, 0) + }) +} + +func TestDeviceFlowTimeout(t *testing.T) { + t.Run("timeout after max duration", func(t *testing.T) { + // This test would take too long to run, so we skip it in normal test runs + // It's here to document the timeout behavior + t.Skip("Skipping timeout test - takes too long") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Always return pending + w.WriteHeader(http.StatusAccepted) + w.Write([]byte(`{"error": "authorization pending"}`)) + })) + defer server.Close() + + sdkConf := smqsdk.Config{ + UsersURL: server.URL, + } + sdk = smqsdk.NewSDK(sdkConf) + + ctx := context.Background() + _, err := pollForAuthorization(ctx, "google", "device123", 1) + + require.Error(t, err) + assert.Contains(t, err.Error(), "timeout") + }) +} diff --git a/cli/oauth_html.go b/cli/oauth_html.go new file mode 100644 index 0000000000..3efe66883e --- /dev/null +++ b/cli/oauth_html.go @@ -0,0 +1,493 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +const successHTML = ` + + + + + Authentication Successful - Magistrala + + + +
+
+ +
+ +
+ + + +
+ +

Authentication Successful!

+

You have been successfully authenticated.

+

You can now close this window and return to the CLI.

+ + +
+ +` + +const errorHTML = ` + + + + + Authentication Failed - Magistrala + + + +
+
+ +
+ +
+
+
+ +

Authentication Failed

+

We encountered an error during authentication.

+ +
+ {{ERROR_MESSAGE}} +
+ +

Please close this window and try again.

+ + +
+ +` diff --git a/cli/oauth_test.go b/cli/oauth_test.go new file mode 100644 index 0000000000..bce8e96751 --- /dev/null +++ b/cli/oauth_test.go @@ -0,0 +1,324 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockBrowserOpener struct { + opened string + err error +} + +func (m *mockBrowserOpener) Open(url string) error { + m.opened = url + return m.err +} + +func TestHandleOAuthCallback(t *testing.T) { + cases := []struct { + name string + queryParams string + expectedCode string + expectedState string + expectedErr bool + expectedHTML string + callTwice bool + secondCallSent bool + }{ + { + name: "successful callback", + queryParams: "?code=test-code&state=test-state", + expectedCode: "test-code", + expectedState: "test-state", + expectedErr: false, + expectedHTML: "Authentication Successful", + }, + { + name: "callback with error parameter", + queryParams: "?error=access_denied", + expectedCode: "", + expectedState: "", + expectedErr: true, + expectedHTML: "Authentication Failed", + }, + { + name: "callback with missing code", + queryParams: "?state=test-state", + expectedCode: "", + expectedState: "", + expectedErr: true, + expectedHTML: "missing authorization code", + }, + { + name: "multiple calls only process first", + queryParams: "?code=test-code&state=test-state", + expectedCode: "test-code", + expectedState: "test-state", + expectedErr: false, + expectedHTML: "Authentication Successful", + callTwice: true, + secondCallSent: false, + }, + { + name: "callback with empty state", + queryParams: "?code=test-code&state=", + expectedCode: "test-code", + expectedState: "", + expectedErr: false, + expectedHTML: "Authentication Successful", + }, + { + name: "callback with special characters in state", + queryParams: "?code=test-code&state=abc%2Fdef%3D123", + expectedCode: "test-code", + expectedState: "abc/def=123", + expectedErr: false, + expectedHTML: "Authentication Successful", + }, + { + name: "callback with both error and code", + queryParams: "?code=test-code&error=some_error", + expectedCode: "", + expectedState: "", + expectedErr: true, + expectedHTML: "Authentication Failed", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + resultChan := make(chan oauthCallbackResult, 1) + var once sync.Once + + req := httptest.NewRequest(http.MethodGet, "/callback"+tc.queryParams, nil) + w := httptest.NewRecorder() + + handleOAuthCallback(w, req, resultChan, &once) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, w.Body.String(), tc.expectedHTML) + + select { + case result := <-resultChan: + if tc.expectedErr { + assert.Error(t, result.err) + } else { + assert.NoError(t, result.err) + assert.Equal(t, tc.expectedCode, result.code) + assert.Equal(t, tc.expectedState, result.state) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for result") + } + + if tc.callTwice { + w2 := httptest.NewRecorder() + handleOAuthCallback(w2, req, resultChan, &once) + + select { + case <-resultChan: + if !tc.secondCallSent { + t.Fatal("second call should not send to channel") + } + case <-time.After(100 * time.Millisecond): + // Expected - second call should not send to channel + } + } + }) + } +} + +func TestPrintAuthInstructions(t *testing.T) { + authURL := "https://example.com/oauth/authorize" + printAuthInstructions(authURL) +} + +func TestWaitForCallback(t *testing.T) { + cases := []struct { + name string + setupChan func() <-chan oauthCallbackResult + expectErr bool + expectedMsg string + timeout time.Duration + }{ + { + name: "successful callback", + setupChan: func() <-chan oauthCallbackResult { + ch := make(chan oauthCallbackResult, 1) + ch <- oauthCallbackResult{code: "test-code", state: "test-state"} + return ch + }, + expectErr: false, + }, + { + name: "callback with error", + setupChan: func() <-chan oauthCallbackResult { + ch := make(chan oauthCallbackResult, 1) + ch <- oauthCallbackResult{err: errors.New("oauth error")} + return ch + }, + expectErr: true, + expectedMsg: "callback error", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + callbackChan := tc.setupChan() + result, err := waitForCallback(callbackChan) + + if tc.expectErr { + assert.Error(t, err) + if tc.expectedMsg != "" { + assert.Contains(t, err.Error(), tc.expectedMsg) + } + } else { + assert.NoError(t, err) + assert.NotEmpty(t, result.code) + } + }) + } +} + +func TestOpenBrowser(t *testing.T) { + err := openBrowser("https://example.com") + // This might fail in CI environments, so we just check it doesn't panic + _ = err +} + +func TestCallbackServer(t *testing.T) { + t.Run("successful callback", func(t *testing.T) { + resultChan := make(chan oauthCallbackResult, 1) + + server, err := newCallbackServer(resultChan) + require.NoError(t, err) + require.NotNil(t, server) + defer server.Shutdown(nil) + + callbackURL := fmt.Sprintf("http://127.0.0.1:%s%s?code=test-code&state=test-state", localServerPort, callbackPath) + + resp, err := http.Get(callbackURL) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + select { + case result := <-resultChan: + assert.NoError(t, result.err) + assert.Equal(t, "test-code", result.code) + assert.Equal(t, "test-state", result.state) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for callback result") + } + }) + + t.Run("callback with error parameter", func(t *testing.T) { + resultChan := make(chan oauthCallbackResult, 1) + + server, err := newCallbackServer(resultChan) + require.NoError(t, err) + require.NotNil(t, server) + defer server.Shutdown(nil) + + callbackURL := fmt.Sprintf("http://127.0.0.1:%s%s?error=access_denied", localServerPort, callbackPath) + + resp, err := http.Get(callbackURL) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + select { + case result := <-resultChan: + assert.Error(t, result.err) + assert.Contains(t, result.err.Error(), "access_denied") + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for callback result") + } + }) + + t.Run("callback with missing code", func(t *testing.T) { + resultChan := make(chan oauthCallbackResult, 1) + + server, err := newCallbackServer(resultChan) + require.NoError(t, err) + require.NotNil(t, server) + defer server.Shutdown(nil) + + callbackURL := fmt.Sprintf("http://127.0.0.1:%s%s?state=test-state", localServerPort, callbackPath) + + resp, err := http.Get(callbackURL) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + select { + case result := <-resultChan: + assert.Error(t, result.err) + assert.Contains(t, result.err.Error(), "missing authorization code") + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for callback result") + } + }) +} + +func TestDefaultBrowserOpener(t *testing.T) { + opener := defaultBrowserOpener{} + err := opener.Open("https://example.com") + // This might fail in CI environments, so we just check it returns an error type + _ = err +} + +func TestMockBrowserOpener(t *testing.T) { + cases := []struct { + name string + url string + setupErr error + expectErr bool + }{ + { + name: "successful browser open", + url: "https://example.com/auth", + setupErr: nil, + expectErr: false, + }, + { + name: "browser fails to open", + url: "https://example.com/auth", + setupErr: errors.New("browser failed"), + expectErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mockBrowser := &mockBrowserOpener{err: tc.setupErr} + err := mockBrowser.Open(tc.url) + + if tc.expectErr { + assert.Error(t, err) + assert.Equal(t, tc.url, mockBrowser.opened) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.url, mockBrowser.opened) + } + }) + } +} diff --git a/cli/users.go b/cli/users.go index 1a8f3fb450..196c0305e5 100644 --- a/cli/users.go +++ b/cli/users.go @@ -25,6 +25,7 @@ const ( username = "username" email = "email" role = "role" + oauth = "oauth" // Usage strings for user operations. usageUserCreate = "cli users create [user_auth_token]" @@ -53,14 +54,15 @@ Available update options: usageUserSearch = "cli users search \nQuery format: username=|firstname=|lastname=|id=[&offset=][&limit=]\nExample: cli users search \"username=john_doe\" " usageUserSendVerification = "cli users sendverification " usageUserVerifyEmail = "cli users verifyemail " + usageUserOAuth = "cli users oauth " ) func NewUsersCmd() *cobra.Command { cmd := &cobra.Command{ - Use: "users [operation] [args...]", + Use: "users [operation] [args...]", Short: "Users management", - Long: `Format: - users [args...] + Long: `Format: + users [args...] users [args...] Operations (require user_id/all): get, update, enable, disable, delete @@ -77,6 +79,7 @@ Examples: users search "firstname=john&limit=10" users sendverification users verifyemail + users oauth users all get users get users update @@ -176,6 +179,9 @@ Examples: } handleUserSearch(cmd, args[1], args[2:]) return + case oauth: + handleUserOAuth(cmd, args[1], args[2:]) + return } if len(args) < 2 { @@ -590,3 +596,16 @@ func handleUserSearch(cmd *cobra.Command, query string, args []string) { logJSONCmd(*cmd, users) } + +func handleUserOAuth(cmd *cobra.Command, provider string, args []string) { + if len(args) != 0 { + logUsageCmd(*cmd, usageUserOAuth) + return + } + + // Use device flow by default for better CLI experience + if err := performOAuthDeviceLogin(cmd, provider); err != nil { + logErrorCmd(*cmd, err) + return + } +} diff --git a/cmd/users/main.go b/cmd/users/main.go index 78eceda6ba..59a5eaedb3 100644 --- a/cmd/users/main.go +++ b/cmd/users/main.go @@ -19,6 +19,7 @@ import ( grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1" grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" grpcUsersV1 "github.com/absmach/supermq/api/grpc/users/v1" + redisclient "github.com/absmach/supermq/internal/clients/redis" "github.com/absmach/supermq/internal/email" smqlog "github.com/absmach/supermq/logger" smqauthn "github.com/absmach/supermq/pkg/authn" @@ -60,16 +61,18 @@ import ( ) const ( - svcName = "users" - envPrefixDB = "SMQ_USERS_DB_" - envPrefixHTTP = "SMQ_USERS_HTTP_" - envPrefixGRPC = "SMQ_USERS_GRPC_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - envPrefixGoogle = "SMQ_GOOGLE_" - defDB = "users" - defSvcHTTPPort = "9002" - defSvcGRPCPort = "7002" + svcName = "users" + envPrefixDB = "SMQ_USERS_DB_" + envPrefixHTTP = "SMQ_USERS_HTTP_" + envPrefixGRPC = "SMQ_USERS_GRPC_" + envPrefixAuth = "SMQ_AUTH_GRPC_" + envPrefixDomains = "SMQ_DOMAINS_GRPC_" + envPrefixGoogle = "SMQ_GOOGLE_" + envPrefixGoogleDevice = "SMQ_GOOGLE_DEVICE_" + envPrefixGoogleUser = "SMQ_GOOGLE_USER_" + defDB = "users" + defSvcHTTPPort = "9002" + defSvcGRPCPort = "7002" ) type config struct { @@ -97,6 +100,7 @@ type config struct { PasswordResetEmailTemplate string `env:"SMQ_PASSWORD_RESET_EMAIL_TEMPLATE" envDefault:"reset-password-email.tmpl"` VerificationURLPrefix string `env:"SMQ_VERIFICATION_URL_PREFIX" envDefault:"http://localhost/verify-email"` VerificationEmailTemplate string `env:"SMQ_VERIFICATION_EMAIL_TEMPLATE" envDefault:"verification-email.tmpl"` + CacheURL string `env:"SMQ_USERS_CACHE_URL" envDefault:"redis://localhost:6379/0"` PassRegex *regexp.Regexp } @@ -152,6 +156,13 @@ func main() { exitCode = 1 return } + cacheClient, err := redisclient.Connect(cfg.CacheURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to redis: %s", err)) + exitCode = 1 + return + } + defer cacheClient.Close() migration := postgres.Migration() db, err := pgclient.Setup(dbConfig, *migration) @@ -265,17 +276,59 @@ func main() { return } - oauthConfig := oauth2.Config{} - if err := env.ParseWithOptions(&oauthConfig, env.Options{Prefix: envPrefixGoogle}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s Google configuration : %s", svcName, err.Error())) + // Try to load separate device and user OAuth configs first + deviceOauthConfig := oauth2.DeviceConfig{} + if err := env.ParseWithOptions(&deviceOauthConfig, env.Options{Prefix: envPrefixGoogleDevice}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s Google device configuration : %s", svcName, err.Error())) exitCode = 1 return } - oauthProvider := googleoauth.NewProvider(oauthConfig, cfg.OAuthUIRedirectURL, cfg.OAuthUIErrorURL) + + userOauthConfig := oauth2.UserConfig{} + if err := env.ParseWithOptions(&userOauthConfig, env.Options{Prefix: envPrefixGoogleUser}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s Google user configuration : %s", svcName, err.Error())) + exitCode = 1 + return + } + + // Fallback to legacy config if new configs are not set + if deviceOauthConfig.ClientID == "" { + legacyConfig := oauth2.Config{} + if err := env.ParseWithOptions(&legacyConfig, env.Options{Prefix: envPrefixGoogle}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s Google configuration : %s", svcName, err.Error())) + exitCode = 1 + return + } + deviceOauthConfig = oauth2.DeviceConfig{ + ClientID: legacyConfig.ClientID, + ClientSecret: legacyConfig.ClientSecret, + State: legacyConfig.State, + RedirectURL: legacyConfig.RedirectURL, + } + } + + if userOauthConfig.ClientID == "" { + legacyConfig := oauth2.Config{} + if err := env.ParseWithOptions(&legacyConfig, env.Options{Prefix: envPrefixGoogle}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s Google configuration : %s", svcName, err.Error())) + exitCode = 1 + return + } + userOauthConfig = oauth2.UserConfig{ + ClientID: legacyConfig.ClientID, + ClientSecret: legacyConfig.ClientSecret, + State: legacyConfig.State, + RedirectURL: legacyConfig.RedirectURL, + } + } + + deviceProvider := []oauth2.Provider{googleoauth.NewProvider(deviceOauthConfig.ToConfig(), cfg.OAuthUIRedirectURL, cfg.OAuthUIErrorURL)} + userProvider := []oauth2.Provider{googleoauth.NewProvider(userOauthConfig.ToConfig(), cfg.OAuthUIRedirectURL, cfg.OAuthUIErrorURL)} mux := chi.NewRouter() idp := uuid.New() - httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(csvc, authnMiddleware, tokenClient, cfg.SelfRegister, mux, logger, cfg.InstanceID, cfg.PassRegex, idp, oauthProvider), logger) + handler := httpapi.MakeHandler(csvc, authnMiddleware, tokenClient, cfg.SelfRegister, mux, logger, cfg.InstanceID, cfg.PassRegex, idp, cacheClient, deviceProvider, userProvider) + httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, handler, logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) diff --git a/docker/.env b/docker/.env index 17974c3928..efd6ab2087 100644 --- a/docker/.env +++ b/docker/.env @@ -269,6 +269,7 @@ SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost/password-reset SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=reset-password-email.tmpl SMQ_VERIFICATION_URL_PREFIX=http://localhost/verify-email SMQ_VERIFICATION_EMAIL_TEMPLATE=verification-email.tmpl +SMQ_USERS_CACHE_URL=redis://users-redis:${SMQ_REDIS_TCP_PORT}/0 #### Users Client Config SMQ_USERS_URL=http://users:9002 @@ -294,11 +295,27 @@ SMQ_NOTIFICATIONS_LOG_LEVEL=debug SMQ_NOTIFICATIONS_INSTANCE_ID= ### Google OAuth2 +# Legacy OAuth2 configuration (for backward compatibility) +# If device/user specific configs are not set, these will be used for both flows SMQ_GOOGLE_CLIENT_ID= SMQ_GOOGLE_CLIENT_SECRET= SMQ_GOOGLE_REDIRECT_URL= SMQ_GOOGLE_STATE= +### Google OAuth2 - Device Flow (CLI) +# Separate OAuth2 client for device/CLI authentication flow +SMQ_GOOGLE_DEVICE_CLIENT_ID= +SMQ_GOOGLE_DEVICE_CLIENT_SECRET= +SMQ_GOOGLE_DEVICE_REDIRECT_URL= +SMQ_GOOGLE_DEVICE_STATE= + +### Google OAuth2 - User Flow (Web) +# Separate OAuth2 client for web user authentication flow +SMQ_GOOGLE_USER_CLIENT_ID= +SMQ_GOOGLE_USER_CLIENT_SECRET= +SMQ_GOOGLE_USER_REDIRECT_URL= +SMQ_GOOGLE_USER_STATE= + ### Groups SMQ_GROUPS_LOG_LEVEL=debug SMQ_GROUPS_HTTP_HOST=groups diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 518866ef3d..c5c3e30478 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -10,6 +10,7 @@ networks: volumes: supermq-users-db-volume: + supermq-users-redis-volume: supermq-groups-db-volume: supermq-clients-db-volume: supermq-channels-db-volume: @@ -823,11 +824,21 @@ services: volumes: - supermq-users-db-volume:/var/lib/postgresql/data + users-redis: + image: docker.io/redis:8.2.2-alpine3.22 + container_name: supermq-users-redis + restart: on-failure + networks: + - supermq-base-net + volumes: + - supermq-users-redis-volume:/data + users: image: docker.io/supermq/users:${SMQ_RELEASE_TAG} container_name: supermq-users depends_on: - users-db + - users-redis - auth - nats restart: on-failure @@ -864,6 +875,7 @@ services: SMQ_USERS_DB_SSL_CERT: ${SMQ_USERS_DB_SSL_CERT} SMQ_USERS_DB_SSL_KEY: ${SMQ_USERS_DB_SSL_KEY} SMQ_USERS_DB_SSL_ROOT_CERT: ${SMQ_USERS_DB_SSL_ROOT_CERT} + SMQ_USERS_CACHE_URL: ${SMQ_USERS_CACHE_URL} SMQ_USERS_ALLOW_SELF_REGISTER: ${SMQ_USERS_ALLOW_SELF_REGISTER} SMQ_EMAIL_HOST: ${SMQ_EMAIL_HOST} SMQ_EMAIL_PORT: ${SMQ_EMAIL_PORT} @@ -889,6 +901,14 @@ services: SMQ_GOOGLE_CLIENT_SECRET: ${SMQ_GOOGLE_CLIENT_SECRET} SMQ_GOOGLE_REDIRECT_URL: ${SMQ_GOOGLE_REDIRECT_URL} SMQ_GOOGLE_STATE: ${SMQ_GOOGLE_STATE} + SMQ_GOOGLE_DEVICE_CLIENT_ID: ${SMQ_GOOGLE_DEVICE_CLIENT_ID} + SMQ_GOOGLE_DEVICE_CLIENT_SECRET: ${SMQ_GOOGLE_DEVICE_CLIENT_SECRET} + SMQ_GOOGLE_DEVICE_REDIRECT_URL: ${SMQ_GOOGLE_DEVICE_REDIRECT_URL} + SMQ_GOOGLE_DEVICE_STATE: ${SMQ_GOOGLE_DEVICE_STATE} + SMQ_GOOGLE_USER_CLIENT_ID: ${SMQ_GOOGLE_USER_CLIENT_ID} + SMQ_GOOGLE_USER_CLIENT_SECRET: ${SMQ_GOOGLE_USER_CLIENT_SECRET} + SMQ_GOOGLE_USER_REDIRECT_URL: ${SMQ_GOOGLE_USER_REDIRECT_URL} + SMQ_GOOGLE_USER_STATE: ${SMQ_GOOGLE_USER_STATE} SMQ_OAUTH_UI_REDIRECT_URL: ${SMQ_OAUTH_UI_REDIRECT_URL} SMQ_OAUTH_UI_ERROR_URL: ${SMQ_OAUTH_UI_ERROR_URL} SMQ_USERS_DELETE_INTERVAL: ${SMQ_USERS_DELETE_INTERVAL} diff --git a/pkg/oauth2/device.go b/pkg/oauth2/device.go new file mode 100644 index 0000000000..2ddcf484ac --- /dev/null +++ b/pkg/oauth2/device.go @@ -0,0 +1,91 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "errors" + "time" +) + +const ( + // DeviceCodeLength is the length of the user code (e.g., "ABCD-EFGH"). + DeviceCodeLength = 8 + + // DeviceCodePollTimeout is the timeout for polling device code status. + DeviceCodePollTimeout = 5 * time.Second + + // CodeCheckInterval is the minimum interval between polling requests. + CodeCheckInterval = 3 * time.Second + + // DeviceStatePrefix is the prefix used in state parameter for device flow. + DeviceStatePrefix = "device:" + + // DeviceCodeExpiry is the time after which device codes expire. + DeviceCodeExpiry = 10 * time.Minute +) + +var ( + // ErrDeviceCodeExpired indicates that the device code has expired. + ErrDeviceCodeExpired = errors.New("device code expired") + + // ErrDeviceCodePending indicates that the user hasn't authorized the device yet. + ErrDeviceCodePending = errors.New("authorization pending") + + // ErrSlowDown indicates that the client is polling too frequently. + ErrSlowDown = errors.New("slow down") + + // ErrAccessDenied indicates that the user denied the authorization request. + ErrAccessDenied = errors.New("access denied") + + // ErrInvalidState indicates that the OAuth state parameter is invalid. + ErrInvalidState = errors.New("invalid state") + + // ErrEmptyCode indicates that the authorization code is empty. + ErrEmptyCode = errors.New("empty code") + + // ErrInvalidProvider indicates that the OAuth provider is not found or disabled. + ErrInvalidProvider = errors.New("invalid provider") + + // ErrDeviceCodeNotFound indicates that the device code was not found. + ErrDeviceCodeNotFound = errors.New("device code not found") + + // ErrUserCodeNotFound indicates that the user code was not found. + ErrUserCodeNotFound = errors.New("user code not found") +) + +// DeviceCode represents an OAuth2 device authorization code. +type DeviceCode struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` + Provider string `json:"provider,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` + State string `json:"state,omitempty"` + AccessToken string `json:"access_token,omitempty"` + Approved bool `json:"approved,omitempty"` + Denied bool `json:"denied,omitempty"` + LastPoll time.Time `json:"last_poll,omitempty"` +} + +// DeviceCodeStore manages device authorization codes. +// It provides operations to save, retrieve, update, and delete device codes +// used in the OAuth2 device authorization flow. +type DeviceCodeStore interface { + // Save stores a new device code. + Save(code DeviceCode) error + + // Get retrieves a device code by its device code value. + Get(deviceCode string) (DeviceCode, error) + + // GetByUserCode retrieves a device code by its user code. + GetByUserCode(userCode string) (DeviceCode, error) + + // Update updates an existing device code. + Update(code DeviceCode) error + + // Delete removes a device code. + Delete(deviceCode string) error +} diff --git a/pkg/oauth2/google/provider.go b/pkg/oauth2/google/provider.go index 951c1ac46a..048c344fdf 100644 --- a/pkg/oauth2/google/provider.go +++ b/pkg/oauth2/google/provider.go @@ -34,9 +34,9 @@ var httpClient = &http.Client{ Timeout: defTimeout, } -var _ mgoauth2.Provider = (*config)(nil) +var _ mgoauth2.Provider = (*provider)(nil) -type config struct { +type provider struct { config *oauth2.Config state string uiRedirectURL string @@ -45,7 +45,7 @@ type config struct { // NewProvider returns a new Google OAuth provider. func NewProvider(cfg mgoauth2.Config, uiRedirectURL, errorURL string) mgoauth2.Provider { - return &config{ + return &provider{ config: &oauth2.Config{ ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, @@ -59,27 +59,27 @@ func NewProvider(cfg mgoauth2.Config, uiRedirectURL, errorURL string) mgoauth2.P } } -func (cfg *config) Name() string { +func (cfg *provider) Name() string { return providerName } -func (cfg *config) State() string { +func (cfg *provider) State() string { return cfg.state } -func (cfg *config) RedirectURL() string { +func (cfg *provider) RedirectURL() string { return cfg.uiRedirectURL } -func (cfg *config) ErrorURL() string { +func (cfg *provider) ErrorURL() string { return cfg.errorURL } -func (cfg *config) IsEnabled() bool { +func (cfg *provider) IsEnabled() bool { return cfg.config.ClientID != "" && cfg.config.ClientSecret != "" } -func (cfg *config) Exchange(ctx context.Context, code string) (oauth2.Token, error) { +func (cfg *provider) Exchange(ctx context.Context, code string) (oauth2.Token, error) { token, err := cfg.config.Exchange(ctx, code) if err != nil { return oauth2.Token{}, err @@ -88,7 +88,19 @@ func (cfg *config) Exchange(ctx context.Context, code string) (oauth2.Token, err return *token, nil } -func (cfg *config) UserInfo(accessToken string) (uclient.User, error) { +func (cfg *provider) ExchangeWithRedirect(ctx context.Context, code, redirectURL string) (oauth2.Token, error) { + // Create a temporary config with the custom redirect URL + tempConfig := *cfg.config + tempConfig.RedirectURL = redirectURL + token, err := tempConfig.Exchange(ctx, code) + if err != nil { + return oauth2.Token{}, err + } + + return *token, nil +} + +func (cfg *provider) UserInfo(accessToken string) (uclient.User, error) { resp, err := httpClient.Get(userInfoURL + url.QueryEscape(accessToken)) if err != nil { return uclient.User{}, err @@ -111,3 +123,14 @@ func (cfg *config) UserInfo(accessToken string) (uclient.User, error) { return user, nil } + +func (cfg *provider) GetAuthURL() string { + return cfg.config.AuthCodeURL(cfg.state) +} + +func (cfg *provider) GetAuthURLWithRedirect(redirectURL string) string { + // Create a temporary config with the custom redirect URL + tempConfig := *cfg.config + tempConfig.RedirectURL = redirectURL + return tempConfig.AuthCodeURL(cfg.state) +} diff --git a/pkg/oauth2/google/provider_test.go b/pkg/oauth2/google/provider_test.go new file mode 100644 index 0000000000..3cd4a85189 --- /dev/null +++ b/pkg/oauth2/google/provider_test.go @@ -0,0 +1,202 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package google_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/absmach/supermq/pkg/oauth2" + "github.com/absmach/supermq/pkg/oauth2/google" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testClientID = "test-client-id" + testClientSecret = "test-client-secret" + testState = "test-state" + testRedirectURL = "http://localhost/callback" + testCode = "test-code" +) + +func TestGetAuthURL(t *testing.T) { + cfg := oauth2.Config{ + ClientID: testClientID, + ClientSecret: testClientSecret, + State: testState, + RedirectURL: testRedirectURL, + } + + provider := google.NewProvider(cfg, "http://localhost/ui", "http://localhost/error") + + authURL := provider.GetAuthURL() + + assert.NotEmpty(t, authURL) + assert.Contains(t, authURL, "accounts.google.com/o/oauth2/auth") + assert.Contains(t, authURL, "client_id="+testClientID) + assert.Contains(t, authURL, "state="+testState) + // redirect_uri is URL-encoded in the query string + assert.Contains(t, authURL, "redirect_uri=") +} + +func TestGetAuthURLWithRedirect(t *testing.T) { + cfg := oauth2.Config{ + ClientID: testClientID, + ClientSecret: testClientSecret, + State: testState, + RedirectURL: testRedirectURL, + } + + provider := google.NewProvider(cfg, "http://localhost/ui", "http://localhost/error") + + customRedirect := "http://localhost:9090/callback" + authURL := provider.GetAuthURLWithRedirect(customRedirect) + + assert.NotEmpty(t, authURL) + assert.Contains(t, authURL, "accounts.google.com/o/oauth2/auth") + assert.Contains(t, authURL, "client_id="+testClientID) + assert.Contains(t, authURL, "state="+testState) + // redirect_uri is URL-encoded in the query string, just verify it exists + assert.Contains(t, authURL, "redirect_uri=") + // Verify the custom redirect is in the URL (URL-encoded) + assert.Contains(t, authURL, "9090") +} + +func TestExchangeWithRedirect(t *testing.T) { + // Create a mock OAuth2 server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" { + http.NotFound(w, r) + return + } + + err := r.ParseForm() + require.NoError(t, err) + + // Verify the code + code := r.FormValue("code") + + if code != testCode { + w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte(`{"error": "invalid_grant"}`)) + assert.NoError(t, err) + return + } + + // Return a mock token + w.Header().Set("Content-Type", "application/json") + _, err = w.Write([]byte(`{ + "access_token": "test-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "test-refresh-token" + }`)) + assert.NoError(t, err) + })) + defer server.Close() + + cfg := oauth2.Config{ + ClientID: testClientID, + ClientSecret: testClientSecret, + State: testState, + RedirectURL: testRedirectURL, + } + + // We can't easily test the actual Google provider without modifying the endpoint + // This test verifies the method exists and has the correct signature + provider := google.NewProvider(cfg, "http://localhost/ui", "http://localhost/error") + + // Test with invalid code (will fail but ensures method works) + _, err := provider.ExchangeWithRedirect(context.Background(), "invalid-code", "http://localhost:9090/callback") + assert.Error(t, err) // Expected to fail with actual Google OAuth +} + +func TestIsEnabled(t *testing.T) { + cases := []struct { + name string + config oauth2.Config + expected bool + }{ + { + name: "enabled with all credentials", + config: oauth2.Config{ + ClientID: testClientID, + ClientSecret: testClientSecret, + State: testState, + RedirectURL: testRedirectURL, + }, + expected: true, + }, + { + name: "disabled without client ID", + config: oauth2.Config{ + ClientID: "", + ClientSecret: testClientSecret, + State: testState, + RedirectURL: testRedirectURL, + }, + expected: false, + }, + { + name: "disabled without client secret", + config: oauth2.Config{ + ClientID: testClientID, + ClientSecret: "", + State: testState, + RedirectURL: testRedirectURL, + }, + expected: false, + }, + { + name: "disabled without credentials", + config: oauth2.Config{ + ClientID: "", + ClientSecret: "", + State: testState, + RedirectURL: testRedirectURL, + }, + expected: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + provider := google.NewProvider(tc.config, "http://localhost/ui", "http://localhost/error") + assert.Equal(t, tc.expected, provider.IsEnabled()) + }) + } +} + +func TestProviderMethods(t *testing.T) { + cfg := oauth2.Config{ + ClientID: testClientID, + ClientSecret: testClientSecret, + State: testState, + RedirectURL: testRedirectURL, + } + + uiRedirectURL := "http://localhost:9095/ui/tokens/secure" + errorURL := "http://localhost:9095/ui/error" + + provider := google.NewProvider(cfg, uiRedirectURL, errorURL) + + t.Run("Name", func(t *testing.T) { + assert.Equal(t, "google", provider.Name()) + }) + + t.Run("State", func(t *testing.T) { + assert.Equal(t, testState, provider.State()) + }) + + t.Run("RedirectURL", func(t *testing.T) { + assert.Equal(t, uiRedirectURL, provider.RedirectURL()) + }) + + t.Run("ErrorURL", func(t *testing.T) { + assert.Equal(t, errorURL, provider.ErrorURL()) + }) +} diff --git a/pkg/oauth2/http/oauth.go b/pkg/oauth2/http/oauth.go new file mode 100644 index 0000000000..6d4a91595c --- /dev/null +++ b/pkg/oauth2/http/oauth.go @@ -0,0 +1,228 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" + "github.com/absmach/supermq/pkg/oauth2" + "github.com/go-chi/chi/v5" +) + +var ( + errInvalidBody = newErrorResponse("invalid request body") + errInvalidState = newErrorResponse("invalid state") + errEmptyCode = newErrorResponse("empty code") + errProviderDisabled = newErrorResponse("oauth provider is disabled") +) + +type errorResponse struct { + Error string `json:"error"` +} + +// newErrorResponse creates a JSON error response. +func newErrorResponse(msg string) errorResponse { + return errorResponse{Error: msg} +} + +type authURLResponse struct { + AuthorizationURL string `json:"authorization_url"` + State string `json:"state"` +} + +// Handler registers OAuth2 routes for the given providers. +// It sets up three endpoints for each provider: +// - GET /oauth/authorize/{provider} - Returns the authorization URL +// - GET /oauth/callback/{provider} - Handles OAuth2 callback and sets cookies +// - POST /oauth/cli/callback/{provider} - Handles CLI OAuth2 callback and returns JSON. +func Handler(r *chi.Mux, tokenClient grpcTokenV1.TokenServiceClient, oauthSvc oauth2.Service, providers ...oauth2.Provider) *chi.Mux { + for _, provider := range providers { + r.HandleFunc("/oauth/callback/"+provider.Name(), oauth2CallbackHandler(provider, oauthSvc)) + r.Get("/oauth/authorize/"+provider.Name(), oauth2AuthorizeHandler(provider)) + r.Post("/oauth/cli/callback/"+provider.Name(), oauth2CLICallbackHandler(provider, oauthSvc)) + } + + return r +} + +// oauth2CallbackHandler is a http.HandlerFunc that handles OAuth2 callbacks. +func oauth2CallbackHandler(oauth oauth2.Provider, oauthSvc oauth2.Service) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !oauth.IsEnabled() { + redirectWithError(w, r, oauth.ErrorURL(), errProviderDisabled.Error) + return + } + + state := r.FormValue("state") + + // Check if this is a device flow callback (state contains device: prefix) + if oauth2.IsDeviceFlowState(state) { + handleDeviceFlowCallback(w, r, oauth, oauthSvc) + return + } + + if state != oauth.State() { + redirectWithError(w, r, oauth.ErrorURL(), errInvalidState.Error) + return + } + + code := r.FormValue("code") + if code == "" { + redirectWithError(w, r, oauth.ErrorURL(), errEmptyCode.Error) + return + } + + jwt, err := oauthSvc.ProcessWebCallback(r.Context(), oauth, code, "") + if err != nil { + redirectWithError(w, r, oauth.ErrorURL(), err.Error()) + return + } + + setTokenCookies(w, jwt) + http.Redirect(w, r, oauth.RedirectURL(), http.StatusFound) + } +} + +// oauth2AuthorizeHandler returns the authorization URL for the OAuth2 provider. +func oauth2AuthorizeHandler(oauth oauth2.Provider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + if !oauth.IsEnabled() { + respondWithJSON(w, http.StatusNotFound, errProviderDisabled) + return + } + + redirectURL := r.URL.Query().Get("redirect_uri") + var authURL string + if redirectURL != "" { + authURL = oauth.GetAuthURLWithRedirect(redirectURL) + } else { + authURL = oauth.GetAuthURL() + } + + resp := authURLResponse{ + AuthorizationURL: authURL, + State: oauth.State(), + } + respondWithJSON(w, http.StatusOK, resp) + } +} + +// oauth2CLICallbackHandler handles OAuth2 callbacks for CLI and returns JSON tokens. +func oauth2CLICallbackHandler(oauth oauth2.Provider, oauthSvc oauth2.Service) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + if !oauth.IsEnabled() { + respondWithJSON(w, http.StatusNotFound, errProviderDisabled) + return + } + var req struct { + Code string `json:"code"` + State string `json:"state"` + RedirectURL string `json:"redirect_url"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondWithJSON(w, http.StatusBadRequest, errInvalidBody) + return + } + + if req.State != oauth.State() { + respondWithJSON(w, http.StatusBadRequest, errInvalidState) + return + } + + if req.Code == "" { + respondWithJSON(w, http.StatusBadRequest, errEmptyCode) + return + } + + jwt, err := oauthSvc.ProcessWebCallback(r.Context(), oauth, req.Code, req.RedirectURL) + if err != nil { + status := http.StatusInternalServerError + // Exchange errors and unauthorized errors should return 401 + errMsg := err.Error() + if errMsg == "unauthorized" || strings.Contains(errMsg, "failed to exchange code") { + status = http.StatusUnauthorized + } + respondWithJSON(w, status, newErrorResponse(errMsg)) + return + } + + jwt.AccessType = "" + respondWithJSON(w, http.StatusOK, jwt) + } +} + +// respondWithJSON writes a JSON response with the given status code and data. +func respondWithJSON(w http.ResponseWriter, status int, data any) { + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(data); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } +} + +// redirectWithError redirects to the baseURL with an error query parameter. +func redirectWithError(w http.ResponseWriter, r *http.Request, baseURL, errMsg string) { + redirectURL := fmt.Sprintf("%s?error=%s", baseURL, errMsg) + http.Redirect(w, r, redirectURL, http.StatusSeeOther) +} + +// setTokenCookies sets the access_token and refresh_token cookies in the response. +func setTokenCookies(w http.ResponseWriter, jwt *grpcTokenV1.Token) { + http.SetCookie(w, &http.Cookie{ + Name: "access_token", + Value: jwt.GetAccessToken(), + Path: "/", + HttpOnly: true, + Secure: true, + }) + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: jwt.GetRefreshToken(), + Path: "/", + HttpOnly: true, + Secure: true, + }) +} + +// handleDeviceFlowCallback processes OAuth callback for device authorization flow. +func handleDeviceFlowCallback(w http.ResponseWriter, r *http.Request, oauth oauth2.Provider, oauthSvc oauth2.Service) { + // Extract user code from state (format: "device:ABCD-EFGH") + state := r.FormValue("state") + userCode := oauth2.ExtractUserCodeFromState(state) + + // Get device code by user code to validate it exists + _, err := oauthSvc.GetDeviceCodeByUserCode(r.Context(), userCode) + if err != nil { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", "The device code is invalid or has expired.", 1)) + return + } + + // Get OAuth authorization code + code := r.FormValue("code") + if code == "" { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", "No authorization code received.", 1)) + return + } + + // Process the device callback + if err := oauthSvc.ProcessDeviceCallback(r.Context(), oauth, userCode, code); err != nil { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", fmt.Sprintf("Failed to process callback: %s.", err.Error()), 1)) + return + } + + // Show success page + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, successHTML) +} diff --git a/pkg/oauth2/http/oauth_device.go b/pkg/oauth2/http/oauth_device.go new file mode 100644 index 0000000000..4abbd10d32 --- /dev/null +++ b/pkg/oauth2/http/oauth_device.go @@ -0,0 +1,180 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" + oauth2 "github.com/absmach/supermq/pkg/oauth2" + "github.com/go-chi/chi/v5" +) + +var ( + errDeviceCodeExpired = newErrorResponse("device code expired") + errDeviceCodePending = newErrorResponse("authorization pending") + errSlowDown = newErrorResponse("slow down") + errAccessDenied = newErrorResponse("access denied") +) + +// DeviceHandler registers device flow routes for OAuth2 providers. +func DeviceHandler(r *chi.Mux, tokenClient grpcTokenV1.TokenServiceClient, oauthSvc oauth2.Service, providers ...oauth2.Provider) *chi.Mux { + for _, provider := range providers { + r.Post("/oauth/device/code/"+provider.Name(), DeviceCodeHandler(provider, oauthSvc)) + r.Post("/oauth/device/token/"+provider.Name(), DeviceTokenHandler(provider, oauthSvc)) + } + // Register verify endpoints once (not per provider) + r.Get("/oauth/device/verify", DeviceVerifyPageHandler()) + r.Post("/oauth/device/verify", DeviceVerifyHandler(oauthSvc, providers...)) + return r +} + +// DeviceCodeHandler initiates the device authorization flow. +func DeviceCodeHandler(provider oauth2.Provider, oauthSvc oauth2.Service) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + if !provider.IsEnabled() { + errResp := errProviderDisabled + respondWithJSON(w, http.StatusNotFound, errResp) + return + } + + // Build verification URI with proper scheme + scheme := "http" + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { + scheme = "https" + } + verificationURI := fmt.Sprintf("%s://%s/oauth/device/verify", scheme, r.Host) + + code, err := oauthSvc.CreateDeviceCode(r.Context(), provider, verificationURI) + if err != nil { + respondWithJSON(w, http.StatusInternalServerError, newErrorResponse(err.Error())) + return + } + + respondWithJSON(w, http.StatusOK, code) + } +} + +// DeviceTokenHandler polls for device authorization completion. +func DeviceTokenHandler(provider oauth2.Provider, oauthSvc oauth2.Service) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + if !provider.IsEnabled() { + respondWithJSON(w, http.StatusNotFound, errProviderDisabled) + return + } + + var req struct { + DeviceCode string `json:"device_code"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondWithJSON(w, http.StatusBadRequest, errInvalidBody) + return + } + + jwt, err := oauthSvc.PollDeviceToken(r.Context(), provider, req.DeviceCode) + if err != nil { + // Map OAuth service errors to appropriate HTTP responses + switch { + case errors.Is(err, oauth2.ErrDeviceCodeNotFound): + respondWithJSON(w, http.StatusNotFound, newErrorResponse("invalid device code")) + case errors.Is(err, oauth2.ErrDeviceCodeExpired): + respondWithJSON(w, http.StatusBadRequest, errDeviceCodeExpired) + case errors.Is(err, oauth2.ErrSlowDown): + respondWithJSON(w, http.StatusBadRequest, errSlowDown) + case errors.Is(err, oauth2.ErrAccessDenied): + respondWithJSON(w, http.StatusUnauthorized, errAccessDenied) + case errors.Is(err, oauth2.ErrDeviceCodePending): + respondWithJSON(w, http.StatusAccepted, errDeviceCodePending) + default: + respondWithJSON(w, http.StatusInternalServerError, newErrorResponse(err.Error())) + } + return + } + + jwt.AccessType = "" + respondWithJSON(w, http.StatusOK, jwt) + } +} + +// DeviceVerifyPageHandler serves the HTML page for device verification. +func DeviceVerifyPageHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, deviceVerifyHTML) + } +} + +// DeviceVerifyHandler handles user verification of device codes. +func DeviceVerifyHandler(oauthSvc oauth2.Service, providers ...oauth2.Provider) http.HandlerFunc { + providerMap := make(map[string]oauth2.Provider) + for _, p := range providers { + providerMap[p.Name()] = p + } + + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var req struct { + UserCode string `json:"user_code"` + Code string `json:"code"` + Approve bool `json:"approve"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondWithJSON(w, http.StatusBadRequest, errInvalidBody) + return + } + + code, err := oauthSvc.GetDeviceCodeByUserCode(r.Context(), req.UserCode) + if err != nil { + if errors.Is(err, oauth2.ErrUserCodeNotFound) { + respondWithJSON(w, http.StatusNotFound, newErrorResponse("invalid user code")) + } else { + respondWithJSON(w, http.StatusInternalServerError, newErrorResponse(err.Error())) + } + return + } + + provider, ok := providerMap[code.Provider] + if !ok { + respondWithJSON(w, http.StatusBadRequest, newErrorResponse("invalid provider")) + return + } + + if !provider.IsEnabled() { + respondWithJSON(w, http.StatusNotFound, errProviderDisabled) + return + } + + if !req.Approve { + // User denied - pass empty code and approve=false + if err := oauthSvc.VerifyDevice(r.Context(), provider, req.UserCode, "", false); err != nil { + respondWithJSON(w, http.StatusInternalServerError, newErrorResponse(err.Error())) + return + } + respondWithJSON(w, http.StatusOK, map[string]string{"status": "denied"}) + return + } + + // User approved - verify with the OAuth code + if err := oauthSvc.VerifyDevice(r.Context(), provider, req.UserCode, req.Code, true); err != nil { + if errors.Is(err, oauth2.ErrDeviceCodeExpired) { + respondWithJSON(w, http.StatusBadRequest, errDeviceCodeExpired) + } else { + respondWithJSON(w, http.StatusUnauthorized, newErrorResponse(err.Error())) + } + return + } + + respondWithJSON(w, http.StatusOK, map[string]string{"status": "approved"}) + } +} diff --git a/pkg/oauth2/http/oauth_device_redis_test.go b/pkg/oauth2/http/oauth_device_redis_test.go new file mode 100644 index 0000000000..5e02cb0ffc --- /dev/null +++ b/pkg/oauth2/http/oauth_device_redis_test.go @@ -0,0 +1,426 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/absmach/supermq/pkg/oauth2" + "github.com/absmach/supermq/pkg/oauth2/store" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + deviceCodePrefix = "oauth:device:code:" + userCodePrefix = "oauth:device:user:" +) + +const ( + testRedisAddr = "localhost:6379" +) + +// setupRedisTest creates a test Redis client and clears test keys. +func setupRedisTest(t *testing.T) (*redis.Client, func()) { + t.Helper() + + client := redis.NewClient(&redis.Options{ + Addr: testRedisAddr, + DB: 1, // Use DB 1 for tests to avoid conflicts + }) + + ctx := context.Background() + + // Test connection + if err := client.Ping(ctx).Err(); err != nil { + t.Skip("Redis not available, skipping Redis tests") + } + + // Clear any existing test keys + pattern := "oauth:device:*" + iter := client.Scan(ctx, 0, pattern, 0).Iterator() + for iter.Next(ctx) { + client.Del(ctx, iter.Val()) + } + + cleanup := func() { + // Clear test keys after test + iter := client.Scan(ctx, 0, pattern, 0).Iterator() + for iter.Next(ctx) { + client.Del(ctx, iter.Val()) + } + client.Close() + } + + return client, cleanup +} + +func TestRedisDeviceCodeStore_Save(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + code := oauth2.DeviceCode{ + DeviceCode: "test-device-code", + UserCode: "ABCD-EFGH", + VerificationURI: "http://localhost/verify", + ExpiresIn: 600, + Interval: 5, + Provider: "google", + CreatedAt: time.Now(), + State: "device:ABCD-EFGH", + } + + err := deviceStore.Save(code) + require.NoError(t, err) + + // Verify device code was saved + deviceKey := deviceCodePrefix + code.DeviceCode + exists, err := client.Exists(ctx, deviceKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(1), exists) + + // Verify user code mapping was saved + userKey := userCodePrefix + code.UserCode + exists, err = client.Exists(ctx, userKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(1), exists) + + // Verify TTL was set + ttl, err := client.TTL(ctx, deviceKey).Result() + require.NoError(t, err) + assert.Greater(t, ttl, time.Duration(0)) + assert.LessOrEqual(t, ttl, oauth2.DeviceCodeExpiry) +} + +func TestRedisDeviceCodeStore_Get(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + originalCode := oauth2.DeviceCode{ + DeviceCode: "test-device-code-2", + UserCode: "XXXX-YYYY", + VerificationURI: "http://localhost/verify", + ExpiresIn: 600, + Interval: 5, + Provider: "google", + CreatedAt: time.Now(), + State: "device:XXXX-YYYY", + AccessToken: "test-access-token", + Approved: false, + } + + err := deviceStore.Save(originalCode) + require.NoError(t, err) + + // Retrieve the code + retrievedCode, err := deviceStore.Get(originalCode.DeviceCode) + require.NoError(t, err) + + assert.Equal(t, originalCode.DeviceCode, retrievedCode.DeviceCode) + assert.Equal(t, originalCode.UserCode, retrievedCode.UserCode) + assert.Equal(t, originalCode.Provider, retrievedCode.Provider) + assert.Equal(t, originalCode.Approved, retrievedCode.Approved) + assert.Equal(t, originalCode.AccessToken, retrievedCode.AccessToken) +} + +func TestRedisDeviceCodeStore_GetNotFound(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + _, err := deviceStore.Get("non-existent-code") + require.Error(t, err) + assert.Contains(t, err.Error(), "device code not found") +} + +func TestRedisDeviceCodeStore_GetByUserCode(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + originalCode := oauth2.DeviceCode{ + DeviceCode: "test-device-code-3", + UserCode: "ZZZZ-AAAA", + VerificationURI: "http://localhost/verify", + ExpiresIn: 600, + Interval: 5, + Provider: "google", + CreatedAt: time.Now(), + State: "device:ZZZZ-AAAA", + } + + err := deviceStore.Save(originalCode) + require.NoError(t, err) + + // Retrieve by user code + retrievedCode, err := deviceStore.GetByUserCode(originalCode.UserCode) + require.NoError(t, err) + + assert.Equal(t, originalCode.DeviceCode, retrievedCode.DeviceCode) + assert.Equal(t, originalCode.UserCode, retrievedCode.UserCode) + assert.Equal(t, originalCode.Provider, retrievedCode.Provider) +} + +func TestRedisDeviceCodeStore_GetByUserCodeNotFound(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + _, err := deviceStore.GetByUserCode("non-existent-user-code") + require.Error(t, err) + assert.Contains(t, err.Error(), "user code not found") +} + +func TestRedisDeviceCodeStore_Update(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + originalCode := oauth2.DeviceCode{ + DeviceCode: "test-device-code-4", + UserCode: "BBBB-CCCC", + VerificationURI: "http://localhost/verify", + ExpiresIn: 600, + Interval: 5, + Provider: "google", + CreatedAt: time.Now(), + State: "device:BBBB-CCCC", + Approved: false, + AccessToken: "", + } + + err := deviceStore.Save(originalCode) + require.NoError(t, err) + + // Update the code + originalCode.Approved = true + originalCode.AccessToken = "new-access-token" + + err = deviceStore.Update(originalCode) + require.NoError(t, err) + + // Retrieve and verify update + retrievedCode, err := deviceStore.Get(originalCode.DeviceCode) + require.NoError(t, err) + + assert.True(t, retrievedCode.Approved) + assert.Equal(t, "new-access-token", retrievedCode.AccessToken) +} + +func TestRedisDeviceCodeStore_UpdateNotFound(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + code := oauth2.DeviceCode{ + DeviceCode: "non-existent-code", + } + + err := deviceStore.Update(code) + require.Error(t, err) +} + +func TestRedisDeviceCodeStore_Delete(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + code := oauth2.DeviceCode{ + DeviceCode: "test-device-code-5", + UserCode: "DDDD-EEEE", + VerificationURI: "http://localhost/verify", + ExpiresIn: 600, + Interval: 5, + Provider: "google", + CreatedAt: time.Now(), + State: "device:DDDD-EEEE", + } + + err := deviceStore.Save(code) + require.NoError(t, err) + + // Verify it exists + _, err = deviceStore.Get(code.DeviceCode) + require.NoError(t, err) + + // Delete it + err = deviceStore.Delete(code.DeviceCode) + require.NoError(t, err) + + // Verify it's gone + _, err = deviceStore.Get(code.DeviceCode) + require.Error(t, err) + + // Verify user code mapping is also gone + deviceKey := deviceCodePrefix + code.DeviceCode + exists, err := client.Exists(ctx, deviceKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(0), exists) + + userKey := userCodePrefix + code.UserCode + exists, err = client.Exists(ctx, userKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(0), exists) +} + +func TestRedisDeviceCodeStore_DeleteNotFound(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + err := deviceStore.Delete("non-existent-code") + require.Error(t, err) +} + +func TestRedisDeviceCodeStore_Expiry(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + code := oauth2.DeviceCode{ + DeviceCode: "test-device-code-6", + UserCode: "FFFF-GGGG", + VerificationURI: "http://localhost/verify", + ExpiresIn: 1, + Interval: 5, + Provider: "google", + CreatedAt: time.Now(), + State: "device:FFFF-GGGG", + } + + err := deviceStore.Save(code) + require.NoError(t, err) + + // Manually set a very short TTL for testing + deviceKey := deviceCodePrefix + code.DeviceCode + userKey := userCodePrefix + code.UserCode + err = client.Expire(ctx, deviceKey, 1*time.Second).Err() + require.NoError(t, err) + err = client.Expire(ctx, userKey, 1*time.Second).Err() + require.NoError(t, err) + + // Wait for expiry + time.Sleep(2 * time.Second) + + // Verify it's expired + _, err = deviceStore.Get(code.DeviceCode) + require.Error(t, err) + assert.Contains(t, err.Error(), "device code not found") + + _, err = deviceStore.GetByUserCode(code.UserCode) + require.Error(t, err) + assert.Contains(t, err.Error(), "user code not found") +} + +func TestRedisDeviceCodeStore_MultipleInstances(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + + // Create two store instances (simulating two service instances) + deviceStore1 := store.NewRedisDeviceCodeStore(ctx, client) + deviceStore2 := store.NewRedisDeviceCodeStore(ctx, client) + + code := oauth2.DeviceCode{ + DeviceCode: "test-device-code-7", + UserCode: "HHHH-IIII", + VerificationURI: "http://localhost/verify", + ExpiresIn: 600, + Interval: 5, + Provider: "google", + CreatedAt: time.Now(), + State: "device:HHHH-IIII", + Approved: false, + } + + // Save from instance 1 + err := deviceStore1.Save(code) + require.NoError(t, err) + + // Retrieve from instance 2 + retrievedCode, err := deviceStore2.Get(code.DeviceCode) + require.NoError(t, err) + assert.Equal(t, code.DeviceCode, retrievedCode.DeviceCode) + assert.False(t, retrievedCode.Approved) + + // Update from instance 2 + retrievedCode.Approved = true + retrievedCode.AccessToken = "shared-token" + err = deviceStore2.Update(retrievedCode) + require.NoError(t, err) + + // Verify update from instance 1 + verifiedCode, err := deviceStore1.Get(code.DeviceCode) + require.NoError(t, err) + assert.True(t, verifiedCode.Approved) + assert.Equal(t, "shared-token", verifiedCode.AccessToken) +} + +func TestRedisDeviceCodeStore_ConcurrentAccess(t *testing.T) { + client, cleanup := setupRedisTest(t) + defer cleanup() + + ctx := context.Background() + deviceStore := store.NewRedisDeviceCodeStore(ctx, client) + + // Save multiple codes concurrently + numCodes := 10 + errChan := make(chan error, numCodes) + + for i := 0; i < numCodes; i++ { + go func(idx int) { + code := oauth2.DeviceCode{ + DeviceCode: fmt.Sprintf("concurrent-code-%d", idx), + UserCode: fmt.Sprintf("CODE-%04d", idx), + VerificationURI: "http://localhost/verify", + ExpiresIn: 600, + Interval: 5, + Provider: "google", + CreatedAt: time.Now(), + State: fmt.Sprintf("device:CODE-%04d", idx), + } + errChan <- deviceStore.Save(code) + }(i) + } + + // Wait for all saves + for i := 0; i < numCodes; i++ { + err := <-errChan + require.NoError(t, err) + } + + // Verify all codes can be retrieved + for i := 0; i < numCodes; i++ { + code, err := deviceStore.Get(fmt.Sprintf("concurrent-code-%d", i)) + require.NoError(t, err) + assert.Equal(t, fmt.Sprintf("CODE-%04d", i), code.UserCode) + } +} diff --git a/pkg/oauth2/http/oauth_device_test.go b/pkg/oauth2/http/oauth_device_test.go new file mode 100644 index 0000000000..e1e4406880 --- /dev/null +++ b/pkg/oauth2/http/oauth_device_test.go @@ -0,0 +1,348 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http_test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/absmach/supermq/pkg/oauth2" + useroauth "github.com/absmach/supermq/pkg/oauth2" + oauthhttp "github.com/absmach/supermq/pkg/oauth2/http" + + "github.com/absmach/supermq/pkg/oauth2/mocks" + "github.com/absmach/supermq/pkg/oauth2/store" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const errProviderDisabled = "oauth provider is disabled" + +type errorResponse struct { + Error string `json:"error"` +} + +func TestInMemoryDeviceCodeStore(t *testing.T) { + deviceStore := store.NewInMemoryDeviceCodeStore() + + code := oauth2.DeviceCode{ + DeviceCode: "device123", + UserCode: "ABCD-EFGH", + VerificationURI: "http://example.com/verify", + ExpiresIn: 600, + Interval: 3, + Provider: "google", + CreatedAt: time.Now(), + State: "state123", + } + + t.Run("Save and Get", func(t *testing.T) { + err := deviceStore.Save(code) + assert.NoError(t, err) + + retrieved, err := deviceStore.Get(code.DeviceCode) + assert.NoError(t, err) + assert.Equal(t, code.DeviceCode, retrieved.DeviceCode) + assert.Equal(t, code.UserCode, retrieved.UserCode) + }) + + t.Run("GetByUserCode", func(t *testing.T) { + retrieved, err := deviceStore.GetByUserCode(code.UserCode) + assert.NoError(t, err) + assert.Equal(t, code.DeviceCode, retrieved.DeviceCode) + }) + + t.Run("Update", func(t *testing.T) { + code.Approved = true + code.AccessToken = "access_token_123" + err := deviceStore.Update(code) + assert.NoError(t, err) + + retrieved, err := deviceStore.Get(code.DeviceCode) + assert.NoError(t, err) + assert.True(t, retrieved.Approved) + assert.Equal(t, "access_token_123", retrieved.AccessToken) + }) + + t.Run("Delete", func(t *testing.T) { + err := deviceStore.Delete(code.DeviceCode) + assert.NoError(t, err) + + _, err = deviceStore.Get(code.DeviceCode) + assert.Error(t, err) + + _, err = deviceStore.GetByUserCode(code.UserCode) + assert.Error(t, err) + }) + + t.Run("Get non-existent", func(t *testing.T) { + _, err := deviceStore.Get("nonexistent") + assert.Error(t, err) + }) + + t.Run("Update non-existent", func(t *testing.T) { + err := deviceStore.Update(oauth2.DeviceCode{DeviceCode: "nonexistent"}) + assert.Error(t, err) + }) +} + +func TestDeviceCodeHandler(t *testing.T) { + tests := []struct { + name string + providerName string + enabled bool + setupMocks func(*mocks.Service, *mocks.Provider) + expectedStatus int + checkResponse func(*testing.T, *httptest.ResponseRecorder) + }{ + { + name: "successful device code generation", + providerName: "google", + enabled: true, + setupMocks: func(oauthSvc *mocks.Service, provider *mocks.Provider) { + provider.On("IsEnabled").Return(true) + mockCode := oauth2.DeviceCode{ + DeviceCode: "mock-device-code", + UserCode: "ABCD-EFGH", + VerificationURI: "http://example.com/verify", + ExpiresIn: 600, + Interval: 5, + } + oauthSvc.On("CreateDeviceCode", mock.Anything, provider, mock.Anything). + Return(mockCode, nil) + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var deviceCode oauth2.DeviceCode + err := json.NewDecoder(rec.Body).Decode(&deviceCode) + assert.NoError(t, err) + assert.NotEmpty(t, deviceCode.DeviceCode) + assert.NotEmpty(t, deviceCode.UserCode) + assert.NotEmpty(t, deviceCode.VerificationURI) + assert.Greater(t, deviceCode.ExpiresIn, 0) + assert.Greater(t, deviceCode.Interval, 0) + }, + }, + { + name: "provider disabled", + providerName: "google", + enabled: false, + setupMocks: func(oauthSvc *mocks.Service, provider *mocks.Provider) { + provider.On("IsEnabled").Return(false) + }, + expectedStatus: http.StatusNotFound, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var resp errorResponse + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + assert.Equal(t, errProviderDisabled, resp.Error) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + provider := new(mocks.Provider) + provider.On("Name").Return(tc.providerName) + oauthSvc := new(mocks.Service) + + tc.setupMocks(oauthSvc, provider) + + handler := oauthhttp.DeviceCodeHandler(provider, oauthSvc) + + req := httptest.NewRequest(http.MethodPost, "/oauth/device/code/google", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatus, rec.Code) + tc.checkResponse(t, rec) + }) + } +} + +func TestDeviceTokenHandler(t *testing.T) { + tests := []struct { + name string + deviceCode string + setupMocks func(*mocks.Service, *mocks.Provider) + enabled bool + expectedStatus int + checkResponse func(*testing.T, *httptest.ResponseRecorder) + }{ + { + name: "authorization pending", + deviceCode: "device123", + setupMocks: func(oauthSvc *mocks.Service, provider *mocks.Provider) { + provider.On("IsEnabled").Return(true) + oauthSvc.On("PollDeviceToken", mock.Anything, provider, "device123"). + Return(nil, oauth2.ErrDeviceCodePending) + }, + enabled: true, + expectedStatus: http.StatusAccepted, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var resp errorResponse + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + assert.Equal(t, "authorization pending", resp.Error) + }, + }, + { + name: "invalid device code", + deviceCode: "invalid", + setupMocks: func(oauthSvc *mocks.Service, provider *mocks.Provider) { + provider.On("IsEnabled").Return(true) + oauthSvc.On("PollDeviceToken", mock.Anything, provider, "invalid"). + Return(nil, oauth2.ErrDeviceCodeNotFound) + }, + enabled: true, + expectedStatus: http.StatusNotFound, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var resp errorResponse + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + assert.Contains(t, resp.Error, "invalid device code") + }, + }, + { + name: "provider disabled", + deviceCode: "device123", + setupMocks: func(oauthSvc *mocks.Service, provider *mocks.Provider) { + provider.On("IsEnabled").Return(false) + }, + enabled: false, + expectedStatus: http.StatusNotFound, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var resp errorResponse + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + assert.Equal(t, errProviderDisabled, resp.Error) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + provider := new(mocks.Provider) + provider.On("Name").Return("google") + oauthSvc := new(mocks.Service) + + tc.setupMocks(oauthSvc, provider) + + handler := oauthhttp.DeviceTokenHandler(provider, oauthSvc) + + reqBody, _ := json.Marshal(map[string]string{ + "device_code": tc.deviceCode, + }) + req := httptest.NewRequest(http.MethodPost, "/oauth/device/token/google", bytes.NewReader(reqBody)) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatus, rec.Code) + tc.checkResponse(t, rec) + }) + } +} + +func TestDeviceVerifyHandler(t *testing.T) { + tests := []struct { + name string + userCode string + code string + approve bool + setupMocks func(*mocks.Service, *mocks.Provider) + enabled bool + expectedStatus int + checkResponse func(*testing.T, *httptest.ResponseRecorder) + }{ + { + name: "deny authorization", + userCode: "ABCD-EFGH", + code: "", + approve: false, + setupMocks: func(oauthSvc *mocks.Service, provider *mocks.Provider) { + oauthSvc.On("GetDeviceCodeByUserCode", mock.Anything, "ABCD-EFGH"). + Return(oauth2.DeviceCode{Provider: "google"}, nil) + provider.On("IsEnabled").Return(true) + oauthSvc.On("VerifyDevice", mock.Anything, provider, "ABCD-EFGH", "", false). + Return(nil) + }, + enabled: true, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var resp map[string]string + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + assert.Equal(t, "denied", resp["status"]) + }, + }, + { + name: "invalid user code", + userCode: "INVALID", + code: "", + approve: false, + setupMocks: func(oauthSvc *mocks.Service, provider *mocks.Provider) { + oauthSvc.On("GetDeviceCodeByUserCode", mock.Anything, "INVALID"). + Return(oauth2.DeviceCode{}, useroauth.ErrUserCodeNotFound) + }, + enabled: true, + expectedStatus: http.StatusNotFound, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var resp errorResponse + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + assert.Contains(t, resp.Error, "invalid user code") + }, + }, + { + name: "provider disabled", + userCode: "ABCD-EFGH", + code: "", + approve: false, + setupMocks: func(oauthSvc *mocks.Service, provider *mocks.Provider) { + oauthSvc.On("GetDeviceCodeByUserCode", mock.Anything, "ABCD-EFGH"). + Return(oauth2.DeviceCode{Provider: "google"}, nil) + provider.On("IsEnabled").Return(false) + }, + enabled: false, + expectedStatus: http.StatusNotFound, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var resp errorResponse + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + assert.Equal(t, errProviderDisabled, resp.Error) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + provider := new(mocks.Provider) + provider.On("Name").Return("google") + oauthSvc := new(mocks.Service) + + tc.setupMocks(oauthSvc, provider) + + handler := oauthhttp.DeviceVerifyHandler(oauthSvc, provider) + + reqBody, _ := json.Marshal(map[string]interface{}{ + "user_code": tc.userCode, + "code": tc.code, + "approve": tc.approve, + }) + req := httptest.NewRequest(http.MethodPost, "/oauth/device/verify", bytes.NewReader(reqBody)) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatus, rec.Code) + tc.checkResponse(t, rec) + }) + } +} diff --git a/pkg/oauth2/http/oauth_html.go b/pkg/oauth2/http/oauth_html.go new file mode 100644 index 0000000000..1c9db2f317 --- /dev/null +++ b/pkg/oauth2/http/oauth_html.go @@ -0,0 +1,666 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +const successHTML = ` + + + + + Authentication Successful - Magistrala + + + +
+
+ +
+ +
+ + + +
+ +

Authentication Successful!

+

You have been successfully authenticated.

+

You can now close this window and return to the CLI.

+ + +
+ +` + +const errorHTML = ` + + + + + Authentication Failed - Magistrala + + + +
+
+ +
+ +
+
+
+ +

Authentication Failed

+

We encountered an error during authentication.

+ +
+ {{ERROR_MESSAGE}} +
+ +

Please close this window and try again.

+ + +
+ +` + +const deviceVerifyHTML = ` + + + + + Device Verification - Magistrala + + + +
+

Device Verification

+

Enter the code displayed on your device

+
+ + +
+
+
+ + + +` diff --git a/pkg/oauth2/http/oauth_test.go b/pkg/oauth2/http/oauth_test.go new file mode 100644 index 0000000000..e40d34424b --- /dev/null +++ b/pkg/oauth2/http/oauth_test.go @@ -0,0 +1,770 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" + authmocks "github.com/absmach/supermq/auth/mocks" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/oauth2" + oauthhttp "github.com/absmach/supermq/pkg/oauth2/http" + oauth2mocks "github.com/absmach/supermq/pkg/oauth2/mocks" + "github.com/absmach/supermq/pkg/oauth2/store" + "github.com/absmach/supermq/users" + usermocks "github.com/absmach/supermq/users/mocks" + "github.com/go-chi/chi/v5" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + goauth2 "golang.org/x/oauth2" +) + +func TestOAuthAuthorizeEndpoint(t *testing.T) { + svc := new(usermocks.Service) + token := new(authmocks.TokenServiceClient) + + cases := []struct { + name string + provider string + redirectURI string + providerName string + providerEnabled bool + expectedStatus int + checkResponse func(t *testing.T, res *http.Response) + }{ + { + name: "get authorization URL successfully", + provider: "google", + redirectURI: "", + providerName: "google", + providerEnabled: true, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "authorization_url") + assert.Contains(t, resp, "state") + assert.NotEmpty(t, resp["authorization_url"]) + assert.NotEmpty(t, resp["state"]) + }, + }, + { + name: "get authorization URL with custom redirect", + provider: "google", + redirectURI: "http://localhost:9090/callback", + providerName: "google", + providerEnabled: true, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "authorization_url") + assert.Contains(t, resp["authorization_url"], "redirect_uri") + }, + }, + { + name: "provider disabled", + provider: "google", + redirectURI: "", + providerName: "google", + providerEnabled: false, + expectedStatus: http.StatusNotFound, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "error") + assert.Equal(t, "oauth provider is disabled", resp["error"]) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + provider := new(oauth2mocks.Provider) + provider.On("Name").Return(tc.providerName) + provider.On("IsEnabled").Return(tc.providerEnabled) + provider.On("GetAuthURL").Return("https://accounts.google.com/o/oauth2/auth?client_id=test&state=test") + provider.On("GetAuthURLWithRedirect", mock.Anything).Return("https://accounts.google.com/o/oauth2/auth?client_id=test&state=test&redirect_uri=" + tc.redirectURI) + provider.On("State").Return("test-state") + + mux := chi.NewRouter() + redisClient := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) + makeHandler(svc, token, mux, redisClient, provider) + + ts := httptest.NewServer(mux) + defer ts.Close() + + url := fmt.Sprintf("%s/oauth/authorize/%s", ts.URL, tc.provider) + if tc.redirectURI != "" { + url = fmt.Sprintf("%s?redirect_uri=%s", url, tc.redirectURI) + } + + req, err := http.NewRequest(http.MethodGet, url, nil) + assert.NoError(t, err) + + res, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer res.Body.Close() + + assert.Equal(t, tc.expectedStatus, res.StatusCode) + if tc.checkResponse != nil { + tc.checkResponse(t, res) + } + }) + } +} + +func TestOAuthCLICallbackEndpoint(t *testing.T) { + svc := new(usermocks.Service) + + validUserID := testsutil.GenerateUUID(t) + validUser := users.User{ + ID: validUserID, + Email: "test@example.com", + Credentials: users.Credentials{ + Username: "testuser", + }, + Status: users.EnabledStatus, + } + + cases := []struct { + name string + provider string + providerName string + providerEnabled bool + requestBody string + mockSetup func(*oauth2mocks.Provider, *usermocks.Service, *authmocks.TokenServiceClient) + expectedStatus int + checkResponse func(t *testing.T, res *http.Response) + }{ + { + name: "successful OAuth callback", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"test-code","state":"test-state","redirect_url":"http://localhost:9090/callback"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("ExchangeWithRedirect", mock.Anything, "test-code", "http://localhost:9090/callback"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email && u.AuthProvider == "google" + })).Return(validUser, nil) + svc.On("OAuthAddUserPolicy", mock.Anything, validUser).Return(nil) + refreshToken := "jwt-refresh-token" + tokenClient.On("Issue", mock.Anything, mock.Anything). + Return(&grpcTokenV1.Token{ + AccessToken: "jwt-access-token", + RefreshToken: &refreshToken, + }, nil) + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Equal(t, "jwt-access-token", resp["access_token"]) + assert.Equal(t, "jwt-refresh-token", resp["refresh_token"]) + }, + }, + { + name: "OAuth callback without redirect URL", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"test-code","state":"test-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email && u.AuthProvider == "google" + })).Return(validUser, nil) + svc.On("OAuthAddUserPolicy", mock.Anything, validUser).Return(nil) + refreshToken := "jwt-refresh-token" + tokenClient.On("Issue", mock.Anything, mock.Anything). + Return(&grpcTokenV1.Token{ + AccessToken: "jwt-access-token", + RefreshToken: &refreshToken, + }, nil) + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Equal(t, "jwt-access-token", resp["access_token"]) + }, + }, + { + name: "provider disabled", + provider: "google", + providerName: "google", + providerEnabled: false, + requestBody: `{"code":"test-code","state":"test-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + }, + expectedStatus: http.StatusNotFound, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "error") + assert.Equal(t, "oauth provider is disabled", resp["error"]) + }, + }, + { + name: "invalid request body", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `invalid json`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + }, + expectedStatus: http.StatusBadRequest, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "error") + }, + }, + { + name: "invalid state", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"test-code","state":"wrong-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + }, + expectedStatus: http.StatusBadRequest, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Equal(t, "invalid state", resp["error"]) + }, + }, + { + name: "empty code", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"","state":"test-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + }, + expectedStatus: http.StatusBadRequest, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Equal(t, "empty code", resp["error"]) + }, + }, + { + name: "exchange token error", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"test-code","state":"test-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{}, fmt.Errorf("exchange failed")) + }, + expectedStatus: http.StatusUnauthorized, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "error") + }, + }, + { + name: "user info retrieval error", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"test-code","state":"test-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(users.User{}, fmt.Errorf("user info failed")) + }, + expectedStatus: http.StatusInternalServerError, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "error") + }, + }, + { + name: "OAuth callback service error", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"test-code","state":"test-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.Anything). + Return(users.User{}, fmt.Errorf("service error")) + }, + expectedStatus: http.StatusInternalServerError, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "error") + }, + }, + { + name: "add user policy error", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"test-code","state":"test-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email + })).Return(validUser, nil) + svc.On("OAuthAddUserPolicy", mock.Anything, validUser). + Return(fmt.Errorf("policy error")) + }, + expectedStatus: http.StatusInternalServerError, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "error") + }, + }, + { + name: "token issuance error", + provider: "google", + providerName: "google", + providerEnabled: true, + requestBody: `{"code":"test-code","state":"test-state"}`, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email + })).Return(validUser, nil) + svc.On("OAuthAddUserPolicy", mock.Anything, validUser).Return(nil) + tokenClient.On("Issue", mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("token issuance failed")) + }, + expectedStatus: http.StatusInternalServerError, + checkResponse: func(t *testing.T, res *http.Response) { + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + var resp map[string]string + err = json.Unmarshal(body, &resp) + assert.NoError(t, err) + assert.Contains(t, resp, "error") + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + provider := new(oauth2mocks.Provider) + provider.On("Name").Return(tc.providerName) + provider.On("IsEnabled").Return(tc.providerEnabled) + provider.On("State").Return("test-state") + + tokenClient := new(authmocks.TokenServiceClient) + + tc.mockSetup(provider, svc, tokenClient) + + mux := chi.NewRouter() + redisClient := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) + makeHandler(svc, tokenClient, mux, redisClient, provider) + + ts := httptest.NewServer(mux) + defer ts.Close() + + url := fmt.Sprintf("%s/oauth/cli/callback/%s", ts.URL, tc.provider) + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(tc.requestBody)) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + res, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer res.Body.Close() + + assert.Equal(t, tc.expectedStatus, res.StatusCode) + if tc.checkResponse != nil { + tc.checkResponse(t, res) + } + + // Reset mocks for next test + svc.ExpectedCalls = nil + tokenClient.ExpectedCalls = nil + }) + } +} + +func TestOAuthCallbackEndpoint(t *testing.T) { + svc := new(usermocks.Service) + + validUserID := testsutil.GenerateUUID(t) + validUser := users.User{ + ID: validUserID, + Email: "test@example.com", + Credentials: users.Credentials{ + Username: "testuser", + }, + Status: users.EnabledStatus, + } + + cases := []struct { + name string + provider string + providerName string + providerEnabled bool + queryParams map[string]string + mockSetup func(*oauth2mocks.Provider, *usermocks.Service, *authmocks.TokenServiceClient) + expectedStatus int + checkResponse func(t *testing.T, res *http.Response) + }{ + { + name: "successful OAuth web callback", + provider: "google", + providerName: "google", + providerEnabled: true, + queryParams: map[string]string{ + "state": "test-state", + "code": "test-code", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email && u.AuthProvider == "google" + })).Return(validUser, nil) + svc.On("OAuthAddUserPolicy", mock.Anything, validUser).Return(nil) + refreshToken := "jwt-refresh-token" + tokenClient.On("Issue", mock.Anything, mock.Anything). + Return(&grpcTokenV1.Token{ + AccessToken: "jwt-access-token", + RefreshToken: &refreshToken, + }, nil) + }, + expectedStatus: http.StatusFound, + checkResponse: func(t *testing.T, res *http.Response) { + // Check that it redirects to the redirect URL + assert.Contains(t, res.Header.Get("Location"), "http://localhost/ui") + // Check cookies are set + cookies := res.Cookies() + assert.NotEmpty(t, cookies) + foundAccessToken := false + foundRefreshToken := false + for _, cookie := range cookies { + if cookie.Name == "access_token" { + foundAccessToken = true + assert.Equal(t, "jwt-access-token", cookie.Value) + } + if cookie.Name == "refresh_token" { + foundRefreshToken = true + assert.Equal(t, "jwt-refresh-token", cookie.Value) + } + } + assert.True(t, foundAccessToken, "access_token cookie not found") + assert.True(t, foundRefreshToken, "refresh_token cookie not found") + }, + }, + { + name: "provider disabled", + provider: "google", + providerName: "google", + providerEnabled: false, + queryParams: map[string]string{ + "state": "test-state", + "code": "test-code", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + }, + expectedStatus: http.StatusSeeOther, + checkResponse: func(t *testing.T, res *http.Response) { + location := res.Header.Get("Location") + assert.Contains(t, location, "error=") + assert.Contains(t, location, "oauth") + }, + }, + { + name: "invalid state", + provider: "google", + providerName: "google", + providerEnabled: true, + queryParams: map[string]string{ + "state": "wrong-state", + "code": "test-code", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + }, + expectedStatus: http.StatusSeeOther, + checkResponse: func(t *testing.T, res *http.Response) { + location := res.Header.Get("Location") + assert.Contains(t, location, "error=") + assert.Contains(t, location, "state") + }, + }, + { + name: "empty code", + provider: "google", + providerName: "google", + providerEnabled: true, + queryParams: map[string]string{ + "state": "test-state", + "code": "", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + }, + expectedStatus: http.StatusSeeOther, + checkResponse: func(t *testing.T, res *http.Response) { + location := res.Header.Get("Location") + assert.Contains(t, location, "error=") + assert.Contains(t, location, "code") + }, + }, + { + name: "exchange token error", + provider: "google", + providerName: "google", + providerEnabled: true, + queryParams: map[string]string{ + "state": "test-state", + "code": "test-code", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{}, fmt.Errorf("exchange failed")) + }, + expectedStatus: http.StatusSeeOther, + checkResponse: func(t *testing.T, res *http.Response) { + assert.Contains(t, res.Header.Get("Location"), "error=") + }, + }, + { + name: "user info retrieval error", + provider: "google", + providerName: "google", + providerEnabled: true, + queryParams: map[string]string{ + "state": "test-state", + "code": "test-code", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(users.User{}, fmt.Errorf("user info failed")) + }, + expectedStatus: http.StatusSeeOther, + checkResponse: func(t *testing.T, res *http.Response) { + assert.Contains(t, res.Header.Get("Location"), "error=") + }, + }, + { + name: "OAuth callback service error", + provider: "google", + providerName: "google", + providerEnabled: true, + queryParams: map[string]string{ + "state": "test-state", + "code": "test-code", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.Anything). + Return(users.User{}, fmt.Errorf("service error")) + }, + expectedStatus: http.StatusSeeOther, + checkResponse: func(t *testing.T, res *http.Response) { + assert.Contains(t, res.Header.Get("Location"), "error=") + }, + }, + { + name: "add user policy error", + provider: "google", + providerName: "google", + providerEnabled: true, + queryParams: map[string]string{ + "state": "test-state", + "code": "test-code", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email + })).Return(validUser, nil) + svc.On("OAuthAddUserPolicy", mock.Anything, validUser). + Return(fmt.Errorf("policy error")) + }, + expectedStatus: http.StatusSeeOther, + checkResponse: func(t *testing.T, res *http.Response) { + assert.Contains(t, res.Header.Get("Location"), "error=") + }, + }, + { + name: "token issuance error", + provider: "google", + providerName: "google", + providerEnabled: true, + queryParams: map[string]string{ + "state": "test-state", + "code": "test-code", + }, + mockSetup: func(provider *oauth2mocks.Provider, svc *usermocks.Service, tokenClient *authmocks.TokenServiceClient) { + provider.On("Exchange", mock.Anything, "test-code"). + Return(goauth2.Token{AccessToken: "access-token"}, nil) + provider.On("UserInfo", "access-token").Return(validUser, nil) + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email + })).Return(validUser, nil) + svc.On("OAuthAddUserPolicy", mock.Anything, validUser).Return(nil) + tokenClient.On("Issue", mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("token issuance failed")) + }, + expectedStatus: http.StatusSeeOther, + checkResponse: func(t *testing.T, res *http.Response) { + assert.Contains(t, res.Header.Get("Location"), "error=") + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + provider := new(oauth2mocks.Provider) + provider.On("Name").Return(tc.providerName) + provider.On("IsEnabled").Return(tc.providerEnabled) + provider.On("State").Return("test-state") + provider.On("ErrorURL").Return("http://localhost/error") + provider.On("RedirectURL").Return("http://localhost/ui") + + tokenClient := new(authmocks.TokenServiceClient) + + tc.mockSetup(provider, svc, tokenClient) + + mux := chi.NewRouter() + redisClient := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) + makeHandler(svc, tokenClient, mux, redisClient, provider) + + ts := httptest.NewServer(mux) + defer ts.Close() + + // Build query string + url := fmt.Sprintf("%s/oauth/callback/%s", ts.URL, tc.provider) + if len(tc.queryParams) > 0 { + url += "?" + first := true + for k, v := range tc.queryParams { + if !first { + url += "&" + } + url += fmt.Sprintf("%s=%s", k, v) + first = false + } + } + + req, err := http.NewRequest(http.MethodGet, url, nil) + assert.NoError(t, err) + + // Don't follow redirects to check the redirect URL + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + res, err := client.Do(req) + assert.NoError(t, err) + defer res.Body.Close() + + assert.Equal(t, tc.expectedStatus, res.StatusCode) + if tc.checkResponse != nil { + tc.checkResponse(t, res) + } + + // Reset mocks for next test + svc.ExpectedCalls = nil + tokenClient.ExpectedCalls = nil + }) + } +} + +func makeHandler(svc users.Service, tokensvc grpcTokenV1.TokenServiceClient, mux *chi.Mux, cacheClient *redis.Client, providers ...oauth2.Provider) http.Handler { + ctx := context.Background() + + deviceStore := store.NewRedisDeviceCodeStore(ctx, cacheClient) + oauthSvc := oauth2.NewOAuthService(deviceStore, svc, tokensvc) + + mux = oauthhttp.Handler(mux, tokensvc, oauthSvc, providers...) + mux = oauthhttp.DeviceHandler(mux, tokensvc, oauthSvc, providers...) + + return mux +} diff --git a/pkg/oauth2/mocks/device_code_store.go b/pkg/oauth2/mocks/device_code_store.go new file mode 100644 index 0000000000..055af014e1 --- /dev/null +++ b/pkg/oauth2/mocks/device_code_store.go @@ -0,0 +1,314 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "github.com/absmach/supermq/pkg/oauth2" + mock "github.com/stretchr/testify/mock" +) + +// NewDeviceCodeStore creates a new instance of DeviceCodeStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDeviceCodeStore(t interface { + mock.TestingT + Cleanup(func()) +}) *DeviceCodeStore { + mock := &DeviceCodeStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// DeviceCodeStore is an autogenerated mock type for the DeviceCodeStore type +type DeviceCodeStore struct { + mock.Mock +} + +type DeviceCodeStore_Expecter struct { + mock *mock.Mock +} + +func (_m *DeviceCodeStore) EXPECT() *DeviceCodeStore_Expecter { + return &DeviceCodeStore_Expecter{mock: &_m.Mock} +} + +// Delete provides a mock function for the type DeviceCodeStore +func (_mock *DeviceCodeStore) Delete(deviceCode string) error { + ret := _mock.Called(deviceCode) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string) error); ok { + r0 = returnFunc(deviceCode) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// DeviceCodeStore_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type DeviceCodeStore_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - deviceCode string +func (_e *DeviceCodeStore_Expecter) Delete(deviceCode interface{}) *DeviceCodeStore_Delete_Call { + return &DeviceCodeStore_Delete_Call{Call: _e.mock.On("Delete", deviceCode)} +} + +func (_c *DeviceCodeStore_Delete_Call) Run(run func(deviceCode string)) *DeviceCodeStore_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *DeviceCodeStore_Delete_Call) Return(err error) *DeviceCodeStore_Delete_Call { + _c.Call.Return(err) + return _c +} + +func (_c *DeviceCodeStore_Delete_Call) RunAndReturn(run func(deviceCode string) error) *DeviceCodeStore_Delete_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function for the type DeviceCodeStore +func (_mock *DeviceCodeStore) Get(deviceCode string) (oauth2.DeviceCode, error) { + ret := _mock.Called(deviceCode) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 oauth2.DeviceCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (oauth2.DeviceCode, error)); ok { + return returnFunc(deviceCode) + } + if returnFunc, ok := ret.Get(0).(func(string) oauth2.DeviceCode); ok { + r0 = returnFunc(deviceCode) + } else { + r0 = ret.Get(0).(oauth2.DeviceCode) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(deviceCode) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// DeviceCodeStore_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type DeviceCodeStore_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - deviceCode string +func (_e *DeviceCodeStore_Expecter) Get(deviceCode interface{}) *DeviceCodeStore_Get_Call { + return &DeviceCodeStore_Get_Call{Call: _e.mock.On("Get", deviceCode)} +} + +func (_c *DeviceCodeStore_Get_Call) Run(run func(deviceCode string)) *DeviceCodeStore_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *DeviceCodeStore_Get_Call) Return(deviceCode1 oauth2.DeviceCode, err error) *DeviceCodeStore_Get_Call { + _c.Call.Return(deviceCode1, err) + return _c +} + +func (_c *DeviceCodeStore_Get_Call) RunAndReturn(run func(deviceCode string) (oauth2.DeviceCode, error)) *DeviceCodeStore_Get_Call { + _c.Call.Return(run) + return _c +} + +// GetByUserCode provides a mock function for the type DeviceCodeStore +func (_mock *DeviceCodeStore) GetByUserCode(userCode string) (oauth2.DeviceCode, error) { + ret := _mock.Called(userCode) + + if len(ret) == 0 { + panic("no return value specified for GetByUserCode") + } + + var r0 oauth2.DeviceCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (oauth2.DeviceCode, error)); ok { + return returnFunc(userCode) + } + if returnFunc, ok := ret.Get(0).(func(string) oauth2.DeviceCode); ok { + r0 = returnFunc(userCode) + } else { + r0 = ret.Get(0).(oauth2.DeviceCode) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(userCode) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// DeviceCodeStore_GetByUserCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByUserCode' +type DeviceCodeStore_GetByUserCode_Call struct { + *mock.Call +} + +// GetByUserCode is a helper method to define mock.On call +// - userCode string +func (_e *DeviceCodeStore_Expecter) GetByUserCode(userCode interface{}) *DeviceCodeStore_GetByUserCode_Call { + return &DeviceCodeStore_GetByUserCode_Call{Call: _e.mock.On("GetByUserCode", userCode)} +} + +func (_c *DeviceCodeStore_GetByUserCode_Call) Run(run func(userCode string)) *DeviceCodeStore_GetByUserCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *DeviceCodeStore_GetByUserCode_Call) Return(deviceCode oauth2.DeviceCode, err error) *DeviceCodeStore_GetByUserCode_Call { + _c.Call.Return(deviceCode, err) + return _c +} + +func (_c *DeviceCodeStore_GetByUserCode_Call) RunAndReturn(run func(userCode string) (oauth2.DeviceCode, error)) *DeviceCodeStore_GetByUserCode_Call { + _c.Call.Return(run) + return _c +} + +// Save provides a mock function for the type DeviceCodeStore +func (_mock *DeviceCodeStore) Save(code oauth2.DeviceCode) error { + ret := _mock.Called(code) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(oauth2.DeviceCode) error); ok { + r0 = returnFunc(code) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// DeviceCodeStore_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save' +type DeviceCodeStore_Save_Call struct { + *mock.Call +} + +// Save is a helper method to define mock.On call +// - code oauth2.DeviceCode +func (_e *DeviceCodeStore_Expecter) Save(code interface{}) *DeviceCodeStore_Save_Call { + return &DeviceCodeStore_Save_Call{Call: _e.mock.On("Save", code)} +} + +func (_c *DeviceCodeStore_Save_Call) Run(run func(code oauth2.DeviceCode)) *DeviceCodeStore_Save_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 oauth2.DeviceCode + if args[0] != nil { + arg0 = args[0].(oauth2.DeviceCode) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *DeviceCodeStore_Save_Call) Return(err error) *DeviceCodeStore_Save_Call { + _c.Call.Return(err) + return _c +} + +func (_c *DeviceCodeStore_Save_Call) RunAndReturn(run func(code oauth2.DeviceCode) error) *DeviceCodeStore_Save_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function for the type DeviceCodeStore +func (_mock *DeviceCodeStore) Update(code oauth2.DeviceCode) error { + ret := _mock.Called(code) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(oauth2.DeviceCode) error); ok { + r0 = returnFunc(code) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// DeviceCodeStore_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type DeviceCodeStore_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - code oauth2.DeviceCode +func (_e *DeviceCodeStore_Expecter) Update(code interface{}) *DeviceCodeStore_Update_Call { + return &DeviceCodeStore_Update_Call{Call: _e.mock.On("Update", code)} +} + +func (_c *DeviceCodeStore_Update_Call) Run(run func(code oauth2.DeviceCode)) *DeviceCodeStore_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 oauth2.DeviceCode + if args[0] != nil { + arg0 = args[0].(oauth2.DeviceCode) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *DeviceCodeStore_Update_Call) Return(err error) *DeviceCodeStore_Update_Call { + _c.Call.Return(err) + return _c +} + +func (_c *DeviceCodeStore_Update_Call) RunAndReturn(run func(code oauth2.DeviceCode) error) *DeviceCodeStore_Update_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/oauth2/mocks/provider.go b/pkg/oauth2/mocks/provider.go index f7f1678b27..7677741ff8 100644 --- a/pkg/oauth2/mocks/provider.go +++ b/pkg/oauth2/mocks/provider.go @@ -153,6 +153,173 @@ func (_c *Provider_Exchange_Call) RunAndReturn(run func(ctx context.Context, cod return _c } +// ExchangeWithRedirect provides a mock function for the type Provider +func (_mock *Provider) ExchangeWithRedirect(ctx context.Context, code string, redirectURL string) (oauth2.Token, error) { + ret := _mock.Called(ctx, code, redirectURL) + + if len(ret) == 0 { + panic("no return value specified for ExchangeWithRedirect") + } + + var r0 oauth2.Token + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (oauth2.Token, error)); ok { + return returnFunc(ctx, code, redirectURL) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) oauth2.Token); ok { + r0 = returnFunc(ctx, code, redirectURL) + } else { + r0 = ret.Get(0).(oauth2.Token) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, code, redirectURL) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Provider_ExchangeWithRedirect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ExchangeWithRedirect' +type Provider_ExchangeWithRedirect_Call struct { + *mock.Call +} + +// ExchangeWithRedirect is a helper method to define mock.On call +// - ctx context.Context +// - code string +// - redirectURL string +func (_e *Provider_Expecter) ExchangeWithRedirect(ctx interface{}, code interface{}, redirectURL interface{}) *Provider_ExchangeWithRedirect_Call { + return &Provider_ExchangeWithRedirect_Call{Call: _e.mock.On("ExchangeWithRedirect", ctx, code, redirectURL)} +} + +func (_c *Provider_ExchangeWithRedirect_Call) Run(run func(ctx context.Context, code string, redirectURL string)) *Provider_ExchangeWithRedirect_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Provider_ExchangeWithRedirect_Call) Return(token oauth2.Token, err error) *Provider_ExchangeWithRedirect_Call { + _c.Call.Return(token, err) + return _c +} + +func (_c *Provider_ExchangeWithRedirect_Call) RunAndReturn(run func(ctx context.Context, code string, redirectURL string) (oauth2.Token, error)) *Provider_ExchangeWithRedirect_Call { + _c.Call.Return(run) + return _c +} + +// GetAuthURL provides a mock function for the type Provider +func (_mock *Provider) GetAuthURL() string { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for GetAuthURL") + } + + var r0 string + if returnFunc, ok := ret.Get(0).(func() string); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(string) + } + return r0 +} + +// Provider_GetAuthURL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthURL' +type Provider_GetAuthURL_Call struct { + *mock.Call +} + +// GetAuthURL is a helper method to define mock.On call +func (_e *Provider_Expecter) GetAuthURL() *Provider_GetAuthURL_Call { + return &Provider_GetAuthURL_Call{Call: _e.mock.On("GetAuthURL")} +} + +func (_c *Provider_GetAuthURL_Call) Run(run func()) *Provider_GetAuthURL_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Provider_GetAuthURL_Call) Return(s string) *Provider_GetAuthURL_Call { + _c.Call.Return(s) + return _c +} + +func (_c *Provider_GetAuthURL_Call) RunAndReturn(run func() string) *Provider_GetAuthURL_Call { + _c.Call.Return(run) + return _c +} + +// GetAuthURLWithRedirect provides a mock function for the type Provider +func (_mock *Provider) GetAuthURLWithRedirect(redirectURL string) string { + ret := _mock.Called(redirectURL) + + if len(ret) == 0 { + panic("no return value specified for GetAuthURLWithRedirect") + } + + var r0 string + if returnFunc, ok := ret.Get(0).(func(string) string); ok { + r0 = returnFunc(redirectURL) + } else { + r0 = ret.Get(0).(string) + } + return r0 +} + +// Provider_GetAuthURLWithRedirect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthURLWithRedirect' +type Provider_GetAuthURLWithRedirect_Call struct { + *mock.Call +} + +// GetAuthURLWithRedirect is a helper method to define mock.On call +// - redirectURL string +func (_e *Provider_Expecter) GetAuthURLWithRedirect(redirectURL interface{}) *Provider_GetAuthURLWithRedirect_Call { + return &Provider_GetAuthURLWithRedirect_Call{Call: _e.mock.On("GetAuthURLWithRedirect", redirectURL)} +} + +func (_c *Provider_GetAuthURLWithRedirect_Call) Run(run func(redirectURL string)) *Provider_GetAuthURLWithRedirect_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Provider_GetAuthURLWithRedirect_Call) Return(s string) *Provider_GetAuthURLWithRedirect_Call { + _c.Call.Return(s) + return _c +} + +func (_c *Provider_GetAuthURLWithRedirect_Call) RunAndReturn(run func(redirectURL string) string) *Provider_GetAuthURLWithRedirect_Call { + _c.Call.Return(run) + return _c +} + // IsEnabled provides a mock function for the type Provider func (_mock *Provider) IsEnabled() bool { ret := _mock.Called() diff --git a/pkg/oauth2/mocks/service.go b/pkg/oauth2/mocks/service.go new file mode 100644 index 0000000000..5dab544781 --- /dev/null +++ b/pkg/oauth2/mocks/service.go @@ -0,0 +1,480 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/api/grpc/token/v1" + "github.com/absmach/supermq/pkg/oauth2" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// CreateDeviceCode provides a mock function for the type Service +func (_mock *Service) CreateDeviceCode(ctx context.Context, provider oauth2.Provider, verificationURI string) (oauth2.DeviceCode, error) { + ret := _mock.Called(ctx, provider, verificationURI) + + if len(ret) == 0 { + panic("no return value specified for CreateDeviceCode") + } + + var r0 oauth2.DeviceCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string) (oauth2.DeviceCode, error)); ok { + return returnFunc(ctx, provider, verificationURI) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string) oauth2.DeviceCode); ok { + r0 = returnFunc(ctx, provider, verificationURI) + } else { + r0 = ret.Get(0).(oauth2.DeviceCode) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, oauth2.Provider, string) error); ok { + r1 = returnFunc(ctx, provider, verificationURI) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_CreateDeviceCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateDeviceCode' +type Service_CreateDeviceCode_Call struct { + *mock.Call +} + +// CreateDeviceCode is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - verificationURI string +func (_e *Service_Expecter) CreateDeviceCode(ctx interface{}, provider interface{}, verificationURI interface{}) *Service_CreateDeviceCode_Call { + return &Service_CreateDeviceCode_Call{Call: _e.mock.On("CreateDeviceCode", ctx, provider, verificationURI)} +} + +func (_c *Service_CreateDeviceCode_Call) Run(run func(ctx context.Context, provider oauth2.Provider, verificationURI string)) *Service_CreateDeviceCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_CreateDeviceCode_Call) Return(deviceCode oauth2.DeviceCode, err error) *Service_CreateDeviceCode_Call { + _c.Call.Return(deviceCode, err) + return _c +} + +func (_c *Service_CreateDeviceCode_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, verificationURI string) (oauth2.DeviceCode, error)) *Service_CreateDeviceCode_Call { + _c.Call.Return(run) + return _c +} + +// GetDeviceCodeByUserCode provides a mock function for the type Service +func (_mock *Service) GetDeviceCodeByUserCode(ctx context.Context, userCode string) (oauth2.DeviceCode, error) { + ret := _mock.Called(ctx, userCode) + + if len(ret) == 0 { + panic("no return value specified for GetDeviceCodeByUserCode") + } + + var r0 oauth2.DeviceCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (oauth2.DeviceCode, error)); ok { + return returnFunc(ctx, userCode) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) oauth2.DeviceCode); ok { + r0 = returnFunc(ctx, userCode) + } else { + r0 = ret.Get(0).(oauth2.DeviceCode) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, userCode) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_GetDeviceCodeByUserCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDeviceCodeByUserCode' +type Service_GetDeviceCodeByUserCode_Call struct { + *mock.Call +} + +// GetDeviceCodeByUserCode is a helper method to define mock.On call +// - ctx context.Context +// - userCode string +func (_e *Service_Expecter) GetDeviceCodeByUserCode(ctx interface{}, userCode interface{}) *Service_GetDeviceCodeByUserCode_Call { + return &Service_GetDeviceCodeByUserCode_Call{Call: _e.mock.On("GetDeviceCodeByUserCode", ctx, userCode)} +} + +func (_c *Service_GetDeviceCodeByUserCode_Call) Run(run func(ctx context.Context, userCode string)) *Service_GetDeviceCodeByUserCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_GetDeviceCodeByUserCode_Call) Return(deviceCode oauth2.DeviceCode, err error) *Service_GetDeviceCodeByUserCode_Call { + _c.Call.Return(deviceCode, err) + return _c +} + +func (_c *Service_GetDeviceCodeByUserCode_Call) RunAndReturn(run func(ctx context.Context, userCode string) (oauth2.DeviceCode, error)) *Service_GetDeviceCodeByUserCode_Call { + _c.Call.Return(run) + return _c +} + +// PollDeviceToken provides a mock function for the type Service +func (_mock *Service) PollDeviceToken(ctx context.Context, provider oauth2.Provider, deviceCode string) (*v1.Token, error) { + ret := _mock.Called(ctx, provider, deviceCode) + + if len(ret) == 0 { + panic("no return value specified for PollDeviceToken") + } + + var r0 *v1.Token + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string) (*v1.Token, error)); ok { + return returnFunc(ctx, provider, deviceCode) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string) *v1.Token); ok { + r0 = returnFunc(ctx, provider, deviceCode) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Token) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, oauth2.Provider, string) error); ok { + r1 = returnFunc(ctx, provider, deviceCode) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_PollDeviceToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PollDeviceToken' +type Service_PollDeviceToken_Call struct { + *mock.Call +} + +// PollDeviceToken is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - deviceCode string +func (_e *Service_Expecter) PollDeviceToken(ctx interface{}, provider interface{}, deviceCode interface{}) *Service_PollDeviceToken_Call { + return &Service_PollDeviceToken_Call{Call: _e.mock.On("PollDeviceToken", ctx, provider, deviceCode)} +} + +func (_c *Service_PollDeviceToken_Call) Run(run func(ctx context.Context, provider oauth2.Provider, deviceCode string)) *Service_PollDeviceToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_PollDeviceToken_Call) Return(token *v1.Token, err error) *Service_PollDeviceToken_Call { + _c.Call.Return(token, err) + return _c +} + +func (_c *Service_PollDeviceToken_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, deviceCode string) (*v1.Token, error)) *Service_PollDeviceToken_Call { + _c.Call.Return(run) + return _c +} + +// ProcessDeviceCallback provides a mock function for the type Service +func (_mock *Service) ProcessDeviceCallback(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string) error { + ret := _mock.Called(ctx, provider, userCode, oauthCode) + + if len(ret) == 0 { + panic("no return value specified for ProcessDeviceCallback") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string, string) error); ok { + r0 = returnFunc(ctx, provider, userCode, oauthCode) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_ProcessDeviceCallback_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessDeviceCallback' +type Service_ProcessDeviceCallback_Call struct { + *mock.Call +} + +// ProcessDeviceCallback is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - userCode string +// - oauthCode string +func (_e *Service_Expecter) ProcessDeviceCallback(ctx interface{}, provider interface{}, userCode interface{}, oauthCode interface{}) *Service_ProcessDeviceCallback_Call { + return &Service_ProcessDeviceCallback_Call{Call: _e.mock.On("ProcessDeviceCallback", ctx, provider, userCode, oauthCode)} +} + +func (_c *Service_ProcessDeviceCallback_Call) Run(run func(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string)) *Service_ProcessDeviceCallback_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_ProcessDeviceCallback_Call) Return(err error) *Service_ProcessDeviceCallback_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_ProcessDeviceCallback_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string) error) *Service_ProcessDeviceCallback_Call { + _c.Call.Return(run) + return _c +} + +// ProcessWebCallback provides a mock function for the type Service +func (_mock *Service) ProcessWebCallback(ctx context.Context, provider oauth2.Provider, code string, redirectURL string) (*v1.Token, error) { + ret := _mock.Called(ctx, provider, code, redirectURL) + + if len(ret) == 0 { + panic("no return value specified for ProcessWebCallback") + } + + var r0 *v1.Token + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string, string) (*v1.Token, error)); ok { + return returnFunc(ctx, provider, code, redirectURL) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string, string) *v1.Token); ok { + r0 = returnFunc(ctx, provider, code, redirectURL) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Token) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, oauth2.Provider, string, string) error); ok { + r1 = returnFunc(ctx, provider, code, redirectURL) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ProcessWebCallback_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessWebCallback' +type Service_ProcessWebCallback_Call struct { + *mock.Call +} + +// ProcessWebCallback is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - code string +// - redirectURL string +func (_e *Service_Expecter) ProcessWebCallback(ctx interface{}, provider interface{}, code interface{}, redirectURL interface{}) *Service_ProcessWebCallback_Call { + return &Service_ProcessWebCallback_Call{Call: _e.mock.On("ProcessWebCallback", ctx, provider, code, redirectURL)} +} + +func (_c *Service_ProcessWebCallback_Call) Run(run func(ctx context.Context, provider oauth2.Provider, code string, redirectURL string)) *Service_ProcessWebCallback_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_ProcessWebCallback_Call) Return(token *v1.Token, err error) *Service_ProcessWebCallback_Call { + _c.Call.Return(token, err) + return _c +} + +func (_c *Service_ProcessWebCallback_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, code string, redirectURL string) (*v1.Token, error)) *Service_ProcessWebCallback_Call { + _c.Call.Return(run) + return _c +} + +// VerifyDevice provides a mock function for the type Service +func (_mock *Service) VerifyDevice(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string, approve bool) error { + ret := _mock.Called(ctx, provider, userCode, oauthCode, approve) + + if len(ret) == 0 { + panic("no return value specified for VerifyDevice") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string, string, bool) error); ok { + r0 = returnFunc(ctx, provider, userCode, oauthCode, approve) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_VerifyDevice_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyDevice' +type Service_VerifyDevice_Call struct { + *mock.Call +} + +// VerifyDevice is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - userCode string +// - oauthCode string +// - approve bool +func (_e *Service_Expecter) VerifyDevice(ctx interface{}, provider interface{}, userCode interface{}, oauthCode interface{}, approve interface{}) *Service_VerifyDevice_Call { + return &Service_VerifyDevice_Call{Call: _e.mock.On("VerifyDevice", ctx, provider, userCode, oauthCode, approve)} +} + +func (_c *Service_VerifyDevice_Call) Run(run func(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string, approve bool)) *Service_VerifyDevice_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 bool + if args[4] != nil { + arg4 = args[4].(bool) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_VerifyDevice_Call) Return(err error) *Service_VerifyDevice_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_VerifyDevice_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string, approve bool) error) *Service_VerifyDevice_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/oauth2/normalize_test.go b/pkg/oauth2/normalize_test.go index 19423258f5..1fc8855250 100644 --- a/pkg/oauth2/normalize_test.go +++ b/pkg/oauth2/normalize_test.go @@ -1,11 +1,12 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package oauth2 +package oauth2_test import ( "testing" + "github.com/absmach/supermq/pkg/oauth2" "github.com/absmach/supermq/users" "github.com/stretchr/testify/assert" ) @@ -19,7 +20,7 @@ func TestNormalizeUser(t *testing.T) { wantErrStr string }{ { - desc: "valid user with standard keys", + desc: "valid user with Google keys (given_name, family_name)", inputJSON: `{ "id": "123", "given_name": "Jane", @@ -39,7 +40,133 @@ func TestNormalizeUser(t *testing.T) { wantErrStr: "", }, { - desc: "missing required fields", + desc: "valid user with alternative key variants (givenName, familyName, emailAddress, profilePicture)", + inputJSON: `{ + "id": "456", + "givenName": "John", + "familyName": "Smith", + "user_name": "jsmith", + "emailAddress": "john@smith.com", + "profilePicture": "avatar.png" + }`, + provider: "github", + wantUser: users.User{ + ID: "456", + FirstName: "John", + LastName: "Smith", + Email: "john@smith.com", + ProfilePicture: "avatar.png", + Metadata: users.Metadata{"oauth_provider": "github"}, + }, + wantErrStr: "", + }, + { + desc: "valid user with snake_case variants (first_name, last_name, email_address, profile_picture)", + inputJSON: `{ + "id": "789", + "first_name": "Alice", + "last_name": "Brown", + "username": "abrown", + "email_address": "alice@brown.com", + "profile_picture": "photo.jpg" + }`, + provider: "custom", + wantUser: users.User{ + ID: "789", + FirstName: "Alice", + LastName: "Brown", + Email: "alice@brown.com", + ProfilePicture: "photo.jpg", + Metadata: users.Metadata{"oauth_provider": "custom"}, + }, + wantErrStr: "", + }, + { + desc: "valid user with lowercase variants (firstname, lastname, avatar)", + inputJSON: `{ + "id": "101112", + "firstname": "Bob", + "lastname": "Wilson", + "userName": "bwilson", + "email": "bob@wilson.com", + "avatar": "img.jpg" + }`, + provider: "oauth", + wantUser: users.User{ + ID: "101112", + FirstName: "Bob", + LastName: "Wilson", + Email: "bob@wilson.com", + ProfilePicture: "img.jpg", + Metadata: users.Metadata{"oauth_provider": "oauth"}, + }, + wantErrStr: "", + }, + { + desc: "valid user with minimal required fields only", + inputJSON: `{ + "id": "999", + "given_name": "Min", + "family_name": "Max", + "email": "min@max.com" + }`, + provider: "minimal", + wantUser: users.User{ + ID: "999", + FirstName: "Min", + LastName: "Max", + Email: "min@max.com", + ProfilePicture: "", + Metadata: users.Metadata{"oauth_provider": "minimal"}, + }, + wantErrStr: "", + }, + { + desc: "missing ID field", + inputJSON: `{ + "given_name": "Jane", + "family_name": "Doe", + "email": "jane@example.com" + }`, + provider: "google", + wantUser: users.User{}, + wantErrStr: "missing required fields: id", + }, + { + desc: "missing first_name field", + inputJSON: `{ + "id": "123", + "family_name": "Doe", + "email": "jane@example.com" + }`, + provider: "google", + wantUser: users.User{}, + wantErrStr: "missing required fields: first_name", + }, + { + desc: "missing last_name field", + inputJSON: `{ + "id": "123", + "given_name": "Jane", + "email": "jane@example.com" + }`, + provider: "google", + wantUser: users.User{}, + wantErrStr: "missing required fields: last_name", + }, + { + desc: "missing email field", + inputJSON: `{ + "id": "123", + "given_name": "Jane", + "family_name": "Doe" + }`, + provider: "google", + wantUser: users.User{}, + wantErrStr: "missing required fields: email", + }, + { + desc: "missing multiple required fields", inputJSON: `{ "given_name": "Jane" }`, @@ -48,109 +175,81 @@ func TestNormalizeUser(t *testing.T) { wantErrStr: "missing required fields: id, last_name, email", }, { - desc: "invalid JSON", - inputJSON: `{invalid json`, + desc: "missing all required fields", + inputJSON: `{}`, provider: "google", wantUser: users.User{}, - wantErrStr: "invalid character", + wantErrStr: "missing required fields: id, first_name, last_name, email", }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - user, err := NormalizeUser([]byte(tc.inputJSON), tc.provider) - if tc.wantErrStr != "" { - assert.Error(t, err) - assert.Contains(t, err.Error(), tc.wantErrStr) - assert.Equal(t, tc.wantUser, user) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.wantUser, user) - } - }) - } -} - -func TestNormalizeProfile(t *testing.T) { - cases := []struct { - desc string - raw map[string]any - expected map[string]any - }{ { - desc: "maps all variants to normalized keys", - raw: map[string]any{ - "id": "id123", - "givenName": "John", - "familyName": "Smith", - "user_name": "jsmith", - "emailAddress": "john@smith.com", - "profilePicture": "pic.png", - }, - expected: map[string]any{ - "id": "id123", - "first_name": "John", - "last_name": "Smith", - "username": "jsmith", - "email": "john@smith.com", - "picture": "pic.png", - }, + desc: "invalid JSON syntax", + inputJSON: `{invalid json`, + provider: "google", + wantUser: users.User{}, + wantErrStr: "invalid character", }, { - desc: "missing keys returns empty map", - raw: map[string]any{"foo": "bar"}, - expected: map[string]any{}, + desc: "empty JSON", + inputJSON: ``, + provider: "google", + wantUser: users.User{}, + wantErrStr: "unexpected end of JSON input", }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - got := normalizeProfile(tc.raw) - assert.Equal(t, tc.expected, got) - }) - } -} - -func TestValidateUser(t *testing.T) { - cases := []struct { - desc string - user normalizedUser - wantErr string - }{ { - desc: "valid user returns nil error", - user: normalizedUser{ - ID: "1", - FirstName: "F", - LastName: "L", - Email: "e@example.com", + desc: "unrecognized keys are ignored", + inputJSON: `{ + "id": "567", + "given_name": "Test", + "family_name": "User", + "email": "test@user.com", + "unrecognized_field": "ignored", + "another_field": 12345 + }`, + provider: "test", + wantUser: users.User{ + ID: "567", + FirstName: "Test", + LastName: "User", + Email: "test@user.com", + ProfilePicture: "", + Metadata: users.Metadata{"oauth_provider": "test"}, }, - wantErr: "", + wantErrStr: "", }, { - desc: "missing id returns error", - user: normalizedUser{ - FirstName: "F", - LastName: "L", - Email: "e@example.com", + desc: "key priority - first matching variant is used", + inputJSON: `{ + "id": "priority", + "given_name": "First", + "first_name": "Second", + "family_name": "Family1", + "last_name": "Family2", + "email": "email1@test.com", + "email_address": "email2@test.com" + }`, + provider: "priority", + wantUser: users.User{ + ID: "priority", + FirstName: "First", + LastName: "Family1", + Email: "email1@test.com", + ProfilePicture: "", + Metadata: users.Metadata{"oauth_provider": "priority"}, }, - wantErr: "missing required fields: id", - }, - { - desc: "multiple missing fields returns all in error", - user: normalizedUser{}, - wantErr: "missing required fields: id, first_name, last_name, email", + wantErrStr: "", }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - err := validateUser(tc.user) - if tc.wantErr == "" { - assert.NoError(t, err) - } else { + user, err := oauth2.NormalizeUser([]byte(tc.inputJSON), tc.provider) + if tc.wantErrStr != "" { assert.Error(t, err) - assert.Equal(t, tc.wantErr, err.Error()) + assert.Contains(t, err.Error(), tc.wantErrStr) + assert.Equal(t, tc.wantUser, user) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.wantUser, user) } }) } diff --git a/pkg/oauth2/oauth2.go b/pkg/oauth2/oauth2.go index 50967f2490..dc29135494 100644 --- a/pkg/oauth2/oauth2.go +++ b/pkg/oauth2/oauth2.go @@ -6,11 +6,14 @@ package oauth2 import ( "context" + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" "github.com/absmach/supermq/users" "golang.org/x/oauth2" ) // Config is the configuration for the OAuth2 provider. +// This is kept for backward compatibility but deprecated in favor of +// DeviceConfig and UserConfig. type Config struct { ClientID string `env:"CLIENT_ID" envDefault:""` ClientSecret string `env:"CLIENT_SECRET" envDefault:""` @@ -18,6 +21,42 @@ type Config struct { RedirectURL string `env:"REDIRECT_URL" envDefault:""` } +// DeviceConfig is the configuration for the OAuth2 device flow (CLI). +type DeviceConfig struct { + ClientID string `env:"DEVICE_CLIENT_ID" envDefault:""` + ClientSecret string `env:"DEVICE_CLIENT_SECRET" envDefault:""` + State string `env:"DEVICE_STATE" envDefault:""` + RedirectURL string `env:"DEVICE_REDIRECT_URL" envDefault:""` +} + +// UserConfig is the configuration for the OAuth2 user flow (web). +type UserConfig struct { + ClientID string `env:"USER_CLIENT_ID" envDefault:""` + ClientSecret string `env:"USER_CLIENT_SECRET" envDefault:""` + State string `env:"USER_STATE" envDefault:""` + RedirectURL string `env:"USER_REDIRECT_URL" envDefault:""` +} + +// ToConfig converts DeviceConfig to Config. +func (dc DeviceConfig) ToConfig() Config { + return Config{ + ClientID: dc.ClientID, + ClientSecret: dc.ClientSecret, + State: dc.State, + RedirectURL: dc.RedirectURL, + } +} + +// ToConfig converts UserConfig to Config. +func (uc UserConfig) ToConfig() Config { + return Config{ + ClientID: uc.ClientID, + ClientSecret: uc.ClientSecret, + State: uc.State, + RedirectURL: uc.RedirectURL, + } +} + // Provider is an interface that provides the OAuth2 flow for a specific provider // (e.g. Google, GitHub, etc.) type Provider interface { @@ -39,6 +78,45 @@ type Provider interface { // Exchange converts an authorization code into a token. Exchange(ctx context.Context, code string) (oauth2.Token, error) + // ExchangeWithRedirect converts an authorization code into a token using a custom redirect URL. + ExchangeWithRedirect(ctx context.Context, code, redirectURL string) (oauth2.Token, error) + // UserInfo retrieves the user's information using the access token. UserInfo(accessToken string) (users.User, error) + + // GetAuthURL returns the authorization URL for the OAuth2 flow. + GetAuthURL() string + + // GetAuthURLWithRedirect returns the authorization URL with a custom redirect URL. + GetAuthURLWithRedirect(redirectURL string) string +} + +// Service provides OAuth authentication operations for the users service. +type Service interface { + // Device flow operations + + // CreateDeviceCode initiates the device authorization flow. + // It generates device and user codes, and returns the verification URI. + CreateDeviceCode(ctx context.Context, provider Provider, verificationURI string) (DeviceCode, error) + + // PollDeviceToken polls for device authorization completion. + // Returns the JWT token once the user has authorized the device. + PollDeviceToken(ctx context.Context, provider Provider, deviceCode string) (*grpcTokenV1.Token, error) + + // VerifyDevice handles user verification of device codes. + // It exchanges the OAuth authorization code for a token and marks the device as approved. + VerifyDevice(ctx context.Context, provider Provider, userCode, oauthCode string, approve bool) error + + // GetDeviceCodeByUserCode retrieves a device code by its user code. + GetDeviceCodeByUserCode(ctx context.Context, userCode string) (DeviceCode, error) + + // Web flow operations + + // ProcessWebCallback handles OAuth callback for web flow. + // It exchanges the authorization code for a token and creates/updates the user. + ProcessWebCallback(ctx context.Context, provider Provider, code, redirectURL string) (*grpcTokenV1.Token, error) + + // ProcessDeviceCallback handles OAuth callback for device flow. + // It's called when a user authorizes a device through the web interface. + ProcessDeviceCallback(ctx context.Context, provider Provider, userCode, oauthCode string) error } diff --git a/pkg/oauth2/service.go b/pkg/oauth2/service.go new file mode 100644 index 0000000000..d64940a54a --- /dev/null +++ b/pkg/oauth2/service.go @@ -0,0 +1,261 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package oauth provides OAuth2 authentication implementation for users service. +// It handles both web-based OAuth flow and device authorization flow for CLI clients. +package oauth2 + +import ( + "context" + "crypto/rand" + "encoding/base32" + "fmt" + "strings" + "time" + + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" + smqauth "github.com/absmach/supermq/auth" + + "github.com/absmach/supermq/users" + "golang.org/x/oauth2" +) + +var _ Service = (*oauthService)(nil) + +// oauthService implements the OAuth Service interface. +type oauthService struct { + deviceStore DeviceCodeStore + userService users.Service + tokenClient grpcTokenV1.TokenServiceClient +} + +// NewOAuthService creates a new OAuth service instance. +func NewOAuthService(deviceStore DeviceCodeStore, userService users.Service, tokenClient grpcTokenV1.TokenServiceClient) Service { + return &oauthService{ + deviceStore: deviceStore, + userService: userService, + tokenClient: tokenClient, + } +} + +// CreateDeviceCode initiates the device authorization flow. +func (s *oauthService) CreateDeviceCode(ctx context.Context, provider Provider, verificationURI string) (DeviceCode, error) { + if !provider.IsEnabled() { + return DeviceCode{}, ErrInvalidProvider + } + + userCode, err := generateUserCode() + if err != nil { + return DeviceCode{}, fmt.Errorf("failed to generate user code: %w", err) + } + + deviceCode, err := generateDeviceCode() + if err != nil { + return DeviceCode{}, fmt.Errorf("failed to generate device code: %w", err) + } + + code := DeviceCode{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: verificationURI, + ExpiresIn: int(DeviceCodeExpiry.Seconds()), + Interval: int(CodeCheckInterval.Seconds()), + Provider: provider.Name(), + CreatedAt: time.Now(), + State: provider.State(), + } + + if err := s.deviceStore.Save(code); err != nil { + return DeviceCode{}, fmt.Errorf("failed to save device code: %w", err) + } + + return code, nil +} + +// PollDeviceToken polls for device authorization completion. +func (s *oauthService) PollDeviceToken(ctx context.Context, provider Provider, deviceCode string) (*grpcTokenV1.Token, error) { + if !provider.IsEnabled() { + return nil, ErrInvalidProvider + } + + code, err := s.deviceStore.Get(deviceCode) + if err != nil { + return nil, ErrDeviceCodeNotFound + } + + // Check expiration + if time.Since(code.CreatedAt) > DeviceCodeExpiry { + s.deviceStore.Delete(deviceCode) + return nil, ErrDeviceCodeExpired + } + + // Check polling rate + if time.Since(code.LastPoll) < CodeCheckInterval { + return nil, ErrSlowDown + } + + // Update last poll time + code.LastPoll = time.Now() + s.deviceStore.Update(code) + + // Check if denied + if code.Denied { + s.deviceStore.Delete(deviceCode) + return nil, ErrAccessDenied + } + + // Check if approved + if !code.Approved || code.AccessToken == "" { + return nil, ErrDeviceCodePending + } + + // Process the OAuth user and issue tokens + jwt, err := s.processOAuthUser(ctx, provider, code.AccessToken) + if err != nil { + s.deviceStore.Delete(deviceCode) + return nil, fmt.Errorf("failed to process oauth user: %w", err) + } + + s.deviceStore.Delete(deviceCode) + jwt.AccessType = "" + return jwt, nil +} + +// VerifyDevice handles user verification of device codes. +func (s *oauthService) VerifyDevice(ctx context.Context, provider Provider, userCode, oauthCode string, approve bool) error { + if !provider.IsEnabled() { + return ErrInvalidProvider + } + + code, err := s.deviceStore.GetByUserCode(userCode) + if err != nil { + return err + } + + // Check expiration + if time.Since(code.CreatedAt) > DeviceCodeExpiry { + s.deviceStore.Delete(code.DeviceCode) + return ErrDeviceCodeExpired + } + + if !approve { + code.Denied = true + s.deviceStore.Update(code) + return nil + } + + // Exchange authorization code for access token + token, err := provider.Exchange(ctx, oauthCode) + if err != nil { + return fmt.Errorf("failed to exchange code: %w", err) + } + + code.Approved = true + code.AccessToken = token.AccessToken + if err := s.deviceStore.Update(code); err != nil { + return fmt.Errorf("failed to update device code: %w", err) + } + + return nil +} + +// GetDeviceCodeByUserCode retrieves a device code by its user code. +func (s *oauthService) GetDeviceCodeByUserCode(ctx context.Context, userCode string) (DeviceCode, error) { + return s.deviceStore.GetByUserCode(userCode) +} + +// ProcessWebCallback handles OAuth callback for web flow. +func (s *oauthService) ProcessWebCallback(ctx context.Context, provider Provider, code, redirectURL string) (*grpcTokenV1.Token, error) { + if !provider.IsEnabled() { + return nil, ErrInvalidProvider + } + + if code == "" { + return nil, ErrEmptyCode + } + + token, err := ExchangeCode(ctx, provider, code, redirectURL) + if err != nil { + return nil, fmt.Errorf("failed to exchange code: %w", err) + } + + return s.processOAuthUser(ctx, provider, token.AccessToken) +} + +// ProcessDeviceCallback handles OAuth callback for device flow. +func (s *oauthService) ProcessDeviceCallback(ctx context.Context, provider Provider, userCode, oauthCode string) error { + return s.VerifyDevice(ctx, provider, userCode, oauthCode, true) +} + +// processOAuthUser retrieves user info from the OAuth provider, creates or updates the user, +// adds user policies, and issues a JWT token. +func (s *oauthService) processOAuthUser(ctx context.Context, provider Provider, accessToken string) (*grpcTokenV1.Token, error) { + user, err := provider.UserInfo(accessToken) + if err != nil { + return nil, fmt.Errorf("failed to get user info: %w", err) + } + + user.AuthProvider = provider.Name() + if user.AuthProvider == "" { + user.AuthProvider = "oauth" + } + + user, err = s.userService.OAuthCallback(ctx, user) + if err != nil { + return nil, fmt.Errorf("failed to handle oauth callback: %w", err) + } + + if err := s.userService.OAuthAddUserPolicy(ctx, user); err != nil { + return nil, fmt.Errorf("failed to add user policy: %w", err) + } + + return s.tokenClient.Issue(ctx, &grpcTokenV1.IssueReq{ + UserId: user.ID, + Type: uint32(smqauth.AccessKey), + UserRole: uint32(smqauth.UserRole), + Verified: !user.VerifiedAt.IsZero(), + }) +} + +// generateUserCode generates a human-friendly code like "ABCD-EFGH". +func generateUserCode() (string, error) { + b := make([]byte, DeviceCodeLength) + if _, err := rand.Read(b); err != nil { + return "", err + } + code := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b) + code = strings.ToUpper(code[:DeviceCodeLength]) + // Format as XXXX-XXXX + if len(code) >= 8 { + code = code[:4] + "-" + code[4:8] + } + return code, nil +} + +// generateDeviceCode generates a random device code. +func generateDeviceCode() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b), nil +} + +// IsDeviceFlowState checks if the state parameter indicates a device flow. +func IsDeviceFlowState(state string) bool { + return strings.HasPrefix(state, DeviceStatePrefix) +} + +// ExtractUserCodeFromState extracts the user code from a device flow state. +func ExtractUserCodeFromState(state string) string { + return strings.TrimPrefix(state, DeviceStatePrefix) +} + +// ExchangeCode exchanges an authorization code for an access token. +// If redirectURL is provided, it uses ExchangeWithRedirect, otherwise uses Exchange. +func ExchangeCode(ctx context.Context, provider Provider, code, redirectURL string) (oauth2.Token, error) { + if redirectURL != "" { + return provider.ExchangeWithRedirect(ctx, code, redirectURL) + } + return provider.Exchange(ctx, code) +} diff --git a/pkg/oauth2/store/memory.go b/pkg/oauth2/store/memory.go new file mode 100644 index 0000000000..c10d6f4666 --- /dev/null +++ b/pkg/oauth2/store/memory.go @@ -0,0 +1,106 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package store + +import ( + "sync" + "time" + + "github.com/absmach/supermq/pkg/oauth2" +) + +// inMemoryDeviceCodeStore is an in-memory implementation of DeviceCodeStore. +type inMemoryDeviceCodeStore struct { + mu sync.RWMutex + codes map[string]oauth2.DeviceCode + userCodes map[string]string // maps user code to device code + cleanupDone chan struct{} +} + +// NewInMemoryDeviceCodeStore creates a new in-memory device code store. +// It automatically starts a cleanup goroutine to remove expired codes. +func NewInMemoryDeviceCodeStore() oauth2.DeviceCodeStore { + store := &inMemoryDeviceCodeStore{ + codes: make(map[string]oauth2.DeviceCode), + userCodes: make(map[string]string), + cleanupDone: make(chan struct{}), + } + go store.cleanup() + return store +} + +func (s *inMemoryDeviceCodeStore) Save(code oauth2.DeviceCode) error { + s.mu.Lock() + defer s.mu.Unlock() + s.codes[code.DeviceCode] = code + s.userCodes[code.UserCode] = code.DeviceCode + return nil +} + +func (s *inMemoryDeviceCodeStore) Get(deviceCode string) (oauth2.DeviceCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + code, ok := s.codes[deviceCode] + if !ok { + return oauth2.DeviceCode{}, oauth2.ErrDeviceCodeNotFound + } + return code, nil +} + +func (s *inMemoryDeviceCodeStore) GetByUserCode(userCode string) (oauth2.DeviceCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + deviceCode, ok := s.userCodes[userCode] + if !ok { + return oauth2.DeviceCode{}, oauth2.ErrUserCodeNotFound + } + code, ok := s.codes[deviceCode] + if !ok { + return oauth2.DeviceCode{}, oauth2.ErrDeviceCodeNotFound + } + return code, nil +} + +func (s *inMemoryDeviceCodeStore) Update(code oauth2.DeviceCode) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.codes[code.DeviceCode]; !ok { + return oauth2.ErrDeviceCodeNotFound + } + s.codes[code.DeviceCode] = code + return nil +} + +func (s *inMemoryDeviceCodeStore) Delete(deviceCode string) error { + s.mu.Lock() + defer s.mu.Unlock() + if code, ok := s.codes[deviceCode]; ok { + delete(s.userCodes, code.UserCode) + } + delete(s.codes, deviceCode) + return nil +} + +// cleanup periodically removes expired device codes. +func (s *inMemoryDeviceCodeStore) cleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.mu.Lock() + now := time.Now() + for deviceCode, code := range s.codes { + if now.Sub(code.CreatedAt) > oauth2.DeviceCodeExpiry { + delete(s.codes, deviceCode) + delete(s.userCodes, code.UserCode) + } + } + s.mu.Unlock() + case <-s.cleanupDone: + return + } + } +} diff --git a/pkg/oauth2/store/redis.go b/pkg/oauth2/store/redis.go new file mode 100644 index 0000000000..cd6287bc20 --- /dev/null +++ b/pkg/oauth2/store/redis.go @@ -0,0 +1,144 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package store + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/absmach/supermq/pkg/oauth2" + "github.com/redis/go-redis/v9" +) + +const ( + deviceCodePrefix = "oauth:device:code:" + userCodePrefix = "oauth:device:user:" +) + +// redisDeviceCodeStore is a Redis-based implementation of DeviceCodeStore. +type redisDeviceCodeStore struct { + client *redis.Client + ctx context.Context +} + +// NewRedisDeviceCodeStore creates a new Redis-based device code store. +func NewRedisDeviceCodeStore(ctx context.Context, client *redis.Client) oauth2.DeviceCodeStore { + return &redisDeviceCodeStore{ + client: client, + ctx: ctx, + } +} + +func (s *redisDeviceCodeStore) Save(code oauth2.DeviceCode) error { + data, err := json.Marshal(code) + if err != nil { + return fmt.Errorf("failed to marshal device code: %w", err) + } + + // Store device code with expiry + deviceKey := deviceCodePrefix + code.DeviceCode + if err := s.client.Set(s.ctx, deviceKey, data, oauth2.DeviceCodeExpiry).Err(); err != nil { + return fmt.Errorf("failed to save device code: %w", err) + } + + // Store user code to device code mapping with expiry + userKey := userCodePrefix + code.UserCode + if err := s.client.Set(s.ctx, userKey, code.DeviceCode, oauth2.DeviceCodeExpiry).Err(); err != nil { + return fmt.Errorf("failed to save user code mapping: %w", err) + } + + return nil +} + +func (s *redisDeviceCodeStore) Get(deviceCode string) (oauth2.DeviceCode, error) { + deviceKey := deviceCodePrefix + deviceCode + data, err := s.client.Get(s.ctx, deviceKey).Bytes() + if err != nil { + if err == redis.Nil { + return oauth2.DeviceCode{}, oauth2.ErrDeviceCodeNotFound + } + return oauth2.DeviceCode{}, fmt.Errorf("failed to get device code: %w", err) + } + + var code oauth2.DeviceCode + if err := json.Unmarshal(data, &code); err != nil { + return oauth2.DeviceCode{}, fmt.Errorf("failed to unmarshal device code: %w", err) + } + + return code, nil +} + +func (s *redisDeviceCodeStore) GetByUserCode(userCode string) (oauth2.DeviceCode, error) { + // First, get the device code from user code mapping + userKey := userCodePrefix + userCode + deviceCode, err := s.client.Get(s.ctx, userKey).Result() + if err != nil { + if err == redis.Nil { + return oauth2.DeviceCode{}, oauth2.ErrUserCodeNotFound + } + return oauth2.DeviceCode{}, fmt.Errorf("failed to get device code by user code: %w", err) + } + + // Then, get the actual device code data + return s.Get(deviceCode) +} + +func (s *redisDeviceCodeStore) Update(code oauth2.DeviceCode) error { + // Get the existing code to check if it exists + existing, err := s.Get(code.DeviceCode) + if err != nil { + return err + } + + // Preserve the creation time and user code from existing + code.CreatedAt = existing.CreatedAt + code.UserCode = existing.UserCode + + data, err := json.Marshal(code) + if err != nil { + return fmt.Errorf("failed to marshal device code: %w", err) + } + + // Calculate remaining TTL + deviceKey := deviceCodePrefix + code.DeviceCode + ttl, err := s.client.TTL(s.ctx, deviceKey).Result() + if err != nil { + return fmt.Errorf("failed to get TTL: %w", err) + } + + // If TTL is negative (key doesn't exist or no expiry), use default + if ttl < 0 { + ttl = oauth2.DeviceCodeExpiry + } + + // Update the device code with remaining TTL + if err := s.client.Set(s.ctx, deviceKey, data, ttl).Err(); err != nil { + return fmt.Errorf("failed to update device code: %w", err) + } + + return nil +} + +func (s *redisDeviceCodeStore) Delete(deviceCode string) error { + // Get the code first to find the user code + code, err := s.Get(deviceCode) + if err != nil { + return err + } + + // Delete both device code and user code mapping + deviceKey := deviceCodePrefix + deviceCode + userKey := userCodePrefix + code.UserCode + + pipe := s.client.Pipeline() + pipe.Del(s.ctx, deviceKey) + pipe.Del(s.ctx, userKey) + + if _, err := pipe.Exec(s.ctx); err != nil { + return fmt.Errorf("failed to delete device code: %w", err) + } + + return nil +} diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index fcf7f9f686..7db8e7ba41 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -6004,6 +6004,314 @@ func (_c *SDK_ListGroupMembers_Call) RunAndReturn(run func(ctx context.Context, return _c } +// OAuthAuthorizationURL provides a mock function for the type SDK +func (_mock *SDK) OAuthAuthorizationURL(ctx context.Context, provider string, redirectURL string) (string, string, errors.SDKError) { + ret := _mock.Called(ctx, provider, redirectURL) + + if len(ret) == 0 { + panic("no return value specified for OAuthAuthorizationURL") + } + + var r0 string + var r1 string + var r2 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (string, string, errors.SDKError)); ok { + return returnFunc(ctx, provider, redirectURL) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) string); ok { + r0 = returnFunc(ctx, provider, redirectURL) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) string); ok { + r1 = returnFunc(ctx, provider, redirectURL) + } else { + r1 = ret.Get(1).(string) + } + if returnFunc, ok := ret.Get(2).(func(context.Context, string, string) errors.SDKError); ok { + r2 = returnFunc(ctx, provider, redirectURL) + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(errors.SDKError) + } + } + return r0, r1, r2 +} + +// SDK_OAuthAuthorizationURL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OAuthAuthorizationURL' +type SDK_OAuthAuthorizationURL_Call struct { + *mock.Call +} + +// OAuthAuthorizationURL is a helper method to define mock.On call +// - ctx context.Context +// - provider string +// - redirectURL string +func (_e *SDK_Expecter) OAuthAuthorizationURL(ctx interface{}, provider interface{}, redirectURL interface{}) *SDK_OAuthAuthorizationURL_Call { + return &SDK_OAuthAuthorizationURL_Call{Call: _e.mock.On("OAuthAuthorizationURL", ctx, provider, redirectURL)} +} + +func (_c *SDK_OAuthAuthorizationURL_Call) Run(run func(ctx context.Context, provider string, redirectURL string)) *SDK_OAuthAuthorizationURL_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *SDK_OAuthAuthorizationURL_Call) Return(s string, s1 string, sDKError errors.SDKError) *SDK_OAuthAuthorizationURL_Call { + _c.Call.Return(s, s1, sDKError) + return _c +} + +func (_c *SDK_OAuthAuthorizationURL_Call) RunAndReturn(run func(ctx context.Context, provider string, redirectURL string) (string, string, errors.SDKError)) *SDK_OAuthAuthorizationURL_Call { + _c.Call.Return(run) + return _c +} + +// OAuthCallback provides a mock function for the type SDK +func (_mock *SDK) OAuthCallback(ctx context.Context, provider string, code string, state string, redirectURL string) (sdk.Token, errors.SDKError) { + ret := _mock.Called(ctx, provider, code, state, redirectURL) + + if len(ret) == 0 { + panic("no return value specified for OAuthCallback") + } + + var r0 sdk.Token + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) (sdk.Token, errors.SDKError)); ok { + return returnFunc(ctx, provider, code, state, redirectURL) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) sdk.Token); ok { + r0 = returnFunc(ctx, provider, code, state, redirectURL) + } else { + r0 = ret.Get(0).(sdk.Token) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, provider, code, state, redirectURL) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_OAuthCallback_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OAuthCallback' +type SDK_OAuthCallback_Call struct { + *mock.Call +} + +// OAuthCallback is a helper method to define mock.On call +// - ctx context.Context +// - provider string +// - code string +// - state string +// - redirectURL string +func (_e *SDK_Expecter) OAuthCallback(ctx interface{}, provider interface{}, code interface{}, state interface{}, redirectURL interface{}) *SDK_OAuthCallback_Call { + return &SDK_OAuthCallback_Call{Call: _e.mock.On("OAuthCallback", ctx, provider, code, state, redirectURL)} +} + +func (_c *SDK_OAuthCallback_Call) Run(run func(ctx context.Context, provider string, code string, state string, redirectURL string)) *SDK_OAuthCallback_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *SDK_OAuthCallback_Call) Return(token sdk.Token, sDKError errors.SDKError) *SDK_OAuthCallback_Call { + _c.Call.Return(token, sDKError) + return _c +} + +func (_c *SDK_OAuthCallback_Call) RunAndReturn(run func(ctx context.Context, provider string, code string, state string, redirectURL string) (sdk.Token, errors.SDKError)) *SDK_OAuthCallback_Call { + _c.Call.Return(run) + return _c +} + +// OAuthDeviceCode provides a mock function for the type SDK +func (_mock *SDK) OAuthDeviceCode(ctx context.Context, provider string) (sdk.DeviceCode, errors.SDKError) { + ret := _mock.Called(ctx, provider) + + if len(ret) == 0 { + panic("no return value specified for OAuthDeviceCode") + } + + var r0 sdk.DeviceCode + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (sdk.DeviceCode, errors.SDKError)); ok { + return returnFunc(ctx, provider) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) sdk.DeviceCode); ok { + r0 = returnFunc(ctx, provider) + } else { + r0 = ret.Get(0).(sdk.DeviceCode) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) errors.SDKError); ok { + r1 = returnFunc(ctx, provider) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_OAuthDeviceCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OAuthDeviceCode' +type SDK_OAuthDeviceCode_Call struct { + *mock.Call +} + +// OAuthDeviceCode is a helper method to define mock.On call +// - ctx context.Context +// - provider string +func (_e *SDK_Expecter) OAuthDeviceCode(ctx interface{}, provider interface{}) *SDK_OAuthDeviceCode_Call { + return &SDK_OAuthDeviceCode_Call{Call: _e.mock.On("OAuthDeviceCode", ctx, provider)} +} + +func (_c *SDK_OAuthDeviceCode_Call) Run(run func(ctx context.Context, provider string)) *SDK_OAuthDeviceCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SDK_OAuthDeviceCode_Call) Return(deviceCode sdk.DeviceCode, sDKError errors.SDKError) *SDK_OAuthDeviceCode_Call { + _c.Call.Return(deviceCode, sDKError) + return _c +} + +func (_c *SDK_OAuthDeviceCode_Call) RunAndReturn(run func(ctx context.Context, provider string) (sdk.DeviceCode, errors.SDKError)) *SDK_OAuthDeviceCode_Call { + _c.Call.Return(run) + return _c +} + +// OAuthDeviceToken provides a mock function for the type SDK +func (_mock *SDK) OAuthDeviceToken(ctx context.Context, provider string, deviceCode string) (sdk.Token, errors.SDKError) { + ret := _mock.Called(ctx, provider, deviceCode) + + if len(ret) == 0 { + panic("no return value specified for OAuthDeviceToken") + } + + var r0 sdk.Token + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (sdk.Token, errors.SDKError)); ok { + return returnFunc(ctx, provider, deviceCode) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) sdk.Token); ok { + r0 = returnFunc(ctx, provider, deviceCode) + } else { + r0 = ret.Get(0).(sdk.Token) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, provider, deviceCode) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_OAuthDeviceToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OAuthDeviceToken' +type SDK_OAuthDeviceToken_Call struct { + *mock.Call +} + +// OAuthDeviceToken is a helper method to define mock.On call +// - ctx context.Context +// - provider string +// - deviceCode string +func (_e *SDK_Expecter) OAuthDeviceToken(ctx interface{}, provider interface{}, deviceCode interface{}) *SDK_OAuthDeviceToken_Call { + return &SDK_OAuthDeviceToken_Call{Call: _e.mock.On("OAuthDeviceToken", ctx, provider, deviceCode)} +} + +func (_c *SDK_OAuthDeviceToken_Call) Run(run func(ctx context.Context, provider string, deviceCode string)) *SDK_OAuthDeviceToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *SDK_OAuthDeviceToken_Call) Return(token sdk.Token, sDKError errors.SDKError) *SDK_OAuthDeviceToken_Call { + _c.Call.Return(token, sDKError) + return _c +} + +func (_c *SDK_OAuthDeviceToken_Call) RunAndReturn(run func(ctx context.Context, provider string, deviceCode string) (sdk.Token, errors.SDKError)) *SDK_OAuthDeviceToken_Call { + _c.Call.Return(run) + return _c +} + // RefreshToken provides a mock function for the type SDK func (_mock *SDK) RefreshToken(ctx context.Context, token string) (sdk.Token, errors.SDKError) { ret := _mock.Called(ctx, token) diff --git a/pkg/sdk/oauth_device_test.go b/pkg/sdk/oauth_device_test.go new file mode 100644 index 0000000000..e74ea54561 --- /dev/null +++ b/pkg/sdk/oauth_device_test.go @@ -0,0 +1,264 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/absmach/supermq/pkg/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOAuthDeviceCode(t *testing.T) { + tests := []struct { + name string + providerName string + serverResponse string + serverStatus int + expectedErr bool + checkResponse func(*testing.T, sdk.DeviceCode) + }{ + { + name: "successful device code request", + providerName: "google", + serverResponse: `{ + "device_code": "device123abc", + "user_code": "ABCD-EFGH", + "verification_uri": "https://example.com/device", + "expires_in": 600, + "interval": 3 + }`, + serverStatus: http.StatusOK, + expectedErr: false, + checkResponse: func(t *testing.T, deviceCode sdk.DeviceCode) { + assert.Equal(t, "device123abc", deviceCode.DeviceCode) + assert.Equal(t, "ABCD-EFGH", deviceCode.UserCode) + assert.Equal(t, "https://example.com/device", deviceCode.VerificationURI) + assert.Equal(t, 600, deviceCode.ExpiresIn) + assert.Equal(t, 3, deviceCode.Interval) + }, + }, + { + name: "provider not found", + providerName: "unknown", + serverResponse: `{"error": "oauth provider is disabled"}`, + serverStatus: http.StatusNotFound, + expectedErr: true, + checkResponse: func(t *testing.T, deviceCode sdk.DeviceCode) {}, + }, + { + name: "invalid json response", + providerName: "google", + serverResponse: `{invalid json}`, + serverStatus: http.StatusOK, + expectedErr: true, + checkResponse: func(t *testing.T, deviceCode sdk.DeviceCode) {}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, fmt.Sprintf("/oauth/device/code/%s", tc.providerName), r.URL.Path) + + w.WriteHeader(tc.serverStatus) + w.Write([]byte(tc.serverResponse)) + })) + defer server.Close() + + sdkConf := sdk.Config{ + UsersURL: server.URL, + } + mgsdk := sdk.NewSDK(sdkConf) + + deviceCode, err := mgsdk.OAuthDeviceCode(context.Background(), tc.providerName) + + if tc.expectedErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + tc.checkResponse(t, deviceCode) + }) + } +} + +func TestOAuthDeviceToken(t *testing.T) { + tests := []struct { + name string + providerName string + deviceCode string + serverResponse string + serverStatus int + expectedErr bool + checkResponse func(*testing.T, sdk.Token) + }{ + { + name: "successful token retrieval", + providerName: "google", + deviceCode: "device123", + serverResponse: `{ + "access_token": "access_token_123", + "refresh_token": "refresh_token_456" + }`, + serverStatus: http.StatusOK, + expectedErr: false, + checkResponse: func(t *testing.T, token sdk.Token) { + assert.Equal(t, "access_token_123", token.AccessToken) + assert.Equal(t, "refresh_token_456", token.RefreshToken) + }, + }, + { + name: "authorization pending", + providerName: "google", + deviceCode: "device123", + serverResponse: `{"error": "authorization pending"}`, + serverStatus: http.StatusAccepted, + expectedErr: true, + checkResponse: func(t *testing.T, token sdk.Token) {}, + }, + { + name: "device code expired", + providerName: "google", + deviceCode: "device123", + serverResponse: `{"error": "device code expired"}`, + serverStatus: http.StatusBadRequest, + expectedErr: true, + checkResponse: func(t *testing.T, token sdk.Token) {}, + }, + { + name: "slow down", + providerName: "google", + deviceCode: "device123", + serverResponse: `{"error": "slow down"}`, + serverStatus: http.StatusBadRequest, + expectedErr: true, + checkResponse: func(t *testing.T, token sdk.Token) {}, + }, + { + name: "access denied", + providerName: "google", + deviceCode: "device123", + serverResponse: `{"error": "access denied"}`, + serverStatus: http.StatusUnauthorized, + expectedErr: true, + checkResponse: func(t *testing.T, token sdk.Token) {}, + }, + { + name: "invalid device code", + providerName: "google", + deviceCode: "invalid", + serverResponse: `{"error": "invalid device code"}`, + serverStatus: http.StatusNotFound, + expectedErr: true, + checkResponse: func(t *testing.T, token sdk.Token) {}, + }, + { + name: "provider disabled", + providerName: "disabled", + deviceCode: "device123", + serverResponse: `{"error": "oauth provider is disabled"}`, + serverStatus: http.StatusNotFound, + expectedErr: true, + checkResponse: func(t *testing.T, token sdk.Token) {}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, fmt.Sprintf("/oauth/device/token/%s", tc.providerName), r.URL.Path) + + w.WriteHeader(tc.serverStatus) + w.Write([]byte(tc.serverResponse)) + })) + defer server.Close() + + sdkConf := sdk.Config{ + UsersURL: server.URL, + } + mgsdk := sdk.NewSDK(sdkConf) + + token, err := mgsdk.OAuthDeviceToken(context.Background(), tc.providerName, tc.deviceCode) + + if tc.expectedErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + tc.checkResponse(t, token) + }) + } +} + +func TestOAuthDeviceFlow(t *testing.T) { + t.Run("complete device flow integration", func(t *testing.T) { + var savedDeviceCode string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/device/code/google": + response := `{ + "device_code": "device_code_123", + "user_code": "ABCD-EFGH", + "verification_uri": "https://example.com/device", + "expires_in": 600, + "interval": 3 + }` + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) + savedDeviceCode = "device_code_123" + + case "/oauth/device/token/google": + // Simulate polling: first call returns pending, second returns token + if savedDeviceCode == "device_code_123" { + // First call - pending + w.WriteHeader(http.StatusAccepted) + w.Write([]byte(`{"error": "authorization pending"}`)) + savedDeviceCode = "approved" // Mark as approved for next call + } else { + // Second call - success + response := `{ + "access_token": "access_token_123", + "refresh_token": "refresh_token_456" + }` + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) + } + } + })) + defer server.Close() + + sdkConf := sdk.Config{ + UsersURL: server.URL, + } + mgsdk := sdk.NewSDK(sdkConf) + + // Step 1: Get device code + deviceCode, err := mgsdk.OAuthDeviceCode(context.Background(), "google") + require.NoError(t, err) + assert.Equal(t, "device_code_123", deviceCode.DeviceCode) + assert.Equal(t, "ABCD-EFGH", deviceCode.UserCode) + + // Step 2: First poll - pending + _, err = mgsdk.OAuthDeviceToken(context.Background(), "google", deviceCode.DeviceCode) + assert.Error(t, err) + assert.Contains(t, err.Error(), "authorization pending") + + // Step 3: Second poll - success + token, err := mgsdk.OAuthDeviceToken(context.Background(), "google", deviceCode.DeviceCode) + require.NoError(t, err) + assert.Equal(t, "access_token_123", token.AccessToken) + assert.Equal(t, "refresh_token_456", token.RefreshToken) + }) +} diff --git a/pkg/sdk/oauth_test.go b/pkg/sdk/oauth_test.go new file mode 100644 index 0000000000..49ef061f4a --- /dev/null +++ b/pkg/sdk/oauth_test.go @@ -0,0 +1,350 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" + authmocks "github.com/absmach/supermq/auth/mocks" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/oauth2" + oauth2mocks "github.com/absmach/supermq/pkg/oauth2/mocks" + sdk "github.com/absmach/supermq/pkg/sdk" + "github.com/absmach/supermq/pkg/uuid" + "github.com/absmach/supermq/users" + httpapi "github.com/absmach/supermq/users/api" + umocks "github.com/absmach/supermq/users/mocks" + "github.com/go-chi/chi/v5" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + goauth2 "golang.org/x/oauth2" +) + +const ( + googleProvider = "google" + jwtRefreshToken = "jwt-refresh-token" + jwtAccessToken = "jwt-access-token" + testAccessToken = "access-token" + testCode = "test-code" + testState = "test-state" + testCallbackURL = "http://localhost:9090/callback" + testAuthURLBase = "https://accounts.google.com/o/oauth2/auth" + testUserEmail = "test@example.com" + testUsername = "testuser" +) + +func setupOAuthServer() (*httptest.Server, *umocks.Service, *oauth2mocks.Provider, *authmocks.TokenServiceClient) { + usvc := new(umocks.Service) + logger := smqlog.NewMock() + mux := chi.NewRouter() + idp := uuid.NewMock() + provider := new(oauth2mocks.Provider) + provider.On("Name").Return(googleProvider) + authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithDomainCheck(false), smqauthn.WithAllowUnverifiedUser(true)) + token := new(authmocks.TokenServiceClient) + redisClient := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) + p := []oauth2.Provider{provider} + httpapi.MakeHandler(usvc, am, token, true, mux, logger, "", passRegex, idp, redisClient, p, p) + + return httptest.NewServer(mux), usvc, provider, token +} + +func TestOAuthAuthorizationURL(t *testing.T) { + ts, _, provider, _ := setupOAuthServer() + defer ts.Close() + + conf := sdk.Config{ + UsersURL: ts.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + providerName string + redirectURL string + providerEnabled bool + getAuthURL string + state string + err errors.SDKError + }{ + { + desc: "get authorization URL successfully", + providerName: googleProvider, + redirectURL: "", + providerEnabled: true, + getAuthURL: testAuthURLBase + "?client_id=test&state=" + testState, + state: testState, + err: nil, + }, + { + desc: "get authorization URL with custom redirect", + providerName: googleProvider, + redirectURL: testCallbackURL, + providerEnabled: true, + getAuthURL: testAuthURLBase + "?client_id=test&state=" + testState + "&redirect_uri=" + testCallbackURL, + state: testState, + err: nil, + }, + { + desc: "get authorization URL with disabled provider", + providerName: googleProvider, + redirectURL: "", + providerEnabled: false, + getAuthURL: "", + state: "", + err: errors.NewSDKErrorWithStatus(errors.Wrap(svcerr.ErrNotFound, errors.New("oauth provider is disabled")), http.StatusNotFound), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + provider.On("IsEnabled").Return(tc.providerEnabled) + if tc.providerEnabled { + if tc.redirectURL != "" { + provider.On("GetAuthURLWithRedirect", tc.redirectURL).Return(tc.getAuthURL) + } else { + provider.On("GetAuthURL").Return(tc.getAuthURL) + } + provider.On("State").Return(tc.state) + } + + authURL, state, err := mgsdk.OAuthAuthorizationURL(context.Background(), tc.providerName, tc.redirectURL) + + if tc.err == nil { + assert.NoError(t, err) + assert.Equal(t, tc.getAuthURL, authURL) + assert.Equal(t, tc.state, state) + } else { + assert.Error(t, err) + assert.Empty(t, authURL) + assert.Empty(t, state) + } + + // Reset mocks + provider.ExpectedCalls = nil + }) + } +} + +func TestOAuthCallback(t *testing.T) { + ts, svc, provider, tokenClient := setupOAuthServer() + defer ts.Close() + + conf := sdk.Config{ + UsersURL: ts.URL, + } + mgsdk := sdk.NewSDK(conf) + + validUser := users.User{ + ID: generateUUID(t), + Email: testUserEmail, + Credentials: users.Credentials{ + Username: testUsername, + }, + Status: users.EnabledStatus, + } + + cases := []struct { + desc string + providerName string + code string + state string + redirectURL string + providerEnabled bool + mockSetup func() + expectedToken sdk.Token + err errors.SDKError + }{ + { + desc: "successful OAuth callback", + providerName: googleProvider, + code: testCode, + state: testState, + redirectURL: testCallbackURL, + providerEnabled: true, + mockSetup: func() { + provider.On("IsEnabled").Return(true) + provider.On("State").Return(testState) + provider.On("ExchangeWithRedirect", mock.Anything, testCode, testCallbackURL). + Return(goauth2.Token{AccessToken: testAccessToken}, nil).Once() + provider.On("UserInfo", testAccessToken).Return(validUser, nil).Once() + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email && u.AuthProvider == googleProvider + })).Return(validUser, nil).Once() + svc.On("OAuthAddUserPolicy", mock.Anything, validUser).Return(nil).Once() + refreshToken := jwtRefreshToken + tokenClient.On("Issue", mock.Anything, mock.Anything). + Return(&grpcTokenV1.Token{ + AccessToken: jwtAccessToken, + RefreshToken: &refreshToken, + }, nil).Once() + }, + expectedToken: sdk.Token{ + AccessToken: jwtAccessToken, + RefreshToken: jwtRefreshToken, + }, + err: nil, + }, + { + desc: "OAuth callback without redirect URL", + providerName: googleProvider, + code: testCode, + state: testState, + redirectURL: "", + providerEnabled: true, + mockSetup: func() { + provider.On("IsEnabled").Return(true) + provider.On("State").Return(testState) + provider.On("Exchange", mock.Anything, testCode). + Return(goauth2.Token{AccessToken: testAccessToken}, nil).Once() + provider.On("UserInfo", testAccessToken).Return(validUser, nil).Once() + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email && u.AuthProvider == googleProvider + })).Return(validUser, nil).Once() + svc.On("OAuthAddUserPolicy", mock.Anything, validUser).Return(nil).Once() + refreshToken := jwtRefreshToken + tokenClient.On("Issue", mock.Anything, mock.Anything). + Return(&grpcTokenV1.Token{ + AccessToken: jwtAccessToken, + RefreshToken: &refreshToken, + }, nil).Once() + }, + expectedToken: sdk.Token{ + AccessToken: jwtAccessToken, + RefreshToken: jwtRefreshToken, + }, + err: nil, + }, + { + desc: "OAuth callback with disabled provider", + providerName: googleProvider, + code: testCode, + state: testState, + redirectURL: "", + providerEnabled: false, + mockSetup: func() { + provider.On("IsEnabled").Return(false) + }, + expectedToken: sdk.Token{}, + err: errors.NewSDKErrorWithStatus(errors.Wrap(svcerr.ErrNotFound, errors.New("oauth provider is disabled")), http.StatusNotFound), + }, + { + desc: "OAuth callback with invalid state", + providerName: googleProvider, + code: testCode, + state: "wrong-state", + redirectURL: "", + providerEnabled: true, + mockSetup: func() { + provider.On("IsEnabled").Return(true) + provider.On("State").Return(testState) + }, + expectedToken: sdk.Token{}, + err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrMalformedEntity, errors.New("invalid state")), http.StatusBadRequest), + }, + { + desc: "OAuth callback with exchange error", + providerName: googleProvider, + code: testCode, + state: testState, + redirectURL: "", + providerEnabled: true, + mockSetup: func() { + provider.On("IsEnabled").Return(true) + provider.On("State").Return(testState) + provider.On("Exchange", mock.Anything, testCode). + Return(goauth2.Token{}, fmt.Errorf("exchange failed")).Once() + }, + expectedToken: sdk.Token{}, + err: errors.NewSDKErrorWithStatus(errors.New("exchange failed"), http.StatusUnauthorized), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tc.mockSetup() + + token, err := mgsdk.OAuthCallback(context.Background(), tc.providerName, tc.code, tc.state, tc.redirectURL) + + if tc.err == nil { + assert.NoError(t, err) + assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.RefreshToken) + } else { + assert.Error(t, err) + assert.Empty(t, token.AccessToken) + assert.Empty(t, token.RefreshToken) + } + + // Reset mocks + svc.ExpectedCalls = nil + }) + } +} + +func TestOAuthIntegration(t *testing.T) { + ts, svc, provider, tokenClient := setupOAuthServer() + defer ts.Close() + + conf := sdk.Config{ + UsersURL: ts.URL, + } + mgsdk := sdk.NewSDK(conf) + + validUser := users.User{ + ID: generateUUID(t), + Email: testUserEmail, + Credentials: users.Credentials{ + Username: testUsername, + }, + Status: users.EnabledStatus, + } + + redirectURL := testCallbackURL + + // Setup mocks for authorization URL + provider.On("IsEnabled").Return(true) + provider.On("GetAuthURLWithRedirect", redirectURL). + Return(testAuthURLBase + "?redirect_uri=" + redirectURL) + provider.On("State").Return(testState) + + // Step 1: Get authorization URL + authURL, state, err := mgsdk.OAuthAuthorizationURL(context.Background(), googleProvider, redirectURL) + assert.NoError(t, err) + assert.NotEmpty(t, authURL) + assert.Equal(t, testState, state) + assert.Contains(t, authURL, redirectURL) + + // Setup mocks for callback + provider.On("ExchangeWithRedirect", mock.Anything, testCode, redirectURL). + Return(goauth2.Token{AccessToken: testAccessToken}, nil).Once() + provider.On("UserInfo", testAccessToken).Return(validUser, nil).Once() + svc.On("OAuthCallback", mock.Anything, mock.MatchedBy(func(u users.User) bool { + return u.Email == validUser.Email && u.AuthProvider == googleProvider + })).Return(validUser, nil).Once() + svc.On("OAuthAddUserPolicy", mock.Anything, validUser).Return(nil).Once() + refreshToken := jwtRefreshToken + tokenClient.On("Issue", mock.Anything, mock.Anything). + Return(&grpcTokenV1.Token{ + AccessToken: jwtAccessToken, + RefreshToken: &refreshToken, + }, nil).Once() + + // Step 2: Exchange code for token + token, err := mgsdk.OAuthCallback(context.Background(), googleProvider, testCode, state, redirectURL) + assert.NoError(t, err) + assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.RefreshToken) +} diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index dc49a72cc7..c1cbd19681 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -368,6 +368,39 @@ type SDK interface { // fmt.Println(token) RefreshToken(ctx context.Context, token string) (Token, errors.SDKError) + // OAuthAuthorizationURL returns the OAuth authorization URL for the given provider. + // + // example: + // ctx := context.Background() + // authURL, state, _ := sdk.OAuthAuthorizationURL(ctx, "google", "http://localhost:9090/callback") + // fmt.Println(authURL) + OAuthAuthorizationURL(ctx context.Context, provider, redirectURL string) (string, string, errors.SDKError) + + // OAuthCallback exchanges the OAuth authorization code for tokens. + // + // example: + // ctx := context.Background() + // token, _ := sdk.OAuthCallback(ctx, "google", "auth_code", "state", "http://localhost:9090/callback") + // fmt.Println(token) + OAuthCallback(ctx context.Context, provider, code, state, redirectURL string) (Token, errors.SDKError) + + // OAuthDeviceCode initiates the device authorization flow and returns device code information. + // + // example: + // ctx := context.Background() + // deviceCode, _ := sdk.OAuthDeviceCode(ctx, "google") + // fmt.Println("Go to:", deviceCode.VerificationURI) + // fmt.Println("Enter code:", deviceCode.UserCode) + OAuthDeviceCode(ctx context.Context, provider string) (DeviceCode, errors.SDKError) + + // OAuthDeviceToken polls for device authorization completion and returns tokens. + // + // example: + // ctx := context.Background() + // token, _ := sdk.OAuthDeviceToken(ctx, "google", deviceCode.DeviceCode) + // fmt.Println(token) + OAuthDeviceToken(ctx context.Context, provider, deviceCode string) (Token, errors.SDKError) + // SeachUsers filters users and returns a page result. // // example: diff --git a/pkg/sdk/tokens.go b/pkg/sdk/tokens.go index ccb0db7df3..120972be73 100644 --- a/pkg/sdk/tokens.go +++ b/pkg/sdk/tokens.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "github.com/absmach/supermq/pkg/errors" ) @@ -20,6 +21,15 @@ type Token struct { AccessType string `json:"access_type,omitempty"` } +// DeviceCode contains device authorization flow information. +type DeviceCode struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + type Login struct { Username string `json:"username"` Password string `json:"password"` @@ -60,3 +70,106 @@ func (sdk mgSDK) RefreshToken(ctx context.Context, token string) (Token, errors. return t, nil } + +// OAuthAuthorizationURL returns the OAuth authorization URL for the given provider. +func (sdk mgSDK) OAuthAuthorizationURL(ctx context.Context, provider, redirectURL string) (string, string, errors.SDKError) { + reqURL := fmt.Sprintf("%s/oauth/authorize/%s", sdk.usersURL, provider) + if redirectURL != "" { + reqURL = fmt.Sprintf("%s?redirect_uri=%s", reqURL, url.QueryEscape(redirectURL)) + } + + _, body, sdkErr := sdk.processRequest(ctx, http.MethodGet, reqURL, "", nil, nil, http.StatusOK) + if sdkErr != nil { + return "", "", sdkErr + } + + var resp struct { + AuthorizationURL string `json:"authorization_url"` + State string `json:"state"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "", "", errors.NewSDKError(err) + } + + return resp.AuthorizationURL, resp.State, nil +} + +// OAuthCallback exchanges the OAuth authorization code for tokens. +func (sdk mgSDK) OAuthCallback(ctx context.Context, provider, code, state, redirectURL string) (Token, errors.SDKError) { + reqURL := fmt.Sprintf("%s/oauth/cli/callback/%s", sdk.usersURL, provider) + + data, err := json.Marshal(map[string]string{ + "code": code, + "state": state, + "redirect_url": redirectURL, + }) + if err != nil { + return Token{}, errors.NewSDKError(err) + } + + _, body, sdkErr := sdk.processRequest(ctx, http.MethodPost, reqURL, "", data, nil, http.StatusOK) + if sdkErr != nil { + return Token{}, sdkErr + } + + t := Token{} + if err := json.Unmarshal(body, &t); err != nil { + return Token{}, errors.NewSDKError(err) + } + + return t, nil +} + +// OAuthDeviceCode initiates the device authorization flow. +func (sdk mgSDK) OAuthDeviceCode(ctx context.Context, provider string) (DeviceCode, errors.SDKError) { + reqURL := fmt.Sprintf("%s/oauth/device/code/%s", sdk.usersURL, provider) + + _, body, sdkErr := sdk.processRequest(ctx, http.MethodPost, reqURL, "", nil, nil, http.StatusOK) + if sdkErr != nil { + return DeviceCode{}, sdkErr + } + + var deviceCode DeviceCode + if err := json.Unmarshal(body, &deviceCode); err != nil { + return DeviceCode{}, errors.NewSDKError(err) + } + + return deviceCode, nil +} + +// OAuthDeviceToken polls for device authorization completion. +func (sdk mgSDK) OAuthDeviceToken(ctx context.Context, provider, deviceCode string) (Token, errors.SDKError) { + reqURL := fmt.Sprintf("%s/oauth/device/token/%s", sdk.usersURL, provider) + + data, err := json.Marshal(map[string]string{ + "device_code": deviceCode, + }) + if err != nil { + return Token{}, errors.NewSDKError(err) + } + + // Accept both 200 (success) and 202 (pending) as valid responses + _, body, sdkErr := sdk.processRequest(ctx, http.MethodPost, reqURL, "", data, nil, http.StatusOK, http.StatusAccepted) + if sdkErr != nil { + return Token{}, sdkErr + } + + // Try to unmarshal as a token first + t := Token{} + if err := json.Unmarshal(body, &t); err == nil && t.AccessToken != "" { + // Successfully got a token + return t, nil + } + + // If no token, check if it's an error response (pending/slow down) + var errResp struct { + Error string `json:"error"` + } + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + // Return an error that preserves the message for CLI to check + return Token{}, errors.NewSDKError(fmt.Errorf("%s", errResp.Error)) + } + + // Shouldn't reach here, but handle gracefully + return Token{}, errors.NewSDKError(fmt.Errorf("unexpected response")) +} diff --git a/pkg/sdk/users_test.go b/pkg/sdk/users_test.go index 56bbf7fefc..30a1b4c61c 100644 --- a/pkg/sdk/users_test.go +++ b/pkg/sdk/users_test.go @@ -20,6 +20,7 @@ import ( authnmocks "github.com/absmach/supermq/pkg/authn/mocks" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/oauth2" oauth2mocks "github.com/absmach/supermq/pkg/oauth2/mocks" sdk "github.com/absmach/supermq/pkg/sdk" "github.com/absmach/supermq/pkg/uuid" @@ -27,6 +28,7 @@ import ( httpapi "github.com/absmach/supermq/users/api" umocks "github.com/absmach/supermq/users/mocks" "github.com/go-chi/chi/v5" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -46,7 +48,9 @@ func setupUsers() (*httptest.Server, *umocks.Service, *authnmocks.Authentication authn := new(authnmocks.Authentication) am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithDomainCheck(false), smqauthn.WithAllowUnverifiedUser(true)) token := new(authmocks.TokenServiceClient) - httpapi.MakeHandler(usvc, am, token, true, mux, logger, "", passRegex, idp, provider) + redisClient := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) + p := []oauth2.Provider{provider} + httpapi.MakeHandler(usvc, am, token, true, mux, logger, "", passRegex, idp, redisClient, p, p) return httptest.NewServer(mux), usvc, authn } diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index 6e8228406f..c703bc4755 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -116,6 +116,8 @@ packages: github.com/absmach/supermq/pkg/oauth2: interfaces: Provider: + DeviceCodeStore: + Service: github.com/absmach/supermq/pkg/policies: interfaces: Evaluator: @@ -140,4 +142,3 @@ packages: github.com/absmach/supermq/notifications: interfaces: Notifier: - diff --git a/users/README.md b/users/README.md index aa9356e43e..da298a864a 100644 --- a/users/README.md +++ b/users/README.md @@ -57,6 +57,7 @@ The service is configured using the environment variables presented in the follo | `SMQ_JAEGER_TRACE_RATIO` | Jaeger sampling ratio | 1.0 | | `SMQ_SEND_TELEMETRY` | Send telemetry to supermq call home server. | true | | `SMQ_USERS_INSTANCE_ID` | SuperMQ instance ID | "" | +| `SMQ_USERS_CACHE_URL` | Cache database URL | redis://localhost:6379/0 | ## Deployment @@ -120,6 +121,7 @@ SMQ_OAUTH_UI_ERROR_URL=http://localhost:9095/error \ SMQ_USERS_DELETE_INTERVAL=24h \ SMQ_USERS_DELETE_AFTER=720h \ SMQ_USERS_INSTANCE_ID="" \ +SMQ_USERS_CACHE_URL=redis://localhost:6379/0 \ $GOBIN/supermq-users ``` diff --git a/users/api/endpoint_test.go b/users/api/endpoint_test.go index 6c5c53ea5a..4989378a5e 100644 --- a/users/api/endpoint_test.go +++ b/users/api/endpoint_test.go @@ -23,12 +23,14 @@ import ( authnmocks "github.com/absmach/supermq/pkg/authn/mocks" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/oauth2" oauth2mocks "github.com/absmach/supermq/pkg/oauth2/mocks" "github.com/absmach/supermq/pkg/uuid" "github.com/absmach/supermq/users" usersapi "github.com/absmach/supermq/users/api" "github.com/absmach/supermq/users/mocks" "github.com/go-chi/chi/v5" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -97,7 +99,9 @@ func newUsersServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenticat authn := new(authnmocks.Authentication) am := smqauthn.NewAuthNMiddleware(authn) token := new(authmocks.TokenServiceClient) - usersapi.MakeHandler(svc, am, token, true, mux, logger, "", passRegex, idp, provider) + // Create a mock Redis client for testing (won't be used in these tests) + redisClient := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) + usersapi.MakeHandler(svc, am, token, true, mux, logger, "", passRegex, idp, redisClient, []oauth2.Provider{provider}, []oauth2.Provider{provider}) return httptest.NewServer(mux), svc, authn } diff --git a/users/api/transport.go b/users/api/transport.go index 1abd74fcc2..a65987f009 100644 --- a/users/api/transport.go +++ b/users/api/transport.go @@ -4,6 +4,7 @@ package api import ( + "context" "log/slog" "net/http" "regexp" @@ -12,14 +13,27 @@ import ( grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/oauth2" + oauthhttp "github.com/absmach/supermq/pkg/oauth2/http" + "github.com/absmach/supermq/pkg/oauth2/store" "github.com/absmach/supermq/users" "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/redis/go-redis/v9" ) // MakeHandler returns a HTTP handler for Users and Groups API endpoints. -func MakeHandler(cls users.Service, authn smqauthn.AuthNMiddleware, tokensvc grpcTokenV1.TokenServiceClient, selfRegister bool, mux *chi.Mux, logger *slog.Logger, instanceID string, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) http.Handler { - mux = usersHandler(cls, authn, tokensvc, selfRegister, mux, logger, pr, idp, providers...) +// It accepts separate providers for device flow and user flow. +// For backward compatibility, if only one provider is passed, it's used for both flows. +func MakeHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokensvc grpcTokenV1.TokenServiceClient, selfRegister bool, mux *chi.Mux, logger *slog.Logger, instanceID string, pr *regexp.Regexp, idp supermq.IDProvider, cacheClient *redis.Client, userProviders, deviceProviders []oauth2.Provider) http.Handler { + ctx := context.Background() + + mux = usersHandler(svc, authn, tokensvc, selfRegister, mux, logger, pr, idp) + + deviceStore := store.NewRedisDeviceCodeStore(ctx, cacheClient) + oauthSvc := oauth2.NewOAuthService(deviceStore, svc, tokensvc) + + mux = oauthhttp.Handler(mux, tokensvc, oauthSvc, userProviders...) + mux = oauthhttp.DeviceHandler(mux, tokensvc, oauthSvc, deviceProviders...) mux.Get("/health", supermq.Health("users", instanceID)) mux.Handle("/metrics", promhttp.Handler()) diff --git a/users/api/users.go b/users/api/users.go index 9dde471168..5229178ca4 100644 --- a/users/api/users.go +++ b/users/api/users.go @@ -15,10 +15,8 @@ import ( grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" - smqauth "github.com/absmach/supermq/auth" smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" - "github.com/absmach/supermq/pkg/oauth2" "github.com/absmach/supermq/users" "github.com/go-chi/chi/v5" kithttp "github.com/go-kit/kit/transport/http" @@ -28,7 +26,7 @@ import ( var passRegex = regexp.MustCompile("^.{8,}$") // usersHandler returns a HTTP handler for API endpoints. -func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient grpcTokenV1.TokenServiceClient, selfRegister bool, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) *chi.Mux { +func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient grpcTokenV1.TokenServiceClient, selfRegister bool, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, idp supermq.IDProvider) *chi.Mux { passRegex = pr opts := []kithttp.ServerOption{ @@ -206,10 +204,6 @@ func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient opts..., ), "verify_email").ServeHTTP) - for _, provider := range providers { - r.HandleFunc("/oauth/callback/"+provider.Name(), oauth2CallbackHandler(provider, svc, tokenClient)) - } - return r } @@ -548,77 +542,3 @@ func decodeChangeUserStatus(_ context.Context, r *http.Request) (any, error) { return req, nil } - -// oauth2CallbackHandler is a http.HandlerFunc that handles OAuth2 callbacks. -func oauth2CallbackHandler(oauth oauth2.Provider, svc users.Service, tokenClient grpcTokenV1.TokenServiceClient) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if !oauth.IsEnabled() { - http.Redirect(w, r, oauth.ErrorURL()+"?error=oauth%20provider%20is%20disabled", http.StatusSeeOther) - return - } - state := r.FormValue("state") - if state != oauth.State() { - http.Redirect(w, r, oauth.ErrorURL()+"?error=invalid%20state", http.StatusSeeOther) - return - } - - if code := r.FormValue("code"); code != "" { - token, err := oauth.Exchange(r.Context(), code) - if err != nil { - http.Redirect(w, r, oauth.ErrorURL()+"?error="+err.Error(), http.StatusSeeOther) - return - } - - user, err := oauth.UserInfo(token.AccessToken) - if err != nil { - http.Redirect(w, r, oauth.ErrorURL()+"?error="+err.Error(), http.StatusSeeOther) - return - } - - user.AuthProvider = oauth.Name() - if user.AuthProvider == "" { - user.AuthProvider = "oauth" - } - user, err = svc.OAuthCallback(r.Context(), user) - if err != nil { - http.Redirect(w, r, oauth.ErrorURL()+"?error="+err.Error(), http.StatusSeeOther) - return - } - if err := svc.OAuthAddUserPolicy(r.Context(), user); err != nil { - http.Redirect(w, r, oauth.ErrorURL()+"?error="+err.Error(), http.StatusSeeOther) - return - } - - jwt, err := tokenClient.Issue(r.Context(), &grpcTokenV1.IssueReq{ - UserId: user.ID, - Type: uint32(smqauth.AccessKey), - UserRole: uint32(smqauth.UserRole), - Verified: !user.VerifiedAt.IsZero(), - }) - if err != nil { - http.Redirect(w, r, oauth.ErrorURL()+"?error="+err.Error(), http.StatusSeeOther) - return - } - - http.SetCookie(w, &http.Cookie{ - Name: "access_token", - Value: jwt.GetAccessToken(), - Path: "/", - HttpOnly: true, - Secure: true, - }) - http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", - Value: jwt.GetRefreshToken(), - Path: "/", - HttpOnly: true, - Secure: true, - }) - - http.Redirect(w, r, oauth.RedirectURL(), http.StatusFound) - return - } - - http.Redirect(w, r, oauth.ErrorURL()+"?error=empty%20code", http.StatusSeeOther) - } -} diff --git a/users/oauth2/mocks/service.go b/users/oauth2/mocks/service.go new file mode 100644 index 0000000000..5dab544781 --- /dev/null +++ b/users/oauth2/mocks/service.go @@ -0,0 +1,480 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/api/grpc/token/v1" + "github.com/absmach/supermq/pkg/oauth2" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// CreateDeviceCode provides a mock function for the type Service +func (_mock *Service) CreateDeviceCode(ctx context.Context, provider oauth2.Provider, verificationURI string) (oauth2.DeviceCode, error) { + ret := _mock.Called(ctx, provider, verificationURI) + + if len(ret) == 0 { + panic("no return value specified for CreateDeviceCode") + } + + var r0 oauth2.DeviceCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string) (oauth2.DeviceCode, error)); ok { + return returnFunc(ctx, provider, verificationURI) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string) oauth2.DeviceCode); ok { + r0 = returnFunc(ctx, provider, verificationURI) + } else { + r0 = ret.Get(0).(oauth2.DeviceCode) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, oauth2.Provider, string) error); ok { + r1 = returnFunc(ctx, provider, verificationURI) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_CreateDeviceCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateDeviceCode' +type Service_CreateDeviceCode_Call struct { + *mock.Call +} + +// CreateDeviceCode is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - verificationURI string +func (_e *Service_Expecter) CreateDeviceCode(ctx interface{}, provider interface{}, verificationURI interface{}) *Service_CreateDeviceCode_Call { + return &Service_CreateDeviceCode_Call{Call: _e.mock.On("CreateDeviceCode", ctx, provider, verificationURI)} +} + +func (_c *Service_CreateDeviceCode_Call) Run(run func(ctx context.Context, provider oauth2.Provider, verificationURI string)) *Service_CreateDeviceCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_CreateDeviceCode_Call) Return(deviceCode oauth2.DeviceCode, err error) *Service_CreateDeviceCode_Call { + _c.Call.Return(deviceCode, err) + return _c +} + +func (_c *Service_CreateDeviceCode_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, verificationURI string) (oauth2.DeviceCode, error)) *Service_CreateDeviceCode_Call { + _c.Call.Return(run) + return _c +} + +// GetDeviceCodeByUserCode provides a mock function for the type Service +func (_mock *Service) GetDeviceCodeByUserCode(ctx context.Context, userCode string) (oauth2.DeviceCode, error) { + ret := _mock.Called(ctx, userCode) + + if len(ret) == 0 { + panic("no return value specified for GetDeviceCodeByUserCode") + } + + var r0 oauth2.DeviceCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (oauth2.DeviceCode, error)); ok { + return returnFunc(ctx, userCode) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) oauth2.DeviceCode); ok { + r0 = returnFunc(ctx, userCode) + } else { + r0 = ret.Get(0).(oauth2.DeviceCode) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, userCode) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_GetDeviceCodeByUserCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDeviceCodeByUserCode' +type Service_GetDeviceCodeByUserCode_Call struct { + *mock.Call +} + +// GetDeviceCodeByUserCode is a helper method to define mock.On call +// - ctx context.Context +// - userCode string +func (_e *Service_Expecter) GetDeviceCodeByUserCode(ctx interface{}, userCode interface{}) *Service_GetDeviceCodeByUserCode_Call { + return &Service_GetDeviceCodeByUserCode_Call{Call: _e.mock.On("GetDeviceCodeByUserCode", ctx, userCode)} +} + +func (_c *Service_GetDeviceCodeByUserCode_Call) Run(run func(ctx context.Context, userCode string)) *Service_GetDeviceCodeByUserCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_GetDeviceCodeByUserCode_Call) Return(deviceCode oauth2.DeviceCode, err error) *Service_GetDeviceCodeByUserCode_Call { + _c.Call.Return(deviceCode, err) + return _c +} + +func (_c *Service_GetDeviceCodeByUserCode_Call) RunAndReturn(run func(ctx context.Context, userCode string) (oauth2.DeviceCode, error)) *Service_GetDeviceCodeByUserCode_Call { + _c.Call.Return(run) + return _c +} + +// PollDeviceToken provides a mock function for the type Service +func (_mock *Service) PollDeviceToken(ctx context.Context, provider oauth2.Provider, deviceCode string) (*v1.Token, error) { + ret := _mock.Called(ctx, provider, deviceCode) + + if len(ret) == 0 { + panic("no return value specified for PollDeviceToken") + } + + var r0 *v1.Token + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string) (*v1.Token, error)); ok { + return returnFunc(ctx, provider, deviceCode) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string) *v1.Token); ok { + r0 = returnFunc(ctx, provider, deviceCode) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Token) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, oauth2.Provider, string) error); ok { + r1 = returnFunc(ctx, provider, deviceCode) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_PollDeviceToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PollDeviceToken' +type Service_PollDeviceToken_Call struct { + *mock.Call +} + +// PollDeviceToken is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - deviceCode string +func (_e *Service_Expecter) PollDeviceToken(ctx interface{}, provider interface{}, deviceCode interface{}) *Service_PollDeviceToken_Call { + return &Service_PollDeviceToken_Call{Call: _e.mock.On("PollDeviceToken", ctx, provider, deviceCode)} +} + +func (_c *Service_PollDeviceToken_Call) Run(run func(ctx context.Context, provider oauth2.Provider, deviceCode string)) *Service_PollDeviceToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_PollDeviceToken_Call) Return(token *v1.Token, err error) *Service_PollDeviceToken_Call { + _c.Call.Return(token, err) + return _c +} + +func (_c *Service_PollDeviceToken_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, deviceCode string) (*v1.Token, error)) *Service_PollDeviceToken_Call { + _c.Call.Return(run) + return _c +} + +// ProcessDeviceCallback provides a mock function for the type Service +func (_mock *Service) ProcessDeviceCallback(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string) error { + ret := _mock.Called(ctx, provider, userCode, oauthCode) + + if len(ret) == 0 { + panic("no return value specified for ProcessDeviceCallback") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string, string) error); ok { + r0 = returnFunc(ctx, provider, userCode, oauthCode) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_ProcessDeviceCallback_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessDeviceCallback' +type Service_ProcessDeviceCallback_Call struct { + *mock.Call +} + +// ProcessDeviceCallback is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - userCode string +// - oauthCode string +func (_e *Service_Expecter) ProcessDeviceCallback(ctx interface{}, provider interface{}, userCode interface{}, oauthCode interface{}) *Service_ProcessDeviceCallback_Call { + return &Service_ProcessDeviceCallback_Call{Call: _e.mock.On("ProcessDeviceCallback", ctx, provider, userCode, oauthCode)} +} + +func (_c *Service_ProcessDeviceCallback_Call) Run(run func(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string)) *Service_ProcessDeviceCallback_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_ProcessDeviceCallback_Call) Return(err error) *Service_ProcessDeviceCallback_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_ProcessDeviceCallback_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string) error) *Service_ProcessDeviceCallback_Call { + _c.Call.Return(run) + return _c +} + +// ProcessWebCallback provides a mock function for the type Service +func (_mock *Service) ProcessWebCallback(ctx context.Context, provider oauth2.Provider, code string, redirectURL string) (*v1.Token, error) { + ret := _mock.Called(ctx, provider, code, redirectURL) + + if len(ret) == 0 { + panic("no return value specified for ProcessWebCallback") + } + + var r0 *v1.Token + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string, string) (*v1.Token, error)); ok { + return returnFunc(ctx, provider, code, redirectURL) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string, string) *v1.Token); ok { + r0 = returnFunc(ctx, provider, code, redirectURL) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Token) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, oauth2.Provider, string, string) error); ok { + r1 = returnFunc(ctx, provider, code, redirectURL) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ProcessWebCallback_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessWebCallback' +type Service_ProcessWebCallback_Call struct { + *mock.Call +} + +// ProcessWebCallback is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - code string +// - redirectURL string +func (_e *Service_Expecter) ProcessWebCallback(ctx interface{}, provider interface{}, code interface{}, redirectURL interface{}) *Service_ProcessWebCallback_Call { + return &Service_ProcessWebCallback_Call{Call: _e.mock.On("ProcessWebCallback", ctx, provider, code, redirectURL)} +} + +func (_c *Service_ProcessWebCallback_Call) Run(run func(ctx context.Context, provider oauth2.Provider, code string, redirectURL string)) *Service_ProcessWebCallback_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_ProcessWebCallback_Call) Return(token *v1.Token, err error) *Service_ProcessWebCallback_Call { + _c.Call.Return(token, err) + return _c +} + +func (_c *Service_ProcessWebCallback_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, code string, redirectURL string) (*v1.Token, error)) *Service_ProcessWebCallback_Call { + _c.Call.Return(run) + return _c +} + +// VerifyDevice provides a mock function for the type Service +func (_mock *Service) VerifyDevice(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string, approve bool) error { + ret := _mock.Called(ctx, provider, userCode, oauthCode, approve) + + if len(ret) == 0 { + panic("no return value specified for VerifyDevice") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, oauth2.Provider, string, string, bool) error); ok { + r0 = returnFunc(ctx, provider, userCode, oauthCode, approve) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_VerifyDevice_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyDevice' +type Service_VerifyDevice_Call struct { + *mock.Call +} + +// VerifyDevice is a helper method to define mock.On call +// - ctx context.Context +// - provider oauth2.Provider +// - userCode string +// - oauthCode string +// - approve bool +func (_e *Service_Expecter) VerifyDevice(ctx interface{}, provider interface{}, userCode interface{}, oauthCode interface{}, approve interface{}) *Service_VerifyDevice_Call { + return &Service_VerifyDevice_Call{Call: _e.mock.On("VerifyDevice", ctx, provider, userCode, oauthCode, approve)} +} + +func (_c *Service_VerifyDevice_Call) Run(run func(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string, approve bool)) *Service_VerifyDevice_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 oauth2.Provider + if args[1] != nil { + arg1 = args[1].(oauth2.Provider) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 bool + if args[4] != nil { + arg4 = args[4].(bool) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_VerifyDevice_Call) Return(err error) *Service_VerifyDevice_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_VerifyDevice_Call) RunAndReturn(run func(ctx context.Context, provider oauth2.Provider, userCode string, oauthCode string, approve bool) error) *Service_VerifyDevice_Call { + _c.Call.Return(run) + return _c +} diff --git a/users/oauth2/oauth.go b/users/oauth2/oauth.go new file mode 100644 index 0000000000..9996e2292e --- /dev/null +++ b/users/oauth2/oauth.go @@ -0,0 +1,111 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package oauth provides OAuth2 authentication implementation for users service. +// It handles both web-based OAuth flow and device authorization flow for CLI clients. +package oauth + +import ( + "context" + "crypto/rand" + "encoding/base32" + "strings" + + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" + "github.com/absmach/supermq/pkg/oauth2" + goauth2 "golang.org/x/oauth2" +) + +// Re-export constants from pkg/oauth2 for backward compatibility. +const ( + DeviceCodeLength = oauth2.DeviceCodeLength + DeviceCodePollTimeout = oauth2.DeviceCodePollTimeout + CodeCheckInterval = oauth2.CodeCheckInterval + DeviceStatePrefix = oauth2.DeviceStatePrefix +) + +// Re-export errors from pkg/oauth2 for backward compatibility. +var ( + ErrDeviceCodeExpired = oauth2.ErrDeviceCodeExpired + ErrDeviceCodePending = oauth2.ErrDeviceCodePending + ErrSlowDown = oauth2.ErrSlowDown + ErrAccessDenied = oauth2.ErrAccessDenied + ErrInvalidState = oauth2.ErrInvalidState + ErrEmptyCode = oauth2.ErrEmptyCode + ErrInvalidProvider = oauth2.ErrInvalidProvider + ErrDeviceCodeNotFound = oauth2.ErrDeviceCodeNotFound + ErrUserCodeNotFound = oauth2.ErrUserCodeNotFound +) + +// Service provides OAuth authentication operations for the users service. +type Service interface { + // Device flow operations + + // CreateDeviceCode initiates the device authorization flow. + // It generates device and user codes, and returns the verification URI. + CreateDeviceCode(ctx context.Context, provider oauth2.Provider, verificationURI string) (oauth2.DeviceCode, error) + + // PollDeviceToken polls for device authorization completion. + // Returns the JWT token once the user has authorized the device. + PollDeviceToken(ctx context.Context, provider oauth2.Provider, deviceCode string) (*grpcTokenV1.Token, error) + + // VerifyDevice handles user verification of device codes. + // It exchanges the OAuth authorization code for a token and marks the device as approved. + VerifyDevice(ctx context.Context, provider oauth2.Provider, userCode, oauthCode string, approve bool) error + + // GetDeviceCodeByUserCode retrieves a device code by its user code. + GetDeviceCodeByUserCode(ctx context.Context, userCode string) (oauth2.DeviceCode, error) + + // Web flow operations + + // ProcessWebCallback handles OAuth callback for web flow. + // It exchanges the authorization code for a token and creates/updates the user. + ProcessWebCallback(ctx context.Context, provider oauth2.Provider, code, redirectURL string) (*grpcTokenV1.Token, error) + + // ProcessDeviceCallback handles OAuth callback for device flow. + // It's called when a user authorizes a device through the web interface. + ProcessDeviceCallback(ctx context.Context, provider oauth2.Provider, userCode, oauthCode string) error +} + +// generateUserCode generates a human-friendly code like "ABCD-EFGH". +func generateUserCode() (string, error) { + b := make([]byte, DeviceCodeLength) + if _, err := rand.Read(b); err != nil { + return "", err + } + code := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b) + code = strings.ToUpper(code[:DeviceCodeLength]) + // Format as XXXX-XXXX + if len(code) >= 8 { + code = code[:4] + "-" + code[4:8] + } + return code, nil +} + +// generateDeviceCode generates a random device code. +func generateDeviceCode() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(b), nil +} + +// IsDeviceFlowState checks if the state parameter indicates a device flow. +func IsDeviceFlowState(state string) bool { + return strings.HasPrefix(state, DeviceStatePrefix) +} + +// ExtractUserCodeFromState extracts the user code from a device flow state. +func ExtractUserCodeFromState(state string) string { + return strings.TrimPrefix(state, DeviceStatePrefix) +} + +// ExchangeCode exchanges an authorization code for an access token. +// If redirectURL is provided, it uses ExchangeWithRedirect, otherwise uses Exchange. +func ExchangeCode(ctx context.Context, provider oauth2.Provider, code, redirectURL string) (goauth2.Token, error) { + if redirectURL != "" { + return provider.ExchangeWithRedirect(ctx, code, redirectURL) + } + return provider.Exchange(ctx, code) +} diff --git a/users/oauth2/service.go b/users/oauth2/service.go new file mode 100644 index 0000000000..97f947e7e5 --- /dev/null +++ b/users/oauth2/service.go @@ -0,0 +1,212 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package oauth + +import ( + "context" + "fmt" + "time" + + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" + smqauth "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/oauth2" + "github.com/absmach/supermq/users" +) + +var _ Service = (*oauthService)(nil) + +// oauthService implements the OAuth Service interface. +type oauthService struct { + deviceStore oauth2.DeviceCodeStore + userService users.Service + tokenClient grpcTokenV1.TokenServiceClient +} + +// NewOAuthService creates a new OAuth service instance. +func NewOAuthService(deviceStore oauth2.DeviceCodeStore, userService users.Service, tokenClient grpcTokenV1.TokenServiceClient) Service { + return &oauthService{ + deviceStore: deviceStore, + userService: userService, + tokenClient: tokenClient, + } +} + +// CreateDeviceCode initiates the device authorization flow. +func (s *oauthService) CreateDeviceCode(ctx context.Context, provider oauth2.Provider, verificationURI string) (oauth2.DeviceCode, error) { + if !provider.IsEnabled() { + return oauth2.DeviceCode{}, ErrInvalidProvider + } + + userCode, err := generateUserCode() + if err != nil { + return oauth2.DeviceCode{}, fmt.Errorf("failed to generate user code: %w", err) + } + + deviceCode, err := generateDeviceCode() + if err != nil { + return oauth2.DeviceCode{}, fmt.Errorf("failed to generate device code: %w", err) + } + + code := oauth2.DeviceCode{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: verificationURI, + ExpiresIn: int(oauth2.DeviceCodeExpiry.Seconds()), + Interval: int(CodeCheckInterval.Seconds()), + Provider: provider.Name(), + CreatedAt: time.Now(), + State: provider.State(), + } + + if err := s.deviceStore.Save(code); err != nil { + return oauth2.DeviceCode{}, fmt.Errorf("failed to save device code: %w", err) + } + + return code, nil +} + +// PollDeviceToken polls for device authorization completion. +func (s *oauthService) PollDeviceToken(ctx context.Context, provider oauth2.Provider, deviceCode string) (*grpcTokenV1.Token, error) { + if !provider.IsEnabled() { + return nil, ErrInvalidProvider + } + + code, err := s.deviceStore.Get(deviceCode) + if err != nil { + return nil, ErrDeviceCodeNotFound + } + + // Check expiration + if time.Since(code.CreatedAt) > oauth2.DeviceCodeExpiry { + s.deviceStore.Delete(deviceCode) + return nil, ErrDeviceCodeExpired + } + + // Check polling rate + if time.Since(code.LastPoll) < CodeCheckInterval { + return nil, ErrSlowDown + } + + // Update last poll time + code.LastPoll = time.Now() + s.deviceStore.Update(code) + + // Check if denied + if code.Denied { + s.deviceStore.Delete(deviceCode) + return nil, ErrAccessDenied + } + + // Check if approved + if !code.Approved || code.AccessToken == "" { + return nil, ErrDeviceCodePending + } + + // Process the OAuth user and issue tokens + jwt, err := s.processOAuthUser(ctx, provider, code.AccessToken) + if err != nil { + s.deviceStore.Delete(deviceCode) + return nil, fmt.Errorf("failed to process oauth user: %w", err) + } + + s.deviceStore.Delete(deviceCode) + jwt.AccessType = "" + return jwt, nil +} + +// VerifyDevice handles user verification of device codes. +func (s *oauthService) VerifyDevice(ctx context.Context, provider oauth2.Provider, userCode, oauthCode string, approve bool) error { + if !provider.IsEnabled() { + return ErrInvalidProvider + } + + code, err := s.deviceStore.GetByUserCode(userCode) + if err != nil { + return err + } + + // Check expiration + if time.Since(code.CreatedAt) > oauth2.DeviceCodeExpiry { + s.deviceStore.Delete(code.DeviceCode) + return ErrDeviceCodeExpired + } + + if !approve { + code.Denied = true + s.deviceStore.Update(code) + return nil + } + + // Exchange authorization code for access token + token, err := provider.Exchange(ctx, oauthCode) + if err != nil { + return fmt.Errorf("failed to exchange code: %w", err) + } + + code.Approved = true + code.AccessToken = token.AccessToken + if err := s.deviceStore.Update(code); err != nil { + return fmt.Errorf("failed to update device code: %w", err) + } + + return nil +} + +// GetDeviceCodeByUserCode retrieves a device code by its user code. +func (s *oauthService) GetDeviceCodeByUserCode(ctx context.Context, userCode string) (oauth2.DeviceCode, error) { + return s.deviceStore.GetByUserCode(userCode) +} + +// ProcessWebCallback handles OAuth callback for web flow. +func (s *oauthService) ProcessWebCallback(ctx context.Context, provider oauth2.Provider, code, redirectURL string) (*grpcTokenV1.Token, error) { + if !provider.IsEnabled() { + return nil, ErrInvalidProvider + } + + if code == "" { + return nil, ErrEmptyCode + } + + token, err := ExchangeCode(ctx, provider, code, redirectURL) + if err != nil { + return nil, fmt.Errorf("failed to exchange code: %w", err) + } + + return s.processOAuthUser(ctx, provider, token.AccessToken) +} + +// ProcessDeviceCallback handles OAuth callback for device flow. +func (s *oauthService) ProcessDeviceCallback(ctx context.Context, provider oauth2.Provider, userCode, oauthCode string) error { + return s.VerifyDevice(ctx, provider, userCode, oauthCode, true) +} + +// processOAuthUser retrieves user info from the OAuth provider, creates or updates the user, +// adds user policies, and issues a JWT token. +func (s *oauthService) processOAuthUser(ctx context.Context, provider oauth2.Provider, accessToken string) (*grpcTokenV1.Token, error) { + user, err := provider.UserInfo(accessToken) + if err != nil { + return nil, fmt.Errorf("failed to get user info: %w", err) + } + + user.AuthProvider = provider.Name() + if user.AuthProvider == "" { + user.AuthProvider = "oauth" + } + + user, err = s.userService.OAuthCallback(ctx, user) + if err != nil { + return nil, fmt.Errorf("failed to handle oauth callback: %w", err) + } + + if err := s.userService.OAuthAddUserPolicy(ctx, user); err != nil { + return nil, fmt.Errorf("failed to add user policy: %w", err) + } + + return s.tokenClient.Issue(ctx, &grpcTokenV1.IssueReq{ + UserId: user.ID, + Type: uint32(smqauth.AccessKey), + UserRole: uint32(smqauth.UserRole), + Verified: !user.VerifiedAt.IsZero(), + }) +} diff --git a/users/oauth2/store/memory.go b/users/oauth2/store/memory.go new file mode 100644 index 0000000000..2d195521c3 --- /dev/null +++ b/users/oauth2/store/memory.go @@ -0,0 +1,104 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package store + +import ( + "sync" + "time" +) + +// inMemoryDeviceCodeStore is an in-memory implementation of DeviceCodeStore. +type inMemoryDeviceCodeStore struct { + mu sync.RWMutex + codes map[string]DeviceCode + userCodes map[string]string // maps user code to device code + cleanupDone chan struct{} +} + +// NewInMemoryDeviceCodeStore creates a new in-memory device code store. +// It automatically starts a cleanup goroutine to remove expired codes. +func NewInMemoryDeviceCodeStore() DeviceCodeStore { + store := &inMemoryDeviceCodeStore{ + codes: make(map[string]DeviceCode), + userCodes: make(map[string]string), + cleanupDone: make(chan struct{}), + } + go store.cleanup() + return store +} + +func (s *inMemoryDeviceCodeStore) Save(code DeviceCode) error { + s.mu.Lock() + defer s.mu.Unlock() + s.codes[code.DeviceCode] = code + s.userCodes[code.UserCode] = code.DeviceCode + return nil +} + +func (s *inMemoryDeviceCodeStore) Get(deviceCode string) (DeviceCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + code, ok := s.codes[deviceCode] + if !ok { + return DeviceCode{}, ErrDeviceCodeNotFound + } + return code, nil +} + +func (s *inMemoryDeviceCodeStore) GetByUserCode(userCode string) (DeviceCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + deviceCode, ok := s.userCodes[userCode] + if !ok { + return DeviceCode{}, ErrUserCodeNotFound + } + code, ok := s.codes[deviceCode] + if !ok { + return DeviceCode{}, ErrDeviceCodeNotFound + } + return code, nil +} + +func (s *inMemoryDeviceCodeStore) Update(code DeviceCode) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.codes[code.DeviceCode]; !ok { + return ErrDeviceCodeNotFound + } + s.codes[code.DeviceCode] = code + return nil +} + +func (s *inMemoryDeviceCodeStore) Delete(deviceCode string) error { + s.mu.Lock() + defer s.mu.Unlock() + if code, ok := s.codes[deviceCode]; ok { + delete(s.userCodes, code.UserCode) + } + delete(s.codes, deviceCode) + return nil +} + +// cleanup periodically removes expired device codes. +func (s *inMemoryDeviceCodeStore) cleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.mu.Lock() + now := time.Now() + for deviceCode, code := range s.codes { + if now.Sub(code.CreatedAt) > DeviceCodeExpiry { + delete(s.codes, deviceCode) + delete(s.userCodes, code.UserCode) + } + } + s.mu.Unlock() + case <-s.cleanupDone: + return + } + } +} diff --git a/users/oauth2/store/redis.go b/users/oauth2/store/redis.go new file mode 100644 index 0000000000..d5e7acd034 --- /dev/null +++ b/users/oauth2/store/redis.go @@ -0,0 +1,143 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package store + +// import ( +// "context" +// "encoding/json" +// "fmt" + +// "github.com/redis/go-redis/v9" +// ) + +// const ( +// deviceCodePrefix = "oauth:device:code:" +// userCodePrefix = "oauth:device:user:" +// ) + +// // redisDeviceCodeStore is a Redis-based implementation of DeviceCodeStore. +// type redisDeviceCodeStore struct { +// client *redis.Client +// ctx context.Context +// } + +// // NewRedisDeviceCodeStore creates a new Redis-based device code store. +// func NewRedisDeviceCodeStore(ctx context.Context, client *redis.Client) DeviceCodeStore { +// return &redisDeviceCodeStore{ +// client: client, +// ctx: ctx, +// } +// } + +// func (s *redisDeviceCodeStore) Save(code DeviceCode) error { +// data, err := json.Marshal(code) +// if err != nil { +// return fmt.Errorf("failed to marshal device code: %w", err) +// } + +// // Store device code with expiry +// deviceKey := deviceCodePrefix + code.DeviceCode +// if err := s.client.Set(s.ctx, deviceKey, data, DeviceCodeExpiry).Err(); err != nil { +// return fmt.Errorf("failed to save device code: %w", err) +// } + +// // Store user code to device code mapping with expiry +// userKey := userCodePrefix + code.UserCode +// if err := s.client.Set(s.ctx, userKey, code.DeviceCode, DeviceCodeExpiry).Err(); err != nil { +// return fmt.Errorf("failed to save user code mapping: %w", err) +// } + +// return nil +// } + +// func (s *redisDeviceCodeStore) Get(deviceCode string) (DeviceCode, error) { +// deviceKey := deviceCodePrefix + deviceCode +// data, err := s.client.Get(s.ctx, deviceKey).Bytes() +// if err != nil { +// if err == redis.Nil { +// return DeviceCode{}, ErrDeviceCodeNotFound +// } +// return DeviceCode{}, fmt.Errorf("failed to get device code: %w", err) +// } + +// var code DeviceCode +// if err := json.Unmarshal(data, &code); err != nil { +// return DeviceCode{}, fmt.Errorf("failed to unmarshal device code: %w", err) +// } + +// return code, nil +// } + +// func (s *redisDeviceCodeStore) GetByUserCode(userCode string) (DeviceCode, error) { +// // First, get the device code from user code mapping +// userKey := userCodePrefix + userCode +// deviceCode, err := s.client.Get(s.ctx, userKey).Result() +// if err != nil { +// if err == redis.Nil { +// return DeviceCode{}, ErrUserCodeNotFound +// } +// return DeviceCode{}, fmt.Errorf("failed to get device code by user code: %w", err) +// } + +// // Then, get the actual device code data +// return s.Get(deviceCode) +// } + +// func (s *redisDeviceCodeStore) Update(code DeviceCode) error { +// // Get the existing code to check if it exists +// existing, err := s.Get(code.DeviceCode) +// if err != nil { +// return err +// } + +// // Preserve the creation time and user code from existing +// code.CreatedAt = existing.CreatedAt +// code.UserCode = existing.UserCode + +// data, err := json.Marshal(code) +// if err != nil { +// return fmt.Errorf("failed to marshal device code: %w", err) +// } + +// // Calculate remaining TTL +// deviceKey := deviceCodePrefix + code.DeviceCode +// ttl, err := s.client.TTL(s.ctx, deviceKey).Result() +// if err != nil { +// return fmt.Errorf("failed to get TTL: %w", err) +// } + +// // If TTL is negative (key doesn't exist or no expiry), use default +// if ttl < 0 { +// ttl = DeviceCodeExpiry +// } + +// // Update the device code with remaining TTL +// if err := s.client.Set(s.ctx, deviceKey, data, ttl).Err(); err != nil { +// return fmt.Errorf("failed to update device code: %w", err) +// } + +// return nil +// } + +// func (s *redisDeviceCodeStore) Delete(deviceCode string) error { +// // Get the code first to find the user code +// code, err := s.Get(deviceCode) +// if err != nil { +// return err +// } + +// // Delete both device code and user code mapping +// deviceKey := deviceCodePrefix + deviceCode +// userKey := userCodePrefix + code.UserCode + +// pipe := s.client.Pipeline() +// pipe.Del(s.ctx, deviceKey) +// pipe.Del(s.ctx, userKey) + +// if _, err := pipe.Exec(s.ctx); err != nil { +// return fmt.Errorf("failed to delete device code: %w", err) +// } + +// return nil +// } diff --git a/users/oauth2/store/store.go b/users/oauth2/store/store.go new file mode 100644 index 0000000000..226136b295 --- /dev/null +++ b/users/oauth2/store/store.go @@ -0,0 +1,26 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package store provides storage implementations for OAuth device codes. +package store + +import ( + "github.com/absmach/supermq/pkg/oauth2" +) + +// Re-export constants from pkg/oauth2 for backward compatibility. +const ( + DeviceCodeExpiry = oauth2.DeviceCodeExpiry +) + +// Re-export errors from pkg/oauth2 for backward compatibility. +var ( + ErrDeviceCodeNotFound = oauth2.ErrDeviceCodeNotFound + ErrUserCodeNotFound = oauth2.ErrUserCodeNotFound +) + +// DeviceCode is an alias for oauth2.DeviceCode for backward compatibility. +type DeviceCode = oauth2.DeviceCode + +// DeviceCodeStore is an alias for oauth2.DeviceCodeStore for backward compatibility. +type DeviceCodeStore = oauth2.DeviceCodeStore