diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 739c67559..f547197f7 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.7.0 +current_version = 0.6.0 commit = False tag = False sign-tags = True diff --git a/.dockerignore b/.dockerignore index 21031a9df..96bc08a95 100644 --- a/.dockerignore +++ b/.dockerignore @@ -241,7 +241,6 @@ log/ # Certificates and secrets certs/ -jwt/ *.pem *.key *.crt diff --git a/.env.example b/.env.example index d75fcb00e..bf2f8616a 100644 --- a/.env.example +++ b/.env.example @@ -88,6 +88,10 @@ PROTOCOL_VERSION=2025-03-26 BASIC_AUTH_USER=admin BASIC_AUTH_PASSWORD=changeme AUTH_REQUIRED=true + + +# Secret used to sign JWTs (use long random value in prod) + # Content type for outgoing requests to Forge FORGE_CONTENT_TYPE=application/json @@ -100,29 +104,18 @@ JWT_ALGORITHM=HS256 # === HMAC (Symmetric) Configuration - Default for Development === # Secret used to sign JWTs (required for HMAC algorithms: HS256, HS384, HS512) + # PRODUCTION: Use a strong, random secret (minimum 32 characters) -# Generate with: openssl rand -base64 32 JWT_SECRET_KEY=my-test-key -# === RSA/ECDSA (Asymmetric) Configuration - Recommended for Production === -# Public and private key paths (required for asymmetric algorithms: RS*, ES*) -# Generate RSA keys with: make certs-jwt -# (creates certs/jwt/private.pem and certs/jwt/public.pem with proper permissions) -# Generate ECDSA keys with: make certs-jwt-ecdsa -# (creates certs/jwt/ec_private.pem and certs/jwt/ec_public.pem with proper permissions) -# Generate both SSL and JWT keys: make certs-all -#JWT_PUBLIC_KEY_PATH=certs/jwt/public.pem -#JWT_PRIVATE_KEY_PATH=certs/jwt/private.pem - -# JWT Claims Configuration +# Algorithm used to sign JWTs (e.g., HS256) +JWT_ALGORITHM=HS256 + +# JWT Audience and Issuer claims for token validation # PRODUCTION: Set these to your service-specific values JWT_AUDIENCE=mcpgateway-api JWT_ISSUER=mcpgateway -# JWT Validation 
Options -# Set to false for Dynamic Client Registration (DCR) scenarios where audience varies -JWT_AUDIENCE_VERIFICATION=true - # Expiry time for generated JWT tokens (in minutes; e.g. 7 days) TOKEN_EXPIRY=10080 REQUIRE_TOKEN_EXPIRATION=false diff --git a/.flake8 b/.flake8 index 4e23db7d5..ee0e3514f 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,5 @@ [flake8] max-line-length = 600 -exclude = mcp-servers extend-ignore = E203, W503 diff --git a/.github/tools/cleanup-ghcr-versions.sh b/.github/tools/cleanup-ghcr-versions.sh index 5e348b83b..e97a6169c 100755 --- a/.github/tools/cleanup-ghcr-versions.sh +++ b/.github/tools/cleanup-ghcr-versions.sh @@ -92,7 +92,7 @@ fi ############################################################################## ORG="ibm" PKG="mcp-context-forge" -KEEP_TAGS=( "0.1.0" "v0.1.0" "0.1.1" "v0.1.1" "0.2.0" "v0.2.0" "0.3.0" "v0.3.0" "0.4.0" "v0.4.0" "0.5.0" "v0.5.0" "0.6.0" "v0.6.0" "0.7.0" "v0.7.0" "latest" ) +KEEP_TAGS=( "0.1.0" "v0.1.0" "0.1.1" "v0.1.1" "0.2.0" "v0.2.0" "0.3.0" "v0.3.0" "0.4.0" "v0.4.0" "0.5.0" "v0.5.0" "0.6.0" "v0.6.0" "latest" ) PER_PAGE=100 DRY_RUN=${DRY_RUN:-true} # default safe diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 12af4b0e7..42bd0f6e8 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -4,12 +4,12 @@ # # This workflow re-tags a Docker image (built by a previous workflow) # when a GitHub Release is published, giving it a semantic version tag -# like `v0.7.0`. It assumes the CI build has already pushed an image +# like `v0.6.0`. It assumes the CI build has already pushed an image # tagged with the commit SHA, and that all checks on that commit passed. # # ➤ Trigger: Release published (e.g. 
from GitHub UI or `gh release` CLI) # ➤ Assumes: Existing image tagged with the commit SHA is available -# ➤ Result: Image re-tagged as `ghcr.io/OWNER/REPO:v0.7.0` +# ➤ Result: Image re-tagged as `ghcr.io/OWNER/REPO:v0.6.0` # # ====================================================================== @@ -25,7 +25,7 @@ on: workflow_dispatch: inputs: tag: - description: 'Release tag (e.g., v0.7.0)' + description: 'Release tag (e.g., v0.6.0)' required: true type: string diff --git a/.github/workflows/release-chart.yml.inactive b/.github/workflows/release-chart.yml.inactive index 4d3db0f19..339dd2dde 100644 --- a/.github/workflows/release-chart.yml.inactive +++ b/.github/workflows/release-chart.yml.inactive @@ -2,7 +2,7 @@ name: Release Helm Chart on: release: - types: [published] # tag repo, ex: v0.7.0 to trigger + types: [published] # tag repo, ex: v0.6.0 to trigger permissions: contents: read packages: write diff --git a/.gitignore b/.gitignore index bf93c4d29..8fe29570c 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,6 @@ node_modules/ .tmp* mcp.db-journal certs/ -jwt/ FIXMEs *.old logs/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9bef2d710..625ee2ce3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ # report issues (linters). Modified files will need to be staged again. 
# ----------------------------------------------------------------------------- -exclude: '(^|/)(\.pre-commit-config\.yaml|normalize_special_characters\.py|test_input_validation\.py)$|(^|/)mcp-servers/templates/|.*\.(jinja|j2)$' # ignore these files, all templates, and jinja files +exclude: '(^|/)(\.pre-commit-config\.yaml|normalize_special_characters\.py|test_input_validation\.py)$|.*\.(jinja|j2)$' # ignore these files and jinja templates repos: # ----------------------------------------------------------------------------- diff --git a/AGENTS.md b/AGENTS.md index 114d2957e..a61c5d3fa 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,7 +1,5 @@ # Repository Guidelines -For specific tasks, see also: llms/api.md llms/helm.md llms/mcpgateway.md llms/mcp-server-go.md llms/mcp-server-python.md llms/mkdocs.md llms/plugins-llms.md llms/testing.md - ## Project Structure & Module Organization - `mcpgateway/`: FastAPI gateway source (entry `main.py`, `cli.py`, services, transports, templates/static, Alembic). - Services: `mcpgateway/services/` (gateway, server, tool, resource, prompt logic). 
diff --git a/DEVELOPING.md b/DEVELOPING.md index 2e1ea39ed..ae508fd42 100644 --- a/DEVELOPING.md +++ b/DEVELOPING.md @@ -6,7 +6,7 @@ # Gateway & auth export MCP_GATEWAY_BASE_URL=http://localhost:4444 export MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp -export MCP_AUTH="Bearer " +export MCP_AUTH="" ``` | Mode | Command | Notes | diff --git a/README.md b/README.md index 57c8785b1..2266a37d2 100644 --- a/README.md +++ b/README.md @@ -411,7 +411,7 @@ npx -y @modelcontextprotocol/inspector 🖧 Using the stdio wrapper (mcpgateway-wrapper) ```bash -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=$MCPGATEWAY_BEARER_TOKEN export MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp python3 -m mcpgateway.wrapper # Ctrl-C to exit ``` @@ -423,7 +423,7 @@ In MCP Inspector, define `MCP_AUTH` and `MCP_SERVER_URL` env variables, and sele ```bash echo $PWD/.venv/bin/python3 # Using the Python3 full path ensures you have a working venv export MCP_SERVER_URL='http://localhost:4444/servers/UUID_OF_SERVER_1/mcp' -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} npx -y @modelcontextprotocol/inspector ``` @@ -446,7 +446,7 @@ When using a MCP Client such as Claude with stdio: "command": "python", "args": ["-m", "mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Bearer your-token-here", + "MCP_AUTH": "your-token-here", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1", "MCP_TOOL_CALL_TIMEOUT": "120" } @@ -645,13 +645,13 @@ The `mcpgateway.wrapper` lets you connect to the gateway over **stdio** while ke ```bash # Set environment variables export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token --username admin@example.com --exp 10080 --secret my-test-key) -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} export MCP_SERVER_URL='http://localhost:4444/servers/UUID_OF_SERVER_1/mcp' export MCP_TOOL_CALL_TIMEOUT=120 
export MCP_WRAPPER_LOG_LEVEL=DEBUG # or OFF to disable logging docker run --rm -i \ - -e MCP_AUTH=$MCP_AUTH \ + -e MCP_AUTH=$MCPGATEWAY_BEARER_TOKEN \ -e MCP_SERVER_URL=http://host.docker.internal:4444/servers/UUID_OF_SERVER_1/mcp \ -e MCP_TOOL_CALL_TIMEOUT=120 \ -e MCP_WRAPPER_LOG_LEVEL=DEBUG \ @@ -669,7 +669,7 @@ Because the wrapper speaks JSON-RPC over stdin/stdout, you can interact with it ```bash # Start the MCP Gateway Wrapper -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} export MCP_SERVER_URL=http://localhost:4444/servers/YOUR_SERVER_UUID python3 -m mcpgateway.wrapper ``` @@ -730,12 +730,10 @@ The `mcpgateway.wrapper` exposes everything your Gateway knows about over **stdi 🐳 Docker / Podman ```bash -export MCP_AUTH="Bearer $MCPGATEWAY_BEARER_TOKEN" - docker run -i --rm \ --network=host \ -e MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp \ - -e MCP_AUTH=${MCP_AUTH} \ + -e MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} \ -e MCP_TOOL_CALL_TIMEOUT=120 \ ghcr.io/ibm/mcp-context-forge:0.7.0 \ python3 -m mcpgateway.wrapper @@ -753,7 +751,7 @@ docker run -i --rm \ pipx install --include-deps mcp-contextforge-gateway # Run the stdio wrapper -MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" \ +MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} \ MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp \ python3 -m mcpgateway.wrapper # Alternatively with uv @@ -769,7 +767,7 @@ uv run --directory . -m mcpgateway.wrapper "command": "python3", "args": ["-m", "mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1/mcp", "MCP_TOOL_CALL_TIMEOUT": "120" } @@ -806,7 +804,7 @@ source ~/.venv/mcpgateway/bin/activate uv pip install mcp-contextforge-gateway # Launch wrapper -MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" \ +MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} \ MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp \ uv run --directory . 
-m mcpgateway.wrapper # Use this just for testing, as the Client will run the uv command ``` @@ -826,7 +824,7 @@ uv run --directory . -m mcpgateway.wrapper # Use this just for testing, as the C "mcpgateway.wrapper" ], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1/mcp" } } diff --git a/check_todays_metrics.py b/check_todays_metrics.py new file mode 100755 index 000000000..8e6897309 --- /dev/null +++ b/check_todays_metrics.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +import sqlite3 +from datetime import datetime + +def check_todays_metrics(): + conn = sqlite3.connect('mcp.db') + cursor = conn.cursor() + + print("=== TOOL METRICS FROM TODAY (2025-08-27) ===") + cursor.execute(""" + SELECT t.name, tm.timestamp, tm.is_success, tm.response_time + FROM tool_metrics tm + JOIN tools t ON t.id = tm.tool_id + WHERE tm.timestamp LIKE '2025-08-27%' + ORDER BY tm.timestamp DESC + """) + results = cursor.fetchall() + if results: + for name, timestamp, success, response_time in results: + print(f" {name}: {timestamp} (success: {success}, {response_time:.3f}s)") + else: + print(" No tool metrics recorded today!") + + print(f"\nTotal tool metrics today: {len(results)}") + + print("\n=== CHECKING TOOLS THAT APPEAR IN UI ===") + cursor.execute("SELECT id, name, enabled FROM tools WHERE name LIKE 'json%' OR name LIKE 'test%' OR name LIKE 'book%' ORDER BY name") + tools = cursor.fetchall() + for tool_id, name, enabled in tools: + cursor.execute("SELECT COUNT(*) FROM tool_metrics WHERE tool_id = ?", (tool_id,)) + metric_count = cursor.fetchone()[0] + cursor.execute("SELECT MAX(timestamp) FROM tool_metrics WHERE tool_id = ?", (tool_id,)) + last_exec = cursor.fetchone()[0] + print(f" {name}:") + print(f" ID: {tool_id}") + print(f" Enabled: {enabled}") + print(f" Total metrics: {metric_count}") + print(f" Last execution: {last_exec}") + print() + + conn.close() + +if __name__ == "__main__": + check_todays_metrics() diff 
--git a/docs/docs/development/mcp-developer-guide-json-rpc.md b/docs/docs/development/mcp-developer-guide-json-rpc.md index 55f599b7f..d1174a02e 100644 --- a/docs/docs/development/mcp-developer-guide-json-rpc.md +++ b/docs/docs/development/mcp-developer-guide-json-rpc.md @@ -474,8 +474,8 @@ For command-line integration and desktop client compatibility, use the STDIO wra ```bash # Configure environment variables -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" -export MCP_SERVER_URL="http://localhost:4444/servers/your-server-id" +export MCP_AUTH_TOKEN=${MCPGATEWAY_BEARER_TOKEN} +export MCP_SERVER_CATALOG_URLS="http://localhost:4444/servers/your-server-id" export MCP_TOOL_CALL_TIMEOUT=120 export MCP_WRAPPER_LOG_LEVEL=INFO diff --git a/docs/docs/faq/index.md b/docs/docs/faq/index.md index 167e47d7d..056effc91 100644 --- a/docs/docs/faq/index.md +++ b/docs/docs/faq/index.md @@ -293,7 +293,7 @@ "command": "python3", "args": ["-m", "mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1/mcp", "MCP_TOOL_CALL_TIMEOUT": "120" } diff --git a/docs/docs/index.md b/docs/docs/index.md index b42d2ace6..ddfb01ad3 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -298,7 +298,7 @@ npx -y @modelcontextprotocol/inspector 🖧 Using the stdio wrapper (mcpgateway-wrapper) ```bash -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=$MCPGATEWAY_BEARER_TOKEN export MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp python3 -m mcpgateway.wrapper # Ctrl-C to exit ``` @@ -310,7 +310,7 @@ In MCP Inspector, define `MCP_AUTH` and `MCP_SERVER_URL` env variables, and sele ```bash echo $PWD/.venv/bin/python3 # Using the Python3 full path ensures you have a working venv export MCP_SERVER_URL='http://localhost:4444/servers/UUID_OF_SERVER_1/mcp' -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} npx -y 
@modelcontextprotocol/inspector ``` @@ -333,7 +333,7 @@ When using a MCP Client such as Claude with stdio: "command": "python", "args": ["-m", "mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Beare ", + "MCP_AUTH": "your-token-here", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1", "MCP_TOOL_CALL_TIMEOUT": "120" } @@ -532,13 +532,13 @@ The `mcpgateway.wrapper` lets you connect to the gateway over **stdio** while ke ```bash # Set environment variables export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token --username admin@example.com --exp 10080 --secret my-test-key) -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} export MCP_SERVER_URL='http://localhost:4444/servers/UUID_OF_SERVER_1/mcp' export MCP_TOOL_CALL_TIMEOUT=120 export MCP_WRAPPER_LOG_LEVEL=DEBUG # or OFF to disable logging docker run --rm -i \ - -e MCP_AUTH=$MCP_AUTH \ + -e MCP_AUTH=$MCPGATEWAY_BEARER_TOKEN \ -e MCP_SERVER_URL=http://host.docker.internal:4444/servers/UUID_OF_SERVER_1/mcp \ -e MCP_TOOL_CALL_TIMEOUT=120 \ -e MCP_WRAPPER_LOG_LEVEL=DEBUG \ @@ -556,7 +556,7 @@ Because the wrapper speaks JSON-RPC over stdin/stdout, you can interact with it ```bash # Start the MCP Gateway Wrapper -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} export MCP_SERVER_URL=http://localhost:4444/servers/YOUR_SERVER_UUID python3 -m mcpgateway.wrapper ``` @@ -617,12 +617,10 @@ The `mcpgateway.wrapper` exposes everything your Gateway knows about over **stdi 🐳 Docker / Podman ```bash -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" - docker run -i --rm \ --network=host \ -e MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp \ - -e MCP_AUTH=${MCP_AUTH} \ + -e MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} \ -e MCP_TOOL_CALL_TIMEOUT=120 \ ghcr.io/ibm/mcp-context-forge:0.7.0 \ python3 -m mcpgateway.wrapper @@ -640,7 +638,7 @@ docker run -i --rm \ pipx install --include-deps 
mcp-contextforge-gateway # Run the stdio wrapper -MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" \ +MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} \ MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp \ python3 -m mcpgateway.wrapper # Alternatively with uv @@ -656,7 +654,7 @@ uv run --directory . -m mcpgateway.wrapper "command": "python3", "args": ["-m", "mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1/mcp", "MCP_TOOL_CALL_TIMEOUT": "120" } @@ -693,7 +691,7 @@ source ~/.venv/mcpgateway/bin/activate uv pip install mcp-contextforge-gateway # Launch wrapper -MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" \ +MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} \ MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp \ uv run --directory . -m mcpgateway.wrapper # Use this just for testing, as the Client will run the uv command ``` @@ -713,7 +711,7 @@ uv run --directory . -m mcpgateway.wrapper # Use this just for testing, as the C "mcpgateway.wrapper" ], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1/mcp" } } diff --git a/docs/docs/overview/quick_start.md b/docs/docs/overview/quick_start.md index 6ee37dd7f..f6b766f18 100644 --- a/docs/docs/overview/quick_start.md +++ b/docs/docs/overview/quick_start.md @@ -291,7 +291,7 @@ npx -y @modelcontextprotocol/inspector ## Connect via `mcpgateway-wrapper` (stdio) ```bash -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=$MCPGATEWAY_BEARER_TOKEN export MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1/mcp python3 -m mcpgateway.wrapper # behaves as a local MCP stdio server - run from MCP client ``` @@ -306,7 +306,7 @@ Use this in GUI clients (Claude Desktop, Continue, etc.) that prefer stdio. 
Exam "args": ["-m", "mcpgateway.wrapper"], "env": { "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1/mcp", - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_TOOL_CALL_TIMEOUT": "120" } } diff --git a/docs/docs/using/agents/bee.md b/docs/docs/using/agents/bee.md index 907162c43..7f5a6d5be 100644 --- a/docs/docs/using/agents/bee.md +++ b/docs/docs/using/agents/bee.md @@ -36,7 +36,7 @@ To use MCP tools in the Bee Agent Framework, follow these steps: ```bash export MCP_GATEWAY_BASE_URL=http://localhost:4444 - export MCP_AUTH="Bearer " + export MCP_AUTH="your_bearer_token" ``` --- diff --git a/docs/docs/using/clients/claude-desktop.md b/docs/docs/using/clients/claude-desktop.md index 147301da7..c85804b25 100644 --- a/docs/docs/using/clients/claude-desktop.md +++ b/docs/docs/using/clients/claude-desktop.md @@ -27,7 +27,7 @@ prompt and resource registered in your Gateway. "args": ["-m", "mcpgateway.wrapper"], "env": { "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1", - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_TOOL_CALL_TIMEOUT": "120" } } @@ -47,7 +47,7 @@ prompt and resource registered in your Gateway. 
"args": [ "run", "--rm", "--network=host", "-i", "-e", "MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1", - "-e", "MCP_AUTH=", + "-e", "MCP_AUTH=", "ghcr.io/ibm/mcp-context-forge:0.7.0", "python3", "-m", "mcpgateway.wrapper" ] @@ -68,7 +68,7 @@ If you installed the package globally: "args": ["run", "python3", "-m", "mcpgateway.wrapper"], "env": { "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1", - "MCP_AUTH": "Bearer " + "MCP_AUTH": "" } } ``` diff --git a/docs/docs/using/clients/continue.md b/docs/docs/using/clients/continue.md index 2741b1dd5..e6f313e31 100644 --- a/docs/docs/using/clients/continue.md +++ b/docs/docs/using/clients/continue.md @@ -80,7 +80,7 @@ pipx install --include-deps mcp-contextforge-gateway "args": ["-m", "mcpgateway.wrapper"], "env": { "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1", - "MCP_AUTH": "Bearer ${env:MCP_AUTH}", + "MCP_AUTH": "${env:MCP_AUTH}", "MCP_TOOL_CALL_TIMEOUT": "120" } } diff --git a/docs/docs/using/clients/copilot.md b/docs/docs/using/clients/copilot.md index 7894ee1ae..b1f0c601e 100644 --- a/docs/docs/using/clients/copilot.md +++ b/docs/docs/using/clients/copilot.md @@ -112,7 +112,7 @@ That's it - VS Code spawns the stdio process, pipes JSON-RPC, and you're ready t "args": [ "run", "--rm", "--network=host", "-i", "-e", "MCP_SERVER_URL=http://localhost:4444/servers/UUID_OF_SERVER_1", - "-e", "MCP_AUTH=", + "-e", "MCP_AUTH=", "ghcr.io/ibm/mcp-context-forge:0.7.0", "python3", "-m", "mcpgateway.wrapper" ] diff --git a/docs/docs/using/clients/mcp-cli.md b/docs/docs/using/clients/mcp-cli.md index a27561c02..13d19fa97 100644 --- a/docs/docs/using/clients/mcp-cli.md +++ b/docs/docs/using/clients/mcp-cli.md @@ -80,7 +80,7 @@ Create a `server_config.json` file to define your MCP Context Forge Gateway conn "command": "/path/to/mcp-context-forge/.venv/bin/python", "args": ["-m", "mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": 
"http://localhost:4444", "MCP_TOOL_CALL_TIMEOUT": "120" } @@ -320,7 +320,7 @@ In interactive mode, use these commands: ```bash # MCP Context Forge Gateway connection -export MCP_AUTH="Bearer your-jwt-token" +export MCP_AUTH="your-jwt-token" export MCP_SERVER_URL="http://localhost:4444" # LLM Provider API keys @@ -461,7 +461,7 @@ The mcp-cli integrates with MCP Context Forge Gateway through multiple connectio "command": "/path/to/mcp-context-forge/.venv/bin/python", "args": ["-m", "mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "your-jwt-token", "MCP_SERVER_URL": "http://localhost:4444" } } diff --git a/docs/docs/using/mcpgateway-wrapper.md b/docs/docs/using/mcpgateway-wrapper.md index 29647f1c0..89a52cdc6 100644 --- a/docs/docs/using/mcpgateway-wrapper.md +++ b/docs/docs/using/mcpgateway-wrapper.md @@ -30,7 +30,7 @@ export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token \ Configure the wrapper via ENV variables: ```bash -export MCP_AUTH="Bearer ${MCPGATEWAY_BEARER_TOKEN}" +export MCP_AUTH=${MCPGATEWAY_BEARER_TOKEN} export MCP_SERVER_URL='http://localhost:4444/servers/UUID_OF_SERVER_1/mcp' # select a virtual server export MCP_TOOL_CALL_TIMEOUT=120 # tool call timeout in seconds (optional - default 90) export MCP_WRAPPER_LOG_LEVEL=INFO # DEBUG | INFO | OFF @@ -105,7 +105,7 @@ The MCP Client calls the entrypoint, which needs to have the `mcp-contextforge-g "command": "python3", "args": ["-m", "mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1" } } @@ -132,7 +132,7 @@ The MCP Client calls the entrypoint, which needs to have the `mcp-contextforge-g "mcpgateway.wrapper" ], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1" } } @@ -150,7 +150,7 @@ The MCP Client calls the entrypoint, which needs to have the `mcp-contextforge-g "command": "/path/to/python", "args": ["-m", 
"mcpgateway.wrapper"], "env": { - "MCP_AUTH": "Bearer ", + "MCP_AUTH": "", "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1" } } @@ -178,7 +178,7 @@ The MCP Client calls the entrypoint, which needs to have the `mcp-contextforge-g ], "env": { "MCP_SERVER_URL": "http://localhost:4444/servers/UUID_OF_SERVER_1", - "MCP_AUTH": "Bearer REPLACE_WITH_MCPGATEWAY_BEARER_TOKEN", + "MCP_AUTH": "REPLACE_WITH_MCPGATEWAY_BEARER_TOKEN", "MCP_WRAPPER_LOG_LEVEL": "OFF" } } diff --git a/docs/docs/using/plugins/index.md b/docs/docs/using/plugins/index.md index b936e77af..b43eafed7 100644 --- a/docs/docs/using/plugins/index.md +++ b/docs/docs/using/plugins/index.md @@ -243,7 +243,7 @@ Each plugin can operate in one of four modes: | Mode | Description | Use Case | |------|-------------|----------| | **enforce** | Blocks requests on policy violations and plugin errors | Production guardrails | -| **enforce_ignore_errors** | Blocks requests on policy violations but only logs errors | Production guardrails | +| **enforce_ignore_errors** | Blocks requests on policy violations; logs errors and continues | Guardrails with fault tolerance | | **permissive** | Logs violations but allows requests | Testing and monitoring | | **disabled** | Plugin loaded but not executed | Temporary deactivation | @@ -308,44 +308,6 @@ The plugin framework provides comprehensive hook coverage across the entire MCP | `federation_pre_sync` | Gateway federation validation and filtering | v0.8.0 | | `federation_post_sync` | Post-federation data processing and reconciliation | v0.8.0 | -### Prompt Hooks Details - -The prompt hooks allow plugins to intercept and modify prompt retrieval and rendering: - -- **`prompt_pre_fetch`**: Receives the prompt name and arguments before prompt template retrieval. Can modify the arguments. -- **`prompt_post_fetch`**: Receives the completed prompt after rendering. Can modify the prompt text or block it from being returned. 
- -Example Use Cases: -- Detect prompt injection attacks -- Sanitize or anonymize prompts -- Search and replace - -#### Prompt Hook Payloads - -**PromptPrehookPayload**: Payload for prompt pre-fetch hooks. - -```python -class PromptPrehookPayload(BaseModel): - name: str # Prompt template name - args: Optional[dict[str, str]] = Field(default_factory=dict) # Template arguments -``` - -**Example**: -```python -payload = PromptPrehookPayload( - name="user_greeting", - args={"user_name": "Alice", "time_of_day": "morning"} -) -``` - -**PromptPosthookPayload**: Payload for prompt post-fetch hooks. - -```python -class PromptPosthookPayload(BaseModel): - name: str # Prompt name - result: PromptResult # Rendered prompt result -``` - ### Tool Hooks Details The tool hooks enable plugins to intercept and modify tool invocations: @@ -360,42 +322,6 @@ Example use cases: - Input validation and sanitization - Output filtering and transformation -#### Tool Hook Payloads - -**ToolPreInvokePayload**: Payload for tool pre-invoke hooks. - -```python -class ToolPreInvokePayload(BaseModel): - name: str # Tool name - args: Optional[dict[str, Any]] = Field(default_factory=dict) # Tool arguments - headers: Optional[HttpHeaderPayload] = None # HTTP pass-through headers -``` - -**ToolPostInvokePayload**: Payload for tool post-invoke hooks. - -```python -class ToolPostInvokePayload(BaseModel): - name: str # Tool name - result: Any # Tool execution result -``` - -The associated `HttpHeaderPayload` object for the `ToolPreInvokePayload` is as follows: - -Special payload for HTTP header manipulation. 
- -```python -class HttpHeaderPayload(RootModel[dict[str, str]]): - # Provides dictionary-like access to HTTP headers - # Supports: __iter__, __getitem__, __setitem__, __len__ -``` - -**Usage**: -```python -headers = HttpHeaderPayload({"Authorization": "Bearer token", "Content-Type": "application/json"}) -headers["X-Custom-Header"] = "custom_value" -auth_header = headers["Authorization"] -``` - ### Resource Hooks Details The resource hooks enable plugins to intercept and modify resource fetching: @@ -411,24 +337,6 @@ Example use cases: - Content transformation and filtering - Resource caching metadata -#### Resource Hook Payloads - -**ResourcePreFetchPayload**: Payload for resource pre-fetch hooks. - -```python -class ResourcePreFetchPayload(BaseModel): - uri: str # Resource URI - metadata: Optional[dict[str, Any]] = Field(default_factory=dict) # Request metadata -``` - -**ResourcePostFetchPayload**: Payload for resource post-fetch hooks. - -```python -class ResourcePostFetchPayload(BaseModel): - uri: str # Resource URI - content: Any # Fetched resource content -``` - Planned hooks (not yet implemented): - `server_pre_register` / `server_post_register` - Server validation @@ -703,32 +611,6 @@ async def prompt_post_fetch(self, payload, context): return PromptPosthookResult() ``` -#### Tool and Gateway Metadata - -Currently, the tool pre/post hooks have access to tool and gateway metadata through the global context metadata dictionary. They are accessible as follows: - -It can be accessed inside of the tool hooks through: - -```python -from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA - -tool_meta = context.global_context.metadata[TOOL_METADATA] -assert tool_meta.original_name == "test_tool" -assert tool_meta.url.host == "example.com" -assert tool_meta.integration_type == "REST" or tool_meta.integration_type == "MCP" -``` - -Note, if the integration type is `MCP` the gateway information may also be available as follows. 
- -```python -gateway_meta = context.global_context.metadata[GATEWAY_METADATA] -assert gateway_meta.name == "test_gateway" -assert gateway_meta.transport == "sse" -assert gateway_meta.url.host == "example.com" -``` - -Metadata for other entities such as prompts and resources will be added in future versions of the gateway. - ### External Service Plugin Example ```python diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 749f84cd3..6865b5f39 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -20,34 +20,29 @@ # Standard from collections import defaultdict import csv -from datetime import datetime, timedelta, timezone +from datetime import datetime from functools import wraps -import html import io +from io import StringIO import json -import logging from pathlib import Path import time -from typing import Any, cast, Dict, List, Optional, Union -import urllib.parse +from typing import Any, Callable, cast, Dict, List, Optional, Union import uuid # Third-Party from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response -from fastapi.encoders import jsonable_encoder -from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse, StreamingResponse +from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse, StreamingResponse import httpx from pydantic import ValidationError from pydantic_core import ValidationError as CoreValidationError from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session -from starlette.datastructures import UploadFile as StarletteUploadFile # First-Party from mcpgateway.config import settings from mcpgateway.db import get_db, GlobalConfig from mcpgateway.db import Tool as DbTool -from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission from mcpgateway.models import LogLevel from mcpgateway.schemas import ( A2AAgentCreate, @@ -77,29 +72,30 @@ ) from mcpgateway.services.a2a_service import A2AAgentError, 
A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService from mcpgateway.services.export_service import ExportError, ExportService -from mcpgateway.services.gateway_service import GatewayConnectionError, GatewayNameConflictError, GatewayNotFoundError, GatewayService, GatewayUrlConflictError +from mcpgateway.services.gateway_service import GatewayConnectionError, GatewayNotFoundError, GatewayService from mcpgateway.services.import_service import ConflictStrategy from mcpgateway.services.import_service import ImportError as ImportServiceError -from mcpgateway.services.import_service import ImportService, ImportValidationError +from mcpgateway.services.import_service import ImportService from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.prompt_service import PromptNotFoundError, PromptService from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService from mcpgateway.services.root_service import RootService from mcpgateway.services.server_service import ServerError, ServerNameConflictError, ServerNotFoundError, ServerService from mcpgateway.services.tag_service import TagService -from mcpgateway.services.team_management_service import TeamManagementService -from mcpgateway.services.tool_service import ToolError, ToolNameConflictError, ToolNotFoundError, ToolService -from mcpgateway.utils.create_jwt_token import create_jwt_token, get_jwt_token +from mcpgateway.services.tool_service import ToolError, ToolNotFoundError, ToolService +from mcpgateway.utils.create_jwt_token import get_jwt_token from mcpgateway.utils.error_formatter import ErrorFormatter from mcpgateway.utils.metadata_capture import MetadataCapture from mcpgateway.utils.oauth_encryption import get_oauth_encryption from mcpgateway.utils.passthrough_headers import PassthroughHeadersError from mcpgateway.utils.retry_manager import ResilientHttpClient +from mcpgateway.utils.security_cookies import set_auth_cookie +from 
mcpgateway.utils.verify_credentials import require_auth, require_basic_auth # Import the shared logging service from main # This will be set by main.py when it imports admin_router logging_service: Optional[LoggingService] = None -LOGGER: logging.Logger = logging.getLogger("mcpgateway.admin") +logger = None def set_logging_service(service: LoggingService): @@ -109,44 +105,32 @@ def set_logging_service(service: LoggingService): Args: service: The LoggingService instance to use - - Examples: - >>> from mcpgateway.services.logging_service import LoggingService - >>> from mcpgateway import admin - >>> logging_svc = LoggingService() - >>> admin.set_logging_service(logging_svc) - >>> admin.logging_service is not None - True - >>> admin.LOGGER is not None - True - - Test with different service instance: - >>> new_svc = LoggingService() - >>> admin.set_logging_service(new_svc) - >>> admin.logging_service == new_svc - True - >>> admin.LOGGER.name - 'mcpgateway.admin' - - Test that global variables are properly set: - >>> admin.set_logging_service(logging_svc) - >>> hasattr(admin, 'logging_service') - True - >>> hasattr(admin, 'LOGGER') - True """ - global logging_service, LOGGER # pylint: disable=global-statement + global logging_service, logger # pylint: disable=global-statement logging_service = service - LOGGER = logging_service.get_logger("mcpgateway.admin") + logger = logging_service.get_logger("mcpgateway.admin") # Fallback for testing - create a temporary instance if not set if logging_service is None: logging_service = LoggingService() - LOGGER = logging_service.get_logger("mcpgateway.admin") + logger = logging_service.get_logger("mcpgateway.admin") -# Removed duplicate function definition - using the more comprehensive version below +def extract_user_email(user: Union[str, Dict[str, Any]]) -> str: + """Extract user email from authentication result. 
+ + Args: + user: Result from require_auth, either email string or user dict + + Returns: + str: User email address or 'anonymous' for unauthenticated users + """ + if isinstance(user, str): + return user # Already an email or 'anonymous' + else: + # JWT payload typically has 'sub' (subject) or 'email' field + return user.get('email') or user.get('sub') or 'anonymous' # Initialize services @@ -164,61 +148,20 @@ def set_logging_service(service: LoggingService): # Set up basic authentication # Rate limiting storage -rate_limit_storage = defaultdict(list) +rate_limit_storage: Dict[str, List[float]] = defaultdict(list) -def rate_limit(requests_per_minute: Optional[int] = None): +def rate_limit(requests_per_minute: Optional[int] = None) -> Callable[..., Any]: """Apply rate limiting to admin endpoints. Args: requests_per_minute: Maximum requests per minute (uses config default if None) Returns: - Decorator function that enforces rate limiting - - Examples: - Test basic decorator creation: - >>> from mcpgateway import admin - >>> decorator = admin.rate_limit(10) - >>> callable(decorator) - True - - Test with None parameter (uses default): - >>> default_decorator = admin.rate_limit(None) - >>> callable(default_decorator) - True - - Test with specific limit: - >>> limited_decorator = admin.rate_limit(5) - >>> callable(limited_decorator) - True - - Test decorator returns wrapper: - >>> async def dummy_func(): - ... 
return "success" - >>> decorated_func = decorator(dummy_func) - >>> callable(decorated_func) - True - - Test rate limit storage structure: - >>> isinstance(admin.rate_limit_storage, dict) - True - >>> from collections import defaultdict - >>> isinstance(admin.rate_limit_storage, defaultdict) - True - - Test decorator with zero limit: - >>> zero_limit_decorator = admin.rate_limit(0) - >>> callable(zero_limit_decorator) - True - - Test decorator with high limit: - >>> high_limit_decorator = admin.rate_limit(1000) - >>> callable(high_limit_decorator) - True + Decorator function that enforces per-IP request limits """ - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: """Decorator that wraps the function with rate limiting logic. Args: @@ -229,7 +172,7 @@ def decorator(func): """ @wraps(func) - async def wrapper(*args, request: Optional[Request] = None, **kwargs): + async def wrapper(*args: Any, request: Optional[Request] = None, **kwargs: Any) -> Any: """Execute the wrapped function with rate limiting enforcement. Args: @@ -256,7 +199,8 @@ async def wrapper(*args, request: Optional[Request] = None, **kwargs): # enforce if len(rate_limit_storage[client_ip]) >= limit: - LOGGER.warning(f"Rate limit exceeded for IP {client_ip} on endpoint {func.__name__}") + if logger: + logger.warning(f"Rate limit exceeded for IP {client_ip} on endpoint {func.__name__}") raise HTTPException( status_code=429, detail=f"Rate limit exceeded. Maximum {limit} requests per minute.", @@ -272,143 +216,6 @@ async def wrapper(*args, request: Optional[Request] = None, **kwargs): return decorator -def get_user_email(user) -> str: - """Extract user email from JWT payload consistently. 
- - Args: - user: User object from JWT token (from get_current_user_with_permissions) - - Returns: - str: User email address - - Examples: - Test with dictionary user (JWT payload) with 'sub': - >>> from mcpgateway import admin - >>> user_dict = {'sub': 'alice@example.com', 'iat': 1234567890} - >>> admin.get_user_email(user_dict) - 'alice@example.com' - - Test with dictionary user with 'email' field: - >>> user_dict = {'email': 'bob@company.com', 'role': 'admin'} - >>> admin.get_user_email(user_dict) - 'bob@company.com' - - Test with dictionary user with both 'sub' and 'email' (sub takes precedence): - >>> user_dict = {'sub': 'charlie@primary.com', 'email': 'charlie@secondary.com'} - >>> admin.get_user_email(user_dict) - 'charlie@primary.com' - - Test with dictionary user with no email fields: - >>> user_dict = {'username': 'dave', 'role': 'user'} - >>> admin.get_user_email(user_dict) - 'unknown' - - Test with user object having email attribute: - >>> class MockUser: - ... def __init__(self, email): - ... self.email = email - >>> user_obj = MockUser('eve@test.com') - >>> admin.get_user_email(user_obj) - 'eve@test.com' - - Test with user object without email attribute: - >>> class BasicUser: - ... def __init__(self, name): - ... self.name = name - ... def __str__(self): - ... 
return self.name - >>> user_obj = BasicUser('frank') - >>> admin.get_user_email(user_obj) - 'frank' - - Test with None user: - >>> admin.get_user_email(None) - 'unknown' - - Test with string user: - >>> admin.get_user_email('grace@example.org') - 'grace@example.org' - - Test with empty dictionary: - >>> admin.get_user_email({}) - 'unknown' - - Test with non-string, non-dict, non-object values: - >>> admin.get_user_email(12345) - '12345' - """ - if isinstance(user, dict): - # Standard JWT format - try 'sub' first, then 'email' - return user.get("sub") or user.get("email", "unknown") - if hasattr(user, "email"): - # User object with email attribute - return user.email - # Fallback to string representation - return str(user) if user else "unknown" - - -def serialize_datetime(obj): - """Convert datetime objects to ISO format strings for JSON serialization. - - Args: - obj: Object to serialize, potentially a datetime - - Returns: - str: ISO format string if obj is datetime, otherwise returns obj unchanged - - Examples: - Test with datetime object: - >>> from mcpgateway import admin - >>> from datetime import datetime, timezone - >>> dt = datetime(2025, 1, 15, 10, 30, 45, tzinfo=timezone.utc) - >>> admin.serialize_datetime(dt) - '2025-01-15T10:30:45+00:00' - - Test with naive datetime: - >>> dt_naive = datetime(2025, 3, 20, 14, 15, 30) - >>> result = admin.serialize_datetime(dt_naive) - >>> '2025-03-20T14:15:30' in result - True - - Test with datetime with microseconds: - >>> dt_micro = datetime(2025, 6, 10, 9, 25, 12, 500000) - >>> result = admin.serialize_datetime(dt_micro) - >>> '2025-06-10T09:25:12.500000' in result - True - - Test with non-datetime objects (should return unchanged): - >>> admin.serialize_datetime("2025-01-15T10:30:45") - '2025-01-15T10:30:45' - >>> admin.serialize_datetime(12345) - 12345 - >>> admin.serialize_datetime(['a', 'list']) - ['a', 'list'] - >>> admin.serialize_datetime({'key': 'value'}) - {'key': 'value'} - >>> 
admin.serialize_datetime(None) - >>> admin.serialize_datetime(True) - True - - Test with current datetime: - >>> import datetime as dt_module - >>> now = dt_module.datetime.now() - >>> result = admin.serialize_datetime(now) - >>> isinstance(result, str) - True - >>> 'T' in result # ISO format contains 'T' separator - True - - Test edge case with datetime min/max: - >>> dt_min = datetime.min - >>> result = admin.serialize_datetime(dt_min) - >>> result.startswith('0001-01-01T') - True - """ - if isinstance(obj, datetime): - return obj.isoformat() - return obj - - admin_router = APIRouter(prefix="/admin", tags=["Admin UI"]) #################### @@ -420,7 +227,7 @@ def serialize_datetime(obj): @rate_limit(requests_per_minute=30) # Lower limit for config endpoints async def get_global_passthrough_headers( db: Session = Depends(get_db), - _user=Depends(get_current_user_with_permissions), + _user: Union[str, Dict[str, Any]] = Depends(require_auth), ) -> GlobalConfigRead: """Get the global passthrough headers configuration. @@ -455,7 +262,7 @@ async def update_global_passthrough_headers( request: Request, # pylint: disable=unused-argument config_update: GlobalConfigUpdate, db: Session = Depends(get_db), - _user=Depends(get_current_user_with_permissions), + _user: Union[str, Dict[str, Any]] = Depends(require_auth), ) -> GlobalConfigRead: """Update the global passthrough headers configuration. 
@@ -490,22 +297,21 @@ async def update_global_passthrough_headers( config.passthrough_headers = config_update.passthrough_headers db.commit() return GlobalConfigRead(passthrough_headers=config.passthrough_headers) - except (IntegrityError, ValidationError, PassthroughHeadersError) as e: + except Exception as e: db.rollback() if isinstance(e, IntegrityError): raise HTTPException(status_code=409, detail="Passthrough headers conflict") - if isinstance(e, ValidationError): + elif isinstance(e, ValidationError): raise HTTPException(status_code=422, detail="Invalid passthrough headers format") - if isinstance(e, PassthroughHeadersError): + else: raise HTTPException(status_code=500, detail=str(e)) - raise HTTPException(status_code=500, detail="Unknown error occurred") @admin_router.get("/servers", response_model=List[ServerRead]) async def admin_list_servers( include_inactive: bool = False, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> List[Dict[str, Any]]: """ List servers for the admin UI with an option to include inactive servers. @@ -525,7 +331,7 @@ async def admin_list_servers( >>> >>> # Mock dependencies >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> >>> # Mock server service >>> from datetime import datetime, timezone @@ -553,9 +359,9 @@ async def admin_list_servers( ... metrics=mock_metrics ... 
) >>> - >>> # Mock the server_service.list_servers_for_user method - >>> original_list_servers_for_user = server_service.list_servers_for_user - >>> server_service.list_servers_for_user = AsyncMock(return_value=[mock_server]) + >>> # Mock the server_service.list_servers method + >>> original_list_servers = server_service.list_servers + >>> server_service.list_servers = AsyncMock(return_value=[mock_server]) >>> >>> # Test the function >>> async def test_admin_list_servers(): @@ -571,10 +377,10 @@ async def admin_list_servers( True >>> >>> # Restore original method - >>> server_service.list_servers_for_user = original_list_servers_for_user + >>> server_service.list_servers = original_list_servers >>> >>> # Additional test for empty server list - >>> server_service.list_servers_for_user = AsyncMock(return_value=[]) + >>> server_service.list_servers = AsyncMock(return_value=[]) >>> async def test_admin_list_servers_empty(): ... result = await admin_list_servers( ... include_inactive=True, @@ -584,13 +390,13 @@ async def admin_list_servers( ... return result == [] >>> asyncio.run(test_admin_list_servers_empty()) True - >>> server_service.list_servers_for_user = original_list_servers_for_user + >>> server_service.list_servers = original_list_servers >>> >>> # Additional test for exception handling >>> import pytest >>> from fastapi import HTTPException >>> async def test_admin_list_servers_exception(): - ... server_service.list_servers_for_user = AsyncMock(side_effect=Exception("Test error")) + ... server_service.list_servers = AsyncMock(side_effect=Exception("Test error")) ... try: ... await admin_list_servers(False, mock_db, mock_user) ... 
except Exception as e: @@ -598,14 +404,13 @@ async def admin_list_servers( >>> asyncio.run(test_admin_list_servers_exception()) True """ - LOGGER.debug(f"User {get_user_email(user)} requested server list") - user_email = get_user_email(user) - servers = await server_service.list_servers_for_user(db, user_email, include_inactive=include_inactive) + if logger: logger.debug(f"User {user} requested server list") + servers = await server_service.list_servers(db, include_inactive=include_inactive) return [server.model_dump(by_alias=True) for server in servers] @admin_router.get("/servers/{server_id}", response_model=ServerRead) -async def admin_get_server(server_id: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: +async def admin_get_server(server_id: str, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> Dict[str, Any]: """ Retrieve server details for the admin UI. @@ -630,7 +435,7 @@ async def admin_get_server(server_id: str, db: Session = Depends(get_db), user=D >>> >>> # Mock dependencies >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> server_id = "test-server-1" >>> >>> # Mock server response @@ -698,18 +503,18 @@ async def admin_get_server(server_id: str, db: Session = Depends(get_db), user=D >>> server_service.get_server = original_get_server """ try: - LOGGER.debug(f"User {get_user_email(user)} requested details for server ID {server_id}") + if logger: logger.debug(f"User {user} requested details for server ID {server_id}") server = await server_service.get_server(db, server_id) return server.model_dump(by_alias=True) except ServerNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - LOGGER.error(f"Error getting gateway {server_id}: {e}") + if logger: logger.error(f"Error getting gateway {server_id}: {e}") raise e @admin_router.post("/servers", response_model=ServerRead) -async 
def admin_add_server(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> JSONResponse: +async def admin_add_server(request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> JSONResponse: """ Add a new server via the admin UI. @@ -721,9 +526,9 @@ async def admin_add_server(request: Request, db: Session = Depends(get_db), user - name (required): The name of the server - description (optional): A description of the server's purpose - icon (optional): URL or path to the server's icon - - associatedTools (optional, multiple values): Tools associated with this server - - associatedResources (optional, multiple values): Resources associated with this server - - associatedPrompts (optional, multiple values): Prompts associated with this server + - associatedTools (optional, comma-separated): Tools associated with this server + - associatedResources (optional, comma-separated): Resources associated with this server + - associatedPrompts (optional, comma-separated): Prompts associated with this server Args: request (Request): FastAPI request containing form data. @@ -747,7 +552,7 @@ async def admin_add_server(request: Request, db: Session = Depends(get_db), user >>> timestamp = datetime.now().strftime("%Y%m%d%H%M%S") >>> short_uuid = str(uuid.uuid4())[:8] >>> unq_ext = f"{timestamp}-{short_uuid}" - >>> mock_user = {"email": "test_user_" + unq_ext, "db": mock_db} + >>> mock_user = "test_user_" + unq_ext >>> # Mock form data for successful server creation >>> form_data = FormData([ ... ("name", "Test-Server-"+unq_ext ), @@ -756,9 +561,7 @@ async def admin_add_server(request: Request, db: Session = Depends(get_db), user ... ("associatedTools", "tool1"), ... ("associatedTools", "tool2"), ... ("associatedResources", "resource1"), - ... ("associatedResources", "resource2"), ... ("associatedPrompts", "prompt1"), - ... ("associatedPrompts", "prompt2"), ... ("is_inactive_checked", "false") ... 
]) >>> @@ -834,48 +637,26 @@ async def admin_add_server(request: Request, db: Session = Depends(get_db), user tags: list[str] = [tag.strip() for tag in tags_str.split(",") if tag.strip()] if tags_str else [] try: - LOGGER.debug(f"User {get_user_email(user)} is adding a new server with name: {form['name']}") - server_id = form.get("id") - visibility = str(form.get("visibility", "private")) - LOGGER.info(f" user input id::{server_id}") + if logger: logger.debug(f"User {user} is adding a new server with name: {form['name']}") server = ServerCreate( - id=form.get("id") or None, - name=form.get("name"), - description=form.get("description"), - icon=form.get("icon"), - associated_tools=",".join(str(x) for x in form.getlist("associatedTools")), - associated_resources=",".join(str(x) for x in form.getlist("associatedResources")), - associated_prompts=",".join(str(x) for x in form.getlist("associatedPrompts")), + id=None, + name=str(form.get("name") or ""), + description=str(form.get("description") or "") if form.get("description") else None, # nullable + icon=str(form.get("icon") or "") if form.get("icon") else None, + associated_tools=[str(x) for x in form.getlist("associatedTools") if isinstance(x, str)], + associated_resources=str(form.get("associatedResources") or "").split(",") if form.get("associatedResources") else None, + associated_prompts=str(form.get("associatedPrompts") or "").split(",") if form.get("associatedPrompts") else None, + associated_a2a_agents=str(form.get("associatedA2AAgents") or "").split(",") if form.get("associatedA2AAgents") else None, tags=tags, - visibility=visibility, + team_id=None, + owner_email=None, ) except KeyError as e: # Convert KeyError to ValidationError-like response return JSONResponse(content={"message": f"Missing required field: {e}", "success": False}, status_code=422) - try: - user_email = get_user_email(user) - # Determine personal team for default assignment - team_id_raw = form.get("team_id", None) - team_id = 
str(team_id_raw) if team_id_raw is not None else None - team_service = TeamManagementService(db) - team_id = await team_service.verify_team_for_user(user_email, team_id) - - # Extract metadata for server creation - creation_metadata = MetadataCapture.extract_creation_metadata(request, user) - - # Ensure default visibility is private and assign to personal team when available - team_id_cast = cast(Optional[str], team_id) - await server_service.register_server( - db, - server, - created_by=user_email, # Use the consistent user_email - created_from_ip=creation_metadata["created_from_ip"], - created_via=creation_metadata["created_via"], - created_user_agent=creation_metadata["created_user_agent"], - team_id=team_id_cast, - visibility=visibility, - ) + try: + await server_service.register_server(db, server) return JSONResponse( content={"message": "Server created successfully!", "success": True}, status_code=200, @@ -883,14 +664,10 @@ async def admin_add_server(request: Request, db: Session = Depends(get_db), user except CoreValidationError as ex: return JSONResponse(content={"message": str(ex), "success": False}, status_code=422) - except ServerNameConflictError as ex: - return JSONResponse(content={"message": str(ex), "success": False}, status_code=409) except ServerError as ex: return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) except ValueError as ex: return JSONResponse(content={"message": str(ex), "success": False}, status_code=400) - except ValidationError as ex: - return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) except IntegrityError as ex: return JSONResponse(content=ErrorFormatter.format_database_error(ex), status_code=409) except Exception as ex: @@ -902,7 +679,7 @@ async def admin_edit_server( server_id: str, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> JSONResponse: """ Edit an existing 
server via the admin UI. @@ -912,13 +689,12 @@ async def admin_edit_server( update operation. Expects form fields: - - id (optional): Updated UUID for the server - name (optional): The updated name of the server - description (optional): An updated description of the server's purpose - icon (optional): Updated URL or path to the server's icon - - associatedTools (optional, multiple values): Updated list of tools associated with this server - - associatedResources (optional, multiple values): Updated list of resources associated with this server - - associatedPrompts (optional, multiple values): Updated list of prompts associated with this server + - associatedTools (optional, comma-separated): Updated list of tools associated with this server + - associatedResources (optional, comma-separated): Updated list of resources associated with this server + - associatedPrompts (optional, comma-separated): Updated list of prompts associated with this server Args: server_id (str): The ID of the server to edit @@ -937,7 +713,7 @@ async def admin_edit_server( >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> server_id = "server-to-edit" >>> >>> # Happy path: Edit server with new name @@ -1019,41 +795,23 @@ async def admin_edit_server( tags_str = str(form.get("tags", "")) tags: list[str] = [tag.strip() for tag in tags_str.split(",") if tag.strip()] if tags_str else [] try: - LOGGER.debug(f"User {get_user_email(user)} is editing server ID {server_id} with name: {form.get('name')}") - visibility = str(form.get("visibility", "private")) - user_email = get_user_email(user) - team_id_raw = form.get("team_id", None) - team_id = str(team_id_raw) if team_id_raw is not None else None - - team_service = TeamManagementService(db) - team_id = await team_service.verify_team_for_user(user_email, team_id) - - mod_metadata = 
MetadataCapture.extract_modification_metadata(request, user, 0) - + if logger: logger.debug(f"User {user} is editing server ID {server_id} with name: {form.get('name')}") server = ServerUpdate( - id=form.get("id"), - name=form.get("name"), - description=form.get("description"), - icon=form.get("icon"), - associated_tools=",".join(str(x) for x in form.getlist("associatedTools")), - associated_resources=",".join(str(x) for x in form.getlist("associatedResources")), - associated_prompts=",".join(str(x) for x in form.getlist("associatedPrompts")), + id=None, + name=str(form.get("name") or "") if form.get("name") else None, + description=str(form.get("description") or "") if form.get("description") else None, + icon=str(form.get("icon") or "") if form.get("icon") else None, + associated_tools=[str(x) for x in form.getlist("associatedTools") if isinstance(x, str)], + associated_resources=str(form.get("associatedResources") or "").split(",") if form.get("associatedResources") else None, + associated_prompts=str(form.get("associatedPrompts") or "").split(",") if form.get("associatedPrompts") else None, + associated_a2a_agents=str(form.get("associatedA2AAgents") or "").split(",") if form.get("associatedA2AAgents") else None, tags=tags, - visibility=visibility, - team_id=team_id, - owner_email=user_email, - ) - - await server_service.update_server( - db, - server_id, - server, - user_email, - modified_by=mod_metadata["modified_by"], - modified_from_ip=mod_metadata["modified_from_ip"], - modified_via=mod_metadata["modified_via"], - modified_user_agent=mod_metadata["modified_user_agent"], + team_id=None, + owner_email=None, + visibility=None, ) + user_email = extract_user_email(user) + await server_service.update_server(db, server_id, server, user_email) return JSONResponse( content={"message": "Server updated successfully!", "success": True}, @@ -1081,8 +839,8 @@ async def admin_toggle_server( server_id: str, request: Request, db: Session = Depends(get_db), - 
user=Depends(get_current_user_with_permissions), -) -> Response: + user: str = Depends(require_auth), +) -> RedirectResponse: """ Toggle a server's active status via the admin UI. @@ -1098,7 +856,7 @@ async def admin_toggle_server( user (str): Authenticated user dependency. Returns: - Response: A redirect to the admin dashboard catalog section with a + RedirectResponse: A redirect to the admin dashboard catalog section with a status code of 303 (See Other). Examples: @@ -1109,7 +867,7 @@ async def admin_toggle_server( >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> server_id = "server-to-toggle" >>> >>> # Happy path: Activate server @@ -1167,13 +925,14 @@ async def admin_toggle_server( >>> server_service.toggle_server_status = original_toggle_server_status """ form = await request.form() - LOGGER.debug(f"User {get_user_email(user)} is toggling server ID {server_id} with activate: {form.get('activate')}") + if logger: logger.debug(f"User {user} is toggling server ID {server_id} with activate: {form.get('activate')}") activate = str(form.get("activate", "true")).lower() == "true" is_inactive_checked = str(form.get("is_inactive_checked", "false")) try: await server_service.toggle_server_status(db, server_id, activate) except Exception as e: - LOGGER.error(f"Error toggling server status: {e}") + if logger: + logger.error(f"Error toggling server status: {e}") root_path = request.scope.get("root_path", "") if is_inactive_checked.lower() == "true": @@ -1182,7 +941,7 @@ async def admin_toggle_server( @admin_router.post("/servers/{server_id}/delete") -async def admin_delete_server(server_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +async def admin_delete_server(server_id: str, request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> 
RedirectResponse: """ Delete a server via the admin UI. @@ -1207,7 +966,7 @@ async def admin_delete_server(server_id: str, request: Request, db: Session = De >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> server_id = "server-to-delete" >>> >>> # Happy path: Delete server @@ -1253,10 +1012,10 @@ async def admin_delete_server(server_id: str, request: Request, db: Session = De >>> server_service.delete_server = original_delete_server """ try: - LOGGER.debug(f"User {get_user_email(user)} is deleting server ID {server_id}") + if logger: logger.debug(f"User {user} is deleting server ID {server_id}") await server_service.delete_server(db, server_id) except Exception as e: - LOGGER.error(f"Error deleting server: {e}") + if logger: logger.error(f"Error deleting server: {e}") form = await request.form() is_inactive_checked = str(form.get("is_inactive_checked", "false")) @@ -1271,7 +1030,7 @@ async def admin_delete_server(server_id: str, request: Request, db: Session = De async def admin_list_resources( include_inactive: bool = False, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> List[Dict[str, Any]]: """ List resources for the admin UI with an option to include inactive resources. @@ -1295,7 +1054,7 @@ async def admin_list_resources( >>> from datetime import datetime, timezone >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> >>> # Mock resource data >>> mock_resource = ResourceRead( @@ -1316,9 +1075,9 @@ async def admin_list_resources( ... tags=[] ... 
) >>> - >>> # Mock the resource_service.list_resources_for_user method - >>> original_list_resources_for_user = resource_service.list_resources_for_user - >>> resource_service.list_resources_for_user = AsyncMock(return_value=[mock_resource]) + >>> # Mock the resource_service.list_resources method + >>> original_list_resources = resource_service.list_resources + >>> resource_service.list_resources = AsyncMock(return_value=[mock_resource]) >>> >>> # Test listing active resources >>> async def test_admin_list_resources_active(): @@ -1339,7 +1098,7 @@ async def admin_list_resources( ... avg_response_time=0.0, last_execution_time=None), ... tags=[] ... ) - >>> resource_service.list_resources_for_user = AsyncMock(return_value=[mock_resource, mock_inactive_resource]) + >>> resource_service.list_resources = AsyncMock(return_value=[mock_resource, mock_inactive_resource]) >>> async def test_admin_list_resources_all(): ... result = await admin_list_resources(include_inactive=True, db=mock_db, user=mock_user) ... return len(result) == 2 and not result[1]['isActive'] @@ -1348,7 +1107,7 @@ async def admin_list_resources( True >>> >>> # Test empty list - >>> resource_service.list_resources_for_user = AsyncMock(return_value=[]) + >>> resource_service.list_resources = AsyncMock(return_value=[]) >>> async def test_admin_list_resources_empty(): ... result = await admin_list_resources(include_inactive=False, db=mock_db, user=mock_user) ... return result == [] @@ -1357,7 +1116,7 @@ async def admin_list_resources( True >>> >>> # Test exception handling - >>> resource_service.list_resources_for_user = AsyncMock(side_effect=Exception("Resource list error")) + >>> resource_service.list_resources = AsyncMock(side_effect=Exception("Resource list error")) >>> async def test_admin_list_resources_exception(): ... try: ... 
await admin_list_resources(False, mock_db, mock_user) @@ -1369,11 +1128,10 @@ async def admin_list_resources( True >>> >>> # Restore original method - >>> resource_service.list_resources_for_user = original_list_resources_for_user + >>> resource_service.list_resources = original_list_resources """ - LOGGER.debug(f"User {get_user_email(user)} requested resource list") - user_email = get_user_email(user) - resources = await resource_service.list_resources_for_user(db, user_email, include_inactive=include_inactive) + if logger: logger.debug(f"User {user} requested resource list") + resources = await resource_service.list_resources(db, include_inactive=include_inactive) return [resource.model_dump(by_alias=True) for resource in resources] @@ -1381,7 +1139,7 @@ async def admin_list_resources( async def admin_list_prompts( include_inactive: bool = False, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> List[Dict[str, Any]]: """ List prompts for the admin UI with an option to include inactive prompts. @@ -1405,7 +1163,7 @@ async def admin_list_prompts( >>> from datetime import datetime, timezone >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> >>> # Mock prompt data >>> mock_prompt = PromptRead( @@ -1425,9 +1183,9 @@ async def admin_list_prompts( ... tags=[] ... ) >>> - >>> # Mock the prompt_service.list_prompts_for_user method - >>> original_list_prompts_for_user = prompt_service.list_prompts_for_user - >>> prompt_service.list_prompts_for_user = AsyncMock(return_value=[mock_prompt]) + >>> # Mock the prompt_service.list_prompts method + >>> original_list_prompts = prompt_service.list_prompts + >>> prompt_service.list_prompts = AsyncMock(return_value=[mock_prompt]) >>> >>> # Test listing active prompts >>> async def test_admin_list_prompts_active(): @@ -1448,7 +1206,7 @@ async def admin_list_prompts( ... ), ... tags=[] ... 
) - >>> prompt_service.list_prompts_for_user = AsyncMock(return_value=[mock_prompt, mock_inactive_prompt]) + >>> prompt_service.list_prompts = AsyncMock(return_value=[mock_prompt, mock_inactive_prompt]) >>> async def test_admin_list_prompts_all(): ... result = await admin_list_prompts(include_inactive=True, db=mock_db, user=mock_user) ... return len(result) == 2 and not result[1]['isActive'] @@ -1457,7 +1215,7 @@ async def admin_list_prompts( True >>> >>> # Test empty list - >>> prompt_service.list_prompts_for_user = AsyncMock(return_value=[]) + >>> prompt_service.list_prompts = AsyncMock(return_value=[]) >>> async def test_admin_list_prompts_empty(): ... result = await admin_list_prompts(include_inactive=False, db=mock_db, user=mock_user) ... return result == [] @@ -1466,7 +1224,7 @@ async def admin_list_prompts( True >>> >>> # Test exception handling - >>> prompt_service.list_prompts_for_user = AsyncMock(side_effect=Exception("Prompt list error")) + >>> prompt_service.list_prompts = AsyncMock(side_effect=Exception("Prompt list error")) >>> async def test_admin_list_prompts_exception(): ... try: ... 
await admin_list_prompts(False, mock_db, mock_user) @@ -1478,11 +1236,10 @@ async def admin_list_prompts( True >>> >>> # Restore original method - >>> prompt_service.list_prompts_for_user = original_list_prompts_for_user + >>> prompt_service.list_prompts = original_list_prompts """ - LOGGER.debug(f"User {get_user_email(user)} requested prompt list") - user_email = get_user_email(user) - prompts = await prompt_service.list_prompts_for_user(db, user_email, include_inactive=include_inactive) + if logger: logger.debug(f"User {user} requested prompt list") + prompts = await prompt_service.list_prompts(db, include_inactive=include_inactive) return [prompt.model_dump(by_alias=True) for prompt in prompts] @@ -1490,7 +1247,7 @@ async def admin_list_prompts( async def admin_list_gateways( include_inactive: bool = False, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> List[Dict[str, Any]]: """ List gateways for the admin UI with an option to include inactive gateways. 
@@ -1514,7 +1271,7 @@ async def admin_list_gateways( >>> from datetime import datetime, timezone >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> >>> # Mock gateway data >>> mock_gateway = GatewayRead( @@ -1587,7 +1344,7 @@ async def admin_list_gateways( >>> # Restore original method >>> gateway_service.list_gateways = original_list_gateways """ - LOGGER.debug(f"User {get_user_email(user)} requested gateway list") + if logger: logger.debug(f"User {user} requested gateway list") gateways = await gateway_service.list_gateways(db, include_inactive=include_inactive) return [gateway.model_dump(by_alias=True) for gateway in gateways] @@ -1597,7 +1354,7 @@ async def admin_toggle_gateway( gateway_id: str, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> RedirectResponse: """ Toggle the active status of a gateway via the admin UI. @@ -1624,7 +1381,7 @@ async def admin_toggle_gateway( >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> gateway_id = "gateway-to-toggle" >>> >>> # Happy path: Activate gateway @@ -1681,7 +1438,7 @@ async def admin_toggle_gateway( >>> # Restore original method >>> gateway_service.toggle_gateway_status = original_toggle_gateway_status """ - LOGGER.debug(f"User {get_user_email(user)} is toggling gateway ID {gateway_id}") + if logger: logger.debug(f"User {user} is toggling gateway ID {gateway_id}") form = await request.form() activate = str(form.get("activate", "true")).lower() == "true" is_inactive_checked = str(form.get("is_inactive_checked", "false")) @@ -1689,7 +1446,7 @@ async def admin_toggle_gateway( try: await gateway_service.toggle_gateway_status(db, gateway_id, activate) except Exception as e: - LOGGER.error(f"Error toggling gateway status: {e}") + if logger: 
logger.error(f"Error toggling gateway status: {e}") root_path = request.scope.get("root_path", "") if is_inactive_checked.lower() == "true": @@ -1700,11 +1457,10 @@ async def admin_toggle_gateway( @admin_router.get("/", name="admin_home", response_class=HTMLResponse) async def admin_ui( request: Request, - team_id: Optional[str] = Query(None), include_inactive: bool = False, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), - _jwt_token: str = Depends(get_jwt_token), + user: str = Depends(require_basic_auth), + jwt_token: str = Depends(get_jwt_token), ) -> Any: """ Render the admin dashboard HTML page. @@ -1713,21 +1469,15 @@ async def admin_ui( servers, tools, resources, prompts, gateways, and roots from their respective services, then renders the admin dashboard template with this data. - Supports optional `team_id` query param to scope the returned data to a team. - If `team_id` is provided and email-based team management is enabled, we - validate the user is a member of that team. We attempt to pass team_id into - service listing functions (preferred). If the service API does not accept a - team_id parameter we fall back to post-filtering the returned items. - The endpoint also sets a JWT token as a cookie for authentication in subsequent requests. This token is HTTP-only for security reasons. Args: request (Request): FastAPI request object. - team_id (Optional[str]): Optional team ID to filter data by team. include_inactive (bool): Whether to include inactive items in all listings. db (Session): Database session dependency. - user (dict): Authenticated user context with permissions. + user (str): Authenticated user from basic auth dependency. + jwt_token (str): JWT token for authentication. Returns: Any: Rendered HTML template for the admin dashboard. 
@@ -1741,20 +1491,21 @@ async def admin_ui( >>> from datetime import datetime, timezone >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "admin_user", "db": mock_db} + >>> mock_user = "admin_user" + >>> mock_jwt = "fake.jwt.token" >>> >>> # Mock services to return empty lists for simplicity in doctest - >>> original_list_servers_for_user = server_service.list_servers_for_user - >>> original_list_tools_for_user = tool_service.list_tools_for_user - >>> original_list_resources_for_user = resource_service.list_resources_for_user - >>> original_list_prompts_for_user = prompt_service.list_prompts_for_user + >>> original_list_servers = server_service.list_servers + >>> original_list_tools = tool_service.list_tools + >>> original_list_resources = resource_service.list_resources + >>> original_list_prompts = prompt_service.list_prompts >>> original_list_gateways = gateway_service.list_gateways >>> original_list_roots = root_service.list_roots >>> - >>> server_service.list_servers_for_user = AsyncMock(return_value=[]) - >>> tool_service.list_tools_for_user = AsyncMock(return_value=[]) - >>> resource_service.list_resources_for_user = AsyncMock(return_value=[]) - >>> prompt_service.list_prompts_for_user = AsyncMock(return_value=[]) + >>> server_service.list_servers = AsyncMock(return_value=[]) + >>> tool_service.list_tools = AsyncMock(return_value=[]) + >>> resource_service.list_resources = AsyncMock(return_value=[]) + >>> prompt_service.list_prompts = AsyncMock(return_value=[]) >>> gateway_service.list_gateways = AsyncMock(return_value=[]) >>> root_service.list_roots = AsyncMock(return_value=[]) >>> @@ -1766,17 +1517,17 @@ async def admin_ui( >>> >>> # Test basic rendering >>> async def test_admin_ui_basic_render(): - ... response = await admin_ui(mock_request, None, False, mock_db, mock_user) - ... return isinstance(response, HTMLResponse) and response.status_code == 200 + ... response = await admin_ui(mock_request, False, mock_db, mock_user, mock_jwt) + ... 
return isinstance(response, HTMLResponse) and response.status_code == 200 and "jwt_token" in response.headers.get("set-cookie", "") >>> >>> asyncio.run(test_admin_ui_basic_render()) True >>> >>> # Test with include_inactive=True >>> async def test_admin_ui_include_inactive(): - ... response = await admin_ui(mock_request, None, True, mock_db, mock_user) + ... response = await admin_ui(mock_request, True, mock_db, mock_user, mock_jwt) ... # Verify list methods were called with include_inactive=True - ... server_service.list_servers_for_user.assert_called_with(mock_db, mock_user["email"], include_inactive=True) + ... server_service.list_servers.assert_called_with(mock_db, include_inactive=True) ... return isinstance(response, HTMLResponse) >>> >>> asyncio.run(test_admin_ui_include_inactive()) @@ -1799,11 +1550,19 @@ async def admin_ui( ... customName="T1", ... tags=[] ... ) + + >>> server_service.list_servers = AsyncMock(return_value=[mock_server]) + >>> tool_service.list_tools = AsyncMock(return_value=[mock_tool]) + >>> + >>> async def test_admin_ui_with_data(): + ... response = await admin_ui(mock_request, False, mock_db, mock_user, mock_jwt) + >>> server_service.list_servers_for_user = AsyncMock(return_value=[mock_server]) >>> tool_service.list_tools_for_user = AsyncMock(return_value=[mock_tool]) >>> >>> async def test_admin_ui_with_data(): ... response = await admin_ui(mock_request, None, False, mock_db, mock_user) + ... # Check if template context was populated (indirectly via mock calls) ... assert mock_request.app.state.templates.TemplateResponse.call_count >= 1 ... context = mock_request.app.state.templates.TemplateResponse.call_args[0][2] @@ -1812,6 +1571,16 @@ async def admin_ui( >>> asyncio.run(test_admin_ui_with_data()) True >>> + + >>> # Test exception handling during data fetching + >>> server_service.list_servers = AsyncMock(side_effect=Exception("DB error")) + >>> async def test_admin_ui_exception_handled(): + ... try: + ... 
response = await admin_ui(mock_request, False, mock_db, mock_user, mock_jwt) + ... return False # Should not reach here if exception is properly raised + ... except Exception as e: + ... return str(e) == "DB error" + >>> from unittest.mock import AsyncMock, patch >>> import logging >>> @@ -1832,11 +1601,18 @@ async def admin_ui( ... log_called = mock_log.called ... # Optionally, you can even inspect the message if you want ... return ok_response and log_called + >>> >>> asyncio.run(test_admin_ui_exception_handled()) True >>> >>> # Restore original methods + >>> server_service.list_servers = original_list_servers + >>> tool_service.list_tools = original_list_tools + >>> resource_service.list_resources = original_list_resources + >>> prompt_service.list_prompts = original_list_prompts + >>> gateway_service.list_gateways = original_list_gateways + >>> root_service.list_roots = original_list_roots >>> server_service.list_servers_for_user = original_list_servers_for_user >>> tool_service.list_tools_for_user = original_list_tools_for_user >>> resource_service.list_resources_for_user = original_list_resources_for_user @@ -4425,45 +4201,58 @@ async def admin_delete_user( Returns: HTMLResponse: Success/error message """ - if not settings.email_auth_enabled: - return HTMLResponse(content='
Email authentication is disabled
', status_code=403) - - try: - # First-Party - from mcpgateway.services.email_auth_service import EmailAuthService # pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel + if logger: logger.debug(f"User {user} accessed the admin UI") + tools = [ + tool.model_dump(by_alias=True) for tool in sorted(await tool_service.list_tools(db, include_inactive=include_inactive), key=lambda t: ((t.url or "").lower(), (t.original_name or "").lower())) + ] + servers = [server.model_dump(by_alias=True) for server in await server_service.list_servers(db, include_inactive=include_inactive)] + resources = [resource.model_dump(by_alias=True) for resource in await resource_service.list_resources(db, include_inactive=include_inactive)] + prompts = [prompt.model_dump(by_alias=True) for prompt in await prompt_service.list_prompts(db, include_inactive=include_inactive)] + gateways_raw = await gateway_service.list_gateways(db, include_inactive=include_inactive) + gateways = [gateway.model_dump(by_alias=True) for gateway in gateways_raw] - auth_service = EmailAuthService(db) - - # URL decode the email - - decoded_email = urllib.parse.unquote(user_email) - - # Get current user email from JWT - current_user_email = get_user_email(user) - - # Prevent self-deletion - if decoded_email == current_user_email: - return HTMLResponse(content='
Cannot delete your own account
', status_code=400) - - # Prevent deleting the last active admin user - if await auth_service.is_last_active_admin(decoded_email): - return HTMLResponse(content='
Cannot delete the last remaining admin user
', status_code=400) + roots = [root.model_dump(by_alias=True) for root in await root_service.list_roots()] - await auth_service.delete_user(decoded_email) + # Load A2A agents if enabled + a2a_agents = [] + if a2a_service and settings.mcpgateway_a2a_enabled: + a2a_agents_raw = await a2a_service.list_agents(db, include_inactive=include_inactive) + a2a_agents = [agent.model_dump(by_alias=True) for agent in a2a_agents_raw] - # Return empty content to remove the user from the list - return HTMLResponse(content="", status_code=200) + root_path = settings.app_root_path + max_name_length = settings.validation_max_name_length + response = request.app.state.templates.TemplateResponse( + request, + "admin.html", + { + "request": request, + "servers": servers, + "tools": tools, + "resources": resources, + "prompts": prompts, + "gateways": gateways, + "a2a_agents": a2a_agents, + "roots": roots, + "include_inactive": include_inactive, + "root_path": root_path, + "max_name_length": max_name_length, + "gateway_tool_name_separator": settings.gateway_tool_name_separator, + "bulk_import_max_tools": settings.mcpgateway_bulk_import_max_tools, + "a2a_enabled": settings.mcpgateway_a2a_enabled, + "mcpgateway_ui_tool_test_timeout": settings.mcpgateway_ui_tool_test_timeout, + }, + ) - except Exception as e: - LOGGER.error(f"Error deleting user {user_email}: {e}") - return HTMLResponse(content=f'
Error deleting user: {str(e)}
', status_code=400) + # Use secure cookie utility for proper security attributes + set_auth_cookie(response, jwt_token, remember_me=False) + return response @admin_router.get("/tools", response_model=List[ToolRead]) async def admin_list_tools( include_inactive: bool = False, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> List[Dict[str, Any]]: """ List tools for the admin UI with an option to include inactive tools. @@ -4487,7 +4276,7 @@ async def admin_list_tools( >>> from datetime import datetime, timezone >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> >>> # Mock tool data >>> mock_tool = ToolRead( @@ -4520,9 +4309,9 @@ async def admin_list_tools( ... tags=[] ... ) # Added gateway_id=None >>> - >>> # Mock the tool_service.list_tools_for_user method - >>> original_list_tools_for_user = tool_service.list_tools_for_user - >>> tool_service.list_tools_for_user = AsyncMock(return_value=[mock_tool]) + >>> # Mock the tool_service.list_tools method + >>> original_list_tools = tool_service.list_tools + >>> tool_service.list_tools = AsyncMock(return_value=[mock_tool]) >>> >>> # Test listing active tools >>> async def test_admin_list_tools_active(): @@ -4548,7 +4337,7 @@ async def admin_list_tools( ... customName="Inactive Tool", ... tags=[] ... ) - >>> tool_service.list_tools_for_user = AsyncMock(return_value=[mock_tool, mock_inactive_tool]) + >>> tool_service.list_tools = AsyncMock(return_value=[mock_tool, mock_inactive_tool]) >>> async def test_admin_list_tools_all(): ... result = await admin_list_tools(include_inactive=True, db=mock_db, user=mock_user) ... 
return len(result) == 2 and not result[1]['enabled'] @@ -4557,7 +4346,7 @@ async def admin_list_tools( True >>> >>> # Test empty list - >>> tool_service.list_tools_for_user = AsyncMock(return_value=[]) + >>> tool_service.list_tools = AsyncMock(return_value=[]) >>> async def test_admin_list_tools_empty(): ... result = await admin_list_tools(include_inactive=False, db=mock_db, user=mock_user) ... return result == [] @@ -4566,7 +4355,7 @@ async def admin_list_tools( True >>> >>> # Test exception handling - >>> tool_service.list_tools_for_user = AsyncMock(side_effect=Exception("Tool list error")) + >>> tool_service.list_tools = AsyncMock(side_effect=Exception("Tool list error")) >>> async def test_admin_list_tools_exception(): ... try: ... await admin_list_tools(False, mock_db, mock_user) @@ -4578,17 +4367,16 @@ async def admin_list_tools( True >>> >>> # Restore original method - >>> tool_service.list_tools_for_user = original_list_tools_for_user + >>> tool_service.list_tools = original_list_tools """ - LOGGER.debug(f"User {get_user_email(user)} requested tool list") - user_email = get_user_email(user) - tools = await tool_service.list_tools_for_user(db, user_email, include_inactive=include_inactive) + if logger: logger.debug(f"User {user} requested tool list") + tools = await tool_service.list_tools(db, include_inactive=include_inactive) return [tool.model_dump(by_alias=True) for tool in tools] @admin_router.get("/tools/{tool_id}", response_model=ToolRead) -async def admin_get_tool(tool_id: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: +async def admin_get_tool(tool_id: str, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> Dict[str, Any]: """ Retrieve specific tool details for the admin UI. 
@@ -4617,7 +4405,7 @@ async def admin_get_tool(tool_id: str, db: Session = Depends(get_db), user=Depen >>> from fastapi import HTTPException >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> tool_id = "test-tool-id" >>> >>> # Mock tool data @@ -4676,7 +4464,7 @@ async def admin_get_tool(tool_id: str, db: Session = Depends(get_db), user=Depen >>> # Restore original method >>> tool_service.get_tool = original_get_tool """ - LOGGER.debug(f"User {get_user_email(user)} requested details for tool ID {tool_id}") + if logger: logger.debug(f"User {user} requested details for tool ID {tool_id}") try: tool = await tool_service.get_tool(db, tool_id) return tool.model_dump(by_alias=True) @@ -4684,7 +4472,7 @@ async def admin_get_tool(tool_id: str, db: Session = Depends(get_db), user=Depen raise HTTPException(status_code=404, detail=str(e)) except Exception as e: # Catch any other unexpected errors and re-raise or log as needed - LOGGER.error(f"Error getting tool {tool_id}: {e}") + if logger: logger.error(f"Error getting tool {tool_id}: {e}") raise e # Re-raise for now, or return a 500 JSONResponse if preferred for API consistency @@ -4693,7 +4481,7 @@ async def admin_get_tool(tool_id: str, db: Session = Depends(get_db), user=Depen async def admin_add_tool( request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> JSONResponse: """ Add a tool via the admin UI with error handling. 
@@ -4736,7 +4524,7 @@ async def admin_add_tool( >>> import json >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> # Happy path: Add a new tool successfully >>> form_data_success = FormData([ @@ -4815,13 +4603,14 @@ async def admin_add_tool( >>> tool_service.register_tool = original_register_tool """ - LOGGER.debug(f"User {get_user_email(user)} is adding a new tool") + if logger: logger.debug(f"User {user} is adding a new tool") form = await request.form() - LOGGER.debug(f"Received form data: {dict(form)}") + if logger: logger.debug(f"Received form data: {dict(form)}") + integration_type = form.get("integrationType", "REST") request_type = form.get("requestType") - visibility = str(form.get("visibility", "private")) + # Map UI fields to internal defaults if missing (REST -> GET, MCP -> SSE) if request_type is None: if integration_type == "REST": request_type = "GET" # or any valid REST method default @@ -4830,28 +4619,19 @@ async def admin_add_tool( else: request_type = "GET" - user_email = get_user_email(user) - # Determine personal team for default assignment - team_id = form.get("team_id", None) - team_service = TeamManagementService(db) - team_id = await team_service.verify_team_for_user(user_email, team_id) # Parse tags from comma-separated string tags_str = str(form.get("tags", "")) tags: list[str] = [tag.strip() for tag in tags_str.split(",") if tag.strip()] if tags_str else [] - # Safely parse potential JSON strings from form - headers_raw = form.get("headers") - input_schema_raw = form.get("input_schema") - annotations_raw = form.get("annotations") + tool_data: dict[str, Any] = { "name": form.get("name"), - "displayName": form.get("displayName"), "url": form.get("url"), "description": form.get("description"), "request_type": request_type, "integration_type": integration_type, - "headers": json.loads(headers_raw if isinstance(headers_raw, str) and headers_raw else "{}"), - "input_schema": 
json.loads(input_schema_raw if isinstance(input_schema_raw, str) and input_schema_raw else "{}"), - "annotations": json.loads(annotations_raw if isinstance(annotations_raw, str) and annotations_raw else "{}"), + "headers": json.loads(str(form.get("headers") or "{}")), + "input_schema": json.loads(str(form.get("input_schema") or "{}")), + "annotations": json.loads(str(form.get("annotations") or "{}")), "jsonpath_filter": form.get("jsonpath_filter", ""), "auth_type": form.get("auth_type", ""), "auth_username": form.get("auth_username", ""), @@ -4860,14 +4640,11 @@ async def admin_add_tool( "auth_header_key": form.get("auth_header_key", ""), "auth_header_value": form.get("auth_header_value", ""), "tags": tags, - "visibility": visibility, - "team_id": team_id, - "owner_email": user_email, } - LOGGER.debug(f"Tool data built: {tool_data}") + if logger: logger.debug(f"Tool data built: {tool_data}") try: - tool = ToolCreate(**tool_data) - LOGGER.debug(f"Validated tool data: {tool.model_dump(by_alias=True)}") + tool = ToolCreate(**tool_data) # Pydantic validation happens here; raises ValidationError on bad input + if logger: logger.debug(f"Validated tool data: {tool.model_dump(by_alias=True)}") # Extract creation metadata metadata = MetadataCapture.extract_creation_metadata(request, user) @@ -4888,18 +4665,15 @@ async def admin_add_tool( ) except IntegrityError as ex: error_message = ErrorFormatter.format_database_error(ex) - LOGGER.error(f"IntegrityError in admin_add_tool: {error_message}") + if logger: logger.error(f"IntegrityError in admin_add_resource: {error_message}") return JSONResponse(status_code=409, content=error_message) - except ToolNameConflictError as ex: - LOGGER.error(f"ToolNameConflictError in admin_add_tool: {str(ex)}") - return JSONResponse(content={"message": str(ex), "success": False}, status_code=409) except ToolError as ex: return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) except ValidationError as ex: # This block 
should catch ValidationError - LOGGER.error(f"ValidationError in admin_add_tool: {str(ex)}") + if logger: logger.error(f"ValidationError in admin_add_tool: {str(ex)}") return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) except Exception as ex: - LOGGER.error(f"Unexpected error in admin_add_tool: {str(ex)}") + if logger: logger.error(f"Unexpected error in admin_add_tool: {str(ex)}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) @@ -4909,14 +4683,13 @@ async def admin_edit_tool( tool_id: str, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> Response: """ Edit a tool via the admin UI. Expects form fields: - name - - displayName (optional) - url - description (optional) - requestType (to be mapped to request_type) @@ -4959,7 +4732,7 @@ async def admin_edit_tool( >>> import json >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> tool_id = "tool-to-edit" >>> # Happy path: Edit tool successfully @@ -5083,32 +4856,21 @@ async def admin_edit_tool( >>> tool_service.update_tool = original_update_tool """ - LOGGER.debug(f"User {get_user_email(user)} is editing tool ID {tool_id}") + if logger: logger.debug(f"User {user} is editing tool ID {tool_id}") form = await request.form() # Parse tags from comma-separated string tags_str = str(form.get("tags", "")) tags: list[str] = [tag.strip() for tag in tags_str.split(",") if tag.strip()] if tags_str else [] - visibility = str(form.get("visibility", "private")) - - user_email = get_user_email(user) - # Determine personal team for default assignment - team_id = form.get("team_id", None) - team_service = TeamManagementService(db) - team_id = await team_service.verify_team_for_user(user_email, team_id) - - headers_raw2 = form.get("headers") - input_schema_raw2 = form.get("input_schema") - annotations_raw2 = 
form.get("annotations") +# Assemble snake_case keys expected by Pydantic schemas from form inputs tool_data: dict[str, Any] = { "name": form.get("name"), - "displayName": form.get("displayName"), "custom_name": form.get("customName"), "url": form.get("url"), "description": form.get("description"), - "headers": json.loads(headers_raw2 if isinstance(headers_raw2, str) and headers_raw2 else "{}"), - "input_schema": json.loads(input_schema_raw2 if isinstance(input_schema_raw2, str) and input_schema_raw2 else "{}"), - "annotations": json.loads(annotations_raw2 if isinstance(annotations_raw2, str) and annotations_raw2 else "{}"), + "headers": json.loads(str(form.get("headers") or "{}")), + "input_schema": json.loads(str(form.get("input_schema") or "{}")), + "annotations": json.loads(str(form.get("annotations") or "{}")), "jsonpath_filter": form.get("jsonpathFilter", ""), "auth_type": form.get("auth_type", ""), "auth_username": form.get("auth_username", ""), @@ -5117,9 +4879,6 @@ async def admin_edit_tool( "auth_header_key": form.get("auth_header_key", ""), "auth_header_value": form.get("auth_header_value", ""), "tags": tags, - "visibility": visibility, - "owner_email": user_email, - "team_id": team_id, } # Only include integration_type if it's provided (not disabled in form) if "integrationType" in form: @@ -5127,9 +4886,9 @@ async def admin_edit_tool( # Only include request_type if it's provided (not disabled in form) if "requestType" in form: tool_data["request_type"] = form.get("requestType") - LOGGER.debug(f"Tool update data built: {tool_data}") + if logger: logger.debug(f"Tool update data built: {tool_data}") try: - tool = ToolUpdate(**tool_data) # Pydantic validation happens here + tool = ToolUpdate(**tool_data) # Pydantic validation happens here; raises ValidationError on bad input # Get current tool to extract current version current_tool = db.get(DbTool, tool_id) @@ -5150,24 +4909,21 @@ async def admin_edit_tool( return JSONResponse(content={"message": "Edit 
tool successfully", "success": True}, status_code=200) except IntegrityError as ex: error_message = ErrorFormatter.format_database_error(ex) - LOGGER.error(f"IntegrityError in admin_tool_resource: {error_message}") + if logger: logger.error(f"IntegrityError in admin_tool_resource: {error_message}") return JSONResponse(status_code=409, content=error_message) - except ToolNameConflictError as ex: - LOGGER.error(f"ToolNameConflictError in admin_edit_tool: {str(ex)}") - return JSONResponse(content={"message": str(ex), "success": False}, status_code=409) except ToolError as ex: - LOGGER.error(f"ToolError in admin_edit_tool: {str(ex)}") + if logger: logger.error(f"ToolError in admin_edit_tool: {str(ex)}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) except ValidationError as ex: # Catch Pydantic validation errors - LOGGER.error(f"ValidationError in admin_edit_tool: {str(ex)}") + if logger: logger.error(f"ValidationError in admin_edit_tool: {str(ex)}") return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) except Exception as ex: # Generic catch-all for unexpected errors - LOGGER.error(f"Unexpected error in admin_edit_tool: {str(ex)}") + if logger: logger.error(f"Unexpected error in admin_edit_tool: {str(ex)}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) @admin_router.post("/tools/{tool_id}/delete") -async def admin_delete_tool(tool_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +async def admin_delete_tool(tool_id: str, request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> RedirectResponse: """ Delete a tool via the admin UI. 
@@ -5193,7 +4949,7 @@ async def admin_delete_tool(tool_id: str, request: Request, db: Session = Depend >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> tool_id = "tool-to-delete" >>> >>> # Happy path: Delete tool @@ -5238,11 +4994,11 @@ async def admin_delete_tool(tool_id: str, request: Request, db: Session = Depend >>> # Restore original method >>> tool_service.delete_tool = original_delete_tool """ - LOGGER.debug(f"User {get_user_email(user)} is deleting tool ID {tool_id}") + if logger: logger.debug(f"User {user} is deleting tool ID {tool_id}") try: await tool_service.delete_tool(db, tool_id) except Exception as e: - LOGGER.error(f"Error deleting tool: {e}") + if logger: logger.error(f"Error deleting tool: {e}") form = await request.form() is_inactive_checked = str(form.get("is_inactive_checked", "false")) @@ -5258,7 +5014,7 @@ async def admin_toggle_tool( tool_id: str, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> RedirectResponse: """ Toggle a tool's active status via the admin UI. 
@@ -5286,7 +5042,7 @@ async def admin_toggle_tool( >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> tool_id = "tool-to-toggle" >>> >>> # Happy path: Activate tool @@ -5343,14 +5099,14 @@ async def admin_toggle_tool( >>> # Restore original method >>> tool_service.toggle_tool_status = original_toggle_tool_status """ - LOGGER.debug(f"User {get_user_email(user)} is toggling tool ID {tool_id}") + if logger: logger.debug(f"User {user} is toggling tool ID {tool_id}") form = await request.form() activate = str(form.get("activate", "true")).lower() == "true" is_inactive_checked = str(form.get("is_inactive_checked", "false")) try: await tool_service.toggle_tool_status(db, tool_id, activate, reachable=activate) except Exception as e: - LOGGER.error(f"Error toggling tool status: {e}") + if logger: logger.error(f"Error toggling tool status: {e}") root_path = request.scope.get("root_path", "") if is_inactive_checked.lower() == "true": @@ -5359,7 +5115,7 @@ async def admin_toggle_tool( @admin_router.get("/gateways/{gateway_id}", response_model=GatewayRead) -async def admin_get_gateway(gateway_id: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: +async def admin_get_gateway(gateway_id: str, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> Dict[str, Any]: """Get gateway details for the admin UI. 
Args: @@ -5383,7 +5139,7 @@ async def admin_get_gateway(gateway_id: str, db: Session = Depends(get_db), user >>> from fastapi import HTTPException >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> gateway_id = "test-gateway-id" >>> >>> # Mock gateway data @@ -5435,19 +5191,19 @@ async def admin_get_gateway(gateway_id: str, db: Session = Depends(get_db), user >>> # Restore original method >>> gateway_service.get_gateway = original_get_gateway """ - LOGGER.debug(f"User {get_user_email(user)} requested details for gateway ID {gateway_id}") + if logger: logger.debug(f"User {user} requested details for gateway ID {gateway_id}") try: gateway = await gateway_service.get_gateway(db, gateway_id) return gateway.model_dump(by_alias=True) except GatewayNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - LOGGER.error(f"Error getting gateway {gateway_id}: {e}") + if logger: logger.error(f"Error getting gateway {gateway_id}: {e}") raise e @admin_router.post("/gateways") -async def admin_add_gateway(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> JSONResponse: +async def admin_add_gateway(request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> JSONResponse: """Add a gateway via the admin UI. 
Expects form fields: @@ -5477,7 +5233,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use >>> import json # Added import for json.loads >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> >>> # Happy path: Add a new gateway successfully with basic auth details >>> form_data_success = FormData([ @@ -5557,7 +5313,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use >>> # Restore original method >>> gateway_service.register_gateway = original_register_gateway """ - LOGGER.debug(f"User {get_user_email(user)} is adding a new gateway") + if logger: logger.debug(f"User {user} is adding a new gateway") form = await request.form() try: # Parse tags from comma-separated string @@ -5584,11 +5340,9 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use encryption = get_oauth_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: - LOGGER.error(f"Failed to parse OAuth config: {e}") + if logger: logger.error(f"Failed to parse OAuth config: {e}") oauth_config = None - visibility = str(form.get("visibility", "private")) - # Handle passthrough_headers passthrough_headers = str(form.get("passthrough_headers")) if passthrough_headers and passthrough_headers.strip(): @@ -5612,10 +5366,12 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use auth_token=str(form.get("auth_token", "")), auth_header_key=str(form.get("auth_header_key", "")), auth_header_value=str(form.get("auth_header_value", "")), + auth_value=str(form.get("auth_value", "")), auth_headers=auth_headers if auth_headers else None, oauth_config=oauth_config, passthrough_headers=passthrough_headers, - visibility=visibility, + team_id=None, + owner_email=None, ) except KeyError as e: # Convert KeyError to 
ValidationError-like response @@ -5623,20 +5379,13 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use except ValidationError as ex: # --- Getting only the custom message from the ValueError --- - error_ctx = [str(err["ctx"]["error"]) for err in ex.errors()] + error_ctx = [str(err.get("ctx", {}).get("error", str(err))) for err in ex.errors()] return JSONResponse(content={"success": False, "message": "; ".join(error_ctx)}, status_code=422) - user_email = get_user_email(user) - team_id = form.get("team_id", None) - - team_service = TeamManagementService(db) - team_id = await team_service.verify_team_for_user(user_email, team_id) - try: # Extract creation metadata metadata = MetadataCapture.extract_creation_metadata(request, user) - team_id_cast = cast(Optional[str], team_id) await gateway_service.register_gateway( db, gateway, @@ -5644,14 +5393,11 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use created_from_ip=metadata["created_from_ip"], created_via=metadata["created_via"], created_user_agent=metadata["created_user_agent"], - visibility=visibility, - team_id=team_id_cast, - owner_email=user_email, ) # Provide specific guidance for OAuth Authorization Code flow message = "Gateway registered successfully!" - if oauth_config and isinstance(oauth_config, dict) and oauth_config.get("grant_type") == "authorization_code": + if oauth_config and oauth_config.get("grant_type") == "authorization_code": message = ( "Gateway registered successfully! 
🎉\n\n" "⚠️ IMPORTANT: This gateway uses OAuth Authorization Code flow.\n" @@ -5669,16 +5415,10 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use except GatewayConnectionError as ex: return JSONResponse(content={"message": str(ex), "success": False}, status_code=502) - except GatewayUrlConflictError as ex: - return JSONResponse(content={"message": str(ex), "success": False}, status_code=409) - except GatewayNameConflictError as ex: - return JSONResponse(content={"message": str(ex), "success": False}, status_code=409) except ValueError as ex: return JSONResponse(content={"message": str(ex), "success": False}, status_code=400) except RuntimeError as ex: return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) - except ValidationError as ex: - return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) except IntegrityError as ex: return JSONResponse(content=ErrorFormatter.format_database_error(ex), status_code=409) except Exception as ex: @@ -5694,7 +5434,7 @@ async def admin_edit_gateway( gateway_id: str, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> JSONResponse: """Edit a gateway via the admin UI. 
@@ -5722,7 +5462,7 @@ async def admin_edit_gateway( >>> from pydantic import ValidationError >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> gateway_id = "gateway-to-edit" >>> >>> # Happy path: Edit gateway successfully @@ -5795,15 +5535,13 @@ async def admin_edit_gateway( >>> # Restore original method >>> gateway_service.update_gateway = original_update_gateway """ - LOGGER.debug(f"User {get_user_email(user)} is editing gateway ID {gateway_id}") + if logger: logger.debug(f"User {user} is editing gateway ID {gateway_id}") form = await request.form() try: # Parse tags from comma-separated string tags_str = str(form.get("tags", "")) tags: List[str] = [tag.strip() for tag in tags_str.split(",") if tag.strip()] if tags_str else [] - visibility = str(form.get("visibility", "private")) - # Parse auth_headers JSON if present auth_headers_json = str(form.get("auth_headers")) auth_headers = [] @@ -5835,17 +5573,9 @@ async def admin_edit_gateway( encryption = get_oauth_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: - LOGGER.error(f"Failed to parse OAuth config: {e}") + if logger: logger.error(f"Failed to parse OAuth config: {e}") oauth_config = None - user_email = get_user_email(user) - # Determine personal team for default assignment - team_id_raw = form.get("team_id", None) - team_id = str(team_id_raw) if team_id_raw is not None else None - - team_service = TeamManagementService(db) - team_id = await team_service.verify_team_for_user(user_email, team_id) - gateway = GatewayUpdate( # Pydantic validation happens here name=str(form.get("name")), url=str(form["url"]), @@ -5862,21 +5592,11 @@ async def admin_edit_gateway( auth_headers=auth_headers if auth_headers else None, passthrough_headers=passthrough_headers, oauth_config=oauth_config, - visibility=visibility, - 
owner_email=user_email, - team_id=team_id, - ) - - mod_metadata = MetadataCapture.extract_modification_metadata(request, user, 0) - await gateway_service.update_gateway( - db, - gateway_id, - gateway, - modified_by=mod_metadata["modified_by"], - modified_from_ip=mod_metadata["modified_from_ip"], - modified_via=mod_metadata["modified_via"], - modified_user_agent=mod_metadata["modified_user_agent"], + team_id=None, + owner_email=None, + visibility=None, ) + await gateway_service.update_gateway(db, gateway_id, gateway) return JSONResponse( content={"message": "Gateway updated successfully!", "success": True}, status_code=200, @@ -5896,7 +5616,7 @@ async def admin_edit_gateway( @admin_router.post("/gateways/{gateway_id}/delete") -async def admin_delete_gateway(gateway_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +async def admin_delete_gateway(gateway_id: str, request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> RedirectResponse: """ Delete a gateway via the admin UI. 
@@ -5922,7 +5642,7 @@ async def admin_delete_gateway(gateway_id: str, request: Request, db: Session = >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> gateway_id = "gateway-to-delete" >>> >>> # Happy path: Delete gateway @@ -5967,11 +5687,11 @@ async def admin_delete_gateway(gateway_id: str, request: Request, db: Session = >>> # Restore original method >>> gateway_service.delete_gateway = original_delete_gateway """ - LOGGER.debug(f"User {get_user_email(user)} is deleting gateway ID {gateway_id}") + if logger: logger.debug(f"User {user} is deleting gateway ID {gateway_id}") try: await gateway_service.delete_gateway(db, gateway_id) except Exception as e: - LOGGER.error(f"Error deleting gateway: {e}") + if logger: logger.error(f"Error deleting gateway: {e}") form = await request.form() is_inactive_checked = str(form.get("is_inactive_checked", "false")) @@ -5983,7 +5703,7 @@ async def admin_delete_gateway(gateway_id: str, request: Request, db: Session = @admin_router.get("/resources/{uri:path}") -async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: +async def admin_get_resource(uri: str, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> Dict[str, Any]: """Get resource details for the admin UI. 
Args: @@ -6007,7 +5727,7 @@ async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depen >>> from fastapi import HTTPException >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> resource_uri = "test://resource/get" >>> >>> # Mock resource data @@ -6066,7 +5786,7 @@ async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depen >>> resource_service.get_resource_by_uri = original_get_resource_by_uri >>> resource_service.read_resource = original_read_resource """ - LOGGER.debug(f"User {get_user_email(user)} requested details for resource URI {uri}") + if logger: logger.debug(f"User {user} requested details for resource URI {uri}") try: resource = await resource_service.get_resource_by_uri(db, uri) content = await resource_service.read_resource(db, uri) @@ -6074,12 +5794,12 @@ async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depen except ResourceNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - LOGGER.error(f"Error getting resource {uri}: {e}") + if logger: logger.error(f"Error getting resource {uri}: {e}") raise e @admin_router.post("/resources") -async def admin_add_resource(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Response: +async def admin_add_resource(request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> Response: """ Add a resource via the admin UI. @@ -6106,7 +5826,7 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> form_data = FormData([ ... ("uri", "test://resource1"), ... 
("name", "Test Resource"), @@ -6129,7 +5849,7 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us True >>> resource_service.register_resource = original_register_resource """ - LOGGER.debug(f"User {get_user_email(user)} is adding a new resource") + if logger: logger.debug(f"User {user} is adding a new resource") form = await request.form() # Parse tags from comma-separated string @@ -6145,6 +5865,8 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us template=cast(str | None, form.get("template")), content=str(form["content"]), tags=tags, + team_id=None, + owner_email=None, ) metadata = MetadataCapture.extract_creation_metadata(request, user) @@ -6165,14 +5887,14 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us ) except Exception as ex: if isinstance(ex, ValidationError): - LOGGER.error(f"ValidationError in admin_add_resource: {ErrorFormatter.format_validation_error(ex)}") + if logger: logger.error(f"ValidationError in admin_add_resource: {ErrorFormatter.format_validation_error(ex)}") return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) if isinstance(ex, IntegrityError): error_message = ErrorFormatter.format_database_error(ex) - LOGGER.error(f"IntegrityError in admin_add_resource: {error_message}") + if logger: logger.error(f"IntegrityError in admin_add_resource: {error_message}") return JSONResponse(status_code=409, content=error_message) - LOGGER.error(f"Error in admin_add_resource: {ex}") + if logger: logger.error(f"Error in admin_add_resource: {ex}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) @@ -6181,7 +5903,7 @@ async def admin_edit_resource( uri: str, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> JSONResponse: """ Edit a resource via the admin UI. 
@@ -6209,7 +5931,7 @@ async def admin_edit_resource( >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> form_data = FormData([ ... ("name", "Updated Resource"), ... ("description", "Updated description"), @@ -6267,7 +5989,7 @@ async def admin_edit_resource( >>> # Reset mock >>> resource_service.update_resource = original_update_resource """ - LOGGER.debug(f"User {get_user_email(user)} is editing resource URI {uri}") + if logger: logger.debug(f"User {user} is editing resource URI {uri}") form = await request.form() # Parse tags from comma-separated string @@ -6275,7 +5997,6 @@ async def admin_edit_resource( tags: List[str] = [tag.strip() for tag in tags_str.split(",") if tag.strip()] if tags_str else [] try: - mod_metadata = MetadataCapture.extract_modification_metadata(request, user, 0) resource = ResourceUpdate( name=str(form["name"]), description=str(form.get("description")), @@ -6283,34 +6004,29 @@ async def admin_edit_resource( content=str(form["content"]), template=str(form.get("template")), tags=tags, + team_id=None, + owner_email=None, + visibility=None, ) - await resource_service.update_resource( - db, - uri, - resource, - modified_by=mod_metadata["modified_by"], - modified_from_ip=mod_metadata["modified_from_ip"], - modified_via=mod_metadata["modified_via"], - modified_user_agent=mod_metadata["modified_user_agent"], - ) + await resource_service.update_resource(db, uri, resource) return JSONResponse( content={"message": "Resource updated successfully!", "success": True}, status_code=200, ) except Exception as ex: if isinstance(ex, ValidationError): - LOGGER.error(f"ValidationError in admin_edit_resource: {ErrorFormatter.format_validation_error(ex)}") + if logger: logger.error(f"ValidationError in admin_edit_resource: {ErrorFormatter.format_validation_error(ex)}") return JSONResponse(content=ErrorFormatter.format_validation_error(ex), 
status_code=422) if isinstance(ex, IntegrityError): error_message = ErrorFormatter.format_database_error(ex) - LOGGER.error(f"IntegrityError in admin_edit_resource: {error_message}") + if logger: logger.error(f"IntegrityError in admin_edit_resource: {error_message}") return JSONResponse(status_code=409, content=error_message) - LOGGER.error(f"Error in admin_edit_resource: {ex}") + if logger: logger.error(f"Error in admin_edit_resource: {ex}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) @admin_router.post("/resources/{uri:path}/delete") -async def admin_delete_resource(uri: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +async def admin_delete_resource(uri: str, request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> RedirectResponse: """ Delete a resource via the admin UI. @@ -6336,7 +6052,7 @@ async def admin_delete_resource(uri: str, request: Request, db: Session = Depend >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> mock_request = MagicMock(spec=Request) >>> form_data = FormData([("is_inactive_checked", "false")]) >>> mock_request.form = AsyncMock(return_value=form_data) @@ -6357,15 +6073,15 @@ async def admin_delete_resource(uri: str, request: Request, db: Session = Depend >>> mock_request.form = AsyncMock(return_value=form_data_inactive) >>> >>> async def test_admin_delete_resource_inactive(): - ... response = await admin_delete_resource("test://resource1", mock_request, mock_db, mock_user) + ... response = await admin_delete_resource("test://resource1", mock_request, mock_user) ... 
return isinstance(response, RedirectResponse) and "include_inactive=true" in response.headers["location"] >>> >>> asyncio.run(test_admin_delete_resource_inactive()) True >>> resource_service.delete_resource = original_delete_resource """ - LOGGER.debug(f"User {get_user_email(user)} is deleting resource URI {uri}") - await resource_service.delete_resource(user["db"] if isinstance(user, dict) else db, uri) + if logger: logger.debug(f"User {user} is deleting resource URI {uri}") + await resource_service.delete_resource(db, uri) form = await request.form() is_inactive_checked: str = str(form.get("is_inactive_checked", "false")) root_path = request.scope.get("root_path", "") @@ -6379,7 +6095,7 @@ async def admin_toggle_resource( resource_id: int, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> RedirectResponse: """ Toggle a resource's active status via the admin UI. @@ -6407,7 +6123,7 @@ async def admin_toggle_resource( >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> mock_request = MagicMock(spec=Request) >>> form_data = FormData([ ... 
("activate", "true"), @@ -6470,14 +6186,14 @@ async def admin_toggle_resource( True >>> resource_service.toggle_resource_status = original_toggle_resource_status """ - LOGGER.debug(f"User {get_user_email(user)} is toggling resource ID {resource_id}") + if logger: logger.debug(f"User {user} is toggling resource ID {resource_id}") form = await request.form() activate = str(form.get("activate", "true")).lower() == "true" is_inactive_checked = str(form.get("is_inactive_checked", "false")) try: await resource_service.toggle_resource_status(db, resource_id, activate) except Exception as e: - LOGGER.error(f"Error toggling resource status: {e}") + if logger: logger.error(f"Error toggling resource status: {e}") root_path = request.scope.get("root_path", "") if is_inactive_checked.lower() == "true": @@ -6486,7 +6202,7 @@ async def admin_toggle_resource( @admin_router.get("/prompts/{name}") -async def admin_get_prompt(name: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: +async def admin_get_prompt(name: str, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> Dict[str, Any]: """Get prompt details for the admin UI. 
Args: @@ -6510,7 +6226,7 @@ async def admin_get_prompt(name: str, db: Session = Depends(get_db), user=Depend >>> from fastapi import HTTPException >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> prompt_name = "test-prompt" >>> >>> # Mock prompt details @@ -6573,7 +6289,7 @@ async def admin_get_prompt(name: str, db: Session = Depends(get_db), user=Depend >>> >>> prompt_service.get_prompt_details = original_get_prompt_details """ - LOGGER.debug(f"User {get_user_email(user)} requested details for prompt name {name}") + if logger: logger.debug(f"User {user} requested details for prompt name {name}") try: prompt_details = await prompt_service.get_prompt_details(db, name) prompt = PromptRead.model_validate(prompt_details) @@ -6581,12 +6297,12 @@ async def admin_get_prompt(name: str, db: Session = Depends(get_db), user=Depend except PromptNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - LOGGER.error(f"Error getting prompt {name}: {e}") + if logger: logger.error(f"Error getting prompt {name}: {e}") raise e @admin_router.post("/prompts") -async def admin_add_prompt(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> JSONResponse: +async def admin_add_prompt(request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> JSONResponse: """Add a prompt via the admin UI. Expects form fields: @@ -6611,7 +6327,7 @@ async def admin_add_prompt(request: Request, db: Session = Depends(get_db), user >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> form_data = FormData([ ... ("name", "Test Prompt"), ... 
("description", "A test prompt"), @@ -6634,7 +6350,7 @@ async def admin_add_prompt(request: Request, db: Session = Depends(get_db), user >>> prompt_service.register_prompt = original_register_prompt """ - LOGGER.debug(f"User {get_user_email(user)} is adding a new prompt") + if logger: logger.debug(f"User {user} is adding a new prompt") form = await request.form() # Parse tags from comma-separated string @@ -6653,6 +6369,8 @@ async def admin_add_prompt(request: Request, db: Session = Depends(get_db), user template=str(form["template"]), arguments=arguments, tags=tags, + team_id=None, + owner_email=None, ) # Extract creation metadata metadata = MetadataCapture.extract_creation_metadata(request, user) @@ -6673,13 +6391,13 @@ async def admin_add_prompt(request: Request, db: Session = Depends(get_db), user ) except Exception as ex: if isinstance(ex, ValidationError): - LOGGER.error(f"ValidationError in admin_add_prompt: {ErrorFormatter.format_validation_error(ex)}") + if logger: logger.error(f"ValidationError in admin_add_prompt: {ErrorFormatter.format_validation_error(ex)}") return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) if isinstance(ex, IntegrityError): error_message = ErrorFormatter.format_database_error(ex) - LOGGER.error(f"IntegrityError in admin_add_prompt: {error_message}") + if logger: logger.error(f"IntegrityError in admin_add_prompt: {error_message}") return JSONResponse(status_code=409, content=error_message) - LOGGER.error(f"Error in admin_add_prompt: {ex}") + if logger: logger.error(f"Error in admin_add_prompt: {ex}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) @@ -6688,7 +6406,7 @@ async def admin_edit_prompt( name: str, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> Response: """Edit a prompt via the admin UI. 
@@ -6715,7 +6433,7 @@ async def admin_edit_prompt( >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> prompt_name = "test-prompt" >>> form_data = FormData([ ... ("name", "Updated Prompt"), @@ -6755,7 +6473,7 @@ async def admin_edit_prompt( True >>> prompt_service.update_prompt = original_update_prompt """ - LOGGER.debug(f"User {get_user_email(user)} is editing prompt name {name}") + if logger: logger.debug(f"User {user} is editing prompt name {name}") form = await request.form() args_json: str = str(form.get("arguments")) or "[]" @@ -6764,23 +6482,17 @@ async def admin_edit_prompt( tags_str = str(form.get("tags", "")) tags: List[str] = [tag.strip() for tag in tags_str.split(",") if tag.strip()] if tags_str else [] try: - mod_metadata = MetadataCapture.extract_modification_metadata(request, user, 0) prompt = PromptUpdate( name=str(form["name"]), description=str(form.get("description")), template=str(form["template"]), arguments=arguments, tags=tags, + team_id=None, + owner_email=None, + visibility=None, ) - await prompt_service.update_prompt( - db, - name, - prompt, - modified_by=mod_metadata["modified_by"], - modified_from_ip=mod_metadata["modified_from_ip"], - modified_via=mod_metadata["modified_via"], - modified_user_agent=mod_metadata["modified_user_agent"], - ) + await prompt_service.update_prompt(db, name, prompt) root_path = request.scope.get("root_path", "") is_inactive_checked: str = str(form.get("is_inactive_checked", "false")) @@ -6793,18 +6505,18 @@ async def admin_edit_prompt( ) except Exception as ex: if isinstance(ex, ValidationError): - LOGGER.error(f"ValidationError in admin_edit_prompt: {ErrorFormatter.format_validation_error(ex)}") + if logger: logger.error(f"ValidationError in admin_edit_prompt: {ErrorFormatter.format_validation_error(ex)}") return JSONResponse(content=ErrorFormatter.format_validation_error(ex), status_code=422) 
if isinstance(ex, IntegrityError): error_message = ErrorFormatter.format_database_error(ex) - LOGGER.error(f"IntegrityError in admin_edit_prompt: {error_message}") + if logger: logger.error(f"IntegrityError in admin_edit_prompt: {error_message}") return JSONResponse(status_code=409, content=error_message) - LOGGER.error(f"Error in admin_edit_prompt: {ex}") + if logger: logger.error(f"Error in admin_edit_prompt: {ex}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) @admin_router.post("/prompts/{name}/delete") -async def admin_delete_prompt(name: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +async def admin_delete_prompt(name: str, request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)) -> RedirectResponse: """ Delete a prompt via the admin UI. @@ -6830,7 +6542,7 @@ async def admin_delete_prompt(name: str, request: Request, db: Session = Depends >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> mock_request = MagicMock(spec=Request) >>> form_data = FormData([("is_inactive_checked", "false")]) >>> mock_request.form = AsyncMock(return_value=form_data) @@ -6858,7 +6570,7 @@ async def admin_delete_prompt(name: str, request: Request, db: Session = Depends True >>> prompt_service.delete_prompt = original_delete_prompt """ - LOGGER.debug(f"User {get_user_email(user)} is deleting prompt name {name}") + if logger: logger.debug(f"User {user} is deleting prompt name {name}") await prompt_service.delete_prompt(db, name) form = await request.form() is_inactive_checked: str = str(form.get("is_inactive_checked", "false")) @@ -6873,7 +6585,7 @@ async def admin_toggle_prompt( prompt_id: int, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = 
Depends(require_auth), ) -> RedirectResponse: """ Toggle a prompt's active status via the admin UI. @@ -6901,7 +6613,7 @@ async def admin_toggle_prompt( >>> from starlette.datastructures import FormData >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> mock_request = MagicMock(spec=Request) >>> form_data = FormData([ ... ("activate", "true"), @@ -6964,14 +6676,14 @@ async def admin_toggle_prompt( True >>> prompt_service.toggle_prompt_status = original_toggle_prompt_status """ - LOGGER.debug(f"User {get_user_email(user)} is toggling prompt ID {prompt_id}") + if logger: logger.debug(f"User {user} is toggling prompt ID {prompt_id}") form = await request.form() activate: bool = str(form.get("activate", "true")).lower() == "true" is_inactive_checked: str = str(form.get("is_inactive_checked", "false")) try: await prompt_service.toggle_prompt_status(db, prompt_id, activate) except Exception as e: - LOGGER.error(f"Error toggling prompt status: {e}") + if logger: logger.error(f"Error toggling prompt status: {e}") root_path = request.scope.get("root_path", "") if is_inactive_checked.lower() == "true": @@ -6980,7 +6692,7 @@ async def admin_toggle_prompt( @admin_router.post("/roots") -async def admin_add_root(request: Request, user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +async def admin_add_root(request: Request, user: str = Depends(require_auth)) -> RedirectResponse: """Add a new root via the admin UI. Expects form fields: @@ -7001,8 +6713,7 @@ async def admin_add_root(request: Request, user=Depends(get_current_user_with_pe >>> from fastapi.responses import RedirectResponse >>> from starlette.datastructures import FormData >>> - >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> mock_request = MagicMock(spec=Request) >>> form_data = FormData([ ... 
("uri", "test://root1"), @@ -7022,7 +6733,7 @@ async def admin_add_root(request: Request, user=Depends(get_current_user_with_pe True >>> root_service.add_root = original_add_root """ - LOGGER.debug(f"User {get_user_email(user)} is adding a new root") + if logger: logger.debug(f"User {user} is adding a new root") form = await request.form() uri = str(form["uri"]) name_value = form.get("name") @@ -7035,7 +6746,7 @@ async def admin_add_root(request: Request, user=Depends(get_current_user_with_pe @admin_router.post("/roots/{uri:path}/delete") -async def admin_delete_root(uri: str, request: Request, user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +async def admin_delete_root(uri: str, request: Request, user: str = Depends(require_auth)) -> RedirectResponse: """ Delete a root via the admin UI. @@ -7059,8 +6770,7 @@ async def admin_delete_root(uri: str, request: Request, user=Depends(get_current >>> from fastapi.responses import RedirectResponse >>> from starlette.datastructures import FormData >>> - >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> mock_request = MagicMock(spec=Request) >>> form_data = FormData([("is_inactive_checked", "false")]) >>> mock_request.form = AsyncMock(return_value=form_data) @@ -7088,7 +6798,7 @@ async def admin_delete_root(uri: str, request: Request, user=Depends(get_current True >>> root_service.remove_root = original_remove_root """ - LOGGER.debug(f"User {get_user_email(user)} is deleting root URI {uri}") + if logger: logger.debug(f"User {user} is deleting root URI {uri}") await root_service.remove_root(uri) form = await request.form() root_path = request.scope.get("root_path", "") @@ -7101,11 +6811,13 @@ async def admin_delete_root(uri: str, request: Request, user=Depends(get_current # Metrics MetricsDict = Dict[str, Union[ToolMetrics, ResourceMetrics, ServerMetrics, PromptMetrics]] +# Import the response time formatting function +from 
mcpgateway.utils.metrics_common import format_response_time # @admin_router.get("/metrics", response_model=MetricsDict) # async def admin_get_metrics( # db: Session = Depends(get_db), -# user=Depends(get_current_user_with_permissions), +# user: str = Depends(require_auth), # ) -> MetricsDict: # """ # Retrieve aggregate metrics for all entity types via the admin UI. @@ -7125,7 +6837,7 @@ async def admin_delete_root(uri: str, request: Request, user=Depends(get_current # resources, servers, and prompts. Each value is a Pydantic model instance # specific to the entity type. # """ -# LOGGER.debug(f"User {get_user_email(user)} requested aggregate metrics") +# if logger: logger.debug(f"User {user} requested aggregate metrics") # tool_metrics = await tool_service.aggregate_metrics(db) # resource_metrics = await resource_service.aggregate_metrics(db) # server_metrics = await server_service.aggregate_metrics(db) @@ -7143,7 +6855,7 @@ async def admin_delete_root(uri: str, request: Request, user=Depends(get_current @admin_router.get("/metrics") async def get_aggregated_metrics( db: Session = Depends(get_db), - _user=Depends(get_current_user_with_permissions), + _user: Union[str, Dict[str, Any]] = Depends(require_auth), ) -> Dict[str, Any]: """Retrieve aggregated metrics and top performers for all entity types. @@ -7164,23 +6876,336 @@ async def get_aggregated_metrics( - 'topPerformers': A nested dictionary with top 5 tools, resources, prompts, and servers. 
""" + # Get ALL entities with metrics for UI display (same logic as CSV export) + from sqlalchemy import func, case, Float + from sqlalchemy.sql import desc + from mcpgateway.db import Tool, ToolMetric, Resource, ResourceMetric, Prompt, PromptMetric, Server, ServerMetric + from mcpgateway.utils.metrics_common import build_top_performers + + # Get ALL tools (including those with 0 metrics) + tools_query = ( + db.query( + Tool.id, + Tool.name, + func.coalesce(func.count(ToolMetric.id), 0).label("execution_count"), + func.avg(ToolMetric.response_time).label("avg_response_time"), + case( + ( + func.count(ToolMetric.id) > 0, + func.sum(case((ToolMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ToolMetric.id) * 100, + ), + else_=None, + ).label("success_rate"), + func.max(ToolMetric.timestamp).label("last_execution"), + ) + .outerjoin(ToolMetric, Tool.id == ToolMetric.tool_id) + .group_by(Tool.id, Tool.name) + .order_by(desc("execution_count"), Tool.name) # Order by exec count, then name + ) + all_tools = build_top_performers(tools_query.all()) + + # Get ALL resources + resources_query = ( + db.query( + Resource.id, + Resource.uri.label("name"), + func.coalesce(func.count(ResourceMetric.id), 0).label("execution_count"), + func.avg(ResourceMetric.response_time).label("avg_response_time"), + case( + ( + func.count(ResourceMetric.id) > 0, + func.sum(case((ResourceMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ResourceMetric.id) * 100, + ), + else_=None, + ).label("success_rate"), + func.max(ResourceMetric.timestamp).label("last_execution"), + ) + .outerjoin(ResourceMetric, Resource.id == ResourceMetric.resource_id) + .group_by(Resource.id, Resource.uri) + .order_by(desc("execution_count"), Resource.uri) + ) + all_resources = build_top_performers(resources_query.all()) + + # Get ALL prompts + prompts_query = ( + db.query( + Prompt.id, + Prompt.name, + func.coalesce(func.count(PromptMetric.id), 0).label("execution_count"), + 
func.avg(PromptMetric.response_time).label("avg_response_time"), + case( + ( + func.count(PromptMetric.id) > 0, + func.sum(case((PromptMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(PromptMetric.id) * 100, + ), + else_=None, + ).label("success_rate"), + func.max(PromptMetric.timestamp).label("last_execution"), + ) + .outerjoin(PromptMetric, Prompt.id == PromptMetric.prompt_id) + .group_by(Prompt.id, Prompt.name) + .order_by(desc("execution_count"), Prompt.name) + ) + all_prompts = build_top_performers(prompts_query.all()) + + # Get ALL servers + servers_query = ( + db.query( + Server.id, + Server.name, + func.coalesce(func.count(ServerMetric.id), 0).label("execution_count"), + func.avg(ServerMetric.response_time).label("avg_response_time"), + case( + ( + func.count(ServerMetric.id) > 0, + func.sum(case((ServerMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ServerMetric.id) * 100, + ), + else_=None, + ).label("success_rate"), + func.max(ServerMetric.timestamp).label("last_execution"), + ) + .outerjoin(ServerMetric, Server.id == ServerMetric.server_id) + .group_by(Server.id, Server.name) + .order_by(desc("execution_count"), Server.name) + ) + all_servers = build_top_performers(servers_query.all()) + metrics = { "tools": await tool_service.aggregate_metrics(db), "resources": await resource_service.aggregate_metrics(db), "prompts": await prompt_service.aggregate_metrics(db), "servers": await server_service.aggregate_metrics(db), "topPerformers": { - "tools": await tool_service.get_top_tools(db, limit=5), - "resources": await resource_service.get_top_resources(db, limit=5), - "prompts": await prompt_service.get_top_prompts(db, limit=5), - "servers": await server_service.get_top_servers(db, limit=5), + "tools": all_tools, # Now includes ALL tools + "resources": all_resources, # Now includes ALL resources + "prompts": all_prompts, # Now includes ALL prompts + "servers": all_servers, # Now includes ALL servers }, } return metrics 
+@admin_router.get("/metrics/export", response_class=Response) +async def export_metrics_csv( + db: Session = Depends(get_db), + entity_type: str = Query(..., description="Entity type to export (tools, resources, prompts, servers)"), + limit: Optional[int] = Query(None, description="Maximum number of results to return. If not provided, all results are returned."), + user: str = Depends(require_auth), +) -> Response: + """Export metrics for a specific entity type to CSV format. + + This endpoint retrieves ALL entities of the specified type from the database and + exports them to CSV format with their performance metrics for download. + Entities without metrics will show 0 executions and N/A for response times. + Response times are formatted to 3 decimal places. + + Args: + db (Session): Database session dependency for querying metrics. + entity_type (str): Type of entity to export (tools, resources, prompts, servers). + limit (Optional[int]): Maximum number of results to return. If None, all results are returned. + user (str): Authenticated user. + + Returns: + Response: CSV file download response containing the metrics data for ALL entities. + + Raises: + HTTPException: If the entity type is invalid. + """ + if logger: logger.debug(f"User {user} requested CSV export of {entity_type} metrics") + + # Validate entity type + valid_types = ["tools", "resources", "prompts", "servers"] + if entity_type not in valid_types: + raise HTTPException(status_code=400, detail=f"Invalid entity type. 
Must be one of: {', '.join(valid_types)}") + + # Get ALL entities with their metrics data for CSV export (including those with 0 executions) + try: + if entity_type == "tools": + # Import required SQLAlchemy functions and models + from sqlalchemy import func, case, desc, Float + from mcpgateway.db import Tool, ToolMetric + from mcpgateway.utils.metrics_common import build_top_performers + + query = ( + db.query( + Tool.id, + Tool.name, + func.coalesce(func.count(ToolMetric.id), 0).label("execution_count"), + func.avg(ToolMetric.response_time).label("avg_response_time"), + case( + ( + func.count(ToolMetric.id) > 0, + func.sum(case((ToolMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ToolMetric.id) * 100, + ), + else_=None, + ).label("success_rate"), + func.max(ToolMetric.timestamp).label("last_execution"), + ) + .outerjoin(ToolMetric, Tool.id == ToolMetric.tool_id) + .group_by(Tool.id, Tool.name) + .order_by(Tool.name) # Order by name for consistent CSV output + ) + + if limit is not None: + query = query.limit(limit) + + results = query.all() + performers = build_top_performers(results) + + elif entity_type == "resources": + from sqlalchemy import func, case, Float + from mcpgateway.db import Resource, ResourceMetric + from mcpgateway.utils.metrics_common import build_top_performers + + query = ( + db.query( + Resource.id, + Resource.uri.label("name"), # Use URI as name for resources + func.coalesce(func.count(ResourceMetric.id), 0).label("execution_count"), + func.avg(ResourceMetric.response_time).label("avg_response_time"), + case( + ( + func.count(ResourceMetric.id) > 0, + func.sum(case((ResourceMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ResourceMetric.id) * 100, + ), + else_=None, + ).label("success_rate"), + func.max(ResourceMetric.timestamp).label("last_execution"), + ) + .outerjoin(ResourceMetric, Resource.id == ResourceMetric.resource_id) + .group_by(Resource.id, Resource.uri) + .order_by(Resource.uri) + ) + + if 
limit is not None: + query = query.limit(limit) + + results = query.all() + performers = build_top_performers(results) + + elif entity_type == "prompts": + from sqlalchemy import func, case, Float + from mcpgateway.db import Prompt, PromptMetric + from mcpgateway.utils.metrics_common import build_top_performers + + query = ( + db.query( + Prompt.id, + Prompt.name, + func.coalesce(func.count(PromptMetric.id), 0).label("execution_count"), + func.avg(PromptMetric.response_time).label("avg_response_time"), + case( + ( + func.count(PromptMetric.id) > 0, + func.sum(case((PromptMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(PromptMetric.id) * 100, + ), + else_=None, + ).label("success_rate"), + func.max(PromptMetric.timestamp).label("last_execution"), + ) + .outerjoin(PromptMetric, Prompt.id == PromptMetric.prompt_id) + .group_by(Prompt.id, Prompt.name) + .order_by(Prompt.name) + ) + + if limit is not None: + query = query.limit(limit) + + results = query.all() + performers = build_top_performers(results) + + elif entity_type == "servers": + from sqlalchemy import func, case, Float + from mcpgateway.db import Server, ServerMetric + from mcpgateway.utils.metrics_common import build_top_performers + + query = ( + db.query( + Server.id, + Server.name, + func.coalesce(func.count(ServerMetric.id), 0).label("execution_count"), + func.avg(ServerMetric.response_time).label("avg_response_time"), + case( + ( + func.count(ServerMetric.id) > 0, + func.sum(case((ServerMetric.is_success.is_(True), 1), else_=0)).cast(Float) / func.count(ServerMetric.id) * 100, + ), + else_=None, + ).label("success_rate"), + func.max(ServerMetric.timestamp).label("last_execution"), + ) + .outerjoin(ServerMetric, Server.id == ServerMetric.server_id) + .group_by(Server.id, Server.name) + .order_by(Server.name) + ) + + if limit is not None: + query = query.limit(limit) + + results = query.all() + performers = build_top_performers(results) + except Exception as e: + if logger: 
logger.error(f"Error exporting {entity_type} metrics to CSV: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to export metrics: {str(e)}") + + # Handle empty data case + if not performers: + # Return empty CSV with headers + csv_content = "ID,Name,Execution Count,Average Response Time (s),Success Rate (%),Last Execution\n" + return Response( + content=csv_content, + media_type="text/csv", + headers={"Content-Disposition": f"attachment; filename={entity_type}_metrics.csv"} + ) + + # Create CSV content + output = StringIO() + writer = csv.writer(output) + + # Write header row + writer.writerow([ + "ID", + "Name", + "Execution Count", + "Average Response Time (s)", + "Success Rate (%)", + "Last Execution" + ]) + + # Write data rows with formatted values + for performer in performers: + # Format response time to 3 decimal places + formatted_response_time = format_response_time(performer.avg_response_time) if performer.avg_response_time is not None else "N/A" + + # Format success rate + success_rate = f"{performer.success_rate:.1f}" if performer.success_rate is not None else "N/A" + + # Format timestamp + last_execution = performer.last_execution.isoformat() if performer.last_execution else "N/A" + + writer.writerow([ + performer.id, + performer.name, + performer.execution_count, + formatted_response_time, + success_rate, + last_execution + ]) + + # Get the CSV content as a string + csv_content = output.getvalue() + output.close() + + # Return CSV response + return Response( + content=csv_content, + media_type="text/csv", + headers={"Content-Disposition": f"attachment; filename={entity_type}_metrics.csv"} + ) + + @admin_router.post("/metrics/reset", response_model=Dict[str, object]) -async def admin_reset_metrics(db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, object]: +async def admin_reset_metrics(db: Session = Depends(get_db), user: str = Depends(require_auth)) -> Dict[str, object]: """ Reset all metrics 
for tools, resources, servers, and prompts. Each service must implement its own reset_metrics method. @@ -7197,7 +7222,7 @@ async def admin_reset_metrics(db: Session = Depends(get_db), user=Depends(get_cu >>> from unittest.mock import AsyncMock, MagicMock >>> >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> >>> original_reset_metrics_tool = tool_service.reset_metrics >>> original_reset_metrics_resource = resource_service.reset_metrics @@ -7221,7 +7246,7 @@ async def admin_reset_metrics(db: Session = Depends(get_db), user=Depends(get_cu >>> server_service.reset_metrics = original_reset_metrics_server >>> prompt_service.reset_metrics = original_reset_metrics_prompt """ - LOGGER.debug(f"User {get_user_email(user)} requested to reset all metrics") + if logger: logger.debug(f"User {user} requested to reset all metrics") await tool_service.reset_metrics(db) await resource_service.reset_metrics(db) await server_service.reset_metrics(db) @@ -7230,7 +7255,7 @@ async def admin_reset_metrics(db: Session = Depends(get_db), user=Depends(get_cu @admin_router.post("/gateways/test", response_model=GatewayTestResponse) -async def admin_test_gateway(request: GatewayTestRequest, user=Depends(get_current_user_with_permissions)) -> GatewayTestResponse: +async def admin_test_gateway(request: GatewayTestRequest, user: str = Depends(require_auth)) -> GatewayTestResponse: """ Test a gateway by sending a request to its URL. This endpoint allows administrators to test the connectivity and response @@ -7249,8 +7274,7 @@ async def admin_test_gateway(request: GatewayTestRequest, user=Depends(get_curre >>> from fastapi import Request >>> import httpx >>> - >>> mock_db = MagicMock() - >>> mock_user = {"email": "test_user", "db": mock_db} + >>> mock_user = "test_user" >>> mock_request = GatewayTestRequest( ... base_url="https://api.example.com", ... 
path="/test", @@ -7371,9 +7395,9 @@ async def admin_test_gateway(request: GatewayTestRequest, user=Depends(get_curre """ full_url = str(request.base_url).rstrip("/") + "/" + request.path.lstrip("/") full_url = full_url.rstrip("/") - LOGGER.debug(f"User {get_user_email(user)} testing server at {request.base_url}.") - start_time: float = time.monotonic() + if logger: logger.debug(f"User {user} testing server at {request.base_url}.") try: + start_time: float = time.monotonic() async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client: response: httpx.Response = await client.request(method=request.method.upper(), url=full_url, headers=request.headers, json=request.body) latency_ms = int((time.monotonic() - start_time) * 1000) @@ -7385,7 +7409,7 @@ async def admin_test_gateway(request: GatewayTestRequest, user=Depends(get_curre return GatewayTestResponse(status_code=response.status_code, latency_ms=latency_ms, body=response_body) except httpx.RequestError as e: - LOGGER.warning(f"Gateway test failed: {e}") + if logger: logger.warning(f"Gateway test failed: {e}") latency_ms = int((time.monotonic() - start_time) * 1000) return GatewayTestResponse(status_code=502, latency_ms=latency_ms, body={"error": "Request failed", "details": str(e)}) @@ -7400,7 +7424,7 @@ async def admin_list_tags( entity_types: Optional[str] = None, include_entities: bool = False, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> List[Dict[str, Any]]: """ List all unique tags with statistics for the admin UI. 
@@ -7436,7 +7460,7 @@ async def admin_list_tags( if entity_types: entity_types_list = [et.strip().lower() for et in entity_types.split(",") if et.strip()] - LOGGER.debug(f"Admin user {user} is retrieving tags for entity types: {entity_types_list}, include_entities: {include_entities}") + if logger: logger.debug(f"Admin user {user} is retrieving tags for entity types: {entity_types_list}, include_entities: {include_entities}") try: tags = await tag_service.get_all_tags(db, entity_types=entity_types_list, include_entities=include_entities) @@ -7470,7 +7494,7 @@ async def admin_list_tags( return result except Exception as e: - LOGGER.error(f"Failed to retrieve tags for admin: {str(e)}") + if logger: logger.error(f"Failed to retrieve tags for admin: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to retrieve tags: {str(e)}") @@ -7480,7 +7504,7 @@ async def admin_list_tags( async def admin_import_tools( request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> JSONResponse: """Bulk import multiple tools in a single request. @@ -7500,10 +7524,10 @@ async def admin_import_tools( """ # Check if bulk import is enabled if not settings.mcpgateway_bulk_import_enabled: - LOGGER.warning("Bulk import attempted but feature is disabled") + if logger: logger.warning("Bulk import attempted but feature is disabled") raise HTTPException(status_code=403, detail="Bulk import feature is disabled. 
Enable MCPGATEWAY_BULK_IMPORT_ENABLED to use this endpoint.") - LOGGER.debug("bulk tool import: user=%s", user) + if logger: logger.debug("bulk tool import: user=%s", user) try: # ---------- robust payload parsing ---------- ctype = (request.headers.get("content-type") or "").lower() @@ -7511,36 +7535,35 @@ async def admin_import_tools( try: payload = await request.json() except Exception as ex: - LOGGER.exception("Invalid JSON body") + logger.exception("Invalid JSON body") return JSONResponse({"success": False, "message": f"Invalid JSON: {ex}"}, status_code=422) else: try: form = await request.form() except Exception as ex: - LOGGER.exception("Invalid form body") + logger.exception("Invalid form body") return JSONResponse({"success": False, "message": f"Invalid form data: {ex}"}, status_code=422) # Check for file upload first if "tools_file" in form: file = form["tools_file"] - if isinstance(file, StarletteUploadFile): + if hasattr(file, "file"): content = await file.read() try: payload = json.loads(content.decode("utf-8")) except (json.JSONDecodeError, UnicodeDecodeError) as ex: - LOGGER.exception("Invalid JSON file") + logger.exception("Invalid JSON file") return JSONResponse({"success": False, "message": f"Invalid JSON file: {ex}"}, status_code=422) else: return JSONResponse({"success": False, "message": "Invalid file upload"}, status_code=422) else: # Check for JSON in form fields - raw_val = form.get("tools") or form.get("tools_json") or form.get("json") or form.get("payload") - raw = raw_val if isinstance(raw_val, str) else None + raw = form.get("tools") or form.get("tools_json") or form.get("json") or form.get("payload") if not raw: return JSONResponse({"success": False, "message": "Missing tools/tools_json/json/payload form field."}, status_code=422) try: payload = json.loads(raw) except Exception as ex: - LOGGER.exception("Invalid JSON in form field") + logger.exception("Invalid JSON in form field") return JSONResponse({"success": False, "message": 
f"Invalid JSON: {ex}"}, status_code=422) if not isinstance(payload, list): @@ -7590,7 +7613,7 @@ async def admin_import_tools( except ToolError as ex: errors.append({"index": i, "name": name, "error": {"message": str(ex)}}) except Exception as ex: - LOGGER.exception("Unexpected error importing tool %r at index %d", name, i) + logger.exception("Unexpected error importing tool %r at index %d", name, i) errors.append({"index": i, "name": name, "error": {"message": str(ex)}}) # Format response to match both frontend and test expectations @@ -7612,11 +7635,10 @@ async def admin_import_tools( }, } - rd = cast(Dict[str, Any], response_data) if len(errors) == 0: - rd["message"] = f"Successfully imported all {len(created)} tools" + response_data["message"] = f"Successfully imported all {len(created)} tools" else: - rd["message"] = f"Imported {len(created)} of {len(payload)} tools. {len(errors)} failed." + response_data["message"] = f"Imported {len(created)} of {len(payload)} tools. {len(errors)} failed." return JSONResponse( response_data, @@ -7628,7 +7650,7 @@ async def admin_import_tools( raise except Exception as ex: # absolute catch-all: report instead of crashing - LOGGER.exception("Fatal error in admin_import_tools") + logger.exception("Fatal error in admin_import_tools") return JSONResponse({"success": False, "message": str(ex)}, status_code=500) @@ -7649,7 +7671,7 @@ async def admin_get_logs( limit: int = 100, offset: int = 0, order: str = "desc", - user=Depends(get_current_user_with_permissions), # pylint: disable=unused-argument + user: str = Depends(require_auth), # pylint: disable=unused-argument ) -> Dict[str, Any]: """Get filtered log entries from the in-memory buffer. 
@@ -7673,7 +7695,7 @@ async def admin_get_logs( HTTPException: If validation fails or service unavailable """ # Get log storage from logging service - storage = cast(Any, logging_service).get_storage() + storage = logging_service.get_storage() if not storage: return {"logs": [], "total": 0, "stats": {}} @@ -7733,7 +7755,7 @@ async def admin_stream_logs( entity_type: Optional[str] = None, entity_id: Optional[str] = None, level: Optional[str] = None, - user=Depends(get_current_user_with_permissions), # pylint: disable=unused-argument + user: str = Depends(require_auth), # pylint: disable=unused-argument ): """Stream real-time log updates via Server-Sent Events. @@ -7751,7 +7773,7 @@ async def admin_stream_logs( HTTPException: If log level is invalid or service unavailable """ # Get log storage from logging service - storage = cast(Any, logging_service).get_storage() + storage = logging_service.get_storage() if not storage: raise HTTPException(503, "Log storage not available") @@ -7800,7 +7822,7 @@ async def generate(): yield f"data: {json.dumps(event)}\n\n" except Exception as e: - LOGGER.error(f"Error in log streaming: {e}") + if logger: logger.error(f"Error in log streaming: {e}") yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" return StreamingResponse( @@ -7816,7 +7838,7 @@ async def generate(): @admin_router.get("/logs/file") async def admin_get_log_file( filename: Optional[str] = None, - user=Depends(get_current_user_with_permissions), # pylint: disable=unused-argument + user: str = Depends(require_auth), # pylint: disable=unused-argument ): """Download log file. 
@@ -7858,21 +7880,12 @@ async def admin_get_log_file( if not (file_path.suffix in [".log", ".jsonl", ".json"] or file_path.stem.startswith(Path(settings.log_file).stem)): raise HTTPException(403, "Not a log file") - # Return file for download using Response with file content - try: - with open(file_path, "rb") as f: - file_content = f.read() - - return Response( - content=file_content, - media_type="application/octet-stream", - headers={ - "Content-Disposition": f'attachment; filename="{file_path.name}"', - }, - ) - except Exception as e: - LOGGER.error(f"Error reading file for download: {e}") - raise HTTPException(500, f"Error reading file for download: {e}") + # Return file for download + return FileResponse( + path=file_path, + filename=file_path.name, + media_type="application/octet-stream", + ) # List available log files log_files = [] @@ -7895,7 +7908,7 @@ async def admin_get_log_file( if settings.log_rotation_enabled: pattern = f"{Path(settings.log_file).stem}.*" for file in log_dir.glob(pattern): - if file.is_file() and file.name != main_log.name: # Exclude main log file + if file.is_file(): stat = file.stat() log_files.append( { @@ -7923,7 +7936,7 @@ async def admin_get_log_file( log_files.sort(key=lambda x: x["modified"], reverse=True) except Exception as e: - LOGGER.error(f"Error listing log files: {e}") + if logger: logger.error(f"Error listing log files: {e}") raise HTTPException(500, f"Error listing log files: {e}") return { @@ -7935,7 +7948,7 @@ async def admin_get_log_file( @admin_router.get("/logs/export") async def admin_export_logs( - export_format: str = Query("json", alias="format"), + export_format: str = "json", entity_type: Optional[str] = None, entity_id: Optional[str] = None, level: Optional[str] = None, @@ -7943,7 +7956,7 @@ async def admin_export_logs( end_time: Optional[str] = None, request_id: Optional[str] = None, search: Optional[str] = None, - user=Depends(get_current_user_with_permissions), # pylint: disable=unused-argument + user: 
str = Depends(require_auth), # pylint: disable=unused-argument ): """Export filtered logs in JSON or CSV format. @@ -7970,7 +7983,7 @@ async def admin_export_logs( raise HTTPException(400, f"Invalid format: {export_format}. Use 'json' or 'csv'") # Get log storage from logging service - storage = cast(Any, logging_service).get_storage() + storage = logging_service.get_storage() if not storage: raise HTTPException(503, "Log storage not available") @@ -8064,20 +8077,18 @@ async def admin_export_logs( @admin_router.get("/export/configuration") async def admin_export_configuration( - request: Request, types: Optional[str] = None, exclude_types: Optional[str] = None, tags: Optional[str] = None, include_inactive: bool = False, include_dependencies: bool = True, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ): """ Export gateway configuration via Admin UI. Args: - request: FastAPI request object for extracting root path types: Comma-separated entity types to include exclude_types: Comma-separated entity types to exclude tags: Comma-separated tags to filter by @@ -8093,7 +8104,7 @@ async def admin_export_configuration( HTTPException: If export fails """ try: - LOGGER.info(f"Admin user {user} requested configuration export") + if logger: logger.info(f"Admin user {user} requested configuration export") # Parse parameters include_types = None @@ -8111,19 +8122,9 @@ async def admin_export_configuration( # Extract username from user (which could be string or dict with token) username = user if isinstance(user, str) else user.get("username", "unknown") - # Get root path for URL construction - root_path = request.scope.get("root_path", "") if request else "" - # Perform export export_data = await export_service.export_configuration( - db=db, - include_types=include_types, - exclude_types=exclude_types_list, - tags=tags_list, - include_inactive=include_inactive, - include_dependencies=include_dependencies, - 
exported_by=username, - root_path=root_path, + db=db, include_types=include_types, exclude_types=exclude_types_list, tags=tags_list, include_inactive=include_inactive, include_dependencies=include_dependencies, exported_by=username ) # Generate filename @@ -8141,15 +8142,15 @@ async def admin_export_configuration( ) except ExportError as e: - LOGGER.error(f"Admin export failed for user {user}: {str(e)}") + if logger: logger.error(f"Admin export failed for user {user}: {str(e)}") raise HTTPException(status_code=400, detail=str(e)) except Exception as e: - LOGGER.error(f"Unexpected admin export error for user {user}: {str(e)}") + if logger: logger.error(f"Unexpected admin export error for user {user}: {str(e)}") raise HTTPException(status_code=500, detail=f"Export failed: {str(e)}") @admin_router.post("/export/selective") -async def admin_export_selective(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)): +async def admin_export_selective(request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)): """ Export selected entities via Admin UI with entity selection. 
@@ -8174,7 +8175,7 @@ async def admin_export_selective(request: Request, db: Session = Depends(get_db) } """ try: - LOGGER.info(f"Admin user {user} requested selective configuration export") + if logger: logger.info(f"Admin user {user} requested selective configuration export") body = await request.json() entity_selections = body.get("entity_selections", {}) @@ -8201,68 +8202,15 @@ async def admin_export_selective(request: Request, db: Session = Depends(get_db) ) except ExportError as e: - LOGGER.error(f"Admin selective export failed for user {user}: {str(e)}") + if logger: logger.error(f"Admin selective export failed for user {user}: {str(e)}") raise HTTPException(status_code=400, detail=str(e)) except Exception as e: - LOGGER.error(f"Unexpected admin selective export error for user {user}: {str(e)}") + if logger: logger.error(f"Unexpected admin selective export error for user {user}: {str(e)}") raise HTTPException(status_code=500, detail=f"Export failed: {str(e)}") -@admin_router.post("/import/preview") -async def admin_import_preview(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)): - """ - Preview import file to show available items for selective import. - - Args: - request: FastAPI request object with import file data - db: Database session - user: Authenticated user - - Returns: - JSON response with categorized import preview data - - Raises: - HTTPException: 400 for invalid JSON or missing data field, validation errors; - 500 for unexpected preview failures - - Expects JSON body: - { - "data": { ... 
} // The import file content - } - """ - try: - LOGGER.info(f"Admin import preview requested by user: {user}") - - # Parse request data - try: - data = await request.json() - except ValueError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") - - # Extract import data - import_data = data.get("data") - if not import_data: - raise HTTPException(status_code=400, detail="Missing 'data' field with import content") - - # Validate user permissions for import preview - username = user if isinstance(user, str) else user.get("username", "unknown") - LOGGER.info(f"Processing import preview for user: {username}") - - # Generate preview - preview_data = await import_service.preview_import(db=db, import_data=import_data) - - return JSONResponse(content={"success": True, "preview": preview_data, "message": f"Import preview generated. Found {preview_data['summary']['total_items']} total items."}) - - except ImportValidationError as e: - LOGGER.error(f"Import validation failed for user {user}: {str(e)}") - raise HTTPException(status_code=400, detail=f"Invalid import data: {str(e)}") - except Exception as e: - LOGGER.error(f"Import preview failed for user {user}: {str(e)}") - raise HTTPException(status_code=500, detail=f"Preview failed: {str(e)}") - - @admin_router.post("/import/configuration") -async def admin_import_configuration(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)): +async def admin_import_configuration(request: Request, db: Session = Depends(get_db), user: str = Depends(require_auth)): """ Import configuration via Admin UI. 
@@ -8287,7 +8235,7 @@ async def admin_import_configuration(request: Request, db: Session = Depends(get } """ try: - LOGGER.info(f"Admin user {user} requested configuration import") + if logger: logger.info(f"Admin user {user} requested configuration import") body = await request.json() import_data = body.get("import_data") @@ -8303,8 +8251,7 @@ async def admin_import_configuration(request: Request, db: Session = Depends(get try: conflict_strategy = ConflictStrategy(conflict_strategy_str.lower()) except ValueError: - allowed = [s.value for s in ConflictStrategy.__members__.values()] - raise HTTPException(status_code=400, detail=f"Invalid conflict strategy. Must be one of: {allowed}") + raise HTTPException(status_code=400, detail=f"Invalid conflict strategy. Must be one of: {[s.value for s in ConflictStrategy]}") # Extract username from user (which could be string or dict with token) username = user if isinstance(user, str) else user.get("username", "unknown") @@ -8317,15 +8264,15 @@ async def admin_import_configuration(request: Request, db: Session = Depends(get return JSONResponse(content=status.to_dict()) except ImportServiceError as e: - LOGGER.error(f"Admin import failed for user {user}: {str(e)}") + if logger: logger.error(f"Admin import failed for user {user}: {str(e)}") raise HTTPException(status_code=400, detail=str(e)) except Exception as e: - LOGGER.error(f"Unexpected admin import error for user {user}: {str(e)}") + if logger: logger.error(f"Unexpected admin import error for user {user}: {str(e)}") raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}") @admin_router.get("/import/status/{import_id}") -async def admin_get_import_status(import_id: str, user=Depends(get_current_user_with_permissions)): +async def admin_get_import_status(import_id: str, user: str = Depends(require_auth)): """Get import status via Admin UI. 
Args: @@ -8338,7 +8285,7 @@ async def admin_get_import_status(import_id: str, user=Depends(get_current_user_ Raises: HTTPException: If import not found """ - LOGGER.debug(f"Admin user {user} requested import status for {import_id}") + if logger: logger.debug(f"Admin user {user} requested import status for {import_id}") status = import_service.get_import_status(import_id) if not status: @@ -8348,7 +8295,7 @@ async def admin_get_import_status(import_id: str, user=Depends(get_current_user_ @admin_router.get("/import/status") -async def admin_list_import_statuses(user=Depends(get_current_user_with_permissions)): +async def admin_list_import_statuses(user: str = Depends(require_auth)): """List all import statuses via Admin UI. Args: @@ -8357,7 +8304,7 @@ async def admin_list_import_statuses(user=Depends(get_current_user_with_permissi Returns: JSON response with list of import statuses """ - LOGGER.debug(f"Admin user {user} requested all import statuses") + if logger: logger.debug(f"Admin user {user} requested all import statuses") statuses = import_service.list_import_statuses() return JSONResponse(content=[status.to_dict() for status in statuses]) @@ -8373,7 +8320,7 @@ async def admin_list_a2a_agents( include_inactive: bool = False, tags: Optional[str] = None, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), + user: str = Depends(require_auth), ) -> HTMLResponse: """List A2A agents for admin UI. 
@@ -8396,7 +8343,7 @@ async def admin_list_a2a_agents( if tags: tags_list = [tag.strip() for tag in tags.split(",") if tag.strip()] - LOGGER.debug(f"Admin user {user} requested A2A agent list with tags={tags_list}") + if logger: logger.debug(f"Admin user {user} requested A2A agent list with tags={tags_list}") agents = await a2a_service.list_agents(db, include_inactive=include_inactive, tags=tags_list) # Convert to template format @@ -8432,7 +8379,7 @@ async def admin_list_a2a_agents( # Generate tags HTML separately tags_html = "" if agent["tags"]: - tag_spans: List[Any] = [] + tag_spans = [] for tag in agent["tags"]: tag_spans.append(f'{tag}') tags_html = f'
{" ".join(tag_spans)}
' @@ -8506,8 +8453,8 @@ async def admin_list_a2a_agents( async def admin_add_a2a_agent( request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), -) -> Response: + user: str = Depends(require_auth), +) -> RedirectResponse: """Add a new A2A agent via admin UI. Args: @@ -8516,24 +8463,23 @@ async def admin_add_a2a_agent( user: Authenticated user Returns: - Response with success/error status + HTML response with success/error status Raises: HTTPException: If A2A features are disabled """ - LOGGER.info(f"A2A agent creation request from user {user}") + if logger: logger.info(f"A2A agent creation request from user {user}") if not a2a_service or not settings.mcpgateway_a2a_enabled: - LOGGER.warning("A2A agent creation attempted but A2A features are disabled") + if logger: logger.warning("A2A agent creation attempted but A2A features are disabled") return HTMLResponse(content='
A2A features are disabled
', status_code=403) try: form = await request.form() - LOGGER.info(f"A2A agent creation form data: {dict(form)}") + if logger: logger.info(f"A2A agent creation form data: {dict(form)}") # Process tags - ts_val = form.get("tags", "") - tags_str = ts_val if isinstance(ts_val, str) else "" + tags_str = form.get("tags", "") tags = [tag.strip() for tag in tags_str.split(",") if tag.strip()] if tags_str else [] agent_data = A2AAgentCreate( @@ -8546,7 +8492,7 @@ async def admin_add_a2a_agent( tags=tags, ) - LOGGER.info(f"Creating A2A agent: {agent_data.name} at {agent_data.endpoint_url}") + if logger: logger.info(f"Creating A2A agent: {agent_data.name} at {agent_data.endpoint_url}") # Extract metadata from request metadata = MetadataCapture.extract_creation_metadata(request, user) @@ -8567,19 +8513,19 @@ async def admin_add_a2a_agent( return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) except A2AAgentNameConflictError as e: - LOGGER.error(f"A2A agent name conflict: {e}") + if logger: logger.error(f"A2A agent name conflict: {e}") root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) except A2AAgentError as e: - LOGGER.error(f"A2A agent error: {e}") + if logger: logger.error(f"A2A agent error: {e}") root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) except ValidationError as e: - LOGGER.error(f"Validation error while creating A2A agent: {e}") + if logger: logger.error(f"Validation error while creating A2A agent: {e}") root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) except Exception as e: - LOGGER.error(f"Error creating A2A agent: {e}") + if logger: logger.error(f"Error creating A2A agent: {e}") root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) @@ -8589,7 +8535,7 @@ async def 
admin_toggle_a2a_agent( agent_id: str, request: Request, db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), # pylint: disable=unused-argument + user: str = Depends(require_auth), # pylint: disable=unused-argument ) -> RedirectResponse: """Toggle A2A agent status via admin UI. @@ -8611,19 +8557,18 @@ async def admin_toggle_a2a_agent( try: form = await request.form() - act_val = form.get("activate", "false") - activate = act_val.lower() == "true" if isinstance(act_val, str) else False + activate = form.get("activate", "false").lower() == "true" await a2a_service.toggle_agent_status(db, agent_id, activate) root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) except A2AAgentNotFoundError as e: - LOGGER.error(f"A2A agent toggle failed - not found: {e}") + if logger: logger.error(f"A2A agent toggle failed - not found: {e}") root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) except Exception as e: - LOGGER.error(f"Error toggling A2A agent: {e}") + if logger: logger.error(f"Error toggling A2A agent: {e}") root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) @@ -8633,7 +8578,7 @@ async def admin_delete_a2a_agent( agent_id: str, request: Request, # pylint: disable=unused-argument db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), # pylint: disable=unused-argument + user: str = Depends(require_auth), # pylint: disable=unused-argument ) -> RedirectResponse: """Delete A2A agent via admin UI. 
@@ -8659,11 +8604,11 @@ async def admin_delete_a2a_agent( return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) except A2AAgentNotFoundError as e: - LOGGER.error(f"A2A agent delete failed - not found: {e}") + if logger: logger.error(f"A2A agent delete failed - not found: {e}") root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) except Exception as e: - LOGGER.error(f"Error deleting A2A agent: {e}") + if logger: logger.error(f"Error deleting A2A agent: {e}") root_path = request.scope.get("root_path", "") return RedirectResponse(f"{root_path}/admin#a2a-agents", status_code=303) @@ -8673,7 +8618,7 @@ async def admin_test_a2a_agent( agent_id: str, request: Request, # pylint: disable=unused-argument db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), # pylint: disable=unused-argument + user: str = Depends(require_auth), # pylint: disable=unused-argument ) -> JSONResponse: """Test A2A agent via admin UI. @@ -8713,280 +8658,5 @@ async def admin_test_a2a_agent( return JSONResponse(content={"success": True, "result": result, "agent_name": agent.name, "test_timestamp": time.time()}) except Exception as e: - LOGGER.error(f"Error testing A2A agent {agent_id}: {e}") + if logger: logger.error(f"Error testing A2A agent {agent_id}: {e}") return JSONResponse(content={"success": False, "error": str(e), "agent_id": agent_id}, status_code=500) - - -# Team-scoped resource section endpoints -@admin_router.get("/sections/tools") -@require_permission("admin") -async def get_tools_section( - team_id: Optional[str] = None, - db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), -): - """Get tools data filtered by team. 
- - Args: - team_id: Optional team ID to filter by - db: Database session - user: Current authenticated user context - - Returns: - JSONResponse: Tools data with team filtering applied - """ - try: - local_tool_service = ToolService() - user_email = get_user_email(user) - - # Get team-filtered tools - tools_list = await local_tool_service.list_tools_for_user(db, user_email, team_id=team_id, include_inactive=True) - - # Convert to JSON-serializable format - tools = [] - for tool in tools_list: - tool_dict = ( - tool.model_dump(by_alias=True) - if hasattr(tool, "model_dump") - else { - "id": tool.id, - "name": tool.name, - "description": tool.description, - "tags": tool.tags or [], - "isActive": getattr(tool, "enabled", False), - "team_id": getattr(tool, "team_id", None), - "visibility": getattr(tool, "visibility", "private"), - } - ) - tools.append(tool_dict) - - return JSONResponse(content=jsonable_encoder({"tools": tools, "team_id": team_id})) - - except Exception as e: - LOGGER.error(f"Error loading tools section: {e}") - return JSONResponse(content={"error": str(e)}, status_code=500) - - -@admin_router.get("/sections/resources") -@require_permission("admin") -async def get_resources_section( - team_id: Optional[str] = None, - db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), -): - """Get resources data filtered by team. 
- - Args: - team_id: Optional team ID to filter by - db: Database session - user: Current authenticated user context - - Returns: - JSONResponse: Resources data with team filtering applied - """ - try: - local_resource_service = ResourceService() - user_email = get_user_email(user) - LOGGER.debug(f"User {user_email} requesting resources section with team_id={team_id}") - - # Get all resources and filter by team - resources_list = await local_resource_service.list_resources(db, include_inactive=True) - - # Apply team filtering if specified - if team_id: - resources_list = [r for r in resources_list if getattr(r, "team_id", None) == team_id] - - # Convert to JSON-serializable format - resources = [] - for resource in resources_list: - resource_dict = ( - resource.model_dump(by_alias=True) - if hasattr(resource, "model_dump") - else { - "id": resource.id, - "name": resource.name, - "description": resource.description, - "uri": resource.uri, - "tags": resource.tags or [], - "isActive": resource.is_active, - "team_id": getattr(resource, "team_id", None), - "visibility": getattr(resource, "visibility", "private"), - } - ) - resources.append(resource_dict) - - return JSONResponse(content={"resources": resources, "team_id": team_id}) - - except Exception as e: - LOGGER.error(f"Error loading resources section: {e}") - return JSONResponse(content={"error": str(e)}, status_code=500) - - -@admin_router.get("/sections/prompts") -@require_permission("admin") -async def get_prompts_section( - team_id: Optional[str] = None, - db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), -): - """Get prompts data filtered by team. 
- - Args: - team_id: Optional team ID to filter by - db: Database session - user: Current authenticated user context - - Returns: - JSONResponse: Prompts data with team filtering applied - """ - try: - local_prompt_service = PromptService() - user_email = get_user_email(user) - LOGGER.debug(f"User {user_email} requesting prompts section with team_id={team_id}") - - # Get all prompts and filter by team - prompts_list = await local_prompt_service.list_prompts(db, include_inactive=True) - - # Apply team filtering if specified - if team_id: - prompts_list = [p for p in prompts_list if getattr(p, "team_id", None) == team_id] - - # Convert to JSON-serializable format - prompts = [] - for prompt in prompts_list: - prompt_dict = ( - prompt.model_dump(by_alias=True) - if hasattr(prompt, "model_dump") - else { - "id": prompt.id, - "name": prompt.name, - "description": prompt.description, - "arguments": prompt.arguments or [], - "tags": prompt.tags or [], - "isActive": prompt.is_active, - "team_id": getattr(prompt, "team_id", None), - "visibility": getattr(prompt, "visibility", "private"), - } - ) - prompts.append(prompt_dict) - - return JSONResponse(content={"prompts": prompts, "team_id": team_id}) - - except Exception as e: - LOGGER.error(f"Error loading prompts section: {e}") - return JSONResponse(content={"error": str(e)}, status_code=500) - - -@admin_router.get("/sections/servers") -@require_permission("admin") -async def get_servers_section( - team_id: Optional[str] = None, - db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), -): - """Get servers data filtered by team. 
- - Args: - team_id: Optional team ID to filter by - db: Database session - user: Current authenticated user context - - Returns: - JSONResponse: Servers data with team filtering applied - """ - try: - local_server_service = ServerService() - user_email = get_user_email(user) - LOGGER.debug(f"User {user_email} requesting servers section with team_id={team_id}") - - # Get all servers and filter by team - servers_list = await local_server_service.list_servers(db, include_inactive=True) - - # Apply team filtering if specified - if team_id: - servers_list = [s for s in servers_list if getattr(s, "team_id", None) == team_id] - - # Convert to JSON-serializable format - servers = [] - for server in servers_list: - server_dict = ( - server.model_dump(by_alias=True) - if hasattr(server, "model_dump") - else { - "id": server.id, - "name": server.name, - "description": server.description, - "tags": server.tags or [], - "isActive": server.is_active, - "team_id": getattr(server, "team_id", None), - "visibility": getattr(server, "visibility", "private"), - } - ) - servers.append(server_dict) - - return JSONResponse(content={"servers": servers, "team_id": team_id}) - - except Exception as e: - LOGGER.error(f"Error loading servers section: {e}") - return JSONResponse(content={"error": str(e)}, status_code=500) - - -@admin_router.get("/sections/gateways") -@require_permission("admin") -async def get_gateways_section( - team_id: Optional[str] = None, - db: Session = Depends(get_db), - user=Depends(get_current_user_with_permissions), -): - """Get gateways data filtered by team. 
- - Args: - team_id: Optional team ID to filter by - db: Database session - user: Current authenticated user context - - Returns: - JSONResponse: Gateways data with team filtering applied - """ - try: - local_gateway_service = GatewayService() - get_user_email(user) - - # Get all gateways and filter by team - gateways_list = await local_gateway_service.list_gateways(db, include_inactive=True) - - # Apply team filtering if specified - if team_id: - gateways_list = [g for g in gateways_list if g.team_id == team_id] - - # Convert to JSON-serializable format - gateways = [] - for gateway in gateways_list: - if hasattr(gateway, "model_dump"): - # Get dict and serialize datetime objects - gateway_dict = gateway.model_dump(by_alias=True) - # Convert datetime objects to strings - for key, value in gateway_dict.items(): - gateway_dict[key] = serialize_datetime(value) - else: - # Parse URL to extract host and port - parsed_url = urllib.parse.urlparse(gateway.url) if gateway.url else None - gateway_dict = { - "id": gateway.id, - "name": gateway.name, - "host": parsed_url.hostname if parsed_url else "", - "port": parsed_url.port if parsed_url else 80, - "tags": gateway.tags or [], - "isActive": getattr(gateway, "enabled", False), - "team_id": getattr(gateway, "team_id", None), - "visibility": getattr(gateway, "visibility", "private"), - "created_at": serialize_datetime(getattr(gateway, "created_at", None)), - "updated_at": serialize_datetime(getattr(gateway, "updated_at", None)), - } - gateways.append(gateway_dict) - - return JSONResponse(content={"gateways": gateways, "team_id": team_id}) - - except Exception as e: - LOGGER.error(f"Error loading gateways section: {e}") - return JSONResponse(content={"error": str(e)}, status_code=500) diff --git a/mcpgateway/federation/forward.py b/mcpgateway/federation/forward.py index 22344a83b..aad6f77d4 100644 --- a/mcpgateway/federation/forward.py +++ b/mcpgateway/federation/forward.py @@ -27,6 +27,7 @@ # Standard import asyncio from 
datetime import datetime, timezone +import time from typing import Any, Dict, List, Optional, Set, Tuple, Union # Third-Party @@ -37,6 +38,7 @@ # First-Party from mcpgateway.config import settings from mcpgateway.db import Gateway as DbGateway +from mcpgateway.db import ServerMetric from mcpgateway.db import Tool as DbTool from mcpgateway.models import ToolResult from mcpgateway.services.logging_service import LoggingService @@ -435,6 +437,10 @@ async def _forward_to_gateway( if not self._check_rate_limit(gateway.url): raise ForwardingError("Rate limit exceeded") + start_time = time.monotonic() + success = False + error_message = None + try: # Build request request = {"jsonrpc": "2.0", "id": 1, "method": method} @@ -462,16 +468,24 @@ async def _forward_to_gateway( # Handle response if "error" in result: - raise ForwardingError(f"Gateway error: {result['error'].get('message')}") + error_message = result['error'].get('message') + raise ForwardingError(f"Gateway error: {error_message}") + + success = True return result.get("result") - except httpx.TimeoutException: + except httpx.TimeoutException as e: + error_message = f"Timeout on attempt {attempt + 1}: {str(e)}" if attempt == settings.max_tool_retries - 1: raise await asyncio.sleep(1 * (attempt + 1)) except Exception as e: + error_message = str(e) raise ForwardingError(f"Failed to forward to {gateway.name}: {str(e)}") + finally: + # Always record server metrics + await self._record_server_metric(db, gateway, start_time, success, error_message) async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None, request_headers: Optional[Dict[str, str]] = None) -> List[Any]: """Forward request to all active gateways. 
@@ -738,3 +752,33 @@ def _get_auth_headers(self) -> Dict[str, str]: """ api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}" return {"Authorization": f"Basic {api_key}", "X-API-Key": api_key} + + async def _record_server_metric(self, db: Session, gateway: DbGateway, start_time: float, success: bool, error_message: Optional[str]) -> None: + """ + Records a metric for a server interaction. + + This function calculates the response time using the provided start time and records + the metric details (including whether the interaction was successful and any error message) + into the database. The metric is then committed to the database. + + Args: + db (Session): The SQLAlchemy database session. + gateway (DbGateway): The gateway that was accessed. + start_time (float): The monotonic start time of the interaction. + success (bool): True if the interaction succeeded; otherwise, False. + error_message (Optional[str]): The error message if the interaction failed, otherwise None. 
+ """ + end_time = time.monotonic() + response_time = end_time - start_time + metric = ServerMetric( + server_id=gateway.id, + response_time=response_time, + is_success=success, + error_message=error_message, + ) + db.add(metric) + db.commit() + try: # pragma: no cover + db.expire(gateway, ["metrics"]) + except Exception: # noqa: BLE001 + pass diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 53668d924..e12e636de 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -51,8 +51,6 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request as starletteRequest -from starlette.responses import Response as starletteResponse from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware # First-Party @@ -71,7 +69,7 @@ from mcpgateway.middleware.token_scoping import token_scoping_middleware from mcpgateway.models import InitializeResult, ListResourceTemplatesResult, LogLevel, Root from mcpgateway.observability import init_telemetry -from mcpgateway.plugins.framework import PluginError, PluginManager, PluginViolationError +from mcpgateway.plugins.framework import PluginManager, PluginViolationError from mcpgateway.routers.well_known import router as well_known_router from mcpgateway.schemas import ( A2AAgentCreate, @@ -574,81 +572,6 @@ async def database_exception_handler(_request: Request, exc: IntegrityError): return JSONResponse(status_code=409, content=ErrorFormatter.format_database_error(exc)) -@app.exception_handler(PluginViolationError) -async def plugin_violation_exception_handler(_request: Request, exc: PluginViolationError): - """Handle plugins violations globally. - - Intercepts PluginViolationError exceptions (e.g., OPA policy violation) and returns a properly formatted JSON error response. - This provides consistent error handling for plugin violation across the entire application. 
- - Args: - _request: The FastAPI request object that triggered the database error. - (Unused but required by FastAPI's exception handler interface) - exc: The PluginViolationError exception containing constraint - violation details. - - Returns: - JSONResponse: A 403 response with access forbidden. - - Examples: - >>> from mcpgateway.plugins.framework import PluginViolationError - >>> from mcpgateway.plugins.framework.models import PluginViolation - >>> from fastapi import Request - >>> import asyncio - >>> - >>> # Create a mock integrity error - >>> mock_error = PluginViolationError(message="plugin violation",violation = PluginViolation( - ... reason="Invalid input", - ... description="The input contains prohibited content", - ... code="PROHIBITED_CONTENT", - ... details={"field": "message", "value": "test"} - ... )) - >>> result = asyncio.run(plugin_violation_exception_handler(None, mock_error)) - >>> result.status_code - 403 - """ - policy_violation = exc.violation.model_dump() if exc.violation else {} - policy_violation["message"] = exc.message - return JSONResponse(status_code=403, content=policy_violation) - - -@app.exception_handler(PluginError) -async def plugin_exception_handler(_request: Request, exc: PluginError): - """Handle plugins errors globally. - - Intercepts PluginError exceptions and returns a properly formatted JSON error response. - This provides consistent error handling for plugin error across the entire application. - - Args: - _request: The FastAPI request object that triggered the database error. - (Unused but required by FastAPI's exception handler interface) - exc: The PluginError exception containing constraint - violation details. - - Returns: - JSONResponse: A 500 response with internal server error. 
- - Examples: - >>> from mcpgateway.plugins.framework import PluginViolationError - >>> from mcpgateway.plugins.framework.models import PluginErrorModel - >>> from fastapi import Request - >>> import asyncio - >>> - >>> # Create a mock integrity error - >>> mock_error = PluginError(error = PluginErrorModel( - ... message="plugin error", - ... code="timeout", - ... plugin_name="abc", - ... details={"field": "message", "value": "test"} - ... )) - >>> result = asyncio.run(plugin_exception_handler(None, mock_error)) - >>> result.status_code - 500 - """ - error_obj = exc.error.model_dump() if exc.error else {} - return JSONResponse(status_code=500, content=error_obj) - - class DocsAuthMiddleware(BaseHTTPMiddleware): """ Middleware to protect FastAPI's auto-generated documentation routes @@ -714,36 +637,22 @@ async def dispatch(self, request: Request, call_next): class MCPPathRewriteMiddleware: """ - Middleware that rewrites paths ending with '/mcp' to '/mcp', after performing authentication. + Supports requests like '/servers//mcp' by rewriting the path to '/mcp'. - - Rewrites paths like '/servers//mcp' to '/mcp'. - - Only paths ending with '/mcp' (but not exactly '/mcp') are rewritten. - - Authentication is performed before any path rewriting. - - If authentication fails, the request is not processed further. + - Only rewrites paths ending with '/mcp' but not exactly '/mcp'. + - Performs authentication before rewriting. + - Passes rewritten requests to `streamable_http_session`. - All other requests are passed through without change. - - Attributes: - application (Callable): The next ASGI application to process the request. """ - def __init__(self, application, dispatch=None): + def __init__(self, application): """ Initialize the middleware with the ASGI application. Args: - application (Callable): The next ASGI application to handle the request. - dispatch (Callable, optional): An optional dispatch function for additional middleware processing. 
- - Example: - >>> import asyncio - >>> from unittest.mock import AsyncMock, patch - >>> app_mock = AsyncMock() - >>> middleware = MCPPathRewriteMiddleware(app_mock) - >>> isinstance(middleware.application, AsyncMock) - True + application (Callable): The next ASGI application in the middleware stack. """ self.application = application - self.dispatch = dispatch # this can be TokenScopingMiddleware async def __call__(self, scope, receive, send): """ @@ -757,86 +666,39 @@ async def __call__(self, scope, receive, send): Examples: >>> import asyncio >>> from unittest.mock import AsyncMock, patch + >>> + >>> # Test non-HTTP request passthrough >>> app_mock = AsyncMock() >>> middleware = MCPPathRewriteMiddleware(app_mock) - - >>> # Test path rewriting for /servers/123/mcp with headers in scope - >>> scope = { "type": "http", "path": "/servers/123/mcp", "headers": [(b"host", b"example.com")] } + >>> scope = {"type": "websocket", "path": "/ws"} >>> receive = AsyncMock() >>> send = AsyncMock() + >>> + >>> asyncio.run(middleware(scope, receive, send)) + >>> app_mock.assert_called_once_with(scope, receive, send) + >>> + >>> # Test path rewriting for /servers/123/mcp + >>> app_mock.reset_mock() + >>> scope = {"type": "http", "path": "/servers/123/mcp"} >>> with patch('mcpgateway.main.streamable_http_auth', return_value=True): ... with patch.object(streamable_http_session, 'handle_streamable_http') as mock_handler: ... asyncio.run(middleware(scope, receive, send)) ... scope["path"] '/mcp' - + >>> >>> # Test regular path (no rewrite) - >>> scope = { "type": "http","path": "/tools","headers": [(b"host", b"example.com")] } + >>> scope = {"type": "http", "path": "/tools"} >>> with patch('mcpgateway.main.streamable_http_auth', return_value=True): ... asyncio.run(middleware(scope, receive, send)) ... 
scope["path"] '/tools' """ + # Only handle HTTP requests, HTTPS uses scope["type"] == "http" in ASGI if scope["type"] != "http": await self.application(scope, receive, send) return - # If a dispatch (request middleware) is provided, adapt it - if self.dispatch is not None: - request = starletteRequest(scope, receive=receive) - - async def call_next(_req: starletteRequest) -> starletteResponse: - """ - Handles the next request in the middleware chain by calling a streamable HTTP response. - - Args: - _req (starletteRequest): The incoming request to be processed. - - Returns: - starletteResponse: A response generated from the streamable HTTP call. - """ - return await self._call_streamable_http(scope, receive, send) - - response = await self.dispatch(request, call_next) - - if response is None: - # Either the dispatch handled the response itself, - # or it blocked the request. Just return. - return - - await response(scope, receive, send) - return - - # Otherwise, just continue as normal - await self._call_streamable_http(scope, receive, send) - - async def _call_streamable_http(self, scope, receive, send): - """ - Handles the streamable HTTP request after authentication and path rewriting. - - - If authentication is successful and the path is rewritten, this method processes the request - using the `streamable_http_session` handler. - - Args: - scope (dict): The ASGI connection scope containing request metadata. - receive (Callable): The function to receive events from the client. - send (Callable): The function to send events to the client. - - Example: - >>> import asyncio - >>> from unittest.mock import AsyncMock, patch - >>> app_mock = AsyncMock() - >>> middleware = MCPPathRewriteMiddleware(app_mock) - >>> scope = {"type": "http", "path": "/servers/123/mcp"} - >>> receive = AsyncMock() - >>> send = AsyncMock() - >>> with patch('mcpgateway.main.streamable_http_auth', return_value=True): - ... 
with patch.object(streamable_http_session, 'handle_streamable_http') as mock_handler: - ... asyncio.run(middleware._call_streamable_http(scope, receive, send)) - >>> mock_handler.assert_called_once_with(scope, receive, send) - >>> # The streamable HTTP session handler was called after path rewriting. - """ - # Auth check first + # Call auth check first auth_ok = await streamable_http_auth(scope, receive, send) if not auth_ok: return @@ -875,15 +737,13 @@ async def _call_streamable_http(self, scope, receive, send): # Add token scoping middleware (only when email auth is enabled) if settings.email_auth_enabled: app.add_middleware(BaseHTTPMiddleware, dispatch=token_scoping_middleware) - # Add streamable HTTP middleware for /mcp routes with token scoping - app.add_middleware(MCPPathRewriteMiddleware, dispatch=token_scoping_middleware) -else: - # Add streamable HTTP middleware for /mcp routes - app.add_middleware(MCPPathRewriteMiddleware) # Add custom DocsAuthMiddleware app.add_middleware(DocsAuthMiddleware) +# Add streamable HTTP middleware for /mcp routes +app.add_middleware(MCPPathRewriteMiddleware) + # Trust all proxies (or lock down with a list of host patterns) app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") @@ -3369,10 +3229,6 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen Returns: Response with the RPC result or error. - - Raises: - PluginError: If encounters issue with plugin - PluginViolationError: If plugin violated the request. Example - In case of OPA plugin, if the request is denied by policy. 
""" try: # Extract user identifier from either RBAC user object or JWT payload @@ -3483,14 +3339,13 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen else: # Backward compatibility: Try to invoke as a tool directly # This allows both old format (method=tool_name) and new format (method=tools/call) - # Standard headers = {k.lower(): v for k, v in request.headers.items()} try: result = await tool_service.invoke_tool(db=db, name=method, arguments=params, request_headers=headers) if hasattr(result, "model_dump"): result = result.model_dump(by_alias=True, exclude_none=True) - except (PluginError, PluginViolationError): - raise + except PluginViolationError: + return JSONResponse(status_code=403, content={"detail": "policy_deny"}) except (ValueError, Exception): # If not a tool, try forwarding to gateway try: @@ -3503,8 +3358,6 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen return {"jsonrpc": "2.0", "result": result, "id": req_id} - except (PluginError, PluginViolationError): - raise except JSONRPCError as e: error = e.to_dict() return {"jsonrpc": "2.0", "error": error["error"], "id": req_id} @@ -4370,7 +4223,7 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user=Depends(get_curr try: # Create a sub-application for static files that will respect root_path static_app = StaticFiles(directory=str(settings.static_dir)) - STATIC_PATH = "/static" + STATIC_PATH = f"{settings.app_root_path}/static" if settings.app_root_path else "/static" app.mount( STATIC_PATH, diff --git a/mcpgateway/middleware/token_scoping.py b/mcpgateway/middleware/token_scoping.py index 8809d1727..9cfe02906 100644 --- a/mcpgateway/middleware/token_scoping.py +++ b/mcpgateway/middleware/token_scoping.py @@ -18,21 +18,15 @@ # Third-Party from fastapi import HTTPException, Request, status -from fastapi.responses import JSONResponse from fastapi.security import HTTPBearer # First-Party from mcpgateway.db import Permissions 
-from mcpgateway.services.logging_service import LoggingService from mcpgateway.utils.verify_credentials import verify_jwt_token # Security scheme bearer_scheme = HTTPBearer(auto_error=False) -# Initialize logging service first -logging_service = LoggingService() -logger = logging_service.get_logger(__name__) - class TokenScopingMiddleware: """Middleware to enforce token scoping restrictions. @@ -349,68 +343,47 @@ async def __call__(self, request: Request, call_next): Raises: HTTPException: If token scoping restrictions are violated """ - try: - # Skip scoping for certain paths (truly public endpoints only) - skip_paths = [ - "/health", - "/metrics", - "/openapi.json", - "/docs", - "/redoc", - "/auth/email/login", - "/auth/email/register", - "/.well-known/", - ] - - # Check exact root path separately - if request.url.path == "/": - return await call_next(request) - - if any(request.url.path.startswith(path) for path in skip_paths): - return await call_next(request) - - # Extract token scopes - scopes = await self._extract_token_scopes(request) - - # If no scopes, continue (regular auth will handle this) - if not scopes: - return await call_next(request) - - # Check server ID restriction - server_id = scopes.get("server_id") - if not self._check_server_restriction(request.url.path, server_id): - logger.warning(f"Token not authorized for this server. Required: {server_id}") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Token not authorized for this server. 
Required: {server_id}") - - # Check IP restrictions - ip_restrictions = scopes.get("ip_restrictions", []) - if ip_restrictions: - client_ip = self._get_client_ip(request) - if not self._check_ip_restrictions(client_ip, ip_restrictions): - logger.warning(f"Request from IP {client_ip} not allowed by token restrictions") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Request from IP {client_ip} not allowed by token restrictions") - - # Check time restrictions - time_restrictions = scopes.get("time_restrictions", {}) - if not self._check_time_restrictions(time_restrictions): - logger.warning("Request not allowed at this time by token restrictions") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Request not allowed at this time by token restrictions") - - # Check permission restrictions - permissions = scopes.get("permissions", []) - if not self._check_permission_restrictions(request.url.path, request.method, permissions): - logger.warning("Insufficient permissions for this operation") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions for this operation") - - # All scoping checks passed, continue + # Skip scoping for certain paths (truly public endpoints only) + skip_paths = ["/health", "/metrics", "/openapi.json", "/docs", "/redoc", "/auth/email/login", "/auth/email/register", "/.well-known/"] + + # Check exact root path separately + if request.url.path == "/": + return await call_next(request) + + if any(request.url.path.startswith(path) for path in skip_paths): + return await call_next(request) + + # Extract token scopes + scopes = await self._extract_token_scopes(request) + + # If no scopes, continue (regular auth will handle this) + if not scopes: return await call_next(request) - except HTTPException as exc: - # Return clean JSON response instead of traceback - return JSONResponse( - status_code=exc.status_code, - content={"detail": exc.detail}, - ) + # Check server ID 
restriction + server_id = scopes.get("server_id") + if not self._check_server_restriction(request.url.path, server_id): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Token not authorized for this server. Required: {server_id}") + + # Check IP restrictions + ip_restrictions = scopes.get("ip_restrictions", []) + if ip_restrictions: + client_ip = self._get_client_ip(request) + if not self._check_ip_restrictions(client_ip, ip_restrictions): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Request from IP {client_ip} not allowed by token restrictions") + + # Check time restrictions + time_restrictions = scopes.get("time_restrictions", {}) + if not self._check_time_restrictions(time_restrictions): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Request not allowed at this time by token restrictions") + + # Check permission restrictions + permissions = scopes.get("permissions", []) + if not self._check_permission_restrictions(request.url.path, request.method, permissions): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions for this operation") + + # All scoping checks passed, continue to next handler + return await call_next(request) # Create middleware instance diff --git a/mcpgateway/models.py b/mcpgateway/models.py index 40ee3087c..82816e436 100644 --- a/mcpgateway/models.py +++ b/mcpgateway/models.py @@ -401,103 +401,37 @@ class PromptResult(BaseModel): description: Optional[str] = None -class CommonAttributes(BaseModel): - """Common attributes for tools and gateways. +# Tool types +class Tool(BaseModel): + """A tool that can be invoked. Attributes: name (str): The unique name of the tool. url (AnyHttpUrl): The URL of the tool. description (Optional[str]): A description of the tool. - created_at (Optional[datetime]): The time at which the tool was created. - update_at (Optional[datetime]): The time at which the tool was updated. 
- enabled (Optional[bool]): If the tool is enabled. - reachable (Optional[bool]): If the tool is currently reachable. - tags (Optional[list[str]]): A list of meta data tags describing the tool. - created_by (Optional[str]): The person that created the tool. - created_from_ip (Optional[str]): The client IP that created the tool. - created_via (Optional[str]): How the tool was created (e.g., ui). - created_user_agent (Optioanl[str]): The client user agent. - modified_by (Optional[str]): The person that modified the tool. - modified_from_ip (Optional[str]): The client IP that modified the tool. - modified_via (Optional[str]): How the tool was modified (e.g., ui). - modified_user_agent (Optioanl[str]): The client user agent. - import_batch_id (Optional[str]): The id of the batch file that imported the tool. - federation_source (Optional[str]): The federation source of the tool - version (Optional[int]): The version of the tool. - team_id (Optional[str]): The id of the team that created the tool. - owner_email (Optional[str]): Tool owner's email. - visibility (Optional[str]): Visibility of the tool (e.g., public, private). 
- """ - - name: str - url: AnyHttpUrl - description: Optional[str] = None - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - enabled: Optional[bool] = None - reachable: Optional[bool] = None - auth_type: Optional[str] = None - tags: Optional[list[str]] = None - # Comprehensive metadata for audit tracking - created_by: Optional[str] = None - created_from_ip: Optional[str] = None - created_via: Optional[str] = None - created_user_agent: Optional[str] = None - - modified_by: Optional[str] = None - modified_from_ip: Optional[str] = None - modified_via: Optional[str] = None - modified_user_agent: Optional[str] = None - - import_batch_id: Optional[str] = None - federation_source: Optional[str] = None - version: Optional[int] = None - # Team scoping fields for resource organization - team_id: Optional[str] = None - owner_email: Optional[str] = None - visibility: Optional[str] = None - - -# Tool types -class Tool(CommonAttributes): - """A tool that can be invoked. - - Attributes: - original_name (str): The original supplied name of the tool before imported by the gateway. integrationType (str): The integration type of the tool (e.g. MCP or REST). requestType (str): The HTTP method used to invoke the tool (GET, POST, PUT, DELETE, SSE, STDIO). headers (Dict[str, Any]): A JSON object representing HTTP headers. input_schema (Dict[str, Any]): A JSON Schema for validating the tool's input. annotations (Optional[Dict[str, Any]]): Tool annotations for behavior hints. + auth_type (Optional[str]): The type of authentication used ("basic", "bearer", or None). auth_username (Optional[str]): The username for basic authentication. auth_password (Optional[str]): The password for basic authentication. auth_token (Optional[str]): The token for bearer authentication. - jsonpath_filter (Optional[str]): Filter the tool based on a JSON path expression. - custom_name (Optional[str]): Custom tool name. 
- custom_name_slug (Optional[str]): Alternative custom tool name. - display_name (Optional[str]): Display name. - gateway_id (Optional[str]): The gateway id on which the tool is hosted. """ - model_config = ConfigDict(from_attributes=True) - original_name: Optional[str] = None + name: str + url: AnyHttpUrl + description: Optional[str] = None integration_type: str = "MCP" request_type: str = "SSE" - headers: Optional[Dict[str, Any]] = Field(default_factory=dict) + headers: Dict[str, Any] = Field(default_factory=dict) input_schema: Dict[str, Any] = Field(default_factory=lambda: {"type": "object", "properties": {}}) annotations: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Tool annotations for behavior hints") + auth_type: Optional[str] = None auth_username: Optional[str] = None auth_password: Optional[str] = None auth_token: Optional[str] = None - jsonpath_filter: Optional[str] = None - - # custom_name,custom_name_slug, display_name - custom_name: Optional[str] = None - custom_name_slug: Optional[str] = None - display_name: Optional[str] = None - - # Federation relationship with a local gateway - gateway_id: Optional[str] = None class ToolResult(BaseModel): @@ -877,7 +811,7 @@ class FederatedPrompt(Prompt): gateway_name: str -class Gateway(CommonAttributes): +class Gateway(BaseModel): """A federated gateway peer. Attributes: @@ -888,17 +822,11 @@ class Gateway(CommonAttributes): last_seen (Optional[datetime]): Timestamp when the gateway was last seen. 
""" - model_config = ConfigDict(from_attributes=True) id: str + name: str + url: AnyHttpUrl capabilities: ServerCapabilities last_seen: Optional[datetime] = None - slug: str - transport: str - last_seen: Optional[datetime] - # Header passthrough configuration - passthrough_headers: Optional[list[str]] # Store list of strings as JSON array - # Request type and authentication fields - auth_value: Optional[str | dict] # ===== RBAC Models ===== diff --git a/mcpgateway/plugins/framework/__init__.py b/mcpgateway/plugins/framework/__init__.py index db61745c7..2d8c59f3c 100644 --- a/mcpgateway/plugins/framework/__init__.py +++ b/mcpgateway/plugins/framework/__init__.py @@ -22,8 +22,6 @@ from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework.models import ( GlobalContext, - HttpHeaderPayload, - HttpHeaderPayloadResult, HookType, PluginCondition, PluginConfig, @@ -52,8 +50,6 @@ "ExternalPluginServer", "GlobalContext", "HookType", - "HttpHeaderPayload", - "HttpHeaderPayloadResult", "Plugin", "PluginCondition", "PluginConfig", diff --git a/mcpgateway/plugins/framework/constants.py b/mcpgateway/plugins/framework/constants.py index 7b446624f..065b85d1f 100644 --- a/mcpgateway/plugins/framework/constants.py +++ b/mcpgateway/plugins/framework/constants.py @@ -27,8 +27,3 @@ ERROR = "error" GET_PLUGIN_CONFIG = "get_plugin_config" IGNORE_CONFIG_EXTERNAL = "ignore_config_external" - -# Global Context Metadata fields - -TOOL_METADATA = "tool" -GATEWAY_METADATA = "gateway" diff --git a/mcpgateway/plugins/framework/external/mcp/client.py b/mcpgateway/plugins/framework/external/mcp/client.py index 7facb160e..fbb24bf52 100644 --- a/mcpgateway/plugins/framework/external/mcp/client.py +++ b/mcpgateway/plugins/framework/external/mcp/client.py @@ -135,29 +135,54 @@ async def __connect_to_stdio_server(self, server_script_path: str) -> None: raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) async def 
__connect_to_http_server(self, uri: str) -> None: - """Connect to an MCP plugin server via streamable http. + """Connect to an MCP plugin server via streamable http with retry logic. Args: uri: the URI of the mcp plugin server. Raises: - PluginError: if there is an external connection error. + PluginError: if there is an external connection error after all retries. """ - - try: - http_transport = await self._exit_stack.enter_async_context(streamablehttp_client(uri)) - self._http, self._write, _ = http_transport - self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write)) - - await self._session.initialize() - - # List available tools - response = await self._session.list_tools() - tools = response.tools - logger.info("\nConnected to plugin MCP (http) server with tools: %s", " ".join([tool.name for tool in tools])) - except Exception as e: - logger.exception(e) - raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) + max_retries = 3 + base_delay = 1.0 + + for attempt in range(max_retries): + logger.info(f"Connecting to external plugin server: {uri} (attempt {attempt + 1}/{max_retries})") + + try: + # Create a fresh exit stack for each attempt + async with AsyncExitStack() as temp_stack: + http_transport = await temp_stack.enter_async_context(streamablehttp_client(uri)) + http_client, write_func, _ = http_transport + session = await temp_stack.enter_async_context(ClientSession(http_client, write_func)) + + await session.initialize() + + # List available tools + response = await session.list_tools() + tools = response.tools + logger.info("Successfully connected to plugin MCP server with tools: %s", " ".join([tool.name for tool in tools])) + + # Success! 
Now move to the main exit stack + self._http = await self._exit_stack.enter_async_context(streamablehttp_client(uri)) + self._http, self._write, _ = self._http + self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write)) + await self._session.initialize() + return + + except Exception as e: + logger.warning(f"Connection attempt {attempt + 1}/{max_retries} failed: {e}") + + if attempt == max_retries - 1: + # Final attempt failed + error_msg = f"External plugin '{self.name}' connection failed after {max_retries} attempts: {uri} is not reachable. Please ensure the MCP server is running." + logger.error(error_msg) + raise PluginError(error=PluginErrorModel(message=error_msg, plugin_name=self.name)) + + # Wait before retry + delay = base_delay * (2**attempt) + logger.info(f"Retrying in {delay}s...") + await asyncio.sleep(delay) async def __invoke_hook(self, payload_result_model: Type[P], hook_type: HookType, payload: BaseModel, context: PluginContext) -> P: """Invoke an external plugin hook using the MCP protocol. 
diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index d373bda6c..2c29de2e7 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -35,7 +35,7 @@ # First-Party from mcpgateway.plugins.framework.base import Plugin, PluginRef -from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError, PluginViolationError +from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.models import ( @@ -74,16 +74,7 @@ # Use standard logging to avoid circular imports (plugins -> services -> plugins) logger = logging.getLogger(__name__) -T = TypeVar( - "T", - PromptPosthookPayload, - PromptPrehookPayload, - ResourcePostFetchPayload, - ResourcePreFetchPayload, - ToolPostInvokePayload, - ToolPreInvokePayload, -) - +T = TypeVar("T") # Configuration constants DEFAULT_PLUGIN_TIMEOUT = 30 # seconds @@ -140,7 +131,6 @@ async def execute( plugin_run: Callable[[PluginRef, T, PluginContext], Coroutine[Any, Any, PluginResult[T]]], compare: Callable[[T, list[PluginCondition], GlobalContext], bool], local_contexts: Optional[PluginContextTable] = None, - violations_as_exceptions: bool = False, ) -> tuple[PluginResult[T], PluginContextTable | None]: """Execute plugins in priority order with timeout protection. @@ -151,7 +141,6 @@ async def execute( plugin_run: Async function to execute a specific plugin hook. compare: Function to check if plugin conditions match the current context. local_contexts: Optional existing contexts from previous hook executions. - violations_as_exceptions: Raise violations as exceptions rather than as returns. Returns: A tuple containing: @@ -161,7 +150,6 @@ async def execute( Raises: PayloadSizeError: If the payload exceeds MAX_PAYLOAD_SIZE. 
PluginError: If there is an error inside a plugin. - PluginViolationError: If a violation occurs and violation_as_exceptions is set. Examples: >>> # Execute plugins with timeout protection @@ -200,7 +188,6 @@ async def execute( tenant_id=global_context.tenant_id, server_id=global_context.server_id, state={} if not global_context.state else deepcopy(global_context.state), - metadata={} if not global_context.metadata else deepcopy(global_context.metadata), ) # Get or create local context for this plugin local_context_key = global_context.request_id + pluginref.uuid @@ -233,16 +220,6 @@ async def execute( if not result.continue_processing: if pluginref.plugin.mode == PluginMode.ENFORCE: logger.warning(f"Plugin {pluginref.plugin.name} blocked request in enforce mode") - if violations_as_exceptions: - if result.violation: - plugin_name = result.violation.plugin_name - violation_reason = result.violation.reason - violation_desc = result.violation.description - violation_code = result.violation.code - raise PluginViolationError( - f"{plugin_run.__name__} blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", violation=result.violation - ) - raise PluginViolationError(f"{plugin_run.__name__} blocked by plugin") return (PluginResult[T](continue_processing=False, modified_payload=current_payload, violation=result.violation, metadata=combined_metadata), res_local_contexts) if pluginref.plugin.mode == PluginMode.PERMISSIVE: logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.violation.description if result.violation else 'No description'}") @@ -253,8 +230,7 @@ async def execute( raise PluginError(error=PluginErrorModel(message=f"Plugin {pluginref.name} exceeded {self.timeout}s timeout", plugin_name=pluginref.name)) # In permissive or enforce_ignore_error mode, continue with next plugin continue - except PluginViolationError: - raise + except PluginError as pe: logger.error(f"Plugin {pluginref.name} failed 
with error: {str(pe)}", exc_info=True) if self.config.plugin_settings.fail_on_plugin_error or pluginref.plugin.mode == PluginMode.ENFORCE: @@ -680,7 +656,10 @@ async def _cleanup_old_contexts(self) -> None: self._last_cleanup = current_time async def prompt_pre_fetch( - self, payload: PromptPrehookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False + self, + payload: PromptPrehookPayload, + global_context: GlobalContext, + local_contexts: Optional[PluginContextTable] = None, ) -> tuple[PromptPrehookResult, PluginContextTable | None]: """Execute pre-fetch hooks before a prompt is retrieved and rendered. @@ -688,7 +667,6 @@ async def prompt_pre_fetch( payload: The prompt payload containing name and arguments. global_context: Shared context for all plugins with request metadata. local_contexts: Optional existing contexts from previous executions. - violations_as_exceptions: Raise violations as exceptions rather than as returns. 
Returns: A tuple containing: @@ -726,7 +704,7 @@ async def prompt_pre_fetch( plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) # Execute plugins - result = await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts, violations_as_exceptions) + result = await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts) # Store contexts for potential reuse if result[1]: @@ -735,7 +713,7 @@ async def prompt_pre_fetch( return result async def prompt_post_fetch( - self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False + self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None ) -> tuple[PromptPosthookResult, PluginContextTable | None]: """Execute post-fetch hooks after a prompt is rendered. @@ -743,7 +721,6 @@ async def prompt_post_fetch( payload: The prompt result payload containing rendered messages. global_context: Shared context for all plugins with request metadata. local_contexts: Optional contexts from pre-fetch hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. 
Returns: A tuple containing: @@ -787,7 +764,7 @@ async def prompt_post_fetch( plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_POST_FETCH) # Execute plugins - result = await self._post_prompt_executor.execute(plugins, payload, global_context, post_prompt_fetch, post_prompt_matches, local_contexts, violations_as_exceptions) + result = await self._post_prompt_executor.execute(plugins, payload, global_context, post_prompt_fetch, post_prompt_matches, local_contexts) # Clean up stored context after post-fetch if global_context.request_id in self._context_store: @@ -796,7 +773,10 @@ async def prompt_post_fetch( return result async def tool_pre_invoke( - self, payload: ToolPreInvokePayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False + self, + payload: ToolPreInvokePayload, + global_context: GlobalContext, + local_contexts: Optional[PluginContextTable] = None, ) -> tuple[ToolPreInvokeResult, PluginContextTable | None]: """Execute pre-invoke hooks before a tool is invoked. @@ -804,7 +784,6 @@ async def tool_pre_invoke( payload: The tool payload containing name and arguments. global_context: Shared context for all plugins with request metadata. local_contexts: Optional existing contexts from previous executions. - violations_as_exceptions: Raise violations as exceptions rather than as returns. 
Returns: A tuple containing: @@ -842,7 +821,7 @@ async def tool_pre_invoke( plugins = self._registry.get_plugins_for_hook(HookType.TOOL_PRE_INVOKE) # Execute plugins - result = await self._pre_tool_executor.execute(plugins, payload, global_context, pre_tool_invoke, pre_tool_matches, local_contexts, violations_as_exceptions) + result = await self._pre_tool_executor.execute(plugins, payload, global_context, pre_tool_invoke, pre_tool_matches, local_contexts) # Store contexts for potential reuse if result[1]: @@ -851,7 +830,7 @@ async def tool_pre_invoke( return result async def tool_post_invoke( - self, payload: ToolPostInvokePayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False + self, payload: ToolPostInvokePayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None ) -> tuple[ToolPostInvokeResult, PluginContextTable | None]: """Execute post-invoke hooks after a tool is invoked. @@ -859,7 +838,6 @@ async def tool_post_invoke( payload: The tool result payload containing invocation results. global_context: Shared context for all plugins with request metadata. local_contexts: Optional contexts from pre-invoke hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. 
Returns: A tuple containing: @@ -895,7 +873,7 @@ async def tool_post_invoke( plugins = self._registry.get_plugins_for_hook(HookType.TOOL_POST_INVOKE) # Execute plugins - result = await self._post_tool_executor.execute(plugins, payload, global_context, post_tool_invoke, post_tool_matches, local_contexts, violations_as_exceptions) + result = await self._post_tool_executor.execute(plugins, payload, global_context, post_tool_invoke, post_tool_matches, local_contexts) # Clean up stored context after post-invoke if global_context.request_id in self._context_store: @@ -904,7 +882,10 @@ async def tool_post_invoke( return result async def resource_pre_fetch( - self, payload: ResourcePreFetchPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False + self, + payload: ResourcePreFetchPayload, + global_context: GlobalContext, + local_contexts: Optional[PluginContextTable] = None, ) -> tuple[ResourcePreFetchResult, PluginContextTable | None]: """Execute pre-fetch hooks before a resource is fetched. @@ -912,7 +893,6 @@ async def resource_pre_fetch( payload: The resource payload containing URI and metadata. global_context: Shared context for all plugins with request metadata. local_contexts: Optional existing contexts from previous hook executions. - violations_as_exceptions: Raise violations as exceptions rather than as returns. 
Returns: A tuple containing: @@ -934,7 +914,7 @@ async def resource_pre_fetch( plugins = self._registry.get_plugins_for_hook(HookType.RESOURCE_PRE_FETCH) # Execute plugins - result = await self._resource_pre_executor.execute(plugins, payload, global_context, pre_resource_fetch, pre_resource_matches, local_contexts, violations_as_exceptions) + result = await self._resource_pre_executor.execute(plugins, payload, global_context, pre_resource_fetch, pre_resource_matches, local_contexts) # Store context for potential post-fetch if result[1]: @@ -946,7 +926,7 @@ async def resource_pre_fetch( return result async def resource_post_fetch( - self, payload: ResourcePostFetchPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False + self, payload: ResourcePostFetchPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None ) -> tuple[ResourcePostFetchResult, PluginContextTable | None]: """Execute post-fetch hooks after a resource is fetched. @@ -954,7 +934,6 @@ async def resource_post_fetch( payload: The resource content payload containing fetched data. global_context: Shared context for all plugins with request metadata. local_contexts: Optional contexts from pre-fetch hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. 
Returns: A tuple containing: @@ -979,7 +958,7 @@ async def resource_post_fetch( plugins = self._registry.get_plugins_for_hook(HookType.RESOURCE_POST_FETCH) # Execute plugins - result = await self._resource_post_executor.execute(plugins, payload, global_context, post_resource_fetch, post_resource_matches, local_contexts, violations_as_exceptions) + result = await self._resource_post_executor.execute(plugins, payload, global_context, post_resource_fetch, post_resource_matches, local_contexts) # Clean up stored context after post-fetch if global_context.request_id in self._context_store: diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index a55a38b1c..5260106d9 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -12,10 +12,15 @@ # Standard from enum import Enum from pathlib import Path -from typing import Any, Generic, Optional, Self, TypeVar +from typing import Any, Generic, Optional, TypeVar + +try: + from typing import Self # Python 3.11+ +except ImportError: + from typing_extensions import Self # Python 3.10 # Third-Party -from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator, PrivateAttr, RootModel, ValidationInfo +from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator, PrivateAttr, ValidationInfo # First-Party from mcpgateway.models import PromptResult @@ -603,56 +608,12 @@ class PluginResult(BaseModel, Generic[T]): PromptPosthookResult = PluginResult[PromptPosthookPayload] -class HttpHeaderPayload(RootModel[dict[str, str]]): - """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" - - def __iter__(self): - """Custom iterator function to override root attribute. - - Returns: - A custom iterator for header dictionary. - """ - return iter(self.root) - - def __getitem__(self, item: str) -> str: - """Custom getitem function to override root attribute. 
- - Args: - item: The http header key. - - Returns: - A custom accesser for the header dictionary. - """ - return self.root[item] - - def __setitem__(self, key: str, value: str) -> None: - """Custom setitem function to override root attribute. - - Args: - key: The http header key. - value: The http header value to be set. - """ - self.root[key] = value - - def __len__(self): - """Custom len function to override root attribute. - - Returns: - The len of the header dictionary. - """ - return len(self.root) - - -HttpHeaderPayloadResult = PluginResult[HttpHeaderPayload] - - class ToolPreInvokePayload(BaseModel): """A tool payload for a tool pre-invoke hook. Args: name: The tool name. args: The tool arguments for invocation. - headers: The http pass through headers. Examples: >>> payload = ToolPreInvokePayload(name="test_tool", args={"input": "data"}) @@ -673,7 +634,6 @@ class ToolPreInvokePayload(BaseModel): name: str args: Optional[dict[str, Any]] = Field(default_factory=dict) - headers: Optional[HttpHeaderPayload] = None class ToolPostInvokePayload(BaseModel): diff --git a/mcpgateway/routers/email_auth.py b/mcpgateway/routers/email_auth.py index a2ff53653..7e615db8e 100644 --- a/mcpgateway/routers/email_auth.py +++ b/mcpgateway/routers/email_auth.py @@ -19,7 +19,7 @@ # Standard from datetime import datetime, timedelta, UTC -from typing import Optional +from typing import Any, Dict, Optional # Third-Party from fastapi import APIRouter, Depends, HTTPException, Request, status @@ -104,7 +104,7 @@ def get_user_agent(request: Request) -> str: return request.headers.get("User-Agent", "unknown") -async def create_access_token(user: EmailUser, token_scopes: Optional[dict] = None, jti: Optional[str] = None) -> tuple[str, int]: +async def create_access_token(user: EmailUser, token_scopes: Optional[Dict[str, Any]] = None, jti: Optional[str] = None) -> tuple[str, int]: """Create JWT access token for user with enhanced scoping. 
Args: @@ -584,7 +584,6 @@ async def update_user(user_email: str, user_request: EmailRegistrationRequest, c new_password=user_request.password, ip_address="admin_update", user_agent="admin_panel", - skip_old_password_check=True, ) db.commit() diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index b2f3b6aaa..e8d8610e8 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -26,7 +26,13 @@ import json import logging import re -from typing import Any, Dict, List, Literal, Optional, Self, Union +from typing import Any, Dict, List, Literal, Optional, Union + +try: + from typing import Self +except ImportError: + # Python < 3.11 compatibility + from typing_extensions import Self # Third-Party from pydantic import AnyHttpUrl, BaseModel, ConfigDict, EmailStr, Field, field_serializer, field_validator, model_validator, ValidationInfo diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index 675c5b498..452680008 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -44,7 +44,7 @@ import os import tempfile import time -from typing import Any, AsyncGenerator, cast, Dict, List, Optional, Set, TYPE_CHECKING +from typing import Any, AsyncGenerator, cast, Dict, List, Optional, Set from urllib.parse import urlparse, urlunparse import uuid @@ -61,10 +61,10 @@ try: # Third-Party import redis - - REDIS_AVAILABLE = True + redis_available = True except ImportError: - REDIS_AVAILABLE = False + redis = None # type: ignore + redis_available = False logging.info("Redis is not utilized in this environment.") # First-Party @@ -72,6 +72,7 @@ from mcpgateway.db import Gateway as DbGateway from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import Resource as DbResource +from mcpgateway.db import ServerMetric from mcpgateway.db import SessionLocal from mcpgateway.db import Tool as DbTool from mcpgateway.observability import create_span @@ -141,7 +142,7 @@ class 
GatewayNameConflictError(GatewayError): >>> error.gateway_id is None True - >>> error_inactive = GatewayNameConflictError("inactive_gw", enabled=False, gateway_id=123) + >>> error_inactive = GatewayNameConflictError("inactive_gw", enabled=False, gateway_id="123") >>> str(error_inactive) 'Public Gateway already exists with name: inactive_gw (currently inactive, ID: 123)' >>> error_inactive.enabled @@ -150,7 +151,7 @@ class GatewayNameConflictError(GatewayError): 123 """ - def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[int] = None, visibility: Optional[str] = "public"): + def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[str] = None, visibility: Optional[str] = "public"): """Initialize the error with gateway information. Args: @@ -192,7 +193,7 @@ class GatewayUrlConflictError(GatewayError): >>> error.gateway_id is None True - >>> error_inactive = GatewayUrlConflictError("http://inactive.com/gw", enabled=False, gateway_id=123) + >>> error_inactive = GatewayUrlConflictError("http://inactive.com/gw", enabled=False, gateway_id="123") >>> str(error_inactive) 'Public Gateway already exists with URL: http://inactive.com/gw (currently inactive, ID: 123)' >>> error_inactive.enabled @@ -201,7 +202,7 @@ class GatewayUrlConflictError(GatewayError): 123 """ - def __init__(self, url: str, enabled: bool = True, gateway_id: Optional[int] = None, visibility: Optional[str] = "public"): + def __init__(self, url: str, enabled: bool = True, gateway_id: Optional[str] = None, visibility: Optional[str] = "public"): """Initialize the error with gateway information. 
Args: @@ -282,10 +283,10 @@ def __init__(self) -> None: >>> hasattr(service, '_instance_id') or True # May not exist if no Redis True """ - self._event_subscribers: List[asyncio.Queue] = [] + self._event_subscribers: List[asyncio.Queue[Dict[str, Any]]] = [] self._http_client = ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) self._health_check_interval = GW_HEALTH_CHECK_INTERVAL - self._health_check_task: Optional[asyncio.Task] = None + self._health_check_task: Optional[asyncio.Task[None]] = None self._active_gateways: Set[str] = set() # Track active gateway URLs self._stream_response = None self._pending_responses = {} @@ -299,8 +300,8 @@ def __init__(self) -> None: # Initialize optional Redis client holder self._redis_client: Optional[Any] = None - if self.redis_url and REDIS_AVAILABLE: - self._redis_client = redis.from_url(self.redis_url) + if self.redis_url and redis_available: + self._redis_client = redis.from_url(self.redis_url) # type: ignore self._instance_id = str(uuid.uuid4()) # Unique ID for this process self._leader_key = "gateway_service_leader" self._leader_ttl = 40 # seconds @@ -352,7 +353,7 @@ def normalize_url(url: str) -> str: # For all other URLs, preserve the domain name return url - async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout: Optional[int] = None): + async def _validate_gateway_url(self, url: str, headers: Dict[str, str], transport_type: str, timeout: Optional[int] = None) -> bool: """ Validate if the given URL is a live Server-Sent Events (SSE) endpoint. 
@@ -554,7 +555,7 @@ async def register_gateway( authentication_headers = {str(k): str(v) for k, v in header_dict.items()} elif isinstance(auth_value, str) and auth_value: # Decode persisted auth for initialization - decoded = decode_auth(auth_value) + decoded = cast(Dict[str, str], decode_auth(auth_value)) authentication_headers = {str(k): str(v) for k, v in decoded.items()} else: authentication_headers = None @@ -679,41 +680,24 @@ async def register_gateway( await self._notify_gateway_added(db_gateway) return GatewayRead.model_validate(db_gateway).masked() - except* GatewayConnectionError as ge: # pragma: no mutate - if TYPE_CHECKING: - ge: ExceptionGroup[GatewayConnectionError] - logger.error(f"GatewayConnectionError in group: {ge.exceptions}") - raise ge.exceptions[0] - except* GatewayNameConflictError as gnce: # pragma: no mutate - if TYPE_CHECKING: - gnce: ExceptionGroup[GatewayNameConflictError] - logger.error(f"GatewayNameConflictError in group: {gnce.exceptions}") - raise gnce.exceptions[0] - except* GatewayUrlConflictError as guce: # pragma: no mutate - if TYPE_CHECKING: - guce: ExceptionGroup[GatewayUrlConflictError] - logger.error(f"GatewayUrlConflictError in group: {guce.exceptions}") - raise guce.exceptions[0] - except* ValueError as ve: # pragma: no mutate - if TYPE_CHECKING: - ve: ExceptionGroup[ValueError] - logger.error(f"ValueErrors in group: {ve.exceptions}") - raise ve.exceptions[0] - except* RuntimeError as re: # pragma: no mutate - if TYPE_CHECKING: - re: ExceptionGroup[RuntimeError] - logger.error(f"RuntimeErrors in group: {re.exceptions}") - raise re.exceptions[0] - except* IntegrityError as ie: # pragma: no mutate - if TYPE_CHECKING: - ie: ExceptionGroup[IntegrityError] - logger.error(f"IntegrityErrors in group: {ie.exceptions}") - raise ie.exceptions[0] - except* BaseException as other: # catches every other sub-exception # pragma: no mutate - if TYPE_CHECKING: - other: ExceptionGroup[Exception] - logger.error(f"Other grouped errors: 
{other.exceptions}") - raise other.exceptions[0] + except GatewayConnectionError as ge: # pragma: no mutate + logger.error(f"GatewayConnectionError: {ge}") + raise ge + except GatewayNameConflictError as gnce: # pragma: no mutate + logger.error(f"GatewayNameConflictError: {gnce}") + raise gnce + except ValueError as ve: # pragma: no mutate + logger.error(f"ValueError: {ve}") + raise ve + except RuntimeError as re: # pragma: no mutate + logger.error(f"RuntimeError: {re}") + raise re + except IntegrityError as ie: # pragma: no mutate + logger.error(f"IntegrityError: {ie}") + raise ie + except BaseException as other: # catches every other sub-exception # pragma: no mutate + logger.error(f"Other error: {other}") + raise other async def fetch_tools_after_oauth(self, db: Session, gateway_id: str) -> Dict[str, Any]: """Fetch tools from MCP server after OAuth completion for Authorization Code flow. @@ -767,10 +751,6 @@ async def fetch_tools_after_oauth(self, db: Session, gateway_id: str) -> Dict[st # Filter out any None tools and create DbTool objects tools_to_add = [] for tool in tools: - if tool is None: - logger.warning("Skipping None tool in tools list") - continue - try: db_tool = self._create_db_tool( tool=tool, @@ -937,7 +917,7 @@ async def update_gateway( modified_via: Optional[str] = None, modified_user_agent: Optional[str] = None, include_inactive: bool = True, - ) -> GatewayRead: + ) -> Optional[GatewayRead]: """Update a gateway. 
Args: @@ -1040,8 +1020,7 @@ async def update_gateway( gateway.url = self.normalize_url(str(gateway_update.url)) if gateway_update.description is not None: gateway.description = gateway_update.description - if gateway_update.transport is not None: - gateway.transport = gateway_update.transport + gateway.transport = gateway_update.transport if gateway_update.tags is not None: gateway.tags = gateway_update.tags if gateway_update.visibility is not None: @@ -1049,16 +1028,8 @@ async def update_gateway( if gateway_update.visibility is not None: gateway.visibility = gateway_update.visibility if gateway_update.passthrough_headers is not None: - if isinstance(gateway_update.passthrough_headers, list): - gateway.passthrough_headers = gateway_update.passthrough_headers - else: - if isinstance(gateway_update.passthrough_headers, str): - parsed: List[str] = [h.strip() for h in gateway_update.passthrough_headers.split(",") if h.strip()] - gateway.passthrough_headers = parsed - else: - raise GatewayError("Invalid passthrough_headers format: must be list[str] or comma-separated string") - - logger.info("Updated passthrough_headers for gateway {gateway.id}: {gateway.passthrough_headers}") + gateway.passthrough_headers = gateway_update.passthrough_headers + logger.info(f"Updated passthrough_headers for gateway {gateway.id}: {gateway.passthrough_headers}") if getattr(gateway, "auth_type", None) is not None: gateway.auth_type = gateway_update.auth_type @@ -1084,7 +1055,7 @@ async def update_gateway( elif settings.masked_auth_value not in (token, password, header_value): # Check if values differ from existing ones if gateway.auth_value != gateway_update.auth_value: - gateway.auth_value = decode_auth(gateway_update.auth_value) if isinstance(gateway_update.auth_value, str) else gateway_update.auth_value + gateway.auth_value = cast(Dict[str, str], decode_auth(gateway_update.auth_value)) if isinstance(gateway_update.auth_value, str) else gateway_update.auth_value # Try to reinitialize 
connection if URL changed if gateway_update.url is not None: @@ -1197,7 +1168,7 @@ async def update_gateway( gateway.modified_via = modified_via if modified_user_agent: gateway.modified_user_agent = modified_user_agent - if hasattr(gateway, "version") and gateway.version is not None: + if hasattr(gateway, "version"): gateway.version = gateway.version + 1 else: gateway.version = 1 @@ -1476,7 +1447,7 @@ async def delete_gateway(self, db: Session, gateway_id: str) -> None: db.rollback() raise GatewayError(f"Failed to delete gateway: {str(e)}") - async def forward_request(self, gateway_or_db, method: str, params: Optional[Dict[str, Any]] = None) -> Any: # noqa: F811 # pylint: disable=function-redefined + async def forward_request(self, gateway_or_db: Any, method: str, params: Optional[Dict[str, Any]] = None) -> Any: # noqa: F811 # pylint: disable=function-redefined """ Forward a request to a gateway or multiple gateways. @@ -1576,7 +1547,7 @@ async def _forward_request_to_gateway(self, gateway: DbGateway, method: str, par # Handle non-OAuth authentication (existing logic) auth_data = gateway.auth_value or {} if isinstance(auth_data, str): - headers = decode_auth(auth_data) if auth_data else self._get_auth_headers() + headers = cast(Dict[str, str], decode_auth(auth_data)) if auth_data else self._get_auth_headers() elif isinstance(auth_data, dict): headers = {str(k): str(v) for k, v in auth_data.items()} else: @@ -1666,7 +1637,7 @@ async def _forward_request_to_all(self, db: Session, method: str, params: Option # Handle non-OAuth authentication auth_data = gateway.auth_value or {} if isinstance(auth_data, str): - headers = decode_auth(auth_data) + headers = cast(Dict[str, str], decode_auth(auth_data)) elif isinstance(auth_data, dict): headers = {str(k): str(v) for k, v in auth_data.items()} else: @@ -1844,7 +1815,7 @@ async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool: # Handle non-OAuth authentication (existing logic) auth_data = 
gateway.auth_value or {} if isinstance(auth_data, str): - headers = decode_auth(auth_data) + headers = cast(Dict[str, str], decode_auth(auth_data)) elif isinstance(auth_data, dict): headers = {str(k): str(v) for k, v in auth_data.items()} else: @@ -1999,7 +1970,7 @@ async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]: >>> asyncio.run(test_event()) 'test' """ - queue: asyncio.Queue = asyncio.Queue() + queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() self._event_subscribers.append(queue) try: while True: @@ -2087,7 +2058,7 @@ async def _initialize_gateway( resources = [] prompts = [] if auth_type in ("basic", "bearer", "headers") and isinstance(authentication, str): - authentication = decode_auth(authentication) + authentication = cast(Dict[str, str], decode_auth(authentication)) if transport.lower() == "sse": capabilities, tools, resources, prompts = await self.connect_to_sse_server(url, authentication) elif transport.lower() == "streamablehttp": @@ -2457,6 +2428,8 @@ async def connect_to_sse_server(self, server_url: str, authentication: Optional[ mime_type=resource_data.get("mime_type"), template=resource_data.get("template"), content="", + team_id=None, + owner_email=None, ) ) logger.info(f"Fetched {len(resources)} resources from gateway") @@ -2484,6 +2457,8 @@ async def connect_to_sse_server(self, server_url: str, authentication: Optional[ name=prompt_data.get("name", ""), description=prompt_data.get("description"), template=prompt_data.get("template", ""), + team_id=None, + owner_email=None, ) ) logger.info(f"Fetched {len(prompts)} prompts from gateway") @@ -2566,3 +2541,33 @@ async def connect_to_streamablehttp_server(self, server_url: str, authentication logger.warning(f"Failed to fetch prompts: {e}") return capabilities, tools, resources, prompts + + async def _record_server_metric(self, db: Session, server: DbGateway, start_time: float, success: bool, error_message: Optional[str]) -> None: + """ + Records a metric for a server 
interaction. + + This function calculates the response time using the provided start time and records + the metric details (including whether the interaction was successful and any error message) + into the database. The metric is then committed to the database. + + Args: + db (Session): The SQLAlchemy database session. + server (DbGateway): The server/gateway that was accessed. + start_time (float): The monotonic start time of the interaction. + success (bool): True if the interaction succeeded; otherwise, False. + error_message (Optional[str]): The error message if the interaction failed, otherwise None. + """ + end_time = time.monotonic() + response_time = end_time - start_time + metric = ServerMetric( + server_id=server.id, + response_time=response_time, + is_success=success, + error_message=error_message, + ) + db.add(metric) + db.commit() + try: + db.expire(server, ["metrics"]) + except Exception: + pass diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 2bae69493..6f5ef90fe 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -35,7 +35,7 @@ from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.models import Message, PromptResult, Role, TextContent from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, PluginManager, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.utils.metrics_common import build_top_performers @@ -152,7 +152,7 @@ async def shutdown(self) -> None: self._event_subscribers.clear() logger.info("Prompt service shutdown complete") - async def get_top_prompts(self, db: Session, limit: int 
= 5) -> List[TopPerformer]: + async def get_top_prompts(self, db: Session, limit: Optional[int] = 5) -> List[TopPerformer]: """Retrieve the top-performing prompts based on execution count. Queries the database to get prompts with their metrics, ordered by the number of executions @@ -161,7 +161,8 @@ async def get_top_prompts(self, db: Session, limit: int = 5) -> List[TopPerforme Args: db (Session): Database session for querying prompt metrics. - limit (int): Maximum number of prompts to return. Defaults to 5. + limit (Optional[int]): Maximum number of prompts to return. Defaults to 5. + If None, returns all prompts. Returns: List[TopPerformer]: A list of TopPerformer objects, each containing: @@ -172,7 +173,7 @@ async def get_top_prompts(self, db: Session, limit: int = 5) -> List[TopPerforme - success_rate: Success rate percentage, or None if no metrics. - last_execution: Timestamp of the last execution, or None if no metrics. """ - results = ( + query = ( db.query( DbPrompt.id, DbPrompt.name, @@ -190,12 +191,47 @@ async def get_top_prompts(self, db: Session, limit: int = 5) -> List[TopPerforme .outerjoin(PromptMetric) .group_by(DbPrompt.id, DbPrompt.name) .order_by(desc("execution_count")) - .limit(limit) - .all() ) + + # If a limit is provided (default 5), cap the number of rows returned + if limit is not None: + query = query.limit(limit) + + results = query.all() return build_top_performers(results) + async def _record_prompt_metric(self, db: Session, prompt: DbPrompt, start_time: float, success: bool, error_message: Optional[str]) -> None: + """ + Records a metric for a prompt invocation. + + This function calculates the response time using the provided start time and records + the metric details (including whether the invocation was successful and any error message) + into the database. The metric is then committed to the database. + + Args: + db (Session): The SQLAlchemy database session. + prompt (DbPrompt): The prompt that was invoked. 
+ start_time (float): The monotonic start time of the invocation. + success (bool): True if the invocation succeeded; otherwise, False. + error_message (Optional[str]): The error message if the invocation failed, otherwise None. + """ + end_time = time.monotonic() + response_time = end_time - start_time + metric = PromptMetric( + prompt_id=prompt.id, + response_time=response_time, + is_success=success, + error_message=error_message, + ) + db.add(metric) + db.commit() + # Expire metrics relationship for accurate immediate aggregation + try: # pragma: no cover + db.expire(prompt, ["metrics"]) + except Exception: # noqa: BLE001 + pass + def _convert_db_prompt(self, db_prompt: DbPrompt) -> Dict[str, Any]: """ Convert a DbPrompt instance to a dictionary matching the PromptRead schema, @@ -565,7 +601,6 @@ async def get_prompt( PluginViolationError: If prompt violates a plugin policy PromptNotFoundError: If prompt not found PromptError: For other prompt errors - PluginError: If encounters issue with plugin Examples: >>> from mcpgateway.services.prompt_service import PromptService @@ -581,6 +616,9 @@ async def get_prompt( """ start_time = time.monotonic() + success = False + error_message = None + prompt = None # Create a trace span for prompt rendering with create_span( @@ -594,67 +632,113 @@ async def get_prompt( "request_id": request_id or "none", }, ) as span: - if self._plugin_manager: - if not request_id: - request_id = uuid.uuid4().hex - global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) - pre_result, context_table = await self._plugin_manager.prompt_pre_fetch( - payload=PromptPrehookPayload(name=name, args=arguments), global_context=global_context, local_contexts=None, violations_as_exceptions=True - ) - - # Use modified payload if provided - if pre_result.modified_payload: - payload = pre_result.modified_payload - name = payload.name - arguments = payload.args - - # Find prompt - prompt = 
db.execute(select(DbPrompt).where(DbPrompt.name == name).where(DbPrompt.is_active)).scalar_one_or_none() - - if not prompt: - inactive_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(not_(DbPrompt.is_active))).scalar_one_or_none() - if inactive_prompt: - raise PromptNotFoundError(f"Prompt '{name}' exists but is inactive") - - raise PromptNotFoundError(f"Prompt not found: {name}") - - if not arguments: - result = PromptResult( - messages=[ - Message( - role=Role.USER, - content=TextContent(type="text", text=prompt.template), - ) - ], - description=prompt.description, - ) - else: - try: - prompt.validate_arguments(arguments) - rendered = self._render_template(prompt.template, arguments) - messages = self._parse_messages(rendered) - result = PromptResult(messages=messages, description=prompt.description) - except Exception as e: - if span: - span.set_attribute("error", True) - span.set_attribute("error.message", str(e)) - raise PromptError(f"Failed to process prompt: {str(e)}") - - if self._plugin_manager: - post_result, _ = await self._plugin_manager.prompt_post_fetch( - payload=PromptPosthookPayload(name=name, result=result), global_context=global_context, local_contexts=context_table, violations_as_exceptions=True - ) - # Use modified payload if provided - return post_result.modified_payload.result if post_result.modified_payload else result - - # Set success attributes on span - if span: - span.set_attribute("success", True) - span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) - if result and hasattr(result, "messages"): - span.set_attribute("messages.count", len(result.messages)) - - return result + try: + if self._plugin_manager: + if not request_id: + request_id = uuid.uuid4().hex + global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) + try: + pre_result, context_table = await self._plugin_manager.prompt_pre_fetch(payload=PromptPrehookPayload(name=name, 
args=arguments), global_context=global_context, local_contexts=None) + + if not pre_result.continue_processing: + # Plugin blocked the request + if pre_result.violation: + plugin_name = pre_result.violation.plugin_name + violation_reason = pre_result.violation.reason + violation_desc = pre_result.violation.description + violation_code = pre_result.violation.code + raise PluginViolationError(f"Pre prompting fetch blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", pre_result.violation) + raise PluginViolationError("Pre prompting fetch blocked by plugin") + + # Use modified payload if provided + if pre_result.modified_payload: + payload = pre_result.modified_payload + name = payload.name + arguments = payload.args + except PluginViolationError: + raise + except Exception as e: + logger.error(f"Error in pre-prompt fetch plugin hook: {e}") + # Only fail if configured to do so + if self._plugin_manager.config and self._plugin_manager.config.plugin_settings.fail_on_plugin_error: + raise + + # Find prompt + prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(DbPrompt.is_active)).scalar_one_or_none() + + if not prompt: + inactive_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(not_(DbPrompt.is_active))).scalar_one_or_none() + if inactive_prompt: + raise PromptNotFoundError(f"Prompt '{name}' exists but is inactive") + + raise PromptNotFoundError(f"Prompt not found: {name}") + + if not arguments: + result = PromptResult( + messages=[ + Message( + role=Role.USER, + content=TextContent(type="text", text=prompt.template), + ) + ], + description=prompt.description, + ) + else: + try: + prompt.validate_arguments(arguments) + rendered = self._render_template(prompt.template, arguments) + messages = self._parse_messages(rendered) + result = PromptResult(messages=messages, description=prompt.description) + except Exception as e: + if span: + span.set_attribute("error", True) + 
span.set_attribute("error.message", str(e)) + raise PromptError(f"Failed to process prompt: {str(e)}") + + if self._plugin_manager: + try: + post_result, _ = await self._plugin_manager.prompt_post_fetch(payload=PromptPosthookPayload(name=name, result=result), global_context=global_context, local_contexts=context_table) + if not post_result.continue_processing: + # Plugin blocked the request + if post_result.violation: + plugin_name = post_result.violation.plugin_name + violation_reason = post_result.violation.reason + violation_desc = post_result.violation.description + violation_code = post_result.violation.code + raise PluginViolationError(f"Post prompting fetch blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", post_result.violation) + raise PluginViolationError("Post prompting fetch blocked by plugin") + # Use modified payload if provided + if post_result.modified_payload: + result = post_result.modified_payload.result + except PluginViolationError: + raise + except Exception as e: + logger.error(f"Error in post-prompt fetch plugin hook: {e}") + # Only fail if configured to do so + if self._plugin_manager.config and self._plugin_manager.config.plugin_settings.fail_on_plugin_error: + raise + + # Set success attributes on span + if span: + span.set_attribute("success", True) + span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) + if result and hasattr(result, "messages"): + span.set_attribute("messages.count", len(result.messages)) + + # Mark as successful only after all operations complete successfully + success = True + return result + except Exception as e: + error_message = str(e) + # Set span error status + if span: + span.set_attribute("error", True) + span.set_attribute("error.message", str(e)) + raise + finally: + # Record metric regardless of success or failure + if prompt: + await self._record_prompt_metric(db, prompt, start_time, success, error_message) async def update_prompt( self, diff 
--git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 8cac7d735..dbd4daddc 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -151,16 +151,16 @@ async def shutdown(self) -> None: self._event_subscribers.clear() logger.info("Resource service shutdown complete") - async def get_top_resources(self, db: Session, limit: int = 5) -> List[TopPerformer]: - """Retrieve the top-performing resources based on execution count. + async def get_top_resources(self, db: Session, limit: Optional[int] = 5) -> List[TopPerformer]: + """Retrieve top-performing resources by execution count. - Queries the database to get resources with their metrics, ordered by the number of executions - in descending order. Uses the resource URI as the name field for TopPerformer objects. - Returns a list of TopPerformer objects containing resource details and performance metrics. + Aggregates resource metrics (execution count, average response time, success rate, last execution) + ordered by execution count descending. Returns TopPerformer objects, using resource URI as the name. Args: db (Session): Database session for querying resource metrics. - limit (int): Maximum number of resources to return. Defaults to 5. + limit (Optional[int]): Maximum number of resources to return. Defaults to 5. + If None, returns all resources. Returns: List[TopPerformer]: A list of TopPerformer objects, each containing: @@ -171,7 +171,7 @@ async def get_top_resources(self, db: Session, limit: int = 5) -> List[TopPerfor - success_rate: Success rate percentage, or None if no metrics. - last_execution: Timestamp of the last execution, or None if no metrics. 
""" - results = ( + query = ( db.query( DbResource.id, DbResource.uri.label("name"), # Using URI as the name field for TopPerformer @@ -189,12 +189,46 @@ async def get_top_resources(self, db: Session, limit: int = 5) -> List[TopPerfor .outerjoin(ResourceMetric) .group_by(DbResource.id, DbResource.uri) .order_by(desc("execution_count")) - .limit(limit) - .all() ) + + if limit is not None: + query = query.limit(limit) + + results = query.all() return build_top_performers(results) + async def _record_resource_metric(self, db: Session, resource: DbResource, start_time: float, success: bool, error_message: Optional[str]) -> None: + """ + Records a metric for a resource access. + + This function calculates the response time using the provided start time and records + the metric details (including whether the access was successful and any error message) + into the database. The metric is then committed to the database. + + Args: + db (Session): The SQLAlchemy database session. + resource (DbResource): The resource that was accessed. + start_time (float): The monotonic start time of the access. + success (bool): True if the access succeeded; otherwise, False. + error_message (Optional[str]): The error message if the access failed, otherwise None. + """ + end_time = time.monotonic() + response_time = end_time - start_time + metric = ResourceMetric( + resource_id=resource.id, + response_time=response_time, + is_success=success, + error_message=error_message, + ) + db.add(metric) + db.commit() + # Expire metrics relationship so subsequent accesses re-query fresh data + try: # pragma: no cover + db.expire(resource, ["metrics"]) + except Exception: # noqa: BLE001 + pass + def _convert_resource_to_read(self, resource: DbResource) -> ResourceRead: """ Converts a DbResource instance into a ResourceRead model, including aggregated metrics. 
@@ -573,8 +607,6 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = Raises: ResourceNotFoundError: If resource not found ResourceError: If blocked by plugin - PluginError: If encounters issue with plugin - PluginViolationError: If plugin violated the request. Example - In case of OPA plugin, if the request is denied by policy. Examples: >>> from mcpgateway.services.resource_service import ResourceService @@ -601,6 +633,9 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = True """ start_time = time.monotonic() + success = False + error_message = None + resource = None # Create trace span for resource reading with create_span( @@ -614,105 +649,124 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = "resource.type": "template" if ("{" in uri and "}" in uri) else "static", }, ) as span: - # Generate request ID if not provided - if not request_id: - request_id = str(uuid.uuid4()) - - original_uri = uri - contexts = None - - # Call pre-fetch hooks if plugin manager is available - plugin_eligible = bool(self._plugin_manager and PLUGINS_AVAILABLE and ("://" in uri)) - if plugin_eligible: - # Initialize plugin manager if needed - # pylint: disable=protected-access - if not self._plugin_manager._initialized: - await self._plugin_manager.initialize() - # pylint: enable=protected-access - - # Create plugin context - # Normalize user to an identifier string if provided - user_id = None - if user is not None: - if isinstance(user, dict) and "email" in user: - user_id = user.get("email") - elif isinstance(user, str): - user_id = user - else: - # Attempt to fallback to attribute access - user_id = getattr(user, "email", None) - - global_context = GlobalContext(request_id=request_id, user=user_id, server_id=server_id) - - # Create pre-fetch payload - pre_payload = ResourcePreFetchPayload(uri=uri, metadata={}) - - # Execute pre-fetch hooks - pre_result, contexts = await 
self._plugin_manager.resource_pre_fetch(pre_payload, global_context, violations_as_exceptions=True) - # Use modified URI if plugin changed it - if pre_result.modified_payload: - uri = pre_result.modified_payload.uri - logger.debug(f"Resource URI modified by plugin: {original_uri} -> {uri}") - - # Original resource fetching logic - # Check for template - if "{" in uri and "}" in uri: - content = await self._read_template_resource(uri) - else: - # Find resource - resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(DbResource.is_active)).scalar_one_or_none() - - if not resource: - # Check if inactive resource exists - inactive_resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(not_(DbResource.is_active))).scalar_one_or_none() - - if inactive_resource: - raise ResourceNotFoundError(f"Resource '{uri}' exists but is inactive") - - raise ResourceNotFoundError(f"Resource not found: {uri}") - - content = resource.content - - # Call post-fetch hooks if plugin manager is available - if plugin_eligible: - # Create post-fetch payload - post_payload = ResourcePostFetchPayload(uri=original_uri, content=content) - - # Execute post-fetch hooks - post_result, _ = await self._plugin_manager.resource_post_fetch(post_payload, global_context, contexts, violations_as_exceptions=True) # Pass contexts from pre-fetch - - # Use modified content if plugin changed it - if post_result.modified_payload: - content = post_result.modified_payload.content - logger.debug(f"Resource content modified by plugin for URI: {original_uri}") - - # Set success attributes on span - if span: - span.set_attribute("success", True) - span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) - if content: - span.set_attribute("content.size", len(str(content))) - - # Return standardized content without breaking callers that expect passthrough - # Prefer returning first-class content models or objects with content-like attributes. 
- # ResourceContent and TextContent already imported at top level - - # If content is already a Pydantic content model, return as-is - if isinstance(content, (ResourceContent, TextContent)): - return content - - # If content is any object that quacks like content (e.g., MagicMock with .text/.blob), return as-is - if hasattr(content, "text") or hasattr(content, "blob"): + try: + # Generate request ID if not provided + if not request_id: + request_id = str(uuid.uuid4()) + + original_uri = uri + contexts = None + global_context = None + + # Call pre-fetch hooks if plugin manager is available + if self._plugin_manager and PLUGINS_AVAILABLE: + # Initialize plugin manager if needed + # pylint: disable=protected-access + if not self._plugin_manager._initialized: + await self._plugin_manager.initialize() + # pylint: enable=protected-access + + # Create plugin context + global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id) + + # Create pre-fetch payload + pre_payload = ResourcePreFetchPayload(uri=uri, metadata={}) + + # Execute pre-fetch hooks + try: + pre_result, contexts = await self._plugin_manager.resource_pre_fetch(pre_payload, global_context) + + # Check if we should continue + if not pre_result.continue_processing: + # Plugin blocked the resource fetch + if pre_result.violation: + logger.warning(f"Resource blocked by plugin: {pre_result.violation.reason} (URI: {uri})") + raise ResourceError(f"Resource blocked: {pre_result.violation.reason}") + raise ResourceError("Resource fetch blocked by plugin") + + # Use modified URI if plugin changed it + if pre_result.modified_payload: + uri = pre_result.modified_payload.uri + logger.debug(f"Resource URI modified by plugin: {original_uri} -> {uri}") + except ResourceError: + raise + except Exception as e: + logger.error(f"Error in resource pre-fetch hooks: {e}") + # Continue without plugin processing if there's an error + + # Original resource fetching logic + # Check for template + if "{" in uri 
and "}" in uri: + content = await self._read_template_resource(uri) + else: + # Find resource + resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(DbResource.is_active)).scalar_one_or_none() + + if not resource: + # Check if inactive resource exists + inactive_resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(not_(DbResource.is_active))).scalar_one_or_none() + + if inactive_resource: + raise ResourceNotFoundError(f"Resource '{uri}' exists but is inactive") + + raise ResourceNotFoundError(f"Resource not found: {uri}") + + content = resource.content + + # Call post-fetch hooks if plugin manager is available + if self._plugin_manager and PLUGINS_AVAILABLE and global_context: + # Create post-fetch payload + post_payload = ResourcePostFetchPayload(uri=original_uri, content=content) + + # Execute post-fetch hooks + try: + post_result, _ = await self._plugin_manager.resource_post_fetch( + post_payload, + global_context, + contexts, # Pass contexts from pre-fetch + ) + + # Check if we should continue + if not post_result.continue_processing: + # Plugin blocked the resource after fetching + if post_result.violation: + logger.warning(f"Resource content blocked by plugin: {post_result.violation.reason} (URI: {original_uri})") + raise ResourceError(f"Resource content blocked: {post_result.violation.reason}") + raise ResourceError("Resource content blocked by plugin") + + # Use modified content if plugin changed it + if post_result.modified_payload: + content = post_result.modified_payload.content + logger.debug(f"Resource content modified by plugin for URI: {original_uri}") + except ResourceError: + raise + except Exception as e: + logger.error(f"Error in resource post-fetch hooks: {e}") + # Continue with unmodified content if there's an error + + # Set success attributes on span + if span: + span.set_attribute("success", True) + span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) + if content: + 
span.set_attribute("content.size", len(str(content))) + + # Mark as successful only after all operations complete successfully + success = True + + # Return content return content - - # Normalize primitive types to ResourceContent - if isinstance(content, bytes): - return ResourceContent(type="resource", uri=original_uri, blob=content) - if isinstance(content, str): - return ResourceContent(type="resource", uri=original_uri, text=content) - - # Fallback to stringified content - return ResourceContent(type="resource", uri=original_uri, text=str(content)) + except Exception as e: + error_message = str(e) + # Set span error status + if span: + span.set_attribute("error", True) + span.set_attribute("error.message", str(e)) + raise + finally: + # Record metric regardless of success or failure, but only if we have a resource + if resource: + await self._record_resource_metric(db, resource, start_time, success, error_message) async def toggle_resource_status(self, db: Session, resource_id: int, activate: bool) -> ResourceRead: """ diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index 19779d756..0581317a7 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ -141,7 +141,7 @@ async def shutdown(self) -> None: logger.info("Server service shutdown complete") # get_top_server - async def get_top_servers(self, db: Session, limit: int = 5) -> List[TopPerformer]: + async def get_top_servers(self, db: Session, limit: Optional[int] = 5) -> List[TopPerformer]: """Retrieve the top-performing servers based on execution count. Queries the database to get servers with their metrics, ordered by the number of executions @@ -150,7 +150,8 @@ async def get_top_servers(self, db: Session, limit: int = 5) -> List[TopPerforme Args: db (Session): Database session for querying server metrics. - limit (int): Maximum number of servers to return. Defaults to 5. 
+ limit (Optional[int]): Maximum number of servers to return. Defaults to 5. + If None, returns all servers. Returns: List[TopPerformer]: A list of TopPerformer objects, each containing: @@ -161,7 +162,7 @@ async def get_top_servers(self, db: Session, limit: int = 5) -> List[TopPerforme - success_rate: Success rate percentage, or None if no metrics. - last_execution: Timestamp of the last execution, or None if no metrics. """ - results = ( + query = ( db.query( DbServer.id, DbServer.name, @@ -179,12 +180,16 @@ async def get_top_servers(self, db: Session, limit: int = 5) -> List[TopPerforme .outerjoin(ServerMetric) .group_by(DbServer.id, DbServer.name) .order_by(desc("execution_count")) - .limit(limit) - .all() ) + + if limit is not None: + query = query.limit(limit) + + results = query.all() return build_top_performers(results) + def _convert_server_to_read(self, server: DbServer) -> ServerRead: """ Converts a DbServer instance into a ServerRead model, including aggregated metrics. diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 4159e9337..7e427c7f8 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -42,13 +42,9 @@ from mcpgateway.db import server_tool_association from mcpgateway.db import Tool as DbTool from mcpgateway.db import ToolMetric -from mcpgateway.models import Gateway as PydanticGateway -from mcpgateway.models import TextContent -from mcpgateway.models import Tool as PydanticTool -from mcpgateway.models import ToolResult +from mcpgateway.models import TextContent, ToolResult from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginError, PluginManager, PluginViolationError, ToolPostInvokePayload, ToolPreInvokePayload -from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError, 
ToolPostInvokePayload, ToolPreInvokePayload from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager @@ -225,7 +221,7 @@ async def shutdown(self) -> None: await self._http_client.aclose() logger.info("Tool service shutdown complete") - async def get_top_tools(self, db: Session, limit: int = 5) -> List[TopPerformer]: + async def get_top_tools(self, db: Session, limit: Optional[int] = 5) -> List[TopPerformer]: """Retrieve the top-performing tools based on execution count. Queries the database to get tools with their metrics, ordered by the number of executions @@ -234,7 +230,8 @@ async def get_top_tools(self, db: Session, limit: int = 5) -> List[TopPerformer] Args: db (Session): Database session for querying tool metrics. - limit (int): Maximum number of tools to return. Defaults to 5. + limit (Optional[int]): Maximum number of tools to return. Defaults to 5. + If None, returns all tools. Returns: List[TopPerformer]: A list of TopPerformer objects, each containing: @@ -245,7 +242,8 @@ async def get_top_tools(self, db: Session, limit: int = 5) -> List[TopPerformer] - success_rate: Success rate percentage, or None if no metrics. - last_execution: Timestamp of the last execution, or None if no metrics. 
""" - results = ( + # Build query to aggregate tool metrics and rank by execution count + query = ( db.query( DbTool.id, DbTool.name, @@ -263,9 +261,12 @@ async def get_top_tools(self, db: Session, limit: int = 5) -> List[TopPerformer] .outerjoin(ToolMetric) .group_by(DbTool.id, DbTool.name) .order_by(desc("execution_count")) - .limit(limit) - .all() ) + + if limit is not None: + query = query.limit(limit) + + results = query.all() return build_top_performers(results) @@ -346,6 +347,13 @@ async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float ) db.add(metric) db.commit() + # Ensure the in-memory relationship collection is expired so that + # subsequent accesses to tool.metrics / derived properties (execution_count, + # metrics_summary, last_execution_time, etc.) reflect the newly added metric. + try: # pragma: no cover - defensive; expire won't raise in normal conditions + db.expire(tool, ["metrics"]) + except Exception: # noqa: BLE001 + pass async def register_tool( self, @@ -787,7 +795,6 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r ToolNotFoundError: If tool not found. ToolInvocationError: If invocation fails. PluginViolationError: If plugin blocks tool invocation. 
- PluginError: If encounters issue with plugin Examples: >>> from mcpgateway.services.tool_service import ToolService @@ -828,6 +835,33 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r server_id = gateway_id if isinstance(gateway_id, str) else "unknown" global_context = GlobalContext(request_id=request_id, server_id=server_id, tenant_id=None) + if self._plugin_manager: + try: + pre_result, context_table = await self._plugin_manager.tool_pre_invoke(payload=ToolPreInvokePayload(name=name, args=arguments), global_context=global_context, local_contexts=None) + + if not pre_result.continue_processing: + # Plugin blocked the request + if pre_result.violation: + plugin_name = pre_result.violation.plugin_name + violation_reason = pre_result.violation.reason + violation_desc = pre_result.violation.description + violation_code = pre_result.violation.code + raise PluginViolationError(f"Tool invocation blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", pre_result.violation) + raise PluginViolationError("Tool invocation blocked by plugin") + + # Use modified payload if provided + if pre_result.modified_payload: + payload = pre_result.modified_payload + name = payload.name + arguments = payload.args + except PluginViolationError: + raise + except Exception as e: + logger.error(f"Error in pre-tool invoke plugin hook: {e}") + # Only fail if configured to do so + if self._plugin_manager.config and self._plugin_manager.config.plugin_settings.fail_on_plugin_error: + raise + start_time = time.monotonic() success = False error_message = None @@ -867,22 +901,6 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r if request_headers: headers = get_passthrough_headers(request_headers, headers, db) - if self._plugin_manager: - tool_metadata = PydanticTool.model_validate(tool) - global_context.metadata[TOOL_METADATA] = tool_metadata - pre_result, context_table = await 
self._plugin_manager.tool_pre_invoke( - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), - global_context=global_context, - local_contexts=None, - violations_as_exceptions=True, - ) - if pre_result.modified_payload: - payload = pre_result.modified_payload - name = payload.name - arguments = payload.args - if payload.headers is not None: - headers = payload.headers.model_dump() - # Build the payload based on integration type payload = arguments.copy() @@ -1011,25 +1029,6 @@ async def connect_to_streamablehttp_server(server_url: str): tool_gateway_id = tool.gateway_id tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id).where(DbGateway.enabled)).scalar_one_or_none() - if self._plugin_manager: - tool_metadata = PydanticTool.model_validate(tool) - global_context.metadata[TOOL_METADATA] = tool_metadata - if tool_gateway: - gateway_metadata = PydanticGateway.model_validate(tool_gateway) - global_context.metadata[GATEWAY_METADATA] = gateway_metadata - pre_result, context_table = await self._plugin_manager.tool_pre_invoke( - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), - global_context=global_context, - local_contexts=None, - violations_as_exceptions=True, - ) - if pre_result.modified_payload: - payload = pre_result.modified_payload - name = payload.name - arguments = payload.args - if payload.headers is not None: - headers = payload.headers.model_dump() - tool_call_result = ToolResult(content=[TextContent(text="", type="text")]) if transport == "sse": tool_call_result = await connect_to_sse_server(tool_gateway.url) @@ -1046,25 +1045,39 @@ async def connect_to_streamablehttp_server(server_url: str): # Plugin hook: tool post-invoke if self._plugin_manager: - post_result, _ = await self._plugin_manager.tool_post_invoke( - payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), - global_context=global_context, - 
local_contexts=context_table, - violations_as_exceptions=True, - ) - # Use modified payload if provided - if post_result.modified_payload: - # Reconstruct ToolResult from modified result - modified_result = post_result.modified_payload.result - if isinstance(modified_result, dict) and "content" in modified_result: - tool_result = ToolResult(content=modified_result["content"]) - else: - # If result is not in expected format, convert it to text content - tool_result = ToolResult(content=[TextContent(type="text", text=str(modified_result))]) + try: + post_result, _ = await self._plugin_manager.tool_post_invoke( + payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), global_context=global_context, local_contexts=context_table + ) + if not post_result.continue_processing: + # Plugin blocked the request + if post_result.violation: + plugin_name = post_result.violation.plugin_name + violation_reason = post_result.violation.reason + violation_desc = post_result.violation.description + violation_code = post_result.violation.code + raise PluginViolationError(f"Tool result blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", post_result.violation) + raise PluginViolationError("Tool result blocked by plugin") + + # Use modified payload if provided + if post_result.modified_payload: + # Reconstruct ToolResult from modified result + modified_result = post_result.modified_payload.result + if isinstance(modified_result, dict) and "content" in modified_result: + tool_result = ToolResult(content=modified_result["content"]) + else: + # If result is not in expected format, convert it to text content + tool_result = ToolResult(content=[TextContent(type="text", text=str(modified_result))]) + + except PluginViolationError: + raise + except Exception as e: + logger.error(f"Error in post-tool invoke plugin hook: {e}") + # Only fail if configured to do so + if self._plugin_manager.config and 
self._plugin_manager.config.plugin_settings.fail_on_plugin_error: + raise return tool_result - except (PluginError, PluginViolationError): - raise except Exception as e: error_message = str(e) # Set span error status @@ -1585,6 +1598,10 @@ async def _invoke_a2a_tool(self, db: Session, tool: DbTool, arguments: Dict[str, ) db.add(metric) db.commit() + try: # Ensure subsequent accesses see fresh metrics + db.expire(tool, ["metrics"]) + except Exception: # noqa: BLE001 + pass return result diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index 2877667b7..19dd00db0 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -865,6 +865,26 @@ function retryLoadMetrics() { // Make retry function available globally immediately window.retryLoadMetrics = retryLoadMetrics; +// --------------------------------------------------------------- +// Auto-refresh aggregated metrics every 10 seconds (when visible) +// --------------------------------------------------------------- +let metricsAutoRefreshTimer = null; +function startMetricsAutoRefresh() { + if (metricsAutoRefreshTimer) return; + metricsAutoRefreshTimer = setInterval(() => { + const panel = safeGetElement("metrics-panel"); + if (!panel || panel.closest(".tab-panel.hidden")) return; // only refresh if visible + loadAggregatedMetrics(); + }, 10000); +} +function stopMetricsAutoRefresh() { + if (metricsAutoRefreshTimer) { + clearInterval(metricsAutoRefreshTimer); + metricsAutoRefreshTimer = null; + } +} +startMetricsAutoRefresh(); + function showMetricsPlaceholder() { const metricsPanel = safeGetElement("metrics-panel"); if (metricsPanel) { @@ -889,6 +909,28 @@ function displayMetrics(data) { } try { + // Normalize snake_case metrics keys to camelCase for consistent downstream processing + const normalizeCategory = (obj) => { + if (!obj || typeof obj !== "object") return obj; + const out = { ...obj }; + const map = [ + ["total_executions", "totalExecutions"], + ["successful_executions", 
"successfulExecutions"], + ["failed_executions", "failedExecutions"], + ["failure_rate", "failureRate"], + ["avg_response_time", "avgResponseTime"], + ["min_response_time", "minResponseTime"], + ["max_response_time", "maxResponseTime"], + ["last_execution_time", "lastExecutionTime"], + ]; + map.forEach(([snake, camel]) => { + if (out[camel] === undefined && out[snake] !== undefined) out[camel] = out[snake]; + }); + return out; + }; + ["tools", "resources", "prompts", "servers", "gateways", "a2a_agents"].forEach((k) => { + if (data[k]) data[k] = normalizeCategory(data[k]); + }); // FIX: Handle completely empty data if (!data || Object.keys(data).length === 0) { const emptyStateDiv = document.createElement("div"); @@ -1185,29 +1227,32 @@ function extractKPIData(data) { let totalFailed = 0; const responseTimes = []; - // Process each category safely - const categories = [ - "tools", - "resources", - "prompts", - "gateways", - "servers", - ]; + // Helper to safely resolve camelCase or snake_case keys. 
+ const metricVal = (obj, camel) => { + if (!obj) return undefined; + if (camel in obj) return obj[camel]; + // convert camelCase to snake_case + const snake = camel + .replace(/([A-Z])/g, "_$1") + .replace(/__/g, "_") + .toLowerCase(); + return obj[snake]; + }; + + // Process each category safely (added a2a_agents; gateways kept for future parity) + const categories = ["tools", "resources", "prompts", "servers", "gateways", "a2a_agents"]; categories.forEach((category) => { - if (data[category]) { - const categoryData = data[category]; - totalExecutions += Number(categoryData.totalExecutions || 0); - totalSuccessful += Number( - categoryData.successfulExecutions || 0, - ); - totalFailed += Number(categoryData.failedExecutions || 0); + const categoryData = data[category]; + if (!categoryData) return; - if ( - categoryData.avgResponseTime && - categoryData.avgResponseTime !== "N/A" - ) { - responseTimes.push(Number(categoryData.avgResponseTime)); - } + totalExecutions += Number(metricVal(categoryData, "totalExecutions") || 0); + totalSuccessful += Number(metricVal(categoryData, "successfulExecutions") || 0); + totalFailed += Number(metricVal(categoryData, "failedExecutions") || 0); + + const avgRt = metricVal(categoryData, "avgResponseTime"); + if (avgRt !== undefined && avgRt !== null && avgRt !== "N/A") { + const n = Number(avgRt); + if (!Number.isNaN(n)) responseTimes.push(n); } }); @@ -9024,8 +9069,8 @@ function generateConfig(server, configType) { command: "python", args: ["-m", "mcpgateway.wrapper"], env: { - MCP_AUTH: "Bearer ", - MCP_SERVER_URL: `${baseUrl}/servers/${server.id}`, + MCP_AUTH_TOKEN: "your-token-here", + MCP_SERVER_CATALOG_URLS: `${baseUrl}/servers/${server.id}`, MCP_TOOL_CALL_TIMEOUT: "120", }, }, diff --git a/mcpgateway/utils/metrics_common.py b/mcpgateway/utils/metrics_common.py index 104eb8094..6e08f4102 100644 --- a/mcpgateway/utils/metrics_common.py +++ b/mcpgateway/utils/metrics_common.py @@ -8,12 +8,75 @@ """ # Standard -from typing 
import List +from typing import List, Optional, Union # First-Party from mcpgateway.schemas import TopPerformer +def calculate_success_rate(successful: Union[int, float], total: Union[int, float]) -> Optional[float]: + """ + Calculate success rate as a percentage. + + This function handles division by zero and ensures the result is always a valid + percentage or None if the calculation is not possible. + + Args: + successful: Number of successful operations + total: Total number of operations + + Returns: + Optional[float]: Success rate as a percentage (0-100) or None if total is zero + + Examples: + >>> calculate_success_rate(75, 100) + 75.0 + >>> calculate_success_rate(0, 0) + None + >>> calculate_success_rate(0, 10) + 0.0 + >>> calculate_success_rate(5, 0) + None + """ + if total is None or successful is None: + return None + + try: + total_float = float(total) + if total_float <= 0: + return None + return (float(successful) / total_float) * 100.0 + except (ValueError, TypeError, ZeroDivisionError): + return None + + +def format_response_time(response_time: Optional[float]) -> Optional[str]: + """ + Format response time to display with 3 decimal places. + + Args: + response_time: Response time in seconds + + Returns: + Optional[str]: Formatted response time with 3 decimal places or None + + Examples: + >>> format_response_time(1.2345) + '1.235' + >>> format_response_time(None) + None + >>> format_response_time(0) + '0.000' + """ + if response_time is None: + return None + + try: + return f"{float(response_time):.3f}" + except (ValueError, TypeError): + return None + + def build_top_performers(results: List) -> List[TopPerformer]: """ Convert database query results to TopPerformer objects. diff --git a/mcpgateway/utils/verify_credentials.py b/mcpgateway/utils/verify_credentials.py index 7a33f4c26..b7a4d2e21 100644 --- a/mcpgateway/utils/verify_credentials.py +++ b/mcpgateway/utils/verify_credentials.py @@ -15,7 +15,6 @@ ... jwt_algorithm = 'HS256' ... 
jwt_audience = 'mcpgateway-api' ... jwt_issuer = 'mcpgateway' - ... jwt_audience_verification = True ... basic_auth_user = 'user' ... basic_auth_password = 'pass' ... auth_required = True @@ -93,18 +92,20 @@ async def verify_jwt_token(token: str) -> dict: unverified = jwt.decode(token, options={"verify_signature": False}) # Check for expiration claim + if "exp" not in unverified and settings.require_token_expiration: + raise jwt.MissingRequiredClaimError("exp") + + # Log warning for non-expiring tokens if "exp" not in unverified: logger.warning(f"JWT token without expiration accepted. Consider enabling REQUIRE_TOKEN_EXPIRATION for better security. Token sub: {unverified.get('sub', 'unknown')}") - if settings.require_token_expiration: - raise jwt.MissingRequiredClaimError("exp") + # Full validation options = {} - if settings.require_token_expiration: options["require"] = ["exp"] + options["verify_aud"] = settings.jwt_audience_verification - options["verify_aud"] = settings.jwt_audience_verification - + # Use configured audience and issuer for validation (security fix) decode_kwargs = { "key": get_jwt_public_key_or_secret(), "algorithms": [settings.jwt_algorithm], @@ -158,7 +159,6 @@ async def verify_credentials(token: str) -> dict: ... jwt_algorithm = 'HS256' ... jwt_audience = 'mcpgateway-api' ... jwt_issuer = 'mcpgateway' - ... jwt_audience_verification = True ... basic_auth_user = 'user' ... basic_auth_password = 'pass' ... auth_required = True @@ -207,7 +207,6 @@ async def require_auth(request: Request, credentials: Optional[HTTPAuthorization ... jwt_algorithm = 'HS256' ... jwt_audience = 'mcpgateway-api' ... jwt_issuer = 'mcpgateway' - ... jwt_audience_verification = True ... basic_auth_user = 'user' ... basic_auth_password = 'pass' ... auth_required = True @@ -311,7 +310,6 @@ async def verify_basic_credentials(credentials: HTTPBasicCredentials) -> str: ... jwt_algorithm = 'HS256' ... jwt_audience = 'mcpgateway-api' ... jwt_issuer = 'mcpgateway' - ... 
jwt_audience_verification = True ... basic_auth_user = 'user' ... basic_auth_password = 'pass' ... auth_required = True @@ -364,7 +362,6 @@ async def require_basic_auth(credentials: HTTPBasicCredentials = Depends(basic_s ... jwt_algorithm = 'HS256' ... jwt_audience = 'mcpgateway-api' ... jwt_issuer = 'mcpgateway' - ... jwt_audience_verification = True ... basic_auth_user = 'user' ... basic_auth_password = 'pass' ... auth_required = True @@ -425,7 +422,6 @@ async def require_docs_basic_auth(auth_header: str) -> str: ... jwt_algorithm = 'HS256' ... jwt_audience = 'mcpgateway-api' ... jwt_issuer = 'mcpgateway' - ... jwt_audience_verification = True ... basic_auth_user = 'user' ... basic_auth_password = 'pass' ... auth_required = True @@ -558,7 +554,6 @@ async def require_docs_auth_override( ... jwt_algorithm = 'HS256' ... jwt_audience = 'mcpgateway-api' ... jwt_issuer = 'mcpgateway' - ... jwt_audience_verification = True ... docs_allow_basic_auth = False ... require_token_expiration = False >>> vc.settings = DummySettings() @@ -638,7 +633,6 @@ async def require_auth_override( ... jwt_algorithm = 'HS256' ... jwt_audience = 'mcpgateway-api' ... jwt_issuer = 'mcpgateway' - ... jwt_audience_verification = True ... basic_auth_user = 'user' ... basic_auth_password = 'pass' ... 
auth_required = True diff --git a/plugins/external/opa/opapluginfilter/plugin.py b/plugins/external/opa/opapluginfilter/plugin.py index 004d67155..c77e36485 100644 --- a/plugins/external/opa/opapluginfilter/plugin.py +++ b/plugins/external/opa/opapluginfilter/plugin.py @@ -153,7 +153,7 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginCo if not decision: violation = PluginViolation( reason="tool invocation not allowed", - description="OPA policy denied for tool preinvocation", + description="OPA policy failed on tool preinvocation", code="deny", details=decision_context,) return ToolPreInvokeResult(modified_payload=payload, violation=violation, continue_processing=False) diff --git a/plugins/external/opa/tests/server/opa_server.py b/plugins/external/opa/tests/server/opa_server.py new file mode 100644 index 000000000..5f969a321 --- /dev/null +++ b/plugins/external/opa/tests/server/opa_server.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +"""Test cases for OPA plugin + +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Shriti Priya + +This module mocks up an opa server for testing. +""" + + +# Standard +import json +import threading + +# Third-Party +from http.server import BaseHTTPRequestHandler, HTTPServer + + +# This class mocks up the post request for OPA server to evaluate policies. 
+class MockOPAHandler(BaseHTTPRequestHandler):
+    def do_POST(self):
+        if self.path == "/v1/data/example/allow":
+            content_length = int(self.headers.get('Content-Length', 0))
+            post_body = self.rfile.read(content_length).decode('utf-8')
+            try:
+                data = json.loads(post_body)
+                if "IBM" in data["input"]["payload"]["args"]["repo_path"]:
+                    self.send_response(200)
+                    self.send_header("Content-Type", "application/json")
+                    self.end_headers()
+                    self.wfile.write(b'{"result": true}')
+                else:
+                    self.send_response(200)
+                    self.send_header("Content-Type", "application/json")
+                    self.end_headers()
+                    self.wfile.write(b'{"result": false}')
+                # Process data dictionary...
+            except (json.JSONDecodeError, KeyError, TypeError):
+                # Handle invalid JSON or a body missing the expected input keys
+                self.send_response(400)
+                self.end_headers()
+                self.wfile.write(b"Invalid JSON")
+        return
+
+# This creates a mock up server for OPA at port 8181
+def run_mock_opa():
+    server = HTTPServer(('localhost', 8181), MockOPAHandler)
+    threading.Thread(target=server.serve_forever, daemon=True).start()
+    return server
diff --git a/tests/integration/test_metrics_export.py b/tests/integration/test_metrics_export.py
new file mode 100644
index 000000000..30687b02c
--- /dev/null
+++ b/tests/integration/test_metrics_export.py
@@ -0,0 +1,142 @@
+# -*- coding: utf-8 -*-
+"""
+Integration tests for metrics export endpoints.
+This test ensures the /admin/metrics/export endpoint produces valid CSV output + +""" + +# Standard +import csv +from io import StringIO +from typing import Dict, List, Any + +# Third-party +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +# First-party +from mcpgateway.db import get_db, Tool as DbTool, ToolMetric +from mcpgateway.main import app +from mcpgateway.utils.metrics_common import format_response_time + +# Tests +from tests.conftest import MockAuthMiddleware + + +@pytest.fixture +def mock_db_with_metrics(db_session): + """Create a database session with tool metrics for testing.""" + # Create test tools + tools = [] + for i in range(10): # Create 10 tools to ensure we test beyond the default limit of 5 + tool = DbTool( + name=f"test_tool_{i}", + url=f"http://example.com/tool_{i}", + integration_type="REST", + enabled=True + ) + db_session.add(tool) + tools.append(tool) + db_session.commit() + + # Add metrics for each tool + import datetime + from datetime import timedelta + + now = datetime.datetime.now(datetime.timezone.utc) + + for i, tool in enumerate(tools): + # Add successful metrics + successful_count = i + 1 + # Add different numbers of metrics to each tool to test sorting + for j in range(successful_count): + metric = ToolMetric( + tool_id=tool.id, + is_success=True, + response_time=1.0 + (j * 0.1), # Different response times + timestamp=now - timedelta(minutes=j) + ) + db_session.add(metric) + + # Add failed metrics for some tools to test success rate + if i % 2 == 0: # Even numbered tools have some failures + failed_count = i // 2 + for j in range(failed_count): + metric = ToolMetric( + tool_id=tool.id, + is_success=False, + response_time=2.0 + (j * 0.1), # Different response times + timestamp=now - timedelta(minutes=j + successful_count) + ) + db_session.add(metric) + + db_session.commit() + return db_session + + +@pytest.mark.asyncio +async def test_export_metrics_csv(mock_db_with_metrics): + 
"""Test exporting metrics to CSV format.""" + # Override the get_db dependency + app.dependency_overrides[get_db] = lambda: mock_db_with_metrics + + # Apply auth middleware for testing + app.middleware_stack = MockAuthMiddleware(app) + + # Create test client + client = TestClient(app) + + # Test export for tools + response = client.get("/admin/metrics/export?entity_type=tools") + + # Check response status and headers + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/csv" + assert "attachment; filename=tools_metrics.csv" in response.headers["Content-Disposition"] + + # Parse CSV content + content = response.content.decode("utf-8") + reader = csv.reader(StringIO(content)) + rows = list(reader) + + # Check headers + headers = rows[0] + assert headers == ["ID", "Name", "Execution Count", "Average Response Time (s)", "Success Rate (%)", "Last Execution"] + + # Check data rows + data_rows = rows[1:] + + # Should export all rows, not just top 5 + assert len(data_rows) > 5 + + # Check first row (should be the tool with highest execution count) + tool_data = data_rows[0] + assert tool_data[1].startswith("test_tool_") # Name + assert int(tool_data[2]) > 0 # Execution Count + + # Verify response time format (x.xxx) + for row in data_rows: + if row[3] != "N/A": + assert len(row[3].split(".")[-1]) == 3 # 3 decimal places + + # Test with explicit limit + limited_response = client.get("/admin/metrics/export?entity_type=tools&limit=3") + limited_content = limited_response.content.decode("utf-8") + limited_reader = csv.reader(StringIO(limited_content)) + limited_rows = list(limited_reader) + assert len(limited_rows) == 4 # header + 3 data rows + + # Test with no data + # First clear the metrics + mock_db_with_metrics.query(ToolMetric).delete() + mock_db_with_metrics.commit() + + empty_response = client.get("/admin/metrics/export?entity_type=tools") + empty_content = empty_response.content.decode("utf-8") + empty_reader = 
csv.reader(StringIO(empty_content)) + empty_rows = list(empty_reader) + assert len(empty_rows) == 1 # just header + + # Clean up + app.dependency_overrides.clear() diff --git a/tests/integration/test_resource_plugin_integration.py b/tests/integration/test_resource_plugin_integration.py index da34737cd..12ac033d6 100644 --- a/tests/integration/test_resource_plugin_integration.py +++ b/tests/integration/test_resource_plugin_integration.py @@ -154,7 +154,7 @@ async def initialize(self): def initialized(self) -> bool: return self._initialized - async def resource_pre_fetch(self, payload, global_context, violations_as_exceptions): + async def resource_pre_fetch(self, payload, global_context): # Allow test:// protocol if payload.uri.startswith("test://"): return ( @@ -167,17 +167,21 @@ async def resource_pre_fetch(self, payload, global_context, violations_as_except else: # First-Party from mcpgateway.plugins.framework.models import PluginViolation - raise PluginViolationError( - message="Protocol not allowed", + + return ( + ResourcePreFetchResult( + continue_processing=False, violation=PluginViolation( reason="Protocol not allowed", description="Protocol is not in the allowed list", code="PROTOCOL_BLOCKED", details={"protocol": payload.uri.split(":")[0], "uri": payload.uri} ), + ), + None, ) - async def resource_post_fetch(self, payload, global_context, contexts, violations_as_exceptions): + async def resource_post_fetch(self, payload, global_context, contexts): # Filter sensitive content if payload.content and payload.content.text: filtered_text = payload.content.text.replace( @@ -226,7 +230,7 @@ async def resource_post_fetch(self, payload, global_context, contexts, violation # Try to read a blocked protocol # First-Party - from mcpgateway.plugins.framework import PluginViolationError + from mcpgateway.services.resource_service import ResourceError blocked_resource = ResourceCreate( uri="file:///etc/passwd", @@ -236,7 +240,7 @@ async def resource_post_fetch(self, 
payload, global_context, contexts, violation ) await service.register_resource(test_db, blocked_resource) - with pytest.raises(PluginViolationError) as exc_info: + with pytest.raises(ResourceError) as exc_info: await service.read_resource(test_db, "file:///etc/passwd") assert "Protocol not allowed" in str(exc_info.value) @@ -248,17 +252,17 @@ async def test_plugin_context_flow(self, test_db, resource_service_with_mock_plu # Track context flow contexts_from_pre = {"plugin_data": "test_value", "validated": True} - async def pre_fetch_side_effect(payload, global_context, violations_as_exceptions): + def pre_fetch_side_effect(payload, global_context): # Verify global context assert global_context.request_id == "integration-test-123" assert global_context.user == "integration-user" assert global_context.server_id == "server-123" return ( - MagicMock(continue_processing=True, modified_payload=None), + MagicMock(continue_processing=True), contexts_from_pre, ) - async def post_fetch_side_effect(payload, global_context, contexts, violations_as_exceptions): + def post_fetch_side_effect(payload, global_context, contexts): # Verify contexts from pre-fetch assert contexts == contexts_from_pre assert contexts["plugin_data"] == "test_value" diff --git a/tests/setup_test_data.py b/tests/setup_test_data.py new file mode 100644 index 000000000..4ab195f8c --- /dev/null +++ b/tests/setup_test_data.py @@ -0,0 +1,142 @@ +""" +Simple test script for testing metrics functionality (issue #699) +""" +import sqlite3 +import datetime +import uuid +import sys + +# Database path +db_path = "mcp-context-forge/mcp.db" + +def create_test_data(): + """Create test data for metrics functionality""" + print("Creating test data in the database...") + + # Connect to SQLite database + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + try: + # 1. 
Create test tools + print("Adding test tools...") + current_time = datetime.datetime.now() + time_suffix = current_time.strftime("%Y%m%d%H%M%S") + test_tools = [ + (str(uuid.uuid4()), f"Test Tool 1 - {time_suffix}", f"Test_Tool_1_{time_suffix}", "http://example.com/tool1", "REST", "GET", + "{}", "{}", "{}", datetime.datetime.now().isoformat(), datetime.datetime.now().isoformat(), + 1, 1, "$", "{}", "test", "127.0.0.1", "test-script", "test-agent", + "test", "127.0.0.1", "test-script", "test-agent", None, None, 1, + None, None, f"Test Tool 1 - {time_suffix}", f"test-tool-1-{time_suffix}", None, f"Test Tool 1 - {time_suffix}"), + + (str(uuid.uuid4()), f"Test Tool 2 - {time_suffix}", f"Test_Tool_2_{time_suffix}", "http://example.com/tool2", "REST", "GET", + "{}", "{}", "{}", datetime.datetime.now().isoformat(), datetime.datetime.now().isoformat(), + 1, 1, "$", "{}", "test", "127.0.0.1", "test-script", "test-agent", + "test", "127.0.0.1", "test-script", "test-agent", None, None, 1, + None, None, f"Test Tool 2 - {time_suffix}", f"test-tool-2-{time_suffix}", None, f"Test Tool 2 - {time_suffix}"), + + (str(uuid.uuid4()), f"Test Tool 3 - {time_suffix}", f"Test_Tool_3_{time_suffix}", "http://example.com/tool3", "REST", "GET", + "{}", "{}", "{}", datetime.datetime.now().isoformat(), datetime.datetime.now().isoformat(), + 1, 1, "$", "{}", "test", "127.0.0.1", "test-script", "test-agent", + "test", "127.0.0.1", "test-script", "test-agent", None, None, 1, + None, None, f"Test Tool 3 - {time_suffix}", f"test-tool-3-{time_suffix}", None, f"Test Tool 3 - {time_suffix}") + ] + + cursor.executemany( + """INSERT OR REPLACE INTO tools ( + id, original_name, custom_name, url, integration_type, request_type, + headers, input_schema, annotations, created_at, updated_at, + enabled, reachable, jsonpath_filter, tags, created_by, created_from_ip, + created_via, created_user_agent, modified_by, modified_from_ip, + modified_via, modified_user_agent, import_batch_id, federation_source, + 
version, auth_type, auth_value, custom_name, custom_name_slug, + gateway_id, name + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + test_tools + ) + + # Store the tool IDs for reference + tool_ids = [tool[0] for tool in test_tools] + print(f"Added tools with IDs: {tool_ids}") + + # 2. Clean up any existing metrics + print("Cleaning up existing metrics...") + cursor.execute("DELETE FROM tool_metrics WHERE tool_id IN (?, ?, ?)", (tool_ids[0], tool_ids[1], tool_ids[2])) + + # 3. Add tool metrics with different patterns: + # - Tool 1: 80% success rate (8 success, 2 failure) + # - Tool 2: 100% success rate (5 success, 0 failure) + # - Tool 3: 0% success rate (0 success, 3 failure) + + now = datetime.datetime.now() + + # Tool 1: 80% success rate + print(f"Adding metrics for Tool 1: 80% success rate... (ID: {tool_ids[0]})") + for i in range(8): # 8 successful calls + cursor.execute( + "INSERT INTO tool_metrics (tool_id, timestamp, response_time, is_success, error_message) VALUES (?, ?, ?, 1, NULL)", + (tool_ids[0], (now - datetime.timedelta(minutes=i)).isoformat(), 1.23456) + ) + + for i in range(2): # 2 failed calls + cursor.execute( + "INSERT INTO tool_metrics (tool_id, timestamp, response_time, is_success, error_message) VALUES (?, ?, ?, 0, ?)", + (tool_ids[0], (now - datetime.timedelta(minutes=i+8)).isoformat(), 2.34567, "Test error") + ) + + # Tool 2: 100% success rate + print(f"Adding metrics for Tool 2: 100% success rate... (ID: {tool_ids[1]})") + for i in range(5): # 5 successful calls + cursor.execute( + "INSERT INTO tool_metrics (tool_id, timestamp, response_time, is_success, error_message) VALUES (?, ?, ?, 1, NULL)", + (tool_ids[1], (now - datetime.timedelta(minutes=i)).isoformat(), 0.9876) + ) + + # Tool 3: 0% success rate + print(f"Adding metrics for Tool 3: 0% success rate... 
(ID: {tool_ids[2]})") + for i in range(3): # 3 failed calls + cursor.execute( + "INSERT INTO tool_metrics (tool_id, timestamp, response_time, is_success, error_message) VALUES (?, ?, ?, 0, ?)", + (tool_ids[2], (now - datetime.timedelta(minutes=i)).isoformat(), 3.45678, "Test error") + ) + + # Commit the changes + conn.commit() + print("Test data created successfully!") + + # Save the tool IDs for reference during testing + with open("test_tool_ids.txt", "w") as f: + for i, tool_id in enumerate(tool_ids): + f.write(f"Tool {i+1} ID: {tool_id}\n") + print("Tool IDs saved to test_tool_ids.txt") + + except Exception as e: + print(f"Error creating test data: {str(e)}") + conn.rollback() + finally: + conn.close() + +if __name__ == "__main__": + create_test_data() + print("\nTest data added to the database.") + print("\nManual Testing Instructions:") + print("1. Make sure the MCP Gateway is running with admin features enabled:") + print(" - $env:MCPGATEWAY_ADMIN_API_ENABLED=\"true\"") + print(" - $env:MCPGATEWAY_UI_ENABLED=\"true\"") + print(" - python -m uvicorn mcpgateway.main:app --host 0.0.0.0 --port 8008 --reload") + print("\n2. Access the admin UI at: http://localhost:8008/admin") + print(" - Login with: admin / changeme") + print("\n3. Navigate to the Metrics tab and verify:") + print(" - Test Tool 1 shows 80% success rate with 10 executions") + print(" - Test Tool 2 shows 100% success rate with 5 executions") + print(" - Test Tool 3 shows 0% success rate with 3 executions") + print(" - Response times are formatted with 3 decimal places (x.xxx)") + print("\n4. Test CSV export:") + print(" - Click Export Metrics button and verify all rows are included") + print(" - Or access: http://localhost:8008/admin/metrics/export?entity_type=tools") + print(" - Verify the CSV includes ALL rows, not just top 5") + print(" - Verify response times have 3 decimal places") + print("\n5. 
Test empty state:") + print(" - Delete all metrics from the database for a specific tool") + print(" - Verify the UI and export handle empty state gracefully") + print("\nDone!") diff --git a/tests/test_metrics_functions.py b/tests/test_metrics_functions.py new file mode 100644 index 000000000..f218f7eb1 --- /dev/null +++ b/tests/test_metrics_functions.py @@ -0,0 +1,67 @@ +""" +Simple test script to directly test the metrics calculation functions in issue #699. +""" + +import sys +import os + +# Add the mcp-context-forge directory to the path so we can import from it +sys.path.append(os.path.join(os.path.dirname(__file__), "mcp-context-forge")) +try: + from mcpgateway.utils.metrics_common import calculate_success_rate, format_response_time +except ImportError: + print("❌ Could not import metrics functions. Make sure you're in the correct directory.") + sys.exit(1) + +def test_calculate_success_rate(): + """Test the calculate_success_rate function.""" + print("\n--- Testing calculate_success_rate function ---") + + test_cases = [ + # (successes, total, expected_result) + (8, 10, 80.0), # 80% success rate + (5, 5, 100.0), # 100% success rate + (0, 3, 0.0), # 0% success rate + (0, 0, None), # No data (should return None) + (None, None, None), # None inputs (should return None) + ] + + for i, (successes, total, expected) in enumerate(test_cases): + result = calculate_success_rate(successes, total) + if result == expected: + print(f"✅ Test case {i+1}: calculate_success_rate({successes}, {total}) = {result} (Expected: {expected})") + else: + print(f"❌ Test case {i+1}: calculate_success_rate({successes}, {total}) = {result} (Expected: {expected})") + +def test_format_response_time(): + """Test the format_response_time function.""" + print("\n--- Testing format_response_time function ---") + + test_cases = [ + # (response_time, expected_result) + (1.23456, "1.235"), # Standard case, rounds to 3 decimal places + (0.12, "0.120"), # Adds trailing zeros if needed + (1, 
"1.000"), # Integer input + (None, None), # None input - returns None (admin.py converts to "N/A") + (0, "0.000"), # Zero input + ] + + for i, (response_time, expected) in enumerate(test_cases): + result = format_response_time(response_time) + if result == expected: + print(f"✅ Test case {i+1}: format_response_time({response_time}) = '{result}' (Expected: '{expected}')") + else: + print(f"❌ Test case {i+1}: format_response_time({response_time}) = '{result}' (Expected: '{expected}')") + +def main(): + """Run all tests.""" + print("Starting tests for issue #699 metrics calculation functions...") + + # Test core metrics functions + test_calculate_success_rate() + test_format_response_time() + + print("\nAll tests completed!") + +if __name__ == "__main__": + main() diff --git a/tests/test_setup_data.py b/tests/test_setup_data.py new file mode 100644 index 000000000..120d05d8a --- /dev/null +++ b/tests/test_setup_data.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +""" +Simple test script for testing metrics functionality (issue #699) +""" +import sqlite3 +import datetime +import uuid +import sys + +# Database path +db_path = "mcp-context-forge/mcp.db" + +def create_test_data(): + """Create test data for metrics functionality""" + print("Creating test data in the database...") + + # Connect to SQLite database + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + try: + # 1. 
Create test tools + print("Adding test tools...") + current_time = datetime.datetime.now() + time_suffix = current_time.strftime("%Y%m%d%H%M%S") + test_tools = [ + (str(uuid.uuid4()), f"Test Tool 1 - {time_suffix}", f"Test_Tool_1_{time_suffix}", "http://example.com/tool1", "REST", "GET", + "{}", "{}", "{}", datetime.datetime.now().isoformat(), datetime.datetime.now().isoformat(), + 1, 1, "$", "{}", "test", "127.0.0.1", "test-script", "test-agent", + "test", "127.0.0.1", "test-script", "test-agent", None, None, 1, + None, None, f"Test Tool 1 - {time_suffix}", f"test-tool-1-{time_suffix}", None, f"Test Tool 1 - {time_suffix}"), + + (str(uuid.uuid4()), f"Test Tool 2 - {time_suffix}", f"Test_Tool_2_{time_suffix}", "http://example.com/tool2", "REST", "GET", + "{}", "{}", "{}", datetime.datetime.now().isoformat(), datetime.datetime.now().isoformat(), + 1, 1, "$", "{}", "test", "127.0.0.1", "test-script", "test-agent", + "test", "127.0.0.1", "test-script", "test-agent", None, None, 1, + None, None, f"Test Tool 2 - {time_suffix}", f"test-tool-2-{time_suffix}", None, f"Test Tool 2 - {time_suffix}"), + + (str(uuid.uuid4()), f"Test Tool 3 - {time_suffix}", f"Test_Tool_3_{time_suffix}", "http://example.com/tool3", "REST", "GET", + "{}", "{}", "{}", datetime.datetime.now().isoformat(), datetime.datetime.now().isoformat(), + 1, 1, "$", "{}", "test", "127.0.0.1", "test-script", "test-agent", + "test", "127.0.0.1", "test-script", "test-agent", None, None, 1, + None, None, f"Test Tool 3 - {time_suffix}", f"test-tool-3-{time_suffix}", None, f"Test Tool 3 - {time_suffix}") + ] + + cursor.executemany( + """INSERT OR REPLACE INTO tools ( + id, original_name, custom_name, url, integration_type, request_type, + headers, input_schema, annotations, created_at, updated_at, + enabled, reachable, jsonpath_filter, tags, created_by, created_from_ip, + created_via, created_user_agent, modified_by, modified_from_ip, + modified_via, modified_user_agent, import_batch_id, federation_source, + 
version, auth_type, auth_value, custom_name, custom_name_slug, + gateway_id, name + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + test_tools + ) + + # Store the tool IDs for reference + tool_ids = [tool[0] for tool in test_tools] + print(f"Added tools with IDs: {tool_ids}") + + # 2. Clean up any existing metrics + print("Cleaning up existing metrics...") + cursor.execute("DELETE FROM tool_metrics WHERE tool_id IN (?, ?, ?)", (tool_ids[0], tool_ids[1], tool_ids[2])) + + # 3. Add tool metrics with different patterns: + # - Tool 1: 80% success rate (8 success, 2 failure) + # - Tool 2: 100% success rate (5 success, 0 failure) + # - Tool 3: 0% success rate (0 success, 3 failure) + + now = datetime.datetime.now() + + # Tool 1: 80% success rate + print(f"Adding metrics for Tool 1: 80% success rate... (ID: {tool_ids[0]})") + for i in range(8): # 8 successful calls + cursor.execute( + "INSERT INTO tool_metrics (tool_id, timestamp, response_time, is_success, error_message) VALUES (?, ?, ?, 1, NULL)", + (tool_ids[0], (now - datetime.timedelta(minutes=i)).isoformat(), 1.23456) + ) + + for i in range(2): # 2 failed calls + cursor.execute( + "INSERT INTO tool_metrics (tool_id, timestamp, response_time, is_success, error_message) VALUES (?, ?, ?, 0, ?)", + (tool_ids[0], (now - datetime.timedelta(minutes=i+8)).isoformat(), 2.34567, "Test error") + ) + + # Tool 2: 100% success rate + print(f"Adding metrics for Tool 2: 100% success rate... (ID: {tool_ids[1]})") + for i in range(5): # 5 successful calls + cursor.execute( + "INSERT INTO tool_metrics (tool_id, timestamp, response_time, is_success, error_message) VALUES (?, ?, ?, 1, NULL)", + (tool_ids[1], (now - datetime.timedelta(minutes=i)).isoformat(), 0.9876) + ) + + # Tool 3: 0% success rate + print(f"Adding metrics for Tool 3: 0% success rate... 
(ID: {tool_ids[2]})") + for i in range(3): # 3 failed calls + cursor.execute( + "INSERT INTO tool_metrics (tool_id, timestamp, response_time, is_success, error_message) VALUES (?, ?, ?, 0, ?)", + (tool_ids[2], (now - datetime.timedelta(minutes=i)).isoformat(), 3.45678, "Test error") + ) + + # Commit the changes + conn.commit() + print("Test data created successfully!") + + # Save the tool IDs for reference during testing + with open("test_tool_ids.txt", "w") as f: + for i, tool_id in enumerate(tool_ids): + f.write(f"Tool {i+1} ID: {tool_id}\n") + print("Tool IDs saved to test_tool_ids.txt") + + except Exception as e: + print(f"Error creating test data: {str(e)}") + conn.rollback() + finally: + conn.close() + +if __name__ == "__main__": + create_test_data() + print("\nTest data added to the database.") + print("\nManual Testing Instructions:") + print("1. Make sure the MCP Gateway is running with admin features enabled:") + print(" - $env:MCPGATEWAY_ADMIN_API_ENABLED=\"true\"") + print(" - $env:MCPGATEWAY_UI_ENABLED=\"true\"") + print(" - python -m uvicorn mcpgateway.main:app --host 0.0.0.0 --port 8008 --reload") + print("\n2. Access the admin UI at: http://localhost:8008/admin") + print(" - Login with: admin / changeme") + print("\n3. Navigate to the Metrics tab and verify:") + print(" - Test Tool 1 shows 80% success rate with 10 executions") + print(" - Test Tool 2 shows 100% success rate with 5 executions") + print(" - Test Tool 3 shows 0% success rate with 3 executions") + print(" - Response times are formatted with 3 decimal places (x.xxx)") + print("\n4. Test CSV export:") + print(" - Click Export Metrics button and verify all rows are included") + print(" - Or access: http://localhost:8008/admin/metrics/export?entity_type=tools") + print(" - Verify the CSV includes ALL rows, not just top 5") + print(" - Verify response times have 3 decimal places") + print("\n5. 
Test empty state:") + print(" - Delete all metrics from the database for a specific tool") + print(" - Verify the UI and export handle empty state gracefully") + print("\nDone!") diff --git a/tests/unit/mcpgateway/middleware/test_token_scoping.py b/tests/unit/mcpgateway/middleware/test_token_scoping.py index 8781916a1..f7d1da632 100644 --- a/tests/unit/mcpgateway/middleware/test_token_scoping.py +++ b/tests/unit/mcpgateway/middleware/test_token_scoping.py @@ -7,7 +7,6 @@ """ # Standard -import json from unittest.mock import AsyncMock, MagicMock, patch # Third-Party @@ -84,8 +83,6 @@ async def test_admin_permissions_use_canonical_constants(self, middleware): result = middleware._check_permission_restrictions("/admin", "GET", ["admin.read"]) assert result == False, "Should reject non-canonical 'admin.read' permission" - - @pytest.mark.asyncio async def test_server_scoped_token_blocked_from_admin(self, middleware, mock_request): """Test that server-scoped tokens are blocked from admin endpoints (security fix).""" @@ -97,19 +94,16 @@ async def test_server_scoped_token_blocked_from_admin(self, middleware, mock_req with patch.object(middleware, '_extract_token_scopes') as mock_extract: mock_extract.return_value = {"server_id": "specific-server"} - # Mock call_next (the next middleware or request handler) + # Create mock call_next call_next = AsyncMock() - # Perform the request, which should return a JSONResponse instead of raising HTTPException - response = await middleware(mock_request, call_next) - - # Ensure response is a JSONResponse and parse its content - content = json.loads(response.body) # Parse response content to dictionary + # Should raise HTTPException due to server restriction + with pytest.raises(HTTPException) as exc_info: + await middleware(mock_request, call_next) - # Check that the response is a JSONResponse with status 403 and the correct detail - assert response.status_code == status.HTTP_403_FORBIDDEN - assert "not authorized for this server" in 
content.get("detail") - call_next.assert_not_called() # Ensure the next handler is not called + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + assert "not authorized for this server" in exc_info.value.detail + call_next.assert_not_called() @pytest.mark.asyncio async def test_permission_restricted_token_blocked_from_admin(self, middleware, mock_request): @@ -122,21 +116,15 @@ async def test_permission_restricted_token_blocked_from_admin(self, middleware, with patch.object(middleware, '_extract_token_scopes') as mock_extract: mock_extract.return_value = {"permissions": [Permissions.TOOLS_READ]} - # Mock call_next (the next middleware or request handler) call_next = AsyncMock() - # Perform the request, which should return a JSONResponse instead of raising HTTPException - response = await middleware(mock_request, call_next) - - # Ensure response is a JSONResponse and parse its content - content = json.loads(response.body) # Parse response content to dictionary - - # Check that the response is a JSONResponse with status 403 and the correct detail - assert response.status_code == status.HTTP_403_FORBIDDEN - assert "Insufficient permissions for this operation" in content.get("detail") - call_next.assert_not_called() # Ensure the next handler is not called - + # Should raise HTTPException due to insufficient permissions + with pytest.raises(HTTPException) as exc_info: + await middleware(mock_request, call_next) + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + assert "Insufficient permissions" in exc_info.value.detail + call_next.assert_not_called() @pytest.mark.asyncio async def test_admin_token_allowed_to_admin_endpoints(self, middleware, mock_request): diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index 22efe7c42..81296cc93 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ 
-11,7 +11,8 @@ # First-Party from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginManager, PluginViolationError, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework.manager import PluginManager +from mcpgateway.plugins.framework.models import GlobalContext, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload from plugins.regex_filter.search_replace import SearchReplaceConfig @@ -120,11 +121,6 @@ async def test_manager_filter_plugins(): result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) assert not result.continue_processing assert result.violation - - with pytest.raises(PluginViolationError) as ve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) - assert ve.value.violation - assert ve.value.violation.reason == "Prompt not allowed" await manager.shutdown() @@ -138,9 +134,6 @@ async def test_manager_multi_filter_plugins(): result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) assert not result.continue_processing assert result.violation - with pytest.raises(PluginViolationError) as ve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) - assert ve.value.violation await manager.shutdown() @@ -238,46 +231,3 @@ async def test_manager_tool_hooks_with_actual_plugin(): assert result.violation is None await manager.shutdown() - - -@pytest.mark.asyncio -async def test_manager_tool_hooks_with_header_mods(): - """Test tool hooks with a real plugin configured for tool processing.""" - manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/tool_headers_plugin.yaml") - await manager.initialize() - assert manager.initialized - - # Test tool pre-invoke with transformation 
- use correct tool name from config - tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}, headers=None) - global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) - - # Should continue processing with transformations applied - assert result.continue_processing - assert result.modified_payload is not None - assert result.modified_payload.name == "test_tool" - assert result.modified_payload.args["input"] == "This is bad data" # bad -> good - assert result.modified_payload.args["quality"] == "wrong" # wrong -> right - assert result.violation is None - assert result.modified_payload.headers - assert result.modified_payload.headers["User-Agent"] == "Mozilla/5.0" - assert result.modified_payload.headers["Connection"] == "keep-alive" - - # Test tool pre-invoke with transformation - use correct tool name from config - tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}, headers=HttpHeaderPayload({'Content-Type': 'application/json'})) - global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) - - # Should continue processing with transformations applied - assert result.continue_processing - assert result.modified_payload is not None - assert result.modified_payload.name == "test_tool" - assert result.modified_payload.args["input"] == "This is bad data" # bad -> good - assert result.modified_payload.args["quality"] == "wrong" # wrong -> right - assert result.violation is None - assert result.modified_payload.headers - assert result.modified_payload.headers["User-Agent"] == "Mozilla/5.0" - assert result.modified_payload.headers["Connection"] == "keep-alive" - assert result.modified_payload.headers['Content-Type'] == 'application/json' - - await manager.shutdown() 
diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py index 9abf9f88d..3a6f78417 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py @@ -29,7 +29,6 @@ PluginMode, PluginResult, PluginViolation, - PluginViolationError, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, @@ -176,17 +175,6 @@ async def prompt_pre_fetch(self, payload, context): assert result.continue_processing assert result.violation is None - plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR - with patch.object(manager._registry, 'get_plugins_for_hook') as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] - - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) - - # Should continue in enforce_ignore_error mode - assert result.continue_processing - assert result.violation is None - await manager.shutdown() @@ -417,12 +405,6 @@ async def prompt_pre_fetch(self, payload, context): assert result.violation.code == "CONTENT_BLOCKED" assert result.violation.plugin_name == "BlockingPlugin" - with pytest.raises(PluginViolationError) as pve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) - assert pve.value.violation - assert pve.value.message - assert pve.value.violation.code == "CONTENT_BLOCKED" - assert pve.value.violation.plugin_name == "BlockingPlugin" await manager.shutdown() diff --git a/tests/unit/mcpgateway/services/test_resource_service_plugins.py b/tests/unit/mcpgateway/services/test_resource_service_plugins.py index ba840333d..b5c43bf9e 100644 --- a/tests/unit/mcpgateway/services/test_resource_service_plugins.py +++ b/tests/unit/mcpgateway/services/test_resource_service_plugins.py @@ -17,9 +17,10 @@ # First-Party from mcpgateway.models import ResourceContent -from 
mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService -from mcpgateway.plugins.framework import PluginError, PluginErrorModel, PluginViolation, PluginViolationError - +from mcpgateway.plugins.framework.models import ( + PluginViolation, +) +from mcpgateway.services.resource_service import ResourceError, ResourceNotFoundError, ResourceService class TestResourceServicePluginIntegration: @@ -130,7 +131,9 @@ async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugi # Setup pre-fetch hook to block mock_manager.resource_pre_fetch = AsyncMock( - side_effect=PluginViolationError(message="Protocol not allowed", + return_value=( + MagicMock( + continue_processing=False, violation=PluginViolation( reason="Protocol not allowed", code="PROTOCOL_BLOCKED", @@ -138,12 +141,14 @@ async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugi details={"protocol": "file", "uri": "file:///etc/passwd"} ), ), + None, + ) ) - with pytest.raises(PluginViolationError) as exc_info: + with pytest.raises(ResourceError) as exc_info: await service.read_resource(mock_db, "file:///etc/passwd") - assert "Protocol not allowed" in str(exc_info.value) + assert "Resource blocked: Protocol not allowed" in str(exc_info.value) mock_manager.resource_pre_fetch.assert_called_once() # Post-fetch should not be called if pre-fetch blocks mock_manager.resource_post_fetch.assert_not_called() @@ -259,11 +264,12 @@ async def test_read_resource_plugin_error_handling(self, resource_service_with_p mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource # Setup pre-fetch hook to raise an error - mock_manager.resource_pre_fetch = AsyncMock(side_effect=PluginError(error=PluginErrorModel(message="Plugin error", plugin_name="mock_plugin"))) + mock_manager.resource_pre_fetch = AsyncMock(side_effect=ValueError("Plugin error")) - with pytest.raises(PluginError): - result = await service.read_resource(mock_db, "test://resource") + # 
Should continue without plugin processing on error + result = await service.read_resource(mock_db, "test://resource") + assert result == mock_resource.content mock_manager.resource_pre_fetch.assert_called_once() @pytest.mark.asyncio @@ -291,19 +297,24 @@ async def test_read_resource_post_fetch_blocking(self, resource_service_with_plu # Setup post-fetch hook to block mock_manager.resource_post_fetch = AsyncMock( - side_effect=PluginViolationError(message="Content contains sensitive data", - violation=PluginViolation( + return_value=( + MagicMock( + continue_processing=False, + violation=PluginViolation( reason="Content contains sensitive data", description="The resource content was flagged as containing sensitive information", code="SENSITIVE_CONTENT", details={"uri": "test://resource"} - )) + ), + ), + None, + ) ) - with pytest.raises(PluginViolationError) as exc_info: + with pytest.raises(ResourceError) as exc_info: await service.read_resource(mock_db, "test://resource") - assert "Content contains sensitive data" in str(exc_info.value) + assert "Resource content blocked: Content contains sensitive data" in str(exc_info.value) @pytest.mark.asyncio async def test_read_resource_with_template(self, resource_service_with_plugins, mock_db): diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index 9a6f4850c..6d9d24b91 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -22,7 +22,7 @@ from mcpgateway.db import A2AAgent as DbA2AAgent from mcpgateway.db import Gateway as DbGateway from mcpgateway.db import Tool as DbTool -from mcpgateway.plugins.framework import PluginError, PluginErrorModel, PluginViolationError, PluginManager +from mcpgateway.plugins.framework import PluginViolationError from mcpgateway.schemas import AuthenticationValues, ToolCreate, ToolRead, ToolUpdate from mcpgateway.services.tool_service import ( TextContent, 
@@ -48,7 +48,7 @@ def tool_service(): def mock_gateway(): """Create a mock gateway model.""" gw = MagicMock(spec=DbGateway) - gw.id = "1" + gw.id = 1 gw.name = "test_gateway" gw.slug = "test-gateway" gw.url = "http://example.com/gateway" @@ -56,11 +56,6 @@ def mock_gateway(): gw.transport = "SSE" gw.capabilities = {"prompts": {"listChanged": True}, "resources": {"listChanged": True}, "tools": {"listChanged": True}} gw.created_at = gw.updated_at = gw.last_seen = "2025-01-01T00:00:00Z" - gw.modified_by = gw.created_by = "Someone" - gw.modified_via = gw.created_via = "ui" - gw.modified_from_ip = gw.created_from_ip = "127.0.0.1" - gw.modified_user_agent = gw.created_user_agent = "Chrome" - gw.import_batch_id = gw.federation_source = gw.team_id = gw.visibility = gw.owner_email = None # one dummy tool hanging off the gateway tool = MagicMock(spec=DbTool, id=101, name="dummy_tool") @@ -90,19 +85,6 @@ def mock_tool(): tool.jsonpath_filter = "" tool.created_at = "2023-01-01T00:00:00" tool.updated_at = "2023-01-01T00:00:00" - tool.created_by = "MCP Gateway team" - tool.created_from_ip = "1.2.3.4" - tool.created_via = "ui" - tool.created_user_agent = "Chrome" - tool.modified_by = "No one" - tool.modified_from_ip = "1.2.3.4" - tool.modified_via = "ui" - tool.modified_user_agent = "Chrome" - tool.import_batch_id = "2" - tool.federation_source = "federation_source" - tool.team_id = "5" - tool.visibility = "private" - tool.owner_email = "admin@admin.org" tool.enabled = True tool.reachable = True tool.auth_type = None @@ -1419,9 +1401,6 @@ async def test_invoke_tool_mcp_streamablehttp(self, tool_service, mock_tool, tes reachable=True, auth_type="bearer", # ←← attribute your error complained about auth_value="Bearer abc123", - capabilities = {"prompts": {"listChanged": True}, "resources": {"listChanged": True}, "tools": {"listChanged": True}}, - transport = "STREAMABLEHTTP", - passthrough_headers = [], ) # Configure tool as REST mock_tool.integration_type = "MCP" @@ -1522,9 +1501,6 
@@ async def test_invoke_tool_mcp_non_standard(self, tool_service, mock_tool, test_ reachable=True, auth_type="bearer", # ←← attribute your error complained about auth_value="Bearer abc123", - capabilities = {"prompts": {"listChanged": True}, "resources": {"listChanged": True}, "tools": {"listChanged": True}}, - transport = "STREAMABLEHTTP", - passthrough_headers = [], ) # Configure tool as REST mock_tool.integration_type = "MCP" @@ -2407,9 +2383,8 @@ async def test_invoke_tool_with_plugin_post_invoke_error_fail_on_error(self, too assert "Plugin error" in str(exc_info.value) - - async def test_invoke_tool_with_plugin_metadata_rest(self, tool_service, mock_tool, test_db): - """Test invoking tool with plugin post-invoke hook error when fail_on_plugin_error is True.""" + async def test_invoke_tool_with_plugin_post_invoke_error_continue_on_error(self, tool_service, mock_tool, test_db): + """Test invoking tool with plugin post-invoke hook error when fail_on_plugin_error is False.""" # Configure tool as REST mock_tool.integration_type = "REST" mock_tool.request_type = "POST" @@ -2428,10 +2403,16 @@ async def test_invoke_tool_with_plugin_metadata_rest(self, tool_service, mock_to tool_service._http_client.request.return_value = mock_response # Mock plugin manager and post-invoke hook with error - tool_service._plugin_manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/tool_headers_metadata_plugin.yaml") - await tool_service._plugin_manager.initialize() - # Mock metrics recording - tool_service._record_tool_metric = AsyncMock() + tool_service._plugin_manager = Mock() + tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) + tool_service._plugin_manager.tool_post_invoke = AsyncMock(side_effect=Exception("Plugin error")) + + # Mock plugin config to continue on errors + mock_plugin_settings = Mock() + mock_plugin_settings.fail_on_plugin_error = False + 
mock_config = Mock() + mock_config.plugin_settings = mock_plugin_settings + tool_service._plugin_manager.config = mock_config with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), @@ -2492,4 +2473,5 @@ async def test_invoke_tool_with_plugin_metadata_sse(self, tool_service, mock_too ): await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) - await tool_service._plugin_manager.shutdown() + # Verify result still succeeded despite plugin error + assert result.content[0].text == '{\n "result": "original response"\n}' diff --git a/tests/unit/mcpgateway/utils/test_metrics_common.py b/tests/unit/mcpgateway/utils/test_metrics_common.py new file mode 100644 index 000000000..38f57cba3 --- /dev/null +++ b/tests/unit/mcpgateway/utils/test_metrics_common.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +""" +Unit tests for metrics_common.py utility functions. +""" + +# Standard +import unittest +from unittest.mock import MagicMock + +# Third-party +import pytest + +# First-party +from mcpgateway.utils.metrics_common import build_top_performers, calculate_success_rate, format_response_time + + +class TestMetricsCommon(unittest.TestCase): + """Test suite for metrics_common.py utility functions.""" + + def test_calculate_success_rate_normal(self): + """Test success rate calculation with normal inputs.""" + # Test with integer inputs + self.assertEqual(calculate_success_rate(75, 100), 75.0) + # Test with float inputs + self.assertEqual(calculate_success_rate(7.5, 10), 75.0) + # Test with 100% success rate + self.assertEqual(calculate_success_rate(10, 10), 100.0) + # Test with 0% success rate + self.assertEqual(calculate_success_rate(0, 10), 0.0) + + def test_calculate_success_rate_edge_cases(self): + """Test success rate calculation with edge cases.""" + # Test with zero total + self.assertIsNone(calculate_success_rate(0, 0)) + # Test with negative total + self.assertIsNone(calculate_success_rate(5, -10)) + # Test with 
None inputs + self.assertIsNone(calculate_success_rate(None, 10)) + self.assertIsNone(calculate_success_rate(5, None)) + self.assertIsNone(calculate_success_rate(None, None)) + # Test with successful > total (should still calculate but might not make logical sense) + self.assertEqual(calculate_success_rate(15, 10), 150.0) + + def test_format_response_time_normal(self): + """Test response time formatting with normal inputs.""" + # Test with integer + self.assertEqual(format_response_time(1), "1.000") + # Test with float, no rounding + self.assertEqual(format_response_time(1.234), "1.234") + # Test with float, rounding up + self.assertEqual(format_response_time(1.2345), "1.235") + # Test with float, rounding down + self.assertEqual(format_response_time(1.2344), "1.234") + # Test with zero + self.assertEqual(format_response_time(0), "0.000") + + def test_format_response_time_edge_cases(self): + """Test response time formatting with edge cases.""" + # Test with None + self.assertIsNone(format_response_time(None)) + # Test with negative value + self.assertEqual(format_response_time(-1.234), "-1.234") + # Test with string that can be converted to float + self.assertEqual(format_response_time("1.234"), "1.234") + # Test with string that cannot be converted to float + with pytest.raises(ValueError): + format_response_time("not a number") + + def test_build_top_performers(self): + """Test building TopPerformer objects from database results.""" + # Create mock results + result1 = MagicMock() + result1.id = 1 + result1.name = "test1" + result1.execution_count = 10 + result1.avg_response_time = 1.5 + result1.success_rate = 85.0 + result1.last_execution = None + + result2 = MagicMock() + result2.id = 2 + result2.name = "test2" + result2.execution_count = 20 + result2.avg_response_time = None + result2.success_rate = None + result2.last_execution = None + + # Test with a list of results + performers = build_top_performers([result1, result2]) + + # Verify the results + 
self.assertEqual(len(performers), 2) + self.assertEqual(performers[0].id, 1) + self.assertEqual(performers[0].name, "test1") + self.assertEqual(performers[0].execution_count, 10) + self.assertEqual(performers[0].avg_response_time, 1.5) + self.assertEqual(performers[0].success_rate, 85.0) + self.assertIsNone(performers[0].last_execution) + + self.assertEqual(performers[1].id, 2) + self.assertEqual(performers[1].name, "test2") + self.assertEqual(performers[1].execution_count, 20) + self.assertIsNone(performers[1].avg_response_time) + self.assertIsNone(performers[1].success_rate) + self.assertIsNone(performers[1].last_execution) + + # Test with empty list + empty_performers = build_top_performers([]) + self.assertEqual(len(empty_performers), 0) + + +if __name__ == "__main__": + unittest.main()