Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,70 @@ supermq-cli users create <user_name> <user_email> <user_password> <user_token>
supermq-cli users token <user_email> <user_password>
```

#### OAuth Authentication

Authenticate using OAuth providers (e.g., Google) to obtain access tokens using the device authorization flow.

```bash
supermq-cli users oauth <provider>
```

Example with Google:

```bash
supermq-cli users oauth google
```

**How it works (Device Authorization Flow):**

The CLI uses the OAuth 2.0 Device Authorization Flow, which is specifically designed for command-line applications and devices with limited input capabilities:

1. The CLI requests a device code from the OAuth provider
2. You receive a short user code (e.g., `ABCD-EFGH`) and a verification URL
3. Visit the verification URL in any browser (on any device)
4. Enter the user code when prompted
5. Authorize the application in your browser
6. The CLI automatically detects the authorization and retrieves your tokens
7. Tokens are displayed in JSON format

**Example Output:**

```
=== OAuth Device Authorization ===

Please complete authentication in your browser:

1. Visit: https://auth.example.com/device
2. Enter code: ABCD-EFGH

Waiting for authorization...

✓ Authentication successful!

{
"access_token": "eyJhbGc...",
"refresh_token": "eyJhbGc..."
}
```

**Benefits of Device Flow:**

- **No local server needed** - Works in any environment (containers, SSH, etc.)
- **Cross-device authentication** - Authenticate on your phone/tablet while using CLI
- **Better security** - No need to open ports or handle callbacks
- **Works everywhere** - Headless servers, restricted networks, Docker containers

**Prerequisites:**

The OAuth provider must be configured on the SuperMQ server. For Google OAuth, the following environment variables must be set:

- `SMQ_GOOGLE_CLIENT_ID` - Google OAuth client ID
- `SMQ_GOOGLE_CLIENT_SECRET` - Google OAuth client secret
- `SMQ_GOOGLE_REDIRECT_URL` - OAuth redirect URL (e.g., `http://localhost:9002/oauth/callback/google`)
- `SMQ_GOOGLE_STATE` - OAuth state parameter for security

