diff --git a/pyproject.toml b/pyproject.toml index 040babe67..787c7e986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ [project.optional-dependencies] anthropic = ["anthropic>=0.21.0,<1.0.0"] +langchain = ["langchain-core>=1.2.0,<2.0.0"] gemini = ["google-genai>=1.32.0,<2.0.0"] litellm = ["litellm>=1.75.9,<2.0.0", "openai>=1.68.0,<1.110.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] @@ -79,7 +80,7 @@ bidi = [ bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] bidi-openai = ["websockets>=15.0.0,<16.0.0"] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,langchain,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] dev = [ diff --git a/src/strands/experimental/tools/__init__.py b/src/strands/experimental/tools/__init__.py index ad693f8ac..9c3922325 100644 --- a/src/strands/experimental/tools/__init__.py +++ b/src/strands/experimental/tools/__init__.py @@ -1,5 +1,20 @@ """Experimental tools package.""" +from typing import Any + from .tool_provider import ToolProvider __all__ = ["ToolProvider"] + + +def __getattr__(name: str) -> Any: + """Lazy load optional dependencies. + + LangChainTool requires langchain-core which is an optional dependency. + """ + if name == "LangChainTool": + from .langchain_tool import LangChainTool + + return LangChainTool + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/strands/experimental/tools/langchain_tool.py b/src/strands/experimental/tools/langchain_tool.py new file mode 100644 index 000000000..434a271bc --- /dev/null +++ b/src/strands/experimental/tools/langchain_tool.py @@ -0,0 +1,224 @@ +"""LangChain tool wrapper for Strands Agents. + +This module provides a Strands AgentTool that wraps LangChain BaseTool instances, +enabling seamless use of LangChain tools with Strands Agents. + +All LangChain tools inherit from BaseTool, so this wrapper works with any LangChain tool: +tools created with the @tool decorator, StructuredTool instances, or custom BaseTool subclasses. + +See: https://python.langchain.com/docs/concepts/tools/ + +Example: + ```python + from langchain_core.tools import tool as langchain_tool + from strands import Agent + from strands.experimental.tools import LangChainTool + + @langchain_tool + def calculator(a: int, b: int) -> int: + '''Add two numbers.''' + return a + b + + agent = Agent(tools=[LangChainTool(calculator)]) + ``` +""" + +import logging +from typing import Any + +from langchain_core.tools import BaseTool as LangChainBaseTool +from typing_extensions import override + +from strands.types._events import ToolResultEvent +from strands.types.tools import AgentTool, ToolGenerator, ToolResultContent, ToolSpec, ToolUse + +logger = logging.getLogger(__name__) + + +class LangChainTool(AgentTool): + """A Strands AgentTool that wraps a LangChain BaseTool. + + This class allows LangChain tools to be used directly with Strands Agents + by wrapping them in the AgentTool interface. + + Example: + ```python + from langchain_core.tools import tool as langchain_tool + + @langchain_tool + def calculator(a: int, b: int) -> int: + '''Add two numbers.''' + return a + b + + # Wrap as Strands tool + strands_calculator = LangChainTool(calculator) + + # Use with Strands Agent + agent = Agent(tools=[strands_calculator]) + ``` + """ + + _langchain_tool: LangChainBaseTool + _tool_name: str + _tool_spec: ToolSpec + + def __init__( + self, + tool: LangChainBaseTool, + name: str | None = None, + description: str | None = None, + ) -> None: + """Initialize with a LangChain BaseTool. + + Args: + tool: A LangChain BaseTool instance. + name: Optional override for the tool name. + description: Optional override for the tool description. + """ + super().__init__() + + self._langchain_tool = tool + self._tool_name = name or tool.name + + tool_description = description or tool.description or f"Tool: {self._tool_name}" + + # Build tool spec + input_schema = self._build_input_schema(tool) + self._tool_spec: ToolSpec = { + "name": self._tool_name, + "description": tool_description, + "inputSchema": {"json": input_schema}, + } + + @staticmethod + def _build_input_schema(tool: LangChainBaseTool) -> dict[str, object]: + """Build JSON schema from a LangChain tool's args_schema. + + Args: + tool: A LangChain BaseTool instance. + + Returns: + A JSON schema dict suitable for Strands' inputSchema format. + """ + args_schema = tool.args_schema + + if args_schema is None: + return { + "type": "object", + "properties": {}, + "required": [], + } + + # args_schema is a Pydantic model class or a dict + # https://python.langchain.com/api_reference/core/tools/langchain_core.tools.base.BaseTool.html#langchain_core.tools.base.BaseTool.args_schema + if isinstance(args_schema, dict): + schema = args_schema.copy() + elif hasattr(args_schema, "model_json_schema"): + schema = args_schema.model_json_schema() + else: + return { + "type": "object", + "properties": {}, + "required": [], + } + + # Remove fields that aren't needed for tool input schemas: + # - title: Pydantic adds the class name, not useful for tool schemas + # - additionalProperties: validation constraint, not needed by model providers + schema.pop("title", None) + schema.pop("additionalProperties", None) + + # Ensure required fields exist + schema.setdefault("type", "object") + schema.setdefault("properties", {}) + schema.setdefault("required", []) + + return schema + + @property + def tool_name(self) -> str: + """Get the name of the tool. + + Returns: + The tool name. + """ + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification. + + Returns: + The Strands-compatible tool specification. + """ + return self._tool_spec + + @property + def tool_type(self) -> str: + """Get the type of the tool. + + Returns: + 'langchain' to identify this as a wrapped LangChain tool. + """ + return "langchain" + + @property + def wrapped_tool(self) -> LangChainBaseTool: + """Access the underlying LangChain tool. + + Returns: + The original LangChain BaseTool instance. + """ + return self._langchain_tool + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, object], **kwargs: object) -> ToolGenerator: + """Execute the LangChain tool and stream the result. + + Args: + tool_use: The tool use request containing input parameters. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments. + + Yields: + ToolResultEvent containing the tool execution result. + """ + tool_use_id = tool_use.get("toolUseId", "unknown") + tool_input = tool_use.get("input", {}) + + result = await self._langchain_tool.ainvoke(tool_input) + content = self._convert_result_to_content(result) + + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": content, + } + ) + + def _convert_result_to_content(self, result: Any) -> list[ToolResultContent]: + """Convert a LangChain tool result to Strands content format. + + LangChain tools can return various content types defined in TOOL_MESSAGE_BLOCK_TYPES: + https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/tools/base.py + + Currently only string results are supported. Support for other types (text blocks, + image, json, document, etc.) will be added in future versions. + + Args: + result: The result from a LangChain tool invocation. + + Returns: + A list of content blocks in Strands format. + + Raises: + ValueError: If the result type is not supported. + """ + # TODO: Expand support for other LangChain content types (text blocks, image, json, etc.) + if isinstance(result, str): + return [{"text": result}] + + raise ValueError( + f"Unsupported LangChain result type: {type(result).__name__}. Only string results are currently supported." + ) diff --git a/tests/strands/tools/test_langchain_tool.py b/tests/strands/tools/test_langchain_tool.py new file mode 100644 index 000000000..3d2933c8e --- /dev/null +++ b/tests/strands/tools/test_langchain_tool.py @@ -0,0 +1,426 @@ +"""Tests for the LangChain tool wrapper.""" + +from typing import Optional, Type + +import pytest +from langchain_core.tools import BaseTool, StructuredTool +from langchain_core.tools import tool as langchain_tool +from pydantic import BaseModel, Field + +from strands.experimental.tools import LangChainTool +from strands.types.tools import ToolUse + + +class MockArgsSchema(BaseModel): + """Mock Pydantic schema for testing.""" + + query: str = Field(description="The search query") + max_results: int = Field(default=10, description="Maximum number of results") + + +class MockBaseTool(BaseTool): + """Mock LangChain BaseTool for testing.""" + + name: str = "mock_tool" + description: str = "A mock tool for testing" + args_schema: Optional[Type[BaseModel]] = None + return_value: str = "Mock result" + + def _run(self, **kwargs: object) -> str: + return self.return_value + + +class MockToolWithSchema(BaseTool): + """Mock LangChain tool with args_schema.""" + + name: str = "schema_tool" + description: str = "A tool with schema" + args_schema: Type[BaseModel] = MockArgsSchema + + def _run(self, query: str, max_results: int = 10) -> str: + return f"Searched: {query}, max: {max_results}" + + +# LangChain tool with explicit schema +class CalculatorInput(BaseModel): + """Input for calculator tool.""" + + a: int = Field(description="First number") + b: int = Field(description="Second number") + operation: str = Field(description="Operation: add, subtract, multiply, divide") + + +@langchain_tool(args_schema=CalculatorInput) +def calculator(a: int, b: int, operation: str) -> str: + """Perform basic arithmetic operations.""" + if operation == "add": + return f"Result: {a + b}" + elif operation == "subtract": + return f"Result: {a - b}" + elif operation == "multiply": + return f"Result: {a * b}" + elif operation == "divide": + return f"Result: {a / b}" if b != 0 else "Error: Division by zero" + return f"Unknown operation: {operation}" + + +# BaseTool subclass with schema +class GreetingInput(BaseModel): + """Input for greeting tool.""" + + person_name: str = Field(description="Name of the person to greet") + + +class GreetingTool(BaseTool): + """A tool that generates greetings.""" + + name: str = "greeting" + description: str = "Generate a greeting for a person" + args_schema: type[BaseModel] = GreetingInput + + def _run(self, person_name: str) -> str: + return f"Hello, {person_name}! Welcome!" + + +# StructuredTool.from_function() +def _reverse_string(text: str) -> str: + """Reverse a string.""" + return text[::-1] + + +reverse_tool = StructuredTool.from_function( + func=_reverse_string, + name="reverse_string", + description="Reverse the characters in a string", +) + + +# Tests for _build_input_schema + + +def test_build_input_schema_no_schema() -> None: + """Test schema building with no args_schema.""" + tool = MockBaseTool() + result = LangChainTool._build_input_schema(tool) + + assert result == { + "type": "object", + "properties": {}, + "required": [], + } + + +def test_build_input_schema_with_schema() -> None: + """Test schema building with args_schema.""" + tool = MockToolWithSchema() + result = LangChainTool._build_input_schema(tool) + + assert result["type"] == "object" + assert "properties" in result + assert "query" in result["properties"] + assert "max_results" in result["properties"] + + +# Tests for LangChainTool + + +def test_langchain_tool_init() -> None: + """Test LangChainTool initialization.""" + mock_tool = MockBaseTool() + tool = LangChainTool(mock_tool) + + assert tool.tool_name == "mock_tool" + assert tool.tool_type == "langchain" + assert tool.wrapped_tool is mock_tool + + +def test_langchain_tool_spec() -> None: + """Test tool spec generation.""" + mock_tool = MockToolWithSchema() + tool = LangChainTool(mock_tool) + + spec = tool.tool_spec + assert spec["name"] == "schema_tool" + assert spec["description"] == "A tool with schema" + assert "inputSchema" in spec + assert "json" in spec["inputSchema"] + + +# Tests for stream execution + + +@pytest.mark.asyncio +async def test_langchain_tool_stream_async() -> None: + """Test stream with async execution.""" + mock_tool = MockBaseTool() + tool = LangChainTool(mock_tool) + + tool_use: ToolUse = { + "toolUseId": "test-123", + "name": "mock_tool", + "input": {}, + } + + results = [] + async for event in tool.stream(tool_use, {}): + results.append(event) + + assert len(results) == 1 + result = results[0]["tool_result"] + assert result["toolUseId"] == "test-123" + assert result["status"] == "success" + assert "Mock result" in result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_langchain_tool_stream_with_input() -> None: + """Test stream passes input correctly.""" + mock_tool = MockToolWithSchema() + tool = LangChainTool(mock_tool) + + tool_use: ToolUse = { + "toolUseId": "test-input", + "name": "schema_tool", + "input": {"query": "hello", "max_results": 5}, + } + + results = [] + async for event in tool.stream(tool_use, {}): + results.append(event) + + assert len(results) == 1 + result = results[0]["tool_result"] + assert result["status"] == "success" + + +# Test with @tool decorator + + +def test_langchain_tool_with_decorator() -> None: + """Test wrapping a tool created with @tool decorator.""" + + @langchain_tool + def search(query: str) -> str: + """Search for information.""" + return f"Results for: {query}" + + tool = LangChainTool(search) + + assert tool.tool_name == "search" + assert "Search for information" in tool.tool_spec["description"] + + +@pytest.mark.asyncio +async def test_langchain_tool_decorator_execution() -> None: + """Test executing a tool created with @tool decorator.""" + + @langchain_tool + def greet(name: str) -> str: + """Greet someone.""" + return f"Hello, {name}!" + + tool = LangChainTool(greet) + + tool_use: ToolUse = { + "toolUseId": "test-greet", + "name": "greet", + "input": {"name": "World"}, + } + + results = [] + async for event in tool.stream(tool_use, {}): + results.append(event) + + assert len(results) == 1 + result = results[0]["tool_result"] + assert result["status"] == "success" + assert "Hello, World!" in result["content"][0]["text"] + + +# Tests for _convert_result_to_content + + +def test_convert_result_string() -> None: + """Test converting a string result.""" + tool = LangChainTool(MockBaseTool()) + content = tool._convert_result_to_content("hello world") + + assert content == [{"text": "hello world"}] + + +def test_convert_result_non_string_raises() -> None: + """Test that non-string results raise ValueError.""" + tool = LangChainTool(MockBaseTool()) + + with pytest.raises(ValueError, match="Unsupported LangChain result type"): + tool._convert_result_to_content({"key": "value"}) + + with pytest.raises(ValueError, match="Unsupported LangChain result type"): + tool._convert_result_to_content(42) + + +# Tests for tool_spec with different LangChain tool types + + +def test_tool_spec_with_schema() -> None: + """Test tool_spec for a LangChain tool with explicit schema.""" + tool = LangChainTool(calculator) + + expected_spec = { + "name": "calculator", + "description": "Perform basic arithmetic operations.", + "inputSchema": { + "json": { + "description": "Input for calculator tool.", + "type": "object", + "properties": { + "a": { + "title": "A", + "type": "integer", + "description": "First number", + }, + "b": { + "title": "B", + "type": "integer", + "description": "Second number", + }, + "operation": { + "title": "Operation", + "type": "string", + "description": "Operation: add, subtract, multiply, divide", + }, + }, + "required": ["a", "b", "operation"], + } + }, + } + + assert tool.tool_spec == expected_spec + + +def test_tool_spec_basetool_subclass() -> None: + """Test tool_spec for a BaseTool subclass.""" + tool = LangChainTool(GreetingTool()) + + expected_spec = { + "name": "greeting", + "description": "Generate a greeting for a person", + "inputSchema": { + "json": { + "description": "Input for greeting tool.", + "type": "object", + "properties": { + "person_name": { + "title": "Person Name", + "type": "string", + "description": "Name of the person to greet", + }, + }, + "required": ["person_name"], + } + }, + } + + assert tool.tool_spec == expected_spec + + +def test_tool_spec_structured_tool() -> None: + """Test tool_spec for StructuredTool.from_function().""" + tool = LangChainTool(reverse_tool) + + assert tool.tool_name == "reverse_string" + assert tool.tool_spec["description"] == "Reverse the characters in a string" + assert "text" in tool.tool_spec["inputSchema"]["json"]["properties"] + + +# Tests for execution of different LangChain tool types + + +@pytest.mark.asyncio +async def test_stream_basetool_subclass() -> None: + """Test stream execution of a BaseTool subclass.""" + tool = LangChainTool(GreetingTool()) + + tool_use: ToolUse = { + "toolUseId": "test-greeting", + "name": "greeting", + "input": {"person_name": "Alice"}, + } + + results = [] + async for event in tool.stream(tool_use, {}): + results.append(event) + + assert len(results) == 1 + result = results[0]["tool_result"] + assert result["status"] == "success" + assert "Hello, Alice" in result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_structured_tool() -> None: + """Test stream execution of StructuredTool.from_function().""" + tool = LangChainTool(reverse_tool) + + tool_use: ToolUse = { + "toolUseId": "test-reverse", + "name": "reverse_string", + "input": {"text": "hello"}, + } + + results = [] + async for event in tool.stream(tool_use, {}): + results.append(event) + + assert len(results) == 1 + result = results[0]["tool_result"] + assert result["status"] == "success" + assert "olleh" in result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_tool_with_schema() -> None: + """Test stream execution of a tool with explicit schema.""" + tool = LangChainTool(calculator) + + tool_use: ToolUse = { + "toolUseId": "test-calc", + "name": "calculator", + "input": {"a": 10, "b": 5, "operation": "add"}, + } + + results = [] + async for event in tool.stream(tool_use, {}): + results.append(event) + + assert len(results) == 1 + result = results[0]["tool_result"] + assert result["status"] == "success" + assert "15" in result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_async_tool() -> None: + """Test stream execution of an async @langchain_tool.""" + + @langchain_tool + async def async_uppercase(text: str) -> str: + """Convert text to uppercase.""" + return text.upper() + + tool = LangChainTool(async_uppercase) + + tool_use: ToolUse = { + "toolUseId": "test-async", + "name": "async_uppercase", + "input": {"text": "hello"}, + } + + results = [] + async for event in tool.stream(tool_use, {}): + results.append(event) + + assert len(results) == 1 + result = results[0]["tool_result"] + assert result["status"] == "success" + assert "HELLO" in result["content"][0]["text"] diff --git a/tests_integ/tools/test_langchain_tool.py b/tests_integ/tools/test_langchain_tool.py new file mode 100644 index 000000000..ec8d5e2a4 --- /dev/null +++ b/tests_integ/tools/test_langchain_tool.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +"""Integration tests for LangChainTool with real agent interactions. + +These tests verify that LangChain tools work correctly when invoked by an agent +through natural language, which requires actual model inference. +""" + +from langchain_core.tools import tool as langchain_tool +from pydantic import BaseModel, Field + +from strands import Agent +from strands.experimental.tools import LangChainTool + + +@langchain_tool +def word_count(text: str) -> str: + """Count the number of words in text.""" + count = len(text.split()) + return f"The text contains {count} words." + + +class CalculatorInput(BaseModel): + """Input for calculator tool.""" + + a: int = Field(description="First number") + b: int = Field(description="Second number") + operation: str = Field(description="Operation: add, subtract, multiply, divide") + + +@langchain_tool(args_schema=CalculatorInput) +def calculator(a: int, b: int, operation: str) -> str: + """Perform basic arithmetic operations.""" + if operation == "add": + result = a + b + elif operation == "subtract": + result = a - b + elif operation == "multiply": + result = a * b + elif operation == "divide": + if b == 0: + return "Error: Division by zero" + result = a / b + else: + return f"Unknown operation: {operation}" + return f"Result: {result}" + + +def test_langchain_tool_natural_language(): + """Test LangChain tool invocation through natural language.""" + strands_tool = LangChainTool(word_count) + agent = Agent(tools=[strands_tool]) + + agent("Count the words in: 'The quick brown fox jumps over the lazy dog'") + + tool_results = [ + block["toolResult"] + for message in agent.messages + for block in message.get("content", []) + if "toolResult" in block + ] + assert len(tool_results) > 0 + assert tool_results[0]["status"] == "success" + + +def test_langchain_tool_calculator_natural_language(): + """Test calculator tool through natural language.""" + strands_tool = LangChainTool(calculator) + agent = Agent(tools=[strands_tool]) + + agent("What is 25 multiplied by 4? Use the calculator tool.") + + tool_results = [ + block["toolResult"] + for message in agent.messages + for block in message.get("content", []) + if "toolResult" in block + ] + assert len(tool_results) > 0 + assert tool_results[0]["status"] == "success"