diff --git a/.cursor/rules/new-tool-from-docs.mdc b/.cursor/rules/new-tool-from-docs.mdc index ced8338..a248266 100644 --- a/.cursor/rules/new-tool-from-docs.mdc +++ b/.cursor/rules/new-tool-from-docs.mdc @@ -62,7 +62,7 @@ Before the implementation use the documentation URL provided to figure out the r Now follow the detailed implementation guide in [pkg/razorpay/README.md](mdc:../pkg/razorpay/README.md) for creating tools and start making code changes. Other guidelines: -1. [Razorpay Go SDK Constants](mdc:https:/github.com/razorpay/razorpay-go/blob/master/constants/url.go) - Use these constants for specifying the api endpoints while writing the tests. +1. [Razorpay Go SDK Constants](mdc:https:/github.com/razorpay/razorpay-go/v2/blob/master/constants/url.go) - Use these constants for specifying the api endpoints while writing the tests. 2. Use the payload and response from the docs provided to write the positive test case for the tool. STYLE: @@ -251,9 +251,9 @@ Contains the Model Context Protocol implementation: - `ToolParameter` types - Response handling utilities (`NewToolResultJSON`, etc.) -### `github.com/razorpay/razorpay-go` - Razorpay Go SDK +### `github.com/razorpay/razorpay-go/v2` - Razorpay Go SDK -**Imported as:** `rzpsdk "github.com/razorpay/razorpay-go"` +**Imported as:** `rzpsdk "github.com/razorpay/razorpay-go/v2"` Official Razorpay client library providing: - `Client` struct with resource-specific clients (Payment, Order, PaymentLink, etc.) diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index b074410..2471de7 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -32,7 +32,7 @@ jobs: IMAGE_TAG="${GITHUB_REF#refs/tags/}" echo "tags=razorpay/mcp:${IMAGE_TAG},razorpay/mcp:latest" >> $GITHUB_OUTPUT else - IMAGE_TAG="${GITHUB_SHA::7}" # short commit ID + IMAGE_TAG="${GITHUB_SHA}" # full commit SHA echo "tags=razorpay/mcp:${IMAGE_TAG}" >> $GITHUB_OUTPUT fi @@ -45,3 +45,5 @@ jobs: tags: ${{ steps.vars.outputs.tags }} build-args: | VERSION=${{ github.ref_name }} + COMMIT=${{ github.sha }} + BUILD_DATE=${{ github.event.head_commit.timestamp }} diff --git a/.gitignore b/.gitignore index 5768972..1c6abd1 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ /.go /logs /vendor -razorpay-mcp-server \ No newline at end of file +/razorpay-mcp-server \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 8872ba5..fa442f1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,8 @@ FROM golang:1.24.2-alpine AS builder +# Install git +RUN apk add --no-cache git + WORKDIR /app COPY go.mod go.sum ./ @@ -9,8 +12,10 @@ RUN go mod download COPY . . ARG VERSION="dev" +ARG COMMIT="" +ARG BUILD_DATE="" -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags "-X main.version=${VERSION} -X main.commit=$(git rev-parse HEAD) -X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ)" -o razorpay-mcp-server ./cmd/razorpay-mcp-server +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags "-X main.version=${VERSION} -X main.commit=${COMMIT:-$(git rev-parse HEAD 2>/dev/null || echo 'unknown')} -X main.date=${BUILD_DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" -o razorpay-mcp-server ./cmd/razorpay-mcp-server FROM alpine:latest @@ -29,9 +34,23 @@ RUN chown -R rzp:rzpgroup /app ENV CONFIG="" \ RAZORPAY_KEY_ID="" \ RAZORPAY_KEY_SECRET="" \ - LOG_FILE="" + PORT="8090" \ + MODE="stdio" \ + LOG_FILE="" \ + ADDRESS="mcp.razorpay.com" # Switch to the non-root user USER rzp -ENTRYPOINT ["sh", "-c", "./razorpay-mcp-server stdio --key ${RAZORPAY_KEY_ID} --secret ${RAZORPAY_KEY_SECRET} ${CONFIG:+--config ${CONFIG}} ${LOG_FILE:+--log-file ${LOG_FILE}}"] +# Expose the SSE server port (used in SSE mode) +EXPOSE ${PORT} + +# Use shell form to allow variable substitution and conditional execution +ENTRYPOINT ["sh", "-c", "\ +if [ \"$MODE\" = \"sse\" ]; then \ + ./razorpay-mcp-server sse --port ${PORT} --address ${ADDRESS} ${CONFIG:+--config ${CONFIG}}; \ +elif [ \"$MODE\" = \"http\" ]; then \ + ./razorpay-mcp-server http --port ${PORT} ${CONFIG:+--config ${CONFIG}}; \ +else \ + ./razorpay-mcp-server stdio --key ${RAZORPAY_KEY_ID} --secret ${RAZORPAY_KEY_SECRET} ${CONFIG:+--config ${CONFIG}} ${LOG_FILE:+--log-file ${LOG_FILE}}; \ +fi"] diff --git a/README.md b/README.md index dc37390..71ad494 100644 --- a/README.md +++ b/README.md @@ -186,26 +186,85 @@ Once the build is ready, you need to specify the path to the binary executable i } ``` +## Usage with SSE Server + +The Razorpay MCP Server also supports the Server-Sent Events (SSE) transport protocol, allowing you to run it as a standalone service that clients can connect to. + +### Running the SSE Server with Docker + +To run the server in SSE mode using the same Docker image: + +```bash +# Run the SSE server on port 8090 (default) +docker run -p 8090:8090 \ + -e MODE=sse \ + -e PORT=8090 \ + razorpay-mcp-server:latest +``` + +You can customize the port by setting the `PORT` environment variable. + +### Testing with MCP Inspector + +You can test your SSE server using the [MCP Inspector](https://github.com/modelcontextprotocol/inspector) tool: + +```bash +# Install MCP Inspector +npm install -g @modelcontextprotocol/inspector + +# Open MCP Inspector tool +npx @modelcontextprotocol/inspector +``` + +This will open a browser interface where you can inspect and test the available tools on your SSE server. + ## Configuration The server requires the following configuration: -- `RAZORPAY_KEY_ID`: Your Razorpay API key ID -- `RAZORPAY_KEY_SECRET`: Your Razorpay API key secret -- `LOG_FILE` (optional): Path to log file for server logs +- `RAZORPAY_KEY_ID`: Your Razorpay API key ID (required for stdio mode) +- `RAZORPAY_KEY_SECRET`: Your Razorpay API key secret (required for stdio mode) +- `MODE`: Server mode ("stdio", "sse", or "http", default: "stdio") +- `PORT`: Port for SSE/HTTP server (default: "8090" for SSE, "8080" for HTTP) +- `ADDRESS`: Address to bind the server to (default: "localhost", used in SSE/HTTP mode) +- `LOG_FILE` (optional): Path to log file for stdio mode logs (SSE/HTTP modes always log to stdout) - `TOOLSETS` (optional): Comma-separated list of toolsets to enable (default: "all") - `READ_ONLY` (optional): Run server in read-only mode (default: false) ### Command Line Flags -The server supports the following command line flags: +The server supports different commands and flags: +#### stdio mode: +```bash +./razorpay-mcp-server stdio [flags] +``` - `--key` or `-k`: Your Razorpay API key ID - `--secret` or `-s`: Your Razorpay API key secret - `--log-file` or `-l`: Path to log file - `--toolsets` or `-t`: Comma-separated list of toolsets to enable - `--read-only`: Run server in read-only mode +#### SSE mode: +```bash +./razorpay-mcp-server sse [flags] +``` +- `--address` or `-a`: Address to bind the server to (default: "localhost") +- `--port` or `-p`: Port to run the SSE server on (default: 8080) +- `--toolsets` or `-t`: Comma-separated list of toolsets to enable +- `--read-only`: Run server in read-only mode + +#### HTTP mode: +```bash +./razorpay-mcp-server http [flags] +``` +- `--address` or `-a`: Address to bind the server to (default: "localhost") +- `--port` or `-p`: Port to run the HTTP server on (default: 8080) +- `--toolsets` or `-t`: Comma-separated list of toolsets to enable +- `--read-only`: Run server in read-only mode + +Note: SSE and HTTP modes log to stdout for better container integration. Authentication is handled via Bearer tokens in HTTP mode and SSE context in SSE mode. + ## Debugging the Server You can use the standard Go debugging tools to troubleshoot issues with the server. Log files can be specified using the `--log-file` flag (defaults to ./logs) diff --git a/cmd/razorpay-mcp-server/http.go b/cmd/razorpay-mcp-server/http.go new file mode 100644 index 0000000..9496063 --- /dev/null +++ b/cmd/razorpay-mcp-server/http.go @@ -0,0 +1,99 @@ +package main + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + + rzpsdk "github.com/razorpay/razorpay-go/v2" + + "github.com/razorpay/razorpay-mcp-server/pkg/razorpay" +) + +// httpCmd starts the mcp server in http transport mode +var httpCmd = &cobra.Command{ + Use: "http", + Short: "start the http server for direct JSON-RPC calls", + Run: func(cmd *cobra.Command, args []string) { + // Create stdout logger + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelError, + })) + + // Get toolsets to enable from config + enabledToolsets := viper.GetStringSlice("toolsets") + + // Get read-only mode from config + readOnly := viper.GetBool("read_only") + + err := runHTTPServer(logger, nil, enabledToolsets, readOnly) + if err != nil { + logger.Error("error running http server", "error", err) + os.Exit(1) + } + }, +} + +func runHTTPServer( + log *slog.Logger, + client *rzpsdk.Client, + enabledToolsets []string, + readOnly bool, +) error { + ctx, stop := signal.NotifyContext( + context.Background(), + os.Interrupt, + syscall.SIGTERM, + ) + defer stop() + + srv, err := razorpay.NewServer( + log, + client, + "1.0.0", + enabledToolsets, + readOnly, + ) + if err != nil { + return fmt.Errorf("failed to create server: %w", err) + } + srv.RegisterTools() + + httpSrv, err := razorpay.NewHTTPServer( + srv, + razorpay.NewHTTPConfig( + razorpay.WithHTTPAddress("localhost"), + razorpay.WithHTTPPort(viper.GetInt("port")), + ), + ) + if err != nil { + return fmt.Errorf("failed to create http server: %w", err) + } + + errC := make(chan error, 1) + go func() { + log.Info("starting http server") + errC <- httpSrv.Start() + }() + + log.Info("Razorpay MCP Server running on http\n") + + // Wait for shutdown signal + select { + case <-ctx.Done(): + log.Info("shutting down server...") + return httpSrv.Shutdown(ctx) + case err := <-errC: + if err != nil { + log.Error("server error", "error", err) + return err + } + return nil + } +} diff --git a/cmd/razorpay-mcp-server/main.go b/cmd/razorpay-mcp-server/main.go index 49fc2e7..cf4b48b 100644 --- a/cmd/razorpay-mcp-server/main.go +++ b/cmd/razorpay-mcp-server/main.go @@ -40,6 +40,8 @@ func init() { rootCmd.PersistentFlags().StringP("log-file", "l", "", "path to the log file") rootCmd.PersistentFlags().StringSliceP("toolsets", "t", []string{}, "comma-separated list of toolsets to enable") rootCmd.PersistentFlags().Bool("read-only", false, "run server in read-only mode") + rootCmd.PersistentFlags().StringP("address", "a", "localhost", "address to bind the sse server to") + rootCmd.PersistentFlags().IntP("port", "p", 8080, "port to bind the sse server to") // bind flags to viper _ = viper.BindPFlag("key", rootCmd.PersistentFlags().Lookup("key")) @@ -47,6 +49,8 @@ func init() { _ = viper.BindPFlag("log_file", rootCmd.PersistentFlags().Lookup("log-file")) _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) _ = viper.BindPFlag("read_only", rootCmd.PersistentFlags().Lookup("read-only")) + _ = viper.BindPFlag("address", rootCmd.PersistentFlags().Lookup("address")) + _ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) // Set environment variable mappings _ = viper.BindEnv("key", "RAZORPAY_KEY_ID") // Maps RAZORPAY_KEY_ID to key @@ -57,6 +61,8 @@ func init() { // subcommands rootCmd.AddCommand(stdioCmd) + rootCmd.AddCommand(sseCmd) + rootCmd.AddCommand(httpCmd) } // initConfig reads in config file and ENV variables if set. diff --git a/cmd/razorpay-mcp-server/sse.go b/cmd/razorpay-mcp-server/sse.go new file mode 100644 index 0000000..bf20aa1 --- /dev/null +++ b/cmd/razorpay-mcp-server/sse.go @@ -0,0 +1,100 @@ +package main + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + + rzpsdk "github.com/razorpay/razorpay-go/v2" + + "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" + "github.com/razorpay/razorpay-mcp-server/pkg/razorpay" +) + +// sseCmd starts the mcp server in sse transport mode +var sseCmd = &cobra.Command{ + Use: "sse", + Short: "start the sse server", + Run: func(cmd *cobra.Command, args []string) { + // Create stdout logger + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelError, + })) + + // Get toolsets to enable from config + enabledToolsets := viper.GetStringSlice("toolsets") + + // Get read-only mode from config + readOnly := viper.GetBool("read_only") + + err := runSseServer(logger, nil, enabledToolsets, readOnly) + if err != nil { + logger.Error("error running sse server", "error", err) + os.Exit(1) + } + }, +} + +func runSseServer( + log *slog.Logger, + client *rzpsdk.Client, + enabledToolsets []string, + readOnly bool, +) error { + ctx, stop := signal.NotifyContext( + context.Background(), + os.Interrupt, + syscall.SIGTERM, + ) + defer stop() + + srv, err := razorpay.NewServer( + log, + client, + "1.0.0", + enabledToolsets, + readOnly, + ) + if err != nil { + return fmt.Errorf("failed to create server: %w", err) + } + srv.RegisterTools() + + sseSrv, err := mcpgo.NewSSEServer( + srv.GetMCPServer(), + mcpgo.NewSSEConfig( + mcpgo.WithSSEAddress(viper.GetString("address")), + mcpgo.WithSSEPort(viper.GetInt("port")), + ), + ) + if err != nil { + return fmt.Errorf("failed to create sse server: %w", err) + } + + errC := make(chan error, 1) + go func() { + log.Info("starting server") + errC <- sseSrv.Start() + }() + + log.Info("Razorpay MCP Server running on sse\n") + + // Wait for shutdown signal + select { + case <-ctx.Done(): + log.Info("shutting down server...") + return nil + case err := <-errC: + if err != nil { + log.Error("server error", "error", err) + return err + } + return nil + } +} diff --git a/cmd/razorpay-mcp-server/stdio.go b/cmd/razorpay-mcp-server/stdio.go index 1f051cb..83f7d12 100644 --- a/cmd/razorpay-mcp-server/stdio.go +++ b/cmd/razorpay-mcp-server/stdio.go @@ -13,7 +13,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/log" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" diff --git a/go.mod b/go.mod index 1348842..7488367 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,8 @@ go 1.24.2 require ( github.com/go-test/deep v1.1.1 github.com/gorilla/mux v1.8.1 - github.com/mark3labs/mcp-go v0.23.1 - github.com/razorpay/razorpay-go v1.3.4 + github.com/mark3labs/mcp-go v0.23.0 + github.com/razorpay/razorpay-go/v2 v2.0.0-20250603191311-f233476cd336 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index 82304f6..5425ae3 100644 --- a/go.sum +++ b/go.sum @@ -22,14 +22,14 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mark3labs/mcp-go v0.23.1 h1:RzTzZ5kJ+HxwnutKA4rll8N/pKV6Wh5dhCmiJUu5S9I= -github.com/mark3labs/mcp-go v0.23.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.23.0 h1:NmtoPx4jf7if7bfAynocpVdpXqX5U8X/18c1gddK/QA= +github.com/mark3labs/mcp-go v0.23.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/razorpay/razorpay-go v1.3.4 h1:A9DZ18GZDn/bGRjQ9SesTGUNIAEw+IB27512l3I81aI= -github.com/razorpay/razorpay-go v1.3.4/go.mod h1:VcljkUylUJAUEvFfGVv/d5ht1to1dUgF4H1+3nv7i+Q= +github.com/razorpay/razorpay-go/v2 v2.0.0-20250603191311-f233476cd336 h1:8Wz7yDP6dyX8V3D8pPVl4RYx+VqKzb4QIXjq7+VNk5Q= +github.com/razorpay/razorpay-go/v2 v2.0.0-20250603191311-f233476cd336/go.mod h1:4vdw1ydjK5PdDV3KE3YOb30BWLDE0IpcXZTA4iW80Ew= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= diff --git a/pkg/mcpgo/auth.go b/pkg/mcpgo/auth.go new file mode 100644 index 0000000..51a3f8e --- /dev/null +++ b/pkg/mcpgo/auth.go @@ -0,0 +1,50 @@ +package mcpgo + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + + rzpsdk "github.com/razorpay/razorpay-go/v2" +) + +// AuthenticateRequest handles authentication for a request context. +// If client is provided, it returns the context as-is (stdio mode). +// Otherwise, it validates the auth token and creates a new client (SSE mode). +func AuthenticateRequest( + ctx context.Context, + client *rzpsdk.Client, +) (context.Context, error) { + // If client is provided, this is the stdio mcp server + if client != nil { + return ctx, nil + } + + // Check if auth token is provided + auth := AuthTokenFromContext(ctx) + if auth == "" { + return nil, fmt.Errorf("unauthorized: no auth token provided") + } + + // Base64 decode the auth token + token, err := base64.StdEncoding.DecodeString(auth) + if err != nil { + return nil, fmt.Errorf("unauthorized: invalid auth token") + } + + // Split token into key:secret + parts := strings.Split(string(token), ":") + if len(parts) != 2 { + return nil, fmt.Errorf("unauthorized: invalid auth token") + } + + // Create a new client with the auth credentials + newClient := rzpsdk.NewClient(parts[0], parts[1]) + newClient.SetUserAgent("razorpay-mcp/" + "/sse") + + // Store the client in context + ctx = WithClient(ctx, newClient) + + return ctx, nil +} diff --git a/pkg/mcpgo/context_key.go b/pkg/mcpgo/context_key.go new file mode 100644 index 0000000..b60d8f6 --- /dev/null +++ b/pkg/mcpgo/context_key.go @@ -0,0 +1,46 @@ +package mcpgo + +import ( + "context" +) + +// contextKey is a type used for context value keys to avoid key collisions. +type contextKey string + +// Context keys for storing various values. +const ( + authTokenKey contextKey = "auth_token" + clientKey contextKey = "client" +) + +// WithAuthToken returns a new context with the authentication token attached. +func WithAuthToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, authTokenKey, token) +} + +// AuthTokenFromContext extracts the authentication token from the context. +// Returns an empty string if no token is found or if the value is not a string. +func AuthTokenFromContext(ctx context.Context) string { + value := ctx.Value(authTokenKey) + if value == nil { + return "" + } + + token, ok := value.(string) + if !ok { + return "" + } + + return token +} + +// WithClient returns a new context with the client instance attached. +func WithClient(ctx context.Context, client interface{}) context.Context { + return context.WithValue(ctx, clientKey, client) +} + +// ClientFromContext extracts the client instance from the context. +// Returns nil if no client is found. +func ClientFromContext(ctx context.Context) interface{} { + return ctx.Value(clientKey) +} diff --git a/pkg/mcpgo/server.go b/pkg/mcpgo/server.go index ca43075..f7b1d20 100644 --- a/pkg/mcpgo/server.go +++ b/pkg/mcpgo/server.go @@ -1,7 +1,12 @@ package mcpgo import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + + rzpsdk "github.com/razorpay/razorpay-go/v2" ) // Server defines the minimal MCP server interface needed by the application @@ -94,3 +99,25 @@ func WithToolCapabilities(enabled bool) ServerOption { return s.SetOption(server.WithToolCapabilities(enabled)) } } + +// WithAuthenticationMiddleware returns a server option that adds an +// authentication middleware to the server. +func WithAuthenticationMiddleware(client *rzpsdk.Client) ServerOption { + return func(s OptionSetter) error { + return s.SetOption(server.WithToolHandlerMiddleware( + func(next server.ToolHandlerFunc) server.ToolHandlerFunc { + return func( + ctx context.Context, + request mcp.CallToolRequest, + ) (result *mcp.CallToolResult, err error) { + authenticatedCtx, err := AuthenticateRequest(ctx, client) + if err != nil { + return nil, err + } + + return next(authenticatedCtx, request) + } + }), + ) + } +} diff --git a/pkg/mcpgo/sse.go b/pkg/mcpgo/sse.go new file mode 100644 index 0000000..d48ca48 --- /dev/null +++ b/pkg/mcpgo/sse.go @@ -0,0 +1,134 @@ +package mcpgo + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/mark3labs/mcp-go/server" +) + +type SSEConfig struct { + // address is the address to bind the server to + address string + // port is the port to bind the server to + port int +} + +// getDefaultSSEConfig returns a default configuration for the SSE server +func getDefaultSSEConfig() *SSEConfig { + return &SSEConfig{ + address: "localhost", + port: 8080, + } +} + +// SSEConfigOpts defines a function type for applying configuration options +type SSEConfigOpts func(*SSEConfig) + +// WithSSEAddress returns an option to set the server address +func WithSSEAddress(address string) SSEConfigOpts { + return func(config *SSEConfig) { + config.address = address + } +} + +// WithSSEPort returns an option to set the server port +func WithSSEPort(port int) SSEConfigOpts { + return func(config *SSEConfig) { + config.port = port + } +} + +// NewSSEConfig creates a new SSE server configuration with the provided options +func NewSSEConfig(opts ...SSEConfigOpts) *SSEConfig { + config := getDefaultSSEConfig() + + for _, opt := range opts { + opt(config) + } + + return config +} + +// NewSSEServer creates a new sse transport server +func NewSSEServer( + mcpServer Server, + config *SSEConfig, +) (*mark3labsSseImpl, error) { + sImpl, ok := mcpServer.(*mark3labsImpl) + if !ok { + return nil, fmt.Errorf("%w: expected *mark3labsImpl, got %T", + ErrInvalidServerImplementation, mcpServer) + } + + // Create a new SSE server with the base options + sseServer := server.NewSSEServer( + sImpl.mcpServer, + server.WithBaseURL(config.address), + server.WithSSEContextFunc(authFromRequest), + ) + + // Wrap the server with a recovery handler + impl := &mark3labsSseImpl{ + mcpSseServer: sseServer, + SSEConfig: config, + } + + return impl, nil +} + +// mark3labsSseImpl implements the TransportServer +// interface for sse transport +type mark3labsSseImpl struct { + mcpSseServer *server.SSEServer + SSEConfig *SSEConfig + httpServer *http.Server + mux *http.ServeMux +} + +// Start implements the TransportServer interface +func (s *mark3labsSseImpl) Start() error { + s.mux = http.NewServeMux() + + // Register health check endpoints + s.mux.HandleFunc("/live", s.handleLiveness) + s.mux.HandleFunc("/ready", s.handleReadiness) + + // Register SSE server as default handler for all other routes + s.mux.Handle("/", s.mcpSseServer) + + // Create HTTP server with our custom mux + s.httpServer = &http.Server{ + Addr: fmt.Sprintf(":%d", s.SSEConfig.port), + Handler: s.mux, + } + + // Start the HTTP server + return s.httpServer.ListenAndServe() +} + +// handleLiveness handles /live endpoint +func (s *mark3labsSseImpl) handleLiveness(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) +} + +// handleReadiness handles /ready endpoint +func (s *mark3labsSseImpl) handleReadiness(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) +} + +// authFromRequest extracts the auth token from the request headers. +func authFromRequest(ctx context.Context, r *http.Request) context.Context { + authHeader := r.Header.Get("Authorization") + + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + return ctx + } + + return WithAuthToken(ctx, parts[1]) +} diff --git a/pkg/mcpgo/tool.go b/pkg/mcpgo/tool.go index 39f8190..12352bf 100644 --- a/pkg/mcpgo/tool.go +++ b/pkg/mcpgo/tool.go @@ -3,6 +3,7 @@ package mcpgo import ( "context" "encoding/json" + "fmt" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -33,6 +34,18 @@ type Tool interface { // GetHandler internal method for fetching the underlying handler GetHandler() ToolHandler + + // GetName returns the name of the tool + GetName() string + + // GetDescription returns the description of the tool + GetDescription() string + + // GetInputSchema returns the input schema for the tool + GetInputSchema() map[string]interface{} + + // Call executes the tool with the given context and request + Call(ctx context.Context, request CallToolRequest) (interface{}, error) } // PropertyOption represents a customization option for @@ -389,6 +402,62 @@ func (t *mark3labsToolImpl) GetHandler() ToolHandler { return t.handler } +// GetName returns the name of the tool +func (t *mark3labsToolImpl) GetName() string { + return t.name +} + +// GetDescription returns the description of the tool +func (t *mark3labsToolImpl) GetDescription() string { + return t.description +} + +// GetInputSchema returns the input schema for the tool +func (t *mark3labsToolImpl) GetInputSchema() map[string]interface{} { + properties := make(map[string]interface{}) + required := make([]string, 0) + + for _, param := range t.parameters { + properties[param.Name] = param.Schema + if isRequired, ok := param.Schema["required"].(bool); ok && isRequired { + required = append(required, param.Name) + } + } + + schema := map[string]interface{}{ + "type": "object", + "properties": properties, + } + + if len(required) > 0 { + schema["required"] = required + } + + return schema +} + +// Call executes the tool with the given context and request +func (t *mark3labsToolImpl) Call(ctx context.Context, request CallToolRequest) (interface{}, error) { + result, err := t.handler(ctx, request) + if err != nil { + return nil, err + } + + // If the result is an error, return it as an error + if result.IsError { + return nil, fmt.Errorf("%s", result.Text) + } + + // Try to parse the result as JSON first + var jsonResult interface{} + if err := json.Unmarshal([]byte(result.Text), &jsonResult); err == nil { + return jsonResult, nil + } + + // If JSON parsing fails, return the text as is + return result.Text, nil +} + // toMCPServerTool converts our Tool to mcp's ServerTool func (t *mark3labsToolImpl) toMCPServerTool() server.ServerTool { // Create the mcp tool with appropriate options diff --git a/pkg/razorpay/http.go b/pkg/razorpay/http.go new file mode 100644 index 0000000..6fac60c --- /dev/null +++ b/pkg/razorpay/http.go @@ -0,0 +1,305 @@ +package razorpay + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + rzpsdk "github.com/razorpay/razorpay-go/v2" + + "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" +) + +// HTTPConfig holds the configuration for the HTTP server +type HTTPConfig struct { + address string + port int +} + +// HTTPConfigOpts defines a function type for applying configuration options +type HTTPConfigOpts func(*HTTPConfig) + +// WithHTTPAddress returns an option to set the server address +func WithHTTPAddress(address string) HTTPConfigOpts { + return func(config *HTTPConfig) { + config.address = address + } +} + +// WithHTTPPort returns an option to set the server port +func WithHTTPPort(port int) HTTPConfigOpts { + return func(config *HTTPConfig) { + config.port = port + } +} + +// NewHTTPConfig creates a new HTTP server configuration with the provided options +func NewHTTPConfig(opts ...HTTPConfigOpts) *HTTPConfig { + config := &HTTPConfig{ + address: "localhost", + port: 8080, + } + + for _, opt := range opts { + opt(config) + } + + return config +} + +// HTTPServer implements a JSON-RPC HTTP server +type HTTPServer struct { + server *Server + config *HTTPConfig + httpServer *http.Server + mux *http.ServeMux +} + +// JSONRPCRequest represents a JSON-RPC 2.0 request +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + +// JSONRPCResponse represents a JSON-RPC 2.0 response +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Result interface{} `json:"result,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` +} + +// JSONRPCError represents a JSON-RPC 2.0 error +type JSONRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// Tool represents a tool that can be called +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]interface{} `json:"inputSchema"` +} + +// ToolsListResponse represents the response to tools/list +type ToolsListResponse struct { + Tools []Tool `json:"tools"` +} + +// ToolCallParams represents the parameters for tools/call +type ToolCallParams struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// NewHTTPServer creates a new HTTP server instance +func NewHTTPServer(server *Server, config *HTTPConfig) (*HTTPServer, error) { + return &HTTPServer{ + server: server, + config: config, + }, nil +} + +// Start starts the HTTP server +func (h *HTTPServer) Start() error { + h.mux = http.NewServeMux() + + // Register health check endpoints + h.mux.HandleFunc("/live", h.handleLiveness) + h.mux.HandleFunc("/ready", h.handleReadiness) + + // Register JSON-RPC endpoint as default handler + h.mux.HandleFunc("/", h.handleJSONRPC) + + // Create HTTP server + h.httpServer = &http.Server{ + Addr: fmt.Sprintf(":%d", h.config.port), + Handler: h.mux, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + + // Start the HTTP server + return h.httpServer.ListenAndServe() +} + +// Shutdown gracefully shuts down the HTTP server +func (h *HTTPServer) Shutdown(ctx context.Context) error { + if h.httpServer != nil { + return h.httpServer.Shutdown(ctx) + } + return nil +} + +// handleLiveness returns 200 OK for liveness probe +func (h *HTTPServer) handleLiveness(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) +} + +// handleReadiness returns 200 OK for readiness probe +func (h *HTTPServer) handleReadiness(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) +} + +// handleJSONRPC handles JSON-RPC requests +func (h *HTTPServer) handleJSONRPC(w http.ResponseWriter, r *http.Request) { + // Set response headers + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Accel-Buffering", "no") // Prevent buffering for SSE compatibility + + // Parse JSON-RPC request + var req JSONRPCRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.writeErrorResponse(w, nil, -32700, "Parse error", err.Error()) + return + } + + // Validate JSON-RPC version + if req.JSONRPC != "2.0" { + h.writeErrorResponse(w, req.ID, -32600, "Invalid Request", "JSON-RPC version must be 2.0") + return + } + + // Extract and validate authentication + ctx, err := h.authenticateRequest(r.Context(), r) + if err != nil { + h.writeErrorResponse(w, req.ID, -32603, "Authentication failed", err.Error()) + return + } + + // Route based on method + switch req.Method { + case "tools/list": + h.handleToolsList(ctx, w, req) + case "tools/call": + h.handleToolsCall(ctx, w, req) + default: + h.writeErrorResponse(w, req.ID, -32601, "Method not found", fmt.Sprintf("Method %s not found", req.Method)) + } +} + +// authenticateRequest extracts and validates the Bearer token +func (h *HTTPServer) authenticateRequest(ctx context.Context, r *http.Request) (context.Context, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("authorization header required") + } + + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + return nil, fmt.Errorf("invalid authorization header format") + } + + token := parts[1] + + // Decode base64 token to get key:secret + decoded, err := base64.StdEncoding.DecodeString(token) + if err != nil { + return nil, fmt.Errorf("invalid token encoding") + } + + // Split into key and secret + credentials := strings.SplitN(string(decoded), ":", 2) + if len(credentials) != 2 { + return nil, fmt.Errorf("invalid credentials format") + } + + keyID := credentials[0] + keySecret := credentials[1] + + // Create Razorpay client + client := rzpsdk.NewClient(keyID, keySecret) + + // Add client to context + return mcpgo.WithClient(ctx, client), nil +} + +// handleToolsList handles the tools/list method +func (h *HTTPServer) handleToolsList(ctx context.Context, w http.ResponseWriter, req JSONRPCRequest) { + // Get all tools from the server + tools := h.server.GetAllTools() + + // Convert to the expected format + var toolsList []Tool + for _, tool := range tools { + toolsList = append(toolsList, Tool{ + Name: tool.GetName(), + Description: tool.GetDescription(), + InputSchema: tool.GetInputSchema(), + }) + } + + response := JSONRPCResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: ToolsListResponse{Tools: toolsList}, + } + + h.writeJSONResponse(w, response) +} + +// handleToolsCall handles the tools/call method +func (h *HTTPServer) handleToolsCall(ctx context.Context, w http.ResponseWriter, req JSONRPCRequest) { + // Parse parameters + var params ToolCallParams + if req.Params != nil { + paramsBytes, err := json.Marshal(req.Params) + if err != nil { + h.writeErrorResponse(w, req.ID, -32602, "Invalid params", err.Error()) + return + } + + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + h.writeErrorResponse(w, req.ID, -32602, "Invalid params", err.Error()) + return + } + } + + // Call the tool + result, err := h.server.CallTool(ctx, params.Name, params.Arguments) + if err != nil { + h.writeErrorResponse(w, req.ID, -32603, "Tool execution failed", err.Error()) + return + } + + response := JSONRPCResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: result, + } + + h.writeJSONResponse(w, response) +} + +// writeJSONResponse writes a JSON response +func (h *HTTPServer) writeJSONResponse(w http.ResponseWriter, response JSONRPCResponse) { + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + } +} + +// writeErrorResponse writes a JSON-RPC error response +func (h *HTTPServer) writeErrorResponse(w http.ResponseWriter, id interface{}, code int, message, data string) { + response := JSONRPCResponse{ + JSONRPC: "2.0", + ID: id, + Error: &JSONRPCError{ + Code: code, + Message: message, + Data: data, + }, + } + + h.writeJSONResponse(w, response) +} diff --git a/pkg/razorpay/http_test.go b/pkg/razorpay/http_test.go new file mode 100644 index 0000000..58f94f2 --- /dev/null +++ b/pkg/razorpay/http_test.go @@ -0,0 +1,192 @@ +package razorpay + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHTTPServer_ToolsList(t *testing.T) { + // Create a test server + server, err := NewServer(nil, nil, "1.0.0", []string{}, false) + require.NoError(t, err) + + httpServer, err := NewHTTPServer(server, NewHTTPConfig()) + require.NoError(t, err) + + // Create test request + requestBody := map[string]interface{}{ + "jsonrpc": "2.0", + "id": "test-request", + "method": "tools/list", + "params": map[string]interface{}{}, + } + + body, err := json.Marshal(requestBody) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + // Create test credentials (test_key:test_secret) + credentials := base64.StdEncoding.EncodeToString([]byte("test_key:test_secret")) + req.Header.Set("Authorization", "Bearer "+credentials) + + w := httptest.NewRecorder() + + // Call the handler + httpServer.handleJSONRPC(w, req) + + // Check response + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + assert.Equal(t, "no", w.Header().Get("X-Accel-Buffering")) + + var response JSONRPCResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "2.0", response.JSONRPC) + assert.Equal(t, "test-request", response.ID) + assert.Nil(t, response.Error) + assert.NotNil(t, response.Result) + + // Check that we have tools in the response + result, ok := response.Result.(map[string]interface{}) + require.True(t, ok) + + tools, ok := result["tools"].([]interface{}) + require.True(t, ok) + assert.Greater(t, len(tools), 0) +} + +func TestHTTPServer_InvalidJSONRPC(t *testing.T) { + server, err := NewServer(nil, nil, "1.0.0", []string{}, false) + require.NoError(t, err) + + httpServer, err := NewHTTPServer(server, NewHTTPConfig()) + require.NoError(t, err) + + // Test invalid JSON-RPC version + requestBody := map[string]interface{}{ + "jsonrpc": "1.0", // Invalid version + "id": "test-request", + "method": "tools/list", + } + + body, err := json.Marshal(requestBody) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + credentials := base64.StdEncoding.EncodeToString([]byte("test_key:test_secret")) + req.Header.Set("Authorization", "Bearer "+credentials) + + w := httptest.NewRecorder() + httpServer.handleJSONRPC(w, req) + + var response JSONRPCResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "2.0", response.JSONRPC) + assert.NotNil(t, response.Error) + assert.Equal(t, -32600, response.Error.Code) + assert.Equal(t, "Invalid Request", response.Error.Message) +} + +func TestHTTPServer_MissingAuth(t *testing.T) { + server, err := NewServer(nil, nil, "1.0.0", []string{}, false) + require.NoError(t, err) + + httpServer, err := NewHTTPServer(server, NewHTTPConfig()) + require.NoError(t, err) + + requestBody := map[string]interface{}{ + "jsonrpc": "2.0", + "id": "test-request", + "method": "tools/list", + } + + body, err := json.Marshal(requestBody) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // No Authorization header + + w := httptest.NewRecorder() + httpServer.handleJSONRPC(w, req) + + var response JSONRPCResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + + assert.NotNil(t, response.Error) + assert.Equal(t, -32603, response.Error.Code) + assert.Equal(t, "Authentication failed", response.Error.Message) +} + +func TestHTTPServer_HealthChecks(t *testing.T) { + server, err := NewServer(nil, nil, "1.0.0", []string{}, false) + require.NoError(t, err) + + httpServer, err := NewHTTPServer(server, NewHTTPConfig()) + require.NoError(t, err) + + // Test liveness endpoint + req := httptest.NewRequest("GET", "/live", nil) + w := httptest.NewRecorder() + httpServer.handleLiveness(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "OK", w.Body.String()) + + // Test readiness endpoint + req = httptest.NewRequest("GET", "/ready", nil) + w = httptest.NewRecorder() + httpServer.handleReadiness(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "OK", w.Body.String()) +} + +func TestHTTPServer_AuthenticateRequest(t *testing.T) { + server, err := NewServer(nil, nil, "1.0.0", []string{}, false) + require.NoError(t, err) + + httpServer, err := NewHTTPServer(server, NewHTTPConfig()) + require.NoError(t, err) + + // Test valid Bearer token + req := httptest.NewRequest("POST", "/", nil) + credentials := base64.StdEncoding.EncodeToString([]byte("test_key:test_secret")) + req.Header.Set("Authorization", "Bearer "+credentials) + + ctx, err := httpServer.authenticateRequest(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, ctx) + + // Test invalid Bearer token format + req = httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer invalid_token") + + _, err = httpServer.authenticateRequest(context.Background(), req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid token encoding") + + // Test missing Authorization header + req = httptest.NewRequest("POST", "/", nil) + + _, err = httpServer.authenticateRequest(context.Background(), req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "authorization header required") +} diff --git a/pkg/razorpay/orders.go b/pkg/razorpay/orders.go index 750f74b..9e13400 100644 --- a/pkg/razorpay/orders.go +++ b/pkg/razorpay/orders.go @@ -5,7 +5,7 @@ import ( "fmt" "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" ) @@ -59,6 +59,12 @@ func CreateOrder( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + payload := make(map[string]interface{}) validator := NewValidator(&r). @@ -112,6 +118,12 @@ func FetchOrder( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + payload := make(map[string]interface{}) validator := NewValidator(&r). @@ -193,6 +205,12 @@ func FetchAllOrders( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + queryParams := make(map[string]interface{}) validator := NewValidator(&r). diff --git a/pkg/razorpay/orders_test.go b/pkg/razorpay/orders_test.go index 978827a..c35f128 100644 --- a/pkg/razorpay/orders_test.go +++ b/pkg/razorpay/orders_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/razorpay/razorpay-go/constants" + "github.com/razorpay/razorpay-go/v2/constants" "github.com/razorpay/razorpay-mcp-server/pkg/razorpay/mock" ) diff --git a/pkg/razorpay/payment_links.go b/pkg/razorpay/payment_links.go index d378c2d..5e53a56 100644 --- a/pkg/razorpay/payment_links.go +++ b/pkg/razorpay/payment_links.go @@ -5,7 +5,7 @@ import ( "fmt" "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" ) @@ -91,6 +91,12 @@ func CreatePaymentLink( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + // Create a parameters map to collect validated parameters plCreateReq := make(map[string]interface{}) customer := make(map[string]interface{}) @@ -303,6 +309,12 @@ func FetchPaymentLink( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + fields := make(map[string]interface{}) validator := NewValidator(&r). diff --git a/pkg/razorpay/payment_links_test.go b/pkg/razorpay/payment_links_test.go index b7c1a0b..dc33f17 100644 --- a/pkg/razorpay/payment_links_test.go +++ b/pkg/razorpay/payment_links_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/razorpay/razorpay-go/constants" + "github.com/razorpay/razorpay-go/v2/constants" "github.com/razorpay/razorpay-mcp-server/pkg/razorpay/mock" ) diff --git a/pkg/razorpay/payments.go b/pkg/razorpay/payments.go index 148a23d..ff248e4 100644 --- a/pkg/razorpay/payments.go +++ b/pkg/razorpay/payments.go @@ -5,7 +5,7 @@ import ( "fmt" "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" ) @@ -28,6 +28,12 @@ func FetchPayment( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + params := make(map[string]interface{}) validator := NewValidator(&r). @@ -76,6 +82,12 @@ func FetchPaymentCardDetails( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + params := make(map[string]interface{}) validator := NewValidator(&r). @@ -131,6 +143,12 @@ func UpdatePayment( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + params := make(map[string]interface{}) paymentUpdateReq := make(map[string]interface{}) @@ -192,6 +210,12 @@ func CapturePayment( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + params := make(map[string]interface{}) paymentCaptureReq := make(map[string]interface{}) @@ -270,6 +294,12 @@ func FetchAllPayments( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + // Create query parameters map paymentListOptions := make(map[string]interface{}) diff --git a/pkg/razorpay/payments_test.go b/pkg/razorpay/payments_test.go index 6283eeb..63b4c2c 100644 --- a/pkg/razorpay/payments_test.go +++ b/pkg/razorpay/payments_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/razorpay/razorpay-go/constants" + "github.com/razorpay/razorpay-go/v2/constants" "github.com/razorpay/razorpay-mcp-server/pkg/razorpay/mock" ) diff --git a/pkg/razorpay/payouts.go b/pkg/razorpay/payouts.go index b163137..7b3d191 100644 --- a/pkg/razorpay/payouts.go +++ b/pkg/razorpay/payouts.go @@ -5,7 +5,7 @@ import ( "fmt" "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" ) diff --git a/pkg/razorpay/payouts_test.go b/pkg/razorpay/payouts_test.go index 42d1ab8..0cdafc8 100644 --- a/pkg/razorpay/payouts_test.go +++ b/pkg/razorpay/payouts_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/razorpay/razorpay-go/constants" + "github.com/razorpay/razorpay-go/v2/constants" "github.com/razorpay/razorpay-mcp-server/pkg/razorpay/mock" ) diff --git a/pkg/razorpay/qr_codes.go b/pkg/razorpay/qr_codes.go index 155a53f..7c3ffd2 100644 --- a/pkg/razorpay/qr_codes.go +++ b/pkg/razorpay/qr_codes.go @@ -5,7 +5,7 @@ import ( "fmt" "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" ) diff --git a/pkg/razorpay/qr_codes_test.go b/pkg/razorpay/qr_codes_test.go index 5b0c0a6..b0e771f 100644 --- a/pkg/razorpay/qr_codes_test.go +++ b/pkg/razorpay/qr_codes_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/razorpay/razorpay-go/constants" + "github.com/razorpay/razorpay-go/v2/constants" "github.com/razorpay/razorpay-mcp-server/pkg/razorpay/mock" ) diff --git a/pkg/razorpay/refunds.go b/pkg/razorpay/refunds.go index 0ba6bd0..6581ef1 100644 --- a/pkg/razorpay/refunds.go +++ b/pkg/razorpay/refunds.go @@ -5,7 +5,7 @@ import ( "fmt" "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" ) @@ -51,6 +51,12 @@ func CreateRefund( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + payload := make(map[string]interface{}) data := make(map[string]interface{}) @@ -105,6 +111,12 @@ func FetchRefund( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + payload := make(map[string]interface{}) validator := NewValidator(&r). @@ -156,6 +168,12 @@ func UpdateRefund( ctx context.Context, r mcpgo.CallToolRequest, ) (*mcpgo.ToolResult, error) { + // Get client from context or use default + client, err := getClientFromContextOrDefault(ctx, client) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + payload := make(map[string]interface{}) data := make(map[string]interface{}) diff --git a/pkg/razorpay/refunds_test.go b/pkg/razorpay/refunds_test.go index 3891e3b..327f4f0 100644 --- a/pkg/razorpay/refunds_test.go +++ b/pkg/razorpay/refunds_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/razorpay/razorpay-go/constants" + "github.com/razorpay/razorpay-go/v2/constants" "github.com/razorpay/razorpay-mcp-server/pkg/razorpay/mock" ) diff --git a/pkg/razorpay/server.go b/pkg/razorpay/server.go index 0c4c596..073307f 100644 --- a/pkg/razorpay/server.go +++ b/pkg/razorpay/server.go @@ -1,9 +1,11 @@ package razorpay import ( + "context" + "fmt" "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" "github.com/razorpay/razorpay-mcp-server/pkg/toolsets" @@ -30,6 +32,9 @@ func NewServer( mcpgo.WithLogging(), mcpgo.WithResourceCapabilities(true, true), mcpgo.WithToolCapabilities(true), + + // Add Tool Middlewares + mcpgo.WithAuthenticationMiddleware(client), } // Create the mcpgo server @@ -68,3 +73,75 @@ func (s *Server) RegisterTools() { func (s *Server) GetMCPServer() mcpgo.Server { return s.server } + +// GetAllTools returns all registered tools +func (s *Server) GetAllTools() []mcpgo.Tool { + var allTools []mcpgo.Tool + + // Iterate through all toolsets and collect their tools + for _, toolset := range s.toolsets.Toolsets { + if toolset.Enabled { + allTools = append(allTools, toolset.ReadTools()...) + if !s.toolsets.ReadOnly() { + allTools = append(allTools, toolset.WriteTools()...) + } + } + } + + return allTools +} + +// CallTool calls a specific tool by name with the provided arguments +func (s *Server) CallTool(ctx context.Context, name string, arguments map[string]interface{}) (interface{}, error) { + // Find the tool by name + tools := s.GetAllTools() + var targetTool mcpgo.Tool + + for _, tool := range tools { + if tool.GetName() == name { + targetTool = tool + break + } + } + + if targetTool == nil { + return nil, fmt.Errorf("tool '%s' not found", name) + } + + // Create a call tool request + request := mcpgo.CallToolRequest{ + Name: name, + Arguments: arguments, + } + + // Call the tool + result, err := targetTool.Call(ctx, request) + if err != nil { + return nil, err + } + + return result, nil +} + +// getClientFromContextOrDefault returns either the provided default +// client or gets one from context. +func getClientFromContextOrDefault( + ctx context.Context, + defaultClient *rzpsdk.Client, +) (*rzpsdk.Client, error) { + if defaultClient != nil { + return defaultClient, nil + } + + clientInterface := mcpgo.ClientFromContext(ctx) + if clientInterface == nil { + return nil, fmt.Errorf("no client found in context") + } + + client, ok := clientInterface.(*rzpsdk.Client) + if !ok { + return nil, fmt.Errorf("invalid client type in context") + } + + return client, nil +} diff --git a/pkg/razorpay/server_test.go b/pkg/razorpay/server_test.go new file mode 100644 index 0000000..59fa88b --- /dev/null +++ b/pkg/razorpay/server_test.go @@ -0,0 +1,153 @@ +package razorpay + +import ( + "context" + "encoding/base64" + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + rzpsdk "github.com/razorpay/razorpay-go/v2" + + "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" +) + +// TestConcurrentRequestHandling tests concurrent requests with different +// auth tokens +func TestConcurrentRequestHandling(t *testing.T) { + var capturedClients sync.Map + + testTool := mcpgo.NewTool( + "test_tool", + "Test tool for concurrent testing", + []mcpgo.ToolParameter{ + mcpgo.WithString("test_param", + mcpgo.Description("Test parameter"), mcpgo.Required()), + }, + func( + ctx context.Context, + request mcpgo.CallToolRequest, + ) (*mcpgo.ToolResult, error) { + client := mcpgo.ClientFromContext(ctx) + if client == nil { + return mcpgo.NewToolResultError("no client found in context"), nil + } + + requestID := request.Arguments["test_param"].(string) + capturedClients.Store(requestID, client) + + return mcpgo.NewToolResultText(fmt.Sprintf("processed: %s", requestID)), nil + }, + ) + + tokens := []string{ + base64.StdEncoding.EncodeToString([]byte("key1:secret1")), + base64.StdEncoding.EncodeToString([]byte("key2:secret2")), + base64.StdEncoding.EncodeToString([]byte("key3:secret3")), + } + + middleware := createAuthMiddleware(nil) + + var wg sync.WaitGroup + var errors sync.Map + numRequests := len(tokens) + + for i, token := range tokens { + wg.Add(1) + go func(requestID int, authToken string) { + defer wg.Done() + + ctx := context.Background() + ctx = mcpgo.WithAuthToken(ctx, authToken) + + request := createMCPRequest(map[string]interface{}{ + "test_param": fmt.Sprintf("request_%d", requestID), + }) + + handler := middleware(testTool.GetHandler()) + _, err := handler(ctx, request) + + if err != nil { + errors.Store(requestID, err.Error()) + } + }(i, token) + } + + wg.Wait() + + errors.Range(func(key, value interface{}) bool { + t.Errorf("Request %v failed with error: %v", key, value) + return true + }) + + clientCount := 0 + capturedClients.Range(func(key, value interface{}) bool { + clientCount++ + assert.NotNil(t, value, "Client should not be nil for request %v", key) + return true + }) + assert.Equal(t, numRequests, clientCount, + "Should have captured clients for all requests") +} + +// TestStdioModeWithClient tests that stdio mode works with provided client +func TestStdioModeWithClient(t *testing.T) { + testClient := rzpsdk.NewClient("test_key", "test_secret") + + var capturedClient *rzpsdk.Client + + testTool := mcpgo.NewTool( + "stdio_test", + "Test tool for stdio mode", + []mcpgo.ToolParameter{}, + func( + ctx context.Context, + request mcpgo.CallToolRequest, + ) (*mcpgo.ToolResult, error) { + client, err := getClientFromContextOrDefault(ctx, testClient) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + capturedClient = client + + return mcpgo.NewToolResultText("stdio test completed"), nil + }, + ) + + middleware := createAuthMiddleware(testClient) + + ctx := context.Background() + + request := createMCPRequest(map[string]interface{}{}) + + handler := middleware(testTool.GetHandler()) + result, err := handler(ctx, request) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.False(t, result.IsError) + + assert.Equal(t, testClient, capturedClient, + "Should use the provided client in stdio mode") +} + +// createAuthMiddleware creates a middleware function for testing purposes +// that uses the shared authentication logic from mcpgo.AuthenticateRequest +func createAuthMiddleware( + client *rzpsdk.Client, +) func(mcpgo.ToolHandler) mcpgo.ToolHandler { + return func(next mcpgo.ToolHandler) mcpgo.ToolHandler { + return func( + ctx context.Context, + request mcpgo.CallToolRequest, + ) (*mcpgo.ToolResult, error) { + authenticatedCtx, err := mcpgo.AuthenticateRequest(ctx, client) + if err != nil { + return nil, err + } + return next(authenticatedCtx, request) + } + } +} diff --git a/pkg/razorpay/settlements.go b/pkg/razorpay/settlements.go index 1e38029..749bf5b 100644 --- a/pkg/razorpay/settlements.go +++ b/pkg/razorpay/settlements.go @@ -5,7 +5,7 @@ import ( "fmt" "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" ) diff --git a/pkg/razorpay/settlements_test.go b/pkg/razorpay/settlements_test.go index 92becc4..fe7afab 100644 --- a/pkg/razorpay/settlements_test.go +++ b/pkg/razorpay/settlements_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/razorpay/razorpay-go/constants" + "github.com/razorpay/razorpay-go/v2/constants" "github.com/razorpay/razorpay-mcp-server/pkg/razorpay/mock" ) diff --git a/pkg/razorpay/test_helpers.go b/pkg/razorpay/test_helpers.go index 9245c59..ee22e4b 100644 --- a/pkg/razorpay/test_helpers.go +++ b/pkg/razorpay/test_helpers.go @@ -12,7 +12,7 @@ import ( "github.com/go-test/deep" "github.com/stretchr/testify/assert" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/mcpgo" ) diff --git a/pkg/razorpay/tools.go b/pkg/razorpay/tools.go index 99fdca4..7ffb485 100644 --- a/pkg/razorpay/tools.go +++ b/pkg/razorpay/tools.go @@ -3,7 +3,7 @@ package razorpay import ( "log/slog" - rzpsdk "github.com/razorpay/razorpay-go" + rzpsdk "github.com/razorpay/razorpay-go/v2" "github.com/razorpay/razorpay-mcp-server/pkg/toolsets" ) diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 60c8dcf..f2dc668 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -121,3 +121,18 @@ func (tg *ToolsetGroup) RegisterTools(s mcpgo.Server) { toolset.RegisterTools(s) } } + +// ReadOnly returns whether the toolset group is in read-only mode +func (tg *ToolsetGroup) ReadOnly() bool { + return tg.readOnly +} + +// ReadTools returns the read tools in this toolset +func (t *Toolset) ReadTools() []mcpgo.Tool { + return t.readTools +} + +// WriteTools returns the write tools in this toolset +func (t *Toolset) WriteTools() []mcpgo.Tool { + return t.writeTools +}