See the [Users Service README](../users/README.md#oauth-configuration) for more details.

#### Get User

```bash
Expand Down
191 changes: 191 additions & 0 deletions cli/oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0

package cli

import (
"context"
"fmt"
"net"
"net/http"
"os/exec"
"runtime"
"strings"
"sync"
"time"

"github.com/spf13/cobra"
)

const (
callbackPath = "/callback"
localServerPort = "9090"
callbackTimeout = 5 * time.Minute
shutdownTimeout = 5 * time.Second
)

type oauthCallbackResult struct {
code string
state string
err error
}

type browserOpener interface {
Open(url string) error
}

type defaultBrowserOpener struct{}

func (defaultBrowserOpener) Open(url string) error {
return openBrowser(url)
}

type callbackServer struct {
listener net.Listener
server *http.Server
}

func newCallbackServer(resultChan chan<- oauthCallbackResult) (*callbackServer, error) {
listener, err := net.Listen("tcp", "127.0.0.1:"+localServerPort)
if err != nil {
return nil, fmt.Errorf("failed to start local server: %w", err)
}

mux := http.NewServeMux()
server := &http.Server{Handler: mux}

var once sync.Once
mux.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) {
handleOAuthCallback(w, r, resultChan, &once)
})

go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
once.Do(func() {
resultChan <- oauthCallbackResult{err: fmt.Errorf("server error: %w", err)}
})
}
}()

return &callbackServer{
listener: listener,
server: server,
}, nil
}

func (cs *callbackServer) Shutdown(cmd *cobra.Command) {
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()
if err := cs.server.Shutdown(ctx); err != nil {
logErrorCmd(*cmd, err)
}
}

func performOAuthLogin(cmd *cobra.Command, provider string) error {
return performOAuthLoginWithBrowser(cmd, provider, defaultBrowserOpener{})
}

func performOAuthLoginWithBrowser(cmd *cobra.Command, provider string, browser browserOpener) error {
ctx := cmd.Context()
callbackChan := make(chan oauthCallbackResult, 1)

server, err := newCallbackServer(callbackChan)
if err != nil {
return err
}
defer server.Shutdown(cmd)

callbackURL := fmt.Sprintf("http://127.0.0.1:%s%s", localServerPort, callbackPath)
authURL, state, err := sdk.OAuthAuthorizationURL(ctx, provider, callbackURL)
if err != nil {
return fmt.Errorf("failed to get authorization URL: %w", err)
}

printAuthInstructions(authURL)
if err := browser.Open(authURL); err != nil {
fmt.Printf("Failed to open browser automatically: %v\n", err)
}

fmt.Println("Waiting for authentication callback...")

result, err := waitForCallback(callbackChan)
if err != nil {
return err
}

fmt.Println("Exchanging authorization code for tokens...")
token, err := sdk.OAuthCallback(ctx, provider, result.code, state, callbackURL)
if err != nil {
return fmt.Errorf("failed to exchange code for token: %w", err)
}

logJSONCmd(*cmd, token)
fmt.Println("\nAuthentication successful! You can now use the access_token for API requests.")

return nil
}

func handleOAuthCallback(w http.ResponseWriter, r *http.Request, resultChan chan<- oauthCallbackResult, once *sync.Once) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
errParam := r.URL.Query().Get("error")

w.Header().Set("Content-Type", "text/html")

if errParam != "" {
html := strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", errParam, 1)
fmt.Fprint(w, html)
once.Do(func() {
resultChan <- oauthCallbackResult{err: fmt.Errorf("oauth error: %s", errParam)}
})
return
}

if code == "" {
html := strings.Replace(errorHTML, "{{ERROR_MESSAGE}}", "missing authorization code", 1)
fmt.Fprint(w, html)
once.Do(func() {
resultChan <- oauthCallbackResult{err: fmt.Errorf("missing authorization code")}
})
return
}

fmt.Fprint(w, successHTML)
once.Do(func() {
resultChan <- oauthCallbackResult{code: code, state: state}
})
}

func printAuthInstructions(authURL string) {
fmt.Printf("Opening browser for authentication...\n")
fmt.Printf("If the browser doesn't open automatically, please visit:\n%s\n\n", authURL)
}

func waitForCallback(callbackChan <-chan oauthCallbackResult) (oauthCallbackResult, error) {
select {
case result := <-callbackChan:
if result.err != nil {
return oauthCallbackResult{}, fmt.Errorf("callback error: %w", result.err)
}
return result, nil
case <-time.After(callbackTimeout):
return oauthCallbackResult{}, fmt.Errorf("authentication timeout after %v", callbackTimeout)
}
}

func openBrowser(url string) error {
var cmd *exec.Cmd

switch runtime.GOOS {
case "linux":
cmd = exec.Command("xdg-open", url)
case "darwin":
cmd = exec.Command("open", url)
case "windows":
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
default:
return fmt.Errorf("unsupported platform")
}

return cmd.Start()
}
100 changes: 100 additions & 0 deletions cli/oauth_device.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0

package cli

import (
"context"
"fmt"
"strings"
"time"

"github.com/fatih/color"
"github.com/spf13/cobra"
)

const (
pollInterval = 3 * time.Second
pollTimeout = 10 * time.Minute
)

func performOAuthDeviceLogin(cmd *cobra.Command, provider string) error {
ctx := cmd.Context()

// Step 1: Get device code
deviceCode, err := sdk.OAuthDeviceCode(ctx, provider)
if err != nil {
return fmt.Errorf("failed to get device code: %w", err)
}

// Step 2: Display instructions to user
printDeviceInstructions(deviceCode.VerificationURI, deviceCode.UserCode)

// Step 3: Poll for authorization
token, pollErr := pollForAuthorization(ctx, provider, deviceCode.DeviceCode, deviceCode.Interval)
if pollErr != nil {
return fmt.Errorf("authorization failed: %w", pollErr)
}

// Step 4: Display success message
logJSONCmd(*cmd, token)
successMsg := color.New(color.FgGreen, color.Bold).SprintFunc()
fmt.Printf("\n%s\n", successMsg("✓ Authentication successful!"))
fmt.Println("You can now use the access_token for API requests.")

return nil
}

func printDeviceInstructions(verificationURI, userCode string) {
fmt.Println()
fmt.Println(color.New(color.FgCyan, color.Bold).Sprint("=== OAuth Device Authorization ==="))
fmt.Println()
fmt.Println(color.New(color.FgYellow).Sprint("Please complete authentication in your browser:"))
fmt.Println()
fmt.Printf(" 1. Visit: %s\n", color.New(color.FgBlue, color.Underline).Sprint(verificationURI))
fmt.Printf(" 2. Enter code: %s\n", color.New(color.FgGreen, color.Bold).Sprint(userCode))
fmt.Println()
fmt.Println(color.New(color.FgWhite).Sprint("Waiting for authorization..."))
fmt.Println()
}

func pollForAuthorization(ctx context.Context, provider, deviceCode string, interval int) (interface{}, error) {
pollDuration := time.Duration(interval) * time.Second
if pollDuration < pollInterval {
pollDuration = pollInterval
}

ticker := time.NewTicker(pollDuration)
defer ticker.Stop()

timeout := time.After(pollTimeout)
spinner := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
spinnerIdx := 0

for {
select {
case <-timeout:
return nil, fmt.Errorf("authentication timeout after %v", pollTimeout)

case <-ticker.C:
// Show spinner
fmt.Printf("\r%s Polling for authorization...", color.CyanString(spinner[spinnerIdx]))
spinnerIdx = (spinnerIdx + 1) % len(spinner)

token, err := sdk.OAuthDeviceToken(ctx, provider, deviceCode)
if err != nil {
errMsg := err.Error()
// Check if it's a pending error (expected during polling)
if strings.Contains(errMsg, "authorization pending") || strings.Contains(errMsg, "slow down") {
continue
}
// Any other error is a real failure
return nil, fmt.Errorf("failed to get token: %w", err)
}

// Clear the spinner line
fmt.Print("\r" + string(make([]byte, 50)) + "\r")
return token, nil
}
}
}
Loading
Loading