diff --git a/crates/potato_type/src/anthropic/v1/request.rs b/crates/potato_type/src/anthropic/v1/request.rs index 8c08c0bd..c9095b7f 100644 --- a/crates/potato_type/src/anthropic/v1/request.rs +++ b/crates/potato_type/src/anthropic/v1/request.rs @@ -1225,6 +1225,7 @@ impl MessageParam { } // Return the text content from the first content part that is text + #[getter] pub fn text(&self) -> String { self.content .iter() diff --git a/crates/potato_type/src/google/v1/generate/request.rs b/crates/potato_type/src/google/v1/generate/request.rs index b40c8549..49aa5e24 100644 --- a/crates/potato_type/src/google/v1/generate/request.rs +++ b/crates/potato_type/src/google/v1/generate/request.rs @@ -1382,6 +1382,7 @@ impl GeminiContent { } // helper method for returning firs text part content + #[getter] pub fn text(&self) -> String { self.parts .iter() diff --git a/crates/potato_type/src/openai/v1/chat/request.rs b/crates/potato_type/src/openai/v1/chat/request.rs index a1934add..a5f0f386 100644 --- a/crates/potato_type/src/openai/v1/chat/request.rs +++ b/crates/potato_type/src/openai/v1/chat/request.rs @@ -328,6 +328,7 @@ impl ChatMessage { } // Return the text content from the first content part that is text + #[getter] pub fn text(&self) -> String { self.content .iter() diff --git a/py-potato/examples/anthropic/structured_output.py b/py-potato/examples/anthropic/structured_output.py index 7fbc58e0..10f75a03 100644 --- a/py-potato/examples/anthropic/structured_output.py +++ b/py-potato/examples/anthropic/structured_output.py @@ -38,7 +38,7 @@ class StructuredTaskOutput(BaseModel): agent = Agent(Provider.Anthropic) if __name__ == "__main__": - result: StructuredTaskOutput = agent.execute_prompt( + result = agent.execute_prompt( prompt=prompt, output_type=StructuredTaskOutput, ).structured_output diff --git a/py-potato/examples/google/query_reformulation.py b/py-potato/examples/google/query_reformulation.py index d3c0ac56..219b785c 100644 --- a/py-potato/examples/google/query_reformulation.py +++ b/py-potato/examples/google/query_reformulation.py @@ -5,8 +5,13 @@ # 2. Evaluates the quality of the reformulated query. # 3. Returns a score and reason for how well the reformulation improves the query. -from potato_head import Agent, Prompt, Provider, Score, Task, Workflow -from potato_head.google import GeminiSettings, GenerationConfig, GeminiThinkingConfig +from potato_head import Agent, Prompt, Provider, Score, Task, Workflow, AgentResponse +from potato_head.google import ( + GeminiSettings, + GenerationConfig, + GeminiThinkingConfig, + GenerateContentResponse, +) from potato_head.logging import LoggingConfig, LogLevel, RustyLogger RustyLogger.setup_logging(LoggingConfig(log_level=LogLevel.Debug)) @@ -112,7 +117,9 @@ def create_workflow(): agent = Agent(Provider.Gemini) user_query = "How do I find good post-hardcore bands?" - response = agent.execute_prompt(prompt=prompt.bind(user_query=user_query)) + response: AgentResponse[str, GenerateContentResponse] = agent.execute_prompt( + prompt=prompt.bind(user_query=user_query) + ) workflow = create_workflow() diff --git a/py-potato/examples/openai/structured_output.py b/py-potato/examples/openai/structured_output.py index c8e979e0..423980de 100644 --- a/py-potato/examples/openai/structured_output.py +++ b/py-potato/examples/openai/structured_output.py @@ -35,7 +35,7 @@ class StructuredTaskOutput(BaseModel): agent = Agent(Provider.OpenAI) if __name__ == "__main__": - result: StructuredTaskOutput = agent.execute_prompt( + result = agent.execute_prompt( prompt=prompt, output_type=StructuredTaskOutput, ).structured_output diff --git a/py-potato/python/potato_head/_potato_head.pyi b/py-potato/python/potato_head/_potato_head.pyi index 8d3e7236..634c7924 100644 --- a/py-potato/python/potato_head/_potato_head.pyi +++ b/py-potato/python/potato_head/_potato_head.pyi @@ -714,8 +714,9 @@ OutT = TypeVar( "OutT", default=str, ) +RespT = TypeVar("RespT", default=_ResponseType) -class AgentResponse(Generic[OutT]): +class AgentResponse(Generic[OutT, RespT]): """Agent response generic over OutputDataT. The structured_output property returns OutputDataT type. @@ -731,7 +732,7 @@ class AgentResponse(Generic[OutT]): """The ID of the agent response.""" @property - def response(self) -> _ResponseType: + def response(self) -> RespT: """The response of the agent.""" @property @@ -847,7 +848,7 @@ class Agent: self, task: Task, output_type: type[OutT] | None = None, - ) -> AgentResponse[OutT]: + ) -> AgentResponse[OutT, _ResponseType]: """Execute a task. Args: @@ -857,15 +858,16 @@ class Agent: The output type to use for the task. Returns: - AgentResponse[OutT]: - The response from the agent after executing the task. + AgentResponse[OutT, _ResponseType]: + The response from the agent. For type-safe response access, + annotate the return value with the specific response type. """ def execute_prompt( self, prompt: Prompt, output_type: type[OutT] | None = None, - ) -> AgentResponse[OutT]: + ) -> AgentResponse[OutT, _ResponseType]: """Execute a prompt. Args: @@ -875,8 +877,9 @@ class Agent: The output type to use for the task. Returns: - AgentResponse: - The response from the agent after executing the task. + AgentResponse[OutT, _ResponseType]: + The response from the agent. For type-safe response access, + annotate the return value with the specific response type. """ @property @@ -2540,6 +2543,7 @@ class ChatMessage: TypeError: If content format is invalid """ + @property def text(self) -> str: """Get the text content of the first part, if available. Returns an empty string if the first part is not text. @@ -5322,6 +5326,7 @@ class GeminiContent: Role of the message sender (e.g., "user", "model", "function") """ + @property def text(self) -> str: """Get the text content of the first part, if available. Returns an empty string if the first part is not text. @@ -8648,6 +8653,7 @@ class MessageParam: Message role ("user" or "assistant") """ + @property def text(self) -> str: """Get the text content of the first part, if available. Returns an empty string if the first part is not text. diff --git a/py-potato/tests/prompt/test_anthropic.py b/py-potato/tests/prompt/test_anthropic.py index dca60218..11f2e56f 100644 --- a/py-potato/tests/prompt/test_anthropic.py +++ b/py-potato/tests/prompt/test_anthropic.py @@ -30,8 +30,8 @@ def test_prompt(): system_instructions="system_prompt", ) - assert prompt.anthropic_message.text() == "My prompt" - assert prompt.system_instructions[0].text() == "system_prompt" + assert prompt.anthropic_message.text == "My prompt" + assert prompt.system_instructions[0].text == "system_prompt" # test string message prompt = Prompt( @@ -76,8 +76,8 @@ def test_bind_prompt(): ) bound_prompt = prompt.bind("variable1", "world").bind("variable2", "Foo") - assert bound_prompt.messages[0].content[0].text == "Hello world" - assert bound_prompt.messages[1].content[0].text == "This is Foo" + assert bound_prompt.messages[0].text == "Hello world" + assert bound_prompt.messages[1].text == "This is Foo" def test_prompt_structured_output(): diff --git a/py-potato/tests/prompt/test_gemini.py b/py-potato/tests/prompt/test_gemini.py index bb104097..0be6f4ea 100644 --- a/py-potato/tests/prompt/test_gemini.py +++ b/py-potato/tests/prompt/test_gemini.py @@ -79,8 +79,8 @@ def test_prompt(): def test_bind_prompt(): prompt = Prompt( - model="gemini-3.0-flash", provider="gemini", + model="gemini-3.0-flash", messages=[ "Hello ${variable1}", "This is ${variable2}", @@ -90,7 +90,8 @@ def test_bind_prompt(): ) bound_prompt = prompt.bind("variable1", "world").bind("variable2", "Foo") - assert bound_prompt.messages[0].parts[0].data == "Hello world" + messages = bound_prompt.messages + assert messages[0].parts[0].data == "Hello world" assert bound_prompt.messages[1].parts[0].data == "This is Foo" diff --git a/py-potato/tests/prompt/test_openai.py b/py-potato/tests/prompt/test_openai.py index 50998387..bf1ade7a 100644 --- a/py-potato/tests/prompt/test_openai.py +++ b/py-potato/tests/prompt/test_openai.py @@ -98,9 +98,11 @@ def test_bind_prompt(): ], system_instructions="system_prompt", ) - bound_prompt = prompt.bind("variable1", "world").bind("variable2", "Foo") - assert bound_prompt.openai_messages[0].content[0].text == "Hello world" - assert bound_prompt.openai_messages[1].content[0].text == "This is Foo" + kwargs = {"variable1": "world", "variable2": "Foo"} + bound_prompt = prompt.bind(**kwargs) + + assert bound_prompt.openai_messages[0].text == "Hello world" + assert bound_prompt.openai_messages[1].text == "This is Foo" # testing binding with kwargs bound_prompt = prompt.bind(variable1="world") diff --git a/py-potato/tests/test_workflow.py b/py-potato/tests/test_workflow.py index ff05faee..ecfe3c24 100644 --- a/py-potato/tests/test_workflow.py +++ b/py-potato/tests/test_workflow.py @@ -33,23 +33,23 @@ class Prompts: def test_simple_workflow(prompt_step1: Prompt): agent = PydanticAgent( prompt_step1.model_identifier, - system_prompt=prompt_step1.system_instructions[0].text(), + system_prompt=prompt_step1.system_instructions[0].text, ) with agent.override(model=TestModel()): - agent.run_sync(prompt_step1.message.text()) + agent.run_sync(prompt_step1.message.text) def test_simple_dep_workflow(prompt_step1: Prompt, prompt_step2: Prompt): agent = PydanticAgent( prompt_step1.model_identifier, - system_prompt=prompt_step1.system_instructions[0].text(), + system_prompt=prompt_step1.system_instructions[0].text, deps_type=Prompts, ) @agent.system_prompt def get_system_instruction(ctx: RunContext[Prompts]) -> str: - return ctx.deps.prompt_step1.system_instructions[0].text() + return ctx.deps.prompt_step1.system_instructions[0].text with agent.override(model=TestModel()): agent.run_sync( @@ -64,7 +64,7 @@ def get_system_instruction(ctx: RunContext[Prompts]) -> str: def test_binding_workflow(prompt_step1: Prompt, prompt_step2: Prompt): agent = PydanticAgent( "openai:gpt-4o", - system_prompt=prompt_step1.system_instructions[0].text(), + system_prompt=prompt_step1.system_instructions[0].text, deps_type=Prompts, )