diff --git a/README.md b/README.md index eeaf8016..2fb26af6 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Strands Agents Tools is a community-driven project that provides a powerful set - 🐝 **Swarm Intelligence** - Coordinate multiple AI agents for parallel problem solving with shared memory - 🔄 **Multiple tools in Parallel** - Call multiple other tools at the same time in parallel with Batch Tool - 🔍 **Browser Tool** - Tool giving an agent access to perform automated actions on a browser (chromium) +- 🐘 **Query Postgres** – Query PostgreSQL with Natural Language ## 📦 Installation @@ -128,7 +129,7 @@ Below is a comprehensive table of all available tools, how to use them with an a | workflow | `agent.tool.workflow(action="create", name="data_pipeline", steps=[{"tool": "file_read"}, {"tool": "python_repl"}])` | Define, execute, and manage multi-step automated workflows | | batch| `agent.tool.batch(invocations=[{"name": "current_time", "arguments": {"timezone": "Europe/London"}}, {"name": "stop", "arguments": {}}])` | Call multiple other tools in parallel. | | browser | `browser = LocalChromiumBrowser(); agent = Agent(tools=[browser.browser])` | Web scraping, automated testing, form filling, web automation tasks | - +| query_postgres | `agent.tool.query_postgres(query="SELECT name FROM users WHERE active = true")` | Run secure, read-only PostgreSQL queries for insights | \* *These tools do not work on windows* ## 💻 Usage Examples @@ -450,6 +451,50 @@ response = agent("discover available agents and send a greeting message") # - send_message(message_text, target_agent_url) to communicate ``` +### Query Postgres +```python +import os +from strands import Agent +from strands.models import BedrockModel +from strands_tools.query_postgres import query_postgres + +# Show rich UI for tools in CLI +os.environ["STRANDS_TOOL_CONSOLE_MODE"] = "enabled" + +model = BedrockModel(model_id="apac.anthropic.claude-sonnet-4-20250514-v1:0") + +# Initialize the agent with tools, model, and configuration +agent = Agent( + tools=[query_postgres], + system_prompt=""" +You are a helpful business analysis tool that answers user questions by generating and executing SQL queries on a PostgreSQL database. +You only respond with SQL query results via the tool named `query_postgres`. + +### Database Schema +List you schema here for best results. + +### Instructions +- Use the `query_postgres` tool to run SQL queries. +- You do not need to handle connection details — they are automatically managed via environment variables: `PGHOST`, `PGDATABASE`, `PGUSER`, `PGPASSWORD` and `PGPORT`. +- Limit SELECT results to a maximum of 100 rows unless otherwise specified. +- Always be accurate with JOINs and field references based on the schema. +- If a query returns multiple results, structure the output for readability. +- Only return text content via `content: [{"text": "..."}]`. + +You are expected to translate user queries like: +- "What is the average product price?" +- "List top 5 customers by total order value" +- "Show the number of orders per customer" + +into valid SQL and return the results cleanly. +""", + model=model +) + +agent("What is the average price of the products") +``` + + ## 🌍 Environment Variables Configuration Agents Tools provides extensive customization through environment variables. This allows you to configure tool behavior without modifying code, making it ideal for different environments (development, testing, production). @@ -605,6 +650,15 @@ The Mem0 Memory Tool supports three different backend configurations: | STRANDS_BROWSER_WIDTH | Default width of the browser | 1280 | | STRANDS_BROWSER_HEIGHT | Default height of the browser | 800 | +### Query Postgres Tool +| Environment Variable | Description | Default | +|----------------------| ----------------------------------------------- |------------| +| PGHOST | Hostname or IP address of the PostgreSQL server | localhost | +| PGPORT | Port number to connect to PostgreSQL | 5432 | +| PGDATABASE | Name of the PostgreSQL database to connect to | *Required* | +| PGUSER | **Read-only** user for executing queries | *Required* | +| PGPASSWORD | Password for the PostgreSQL user | *Required* | + ## Contributing ❤️ diff --git a/pyproject.toml b/pyproject.toml index 2ceabb19..9eb42f79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,8 @@ agent_core_code_interpreter = [ "bedrock-agentcore==0.1.0" ] a2a_client = ["a2a-sdk[sql]>=0.2.11"] +# Optional dependency for PostgreSQL-related tools (e.g., query_postgres) +query_postgres = ["psycopg2-binary>=2.9.9,<3.0.0"] [tool.hatch.envs.hatch-static-analysis] features = ["mem0_memory", "local_chromium_browser", "agent_core_browser", "agent_core_code_interpreter", "a2a_client"] @@ -112,7 +114,7 @@ lint-check = [ lint-fix = ["ruff check --fix"] [tool.hatch.envs.hatch-test] -features = ["mem0_memory", "local_chromium_browser", "agent_core_browser", "agent_core_code_interpreter", "a2a_client"] +features = ["mem0_memory", "local_chromium_browser", "agent_core_browser", "agent_core_code_interpreter", "a2a_client", "query_postgres"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", diff --git a/src/strands_tools/query_postgres.py b/src/strands_tools/query_postgres.py new file mode 100644 index 00000000..8c0c0e64 --- /dev/null +++ b/src/strands_tools/query_postgres.py @@ -0,0 +1,95 @@ +import os + +from strands import tool + + +@tool +def query_postgres(tool_use_id: str, query: str, limit: int = 100) -> dict: + """ + Safely execute **read-only** SQL queries (e.g., SELECT, WITH) against a PostgreSQL database. + + 🔐 Security Guidelines: + - This tool **strictly blocks** any non-read query such as INSERT, UPDATE, DELETE, DROP, ALTER, etc. + - Use **read-only PostgreSQL credentials** (e.g., a user with SELECT-only permissions). + - All connections should be made using **environment-controlled credentials** to avoid exposure in code. + - Only use this tool for data exploration, reporting, and analytics — not transactional workloads. + + Parameters: + tool_use_id: Unique ID for tool invocation (provided by the agent runtime) + query: SQL SELECT/CTE query to execute + limit: Optional row limit for SELECT queries (defaults to 100) + + Returns: + A dict with toolUseId, status ("success" | "error"), and content (text response) + """ + try: + import psycopg2 + from psycopg2.extras import RealDictCursor + except ImportError: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "❌ psycopg2 not installed. Run: pip install psycopg2-binary"}], + } + + # Sanitize query + q_upper = query.strip().upper() + disallowed = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", "CREATE", "REPLACE", "GRANT", "REVOKE"] + if any(q_upper.startswith(k) for k in disallowed): + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "🚫 Only SELECT/CTE queries are allowed. This tool is read-only."}], + } + + # Use env vars for connection + try: + conn = psycopg2.connect( + host=os.getenv("PGHOST", "localhost"), + port=os.getenv("PGPORT", "5432"), + dbname=os.getenv("PGDATABASE"), + user=os.getenv("PGUSER"), + password=os.getenv("PGPASSWORD"), + cursor_factory=RealDictCursor, + ) + except Exception as e: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"❌ Could not connect to PostgreSQL: {str(e)}"}], + } + + try: + with conn: + with conn.cursor() as cur: + if q_upper.startswith("SELECT") and "LIMIT" not in q_upper: + query = f"{query.rstrip(';')} LIMIT {limit}" + + cur.execute(query) + rows = cur.fetchall() + cols = [desc[0] for desc in cur.description] if cur.description else [] + + lines = [f"📊 Query: `{query}`", f"🧮 Rows: {len(rows)}", f"🔠 Columns: {', '.join(cols)}"] + + for i, row in enumerate(rows[: min(10, len(rows))], start=1): + lines.append(f" • Row {i}: " + ", ".join(f"{c}={row[c]}" for c in cols)) + if len(rows) > 10: + lines.append(f"...and {len(rows) - 10} more rows.") + + return { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": "\n".join(lines)}], + } + + except Exception as e: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"❌ Query execution error: {str(e)}"}], + } + finally: + try: + conn.close() + except Exception: + pass diff --git a/tests/test_query_postgres.py b/tests/test_query_postgres.py new file mode 100644 index 00000000..c03b3579 --- /dev/null +++ b/tests/test_query_postgres.py @@ -0,0 +1,108 @@ +import os + +import pytest +from strands_tools.query_postgres import query_postgres # update to actual import path + + +# Utility to simulate tool invocation +def run_tool(query, env=None, limit=None): + # set tool_use_id arbitrarily + tool_use_id = "test-invocation" + if env: + for k, v in env.items(): + os.environ[k] = v + result = query_postgres(tool_use_id=tool_use_id, query=query, limit=limit or 100) + return result + + +@pytest.fixture(autouse=True) +def clear_env(): + # clear environment variables for isolation + for var in ("PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD"): + os.environ.pop(var, None) + yield + for var in ("PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD"): + os.environ.pop(var, None) + + +def test_missing_env_vars(): + res = run_tool("SELECT 1 as one;") + assert res["toolUseId"] == "test-invocation" + assert res["status"] == "error" + assert "Could not connect" in res["content"][0]["text"] + + +def test_disallowed_query(): + # Provide env so connection is attempted but the query is blocked first + env = { + "PGHOST": "localhost", + "PGPORT": "5432", + "PGDATABASE": "testdb", + "PGUSER": "user", + "PGPASSWORD": "pwd", + } + res = run_tool("DELETE FROM users;", env=env) + assert res["status"] == "error" + assert "🚫 Only SELECT/CTE queries are allowed. This tool is read-only." in res["content"][0]["text"] + + +def test_read_only_select(monkeypatch): + # --- Set env vars --- + monkeypatch.setenv("PGHOST", "localhost") + monkeypatch.setenv("PGPORT", "5432") + monkeypatch.setenv("PGDATABASE", "testdb") + monkeypatch.setenv("PGUSER", "readonly") + monkeypatch.setenv("PGPASSWORD", "pwd") + + # --- Mock psycopg2 --- + import sys + import types + + # Fake cursor that simulates fetch + class FakeCursor: + def execute(self, q): + self._rows = [{"col1": 123}, {"col1": 456}] + self.description = [("col1",)] + + def fetchall(self): + return self._rows + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + # Fake connection that returns fake cursor + class FakeConn: + def cursor(self): + return FakeCursor() + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def close(self): + pass + + # Create a fake psycopg2 module with connect + extras.RealDictCursor + fake_psycopg2 = types.ModuleType("psycopg2") + fake_psycopg2.connect = lambda **kwargs: FakeConn() + + fake_extras = types.SimpleNamespace(RealDictCursor=object) + fake_psycopg2.extras = fake_extras + + sys.modules["psycopg2"] = fake_psycopg2 + sys.modules["psycopg2.extras"] = fake_extras + + from strands_tools.query_postgres import query_postgres # Import after patching + + result = query_postgres("test-invoke", "SELECT col1 FROM test") + assert result["status"] == "success" + assert ( + "\U0001f4ca Query: `SELECT col1 FROM test LIMIT 100`\n" + "\U0001f9ee Rows: 2\n\U0001f520 Columns: col1\n \u2022 Row 1: col1=123\n \u2022 Row 2: col1=456" + in result["content"][0]["text"] + )