Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/potato_type/src/anthropic/v1/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,7 @@ impl MessageParam {
}

// Return the text content from the first content part that is text
#[getter]
pub fn text(&self) -> String {
self.content
.iter()
Expand Down
1 change: 1 addition & 0 deletions crates/potato_type/src/google/v1/generate/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,7 @@ impl GeminiContent {
}

// helper method for returning firs text part content
#[getter]
pub fn text(&self) -> String {
self.parts
.iter()
Expand Down
1 change: 1 addition & 0 deletions crates/potato_type/src/openai/v1/chat/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ impl ChatMessage {
}

// Return the text content from the first content part that is text
#[getter]
pub fn text(&self) -> String {
self.content
.iter()
Expand Down
2 changes: 1 addition & 1 deletion py-potato/examples/anthropic/structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class StructuredTaskOutput(BaseModel):
agent = Agent(Provider.Anthropic)

if __name__ == "__main__":
result: StructuredTaskOutput = agent.execute_prompt(
result = agent.execute_prompt(
prompt=prompt,
output_type=StructuredTaskOutput,
).structured_output
Expand Down
13 changes: 10 additions & 3 deletions py-potato/examples/google/query_reformulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
# 2. Evaluates the quality of the reformulated query.
# 3. Returns a score and reason for how well the reformulation improves the query.

from potato_head import Agent, Prompt, Provider, Score, Task, Workflow
from potato_head.google import GeminiSettings, GenerationConfig, GeminiThinkingConfig
from potato_head import Agent, Prompt, Provider, Score, Task, Workflow, AgentResponse
from potato_head.google import (
GeminiSettings,
GenerationConfig,
GeminiThinkingConfig,
GenerateContentResponse,
)
from potato_head.logging import LoggingConfig, LogLevel, RustyLogger

RustyLogger.setup_logging(LoggingConfig(log_level=LogLevel.Debug))
Expand Down Expand Up @@ -112,7 +117,9 @@ def create_workflow():

agent = Agent(Provider.Gemini)
user_query = "How do I find good post-hardcore bands?"
response = agent.execute_prompt(prompt=prompt.bind(user_query=user_query))
response: AgentResponse[str, GenerateContentResponse] = agent.execute_prompt(
prompt=prompt.bind(user_query=user_query)
)

workflow = create_workflow()

Expand Down
2 changes: 1 addition & 1 deletion py-potato/examples/openai/structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class StructuredTaskOutput(BaseModel):
agent = Agent(Provider.OpenAI)

if __name__ == "__main__":
result: StructuredTaskOutput = agent.execute_prompt(
result = agent.execute_prompt(
prompt=prompt,
output_type=StructuredTaskOutput,
).structured_output
Expand Down
22 changes: 14 additions & 8 deletions py-potato/python/potato_head/_potato_head.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,9 @@ OutT = TypeVar(
"OutT",
default=str,
)
RespT = TypeVar("RespT", default=_ResponseType)

class AgentResponse(Generic[OutT]):
class AgentResponse(Generic[OutT, RespT]):
"""Agent response generic over OutputDataT.

The structured_output property returns OutputDataT type.
Expand All @@ -731,7 +732,7 @@ class AgentResponse(Generic[OutT]):
"""The ID of the agent response."""

@property
def response(self) -> _ResponseType:
def response(self) -> RespT:
"""The response of the agent."""

@property
Expand Down Expand Up @@ -847,7 +848,7 @@ class Agent:
self,
task: Task,
output_type: type[OutT] | None = None,
) -> AgentResponse[OutT]:
) -> AgentResponse[OutT, _ResponseType]:
"""Execute a task.

Args:
Expand All @@ -857,15 +858,16 @@ class Agent:
The output type to use for the task.

Returns:
AgentResponse[OutT]:
The response from the agent after executing the task.
AgentResponse[OutT, _ResponseType]:
The response from the agent. For type-safe response access,
annotate the return value with the specific response type.
"""

def execute_prompt(
self,
prompt: Prompt,
output_type: type[OutT] | None = None,
) -> AgentResponse[OutT]:
) -> AgentResponse[OutT, _ResponseType]:
"""Execute a prompt.

Args:
Expand All @@ -875,8 +877,9 @@ class Agent:
The output type to use for the task.

Returns:
AgentResponse:
The response from the agent after executing the task.
AgentResponse[OutT, _ResponseType]:
The response from the agent. For type-safe response access,
annotate the return value with the specific response type.
"""

@property
Expand Down Expand Up @@ -2540,6 +2543,7 @@ class ChatMessage:
TypeError: If content format is invalid
"""

@property
def text(self) -> str:
"""Get the text content of the first part, if available. Returns
an empty string if the first part is not text.
Expand Down Expand Up @@ -5322,6 +5326,7 @@ class GeminiContent:
Role of the message sender (e.g., "user", "model", "function")
"""

@property
def text(self) -> str:
"""Get the text content of the first part, if available. Returns
an empty string if the first part is not text.
Expand Down Expand Up @@ -8648,6 +8653,7 @@ class MessageParam:
Message role ("user" or "assistant")
"""

@property
def text(self) -> str:
"""Get the text content of the first part, if available. Returns
an empty string if the first part is not text.
Expand Down
8 changes: 4 additions & 4 deletions py-potato/tests/prompt/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def test_prompt():
system_instructions="system_prompt",
)

assert prompt.anthropic_message.text() == "My prompt"
assert prompt.system_instructions[0].text() == "system_prompt"
assert prompt.anthropic_message.text == "My prompt"
assert prompt.system_instructions[0].text == "system_prompt"

# test string message
prompt = Prompt(
Expand Down Expand Up @@ -76,8 +76,8 @@ def test_bind_prompt():
)

bound_prompt = prompt.bind("variable1", "world").bind("variable2", "Foo")
assert bound_prompt.messages[0].content[0].text == "Hello world"
assert bound_prompt.messages[1].content[0].text == "This is Foo"
assert bound_prompt.messages[0].text == "Hello world"
assert bound_prompt.messages[1].text == "This is Foo"


def test_prompt_structured_output():
Expand Down
5 changes: 3 additions & 2 deletions py-potato/tests/prompt/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def test_prompt():

def test_bind_prompt():
prompt = Prompt(
model="gemini-3.0-flash",
provider="gemini",
model="gemini-3.0-flash",
messages=[
"Hello ${variable1}",
"This is ${variable2}",
Expand All @@ -90,7 +90,8 @@ def test_bind_prompt():
)

bound_prompt = prompt.bind("variable1", "world").bind("variable2", "Foo")
assert bound_prompt.messages[0].parts[0].data == "Hello world"
messages = bound_prompt.messages
assert messages[0].parts[0].data == "Hello world"
assert bound_prompt.messages[1].parts[0].data == "This is Foo"


Expand Down
8 changes: 5 additions & 3 deletions py-potato/tests/prompt/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,11 @@ def test_bind_prompt():
],
system_instructions="system_prompt",
)
bound_prompt = prompt.bind("variable1", "world").bind("variable2", "Foo")
assert bound_prompt.openai_messages[0].content[0].text == "Hello world"
assert bound_prompt.openai_messages[1].content[0].text == "This is Foo"
kwargs = {"variable1": "world", "variable2": "Foo"}
bound_prompt = prompt.bind(**kwargs)

assert bound_prompt.openai_messages[0].text == "Hello world"
assert bound_prompt.openai_messages[1].text == "This is Foo"

# testing binding with kwargs
bound_prompt = prompt.bind(variable1="world")
Expand Down
10 changes: 5 additions & 5 deletions py-potato/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@ class Prompts:
def test_simple_workflow(prompt_step1: Prompt):
agent = PydanticAgent(
prompt_step1.model_identifier,
system_prompt=prompt_step1.system_instructions[0].text(),
system_prompt=prompt_step1.system_instructions[0].text,
)

with agent.override(model=TestModel()):
agent.run_sync(prompt_step1.message.text())
agent.run_sync(prompt_step1.message.text)


def test_simple_dep_workflow(prompt_step1: Prompt, prompt_step2: Prompt):
agent = PydanticAgent(
prompt_step1.model_identifier,
system_prompt=prompt_step1.system_instructions[0].text(),
system_prompt=prompt_step1.system_instructions[0].text,
deps_type=Prompts,
)

@agent.system_prompt
def get_system_instruction(ctx: RunContext[Prompts]) -> str:
return ctx.deps.prompt_step1.system_instructions[0].text()
return ctx.deps.prompt_step1.system_instructions[0].text

with agent.override(model=TestModel()):
agent.run_sync(
Expand All @@ -64,7 +64,7 @@ def get_system_instruction(ctx: RunContext[Prompts]) -> str:
def test_binding_workflow(prompt_step1: Prompt, prompt_step2: Prompt):
agent = PydanticAgent(
"openai:gpt-4o",
system_prompt=prompt_step1.system_instructions[0].text(),
system_prompt=prompt_step1.system_instructions[0].text,
deps_type=Prompts,
)

Expand Down