Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions py-potato/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ lints.ruff:
uv run ruff check ${SOURCE_OBJECTS}
lints.mypy:
uv run mypy ${SOURCE_OBJECTS}
lints.pylint:
uv run pylint ${SOURCE_OBJECTS}
lints: lints.ruff lints.pylint lints.mypy
lints.ci: lints.format_check lints.ruff lints.pylint lints.mypy
lints: lints.ruff lints.mypy
lints.ci: lints.format_check lints.ruff lints.mypy

setup.project:
uv sync --all-extras --group dev --group docs
Expand Down
1 change: 0 additions & 1 deletion py-potato/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ dev = [
"ruff >= 0.1.0, < 1.0.0",
"mypy >= 1.0.0, < 2.0.0",
"black >= 24.3.0, < 25.0.0",
"pylint >= 3.0.0, < 4.0.0",
"isort >= 5.13.2, < 6.0.0",
"pydantic>=2.10.5",
"pydantic-ai>=0.0.41",
Expand Down
124 changes: 71 additions & 53 deletions py-potato/python/potato_head/_potato_head.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@

import datetime
from pathlib import Path
from typing import (
Any,
Dict,
Generic,
List,
Optional,
TypeAlias,
TypeVar,
Union,
overload,
)
from typing import Any, Dict, Generic, List, Optional, TypeAlias, Union, overload

from typing_extensions import TypeVar

###### __potatohead__.main module ######

Expand Down Expand Up @@ -219,7 +211,9 @@ PromptMessage: TypeAlias = Union[
List[Union[str, "ChatMessage", "MessageParam", "GeminiContent"]],
]

class Prompt:
OutputType = TypeVar("OutputType", default=str)

class Prompt(Generic[OutputType]):
"""Prompt for interacting with an LLM API.

The Prompt class handles message parsing, provider-specific formatting, and
Expand All @@ -233,7 +227,7 @@ class Prompt:
provider: Provider | str,
system_instructions: Optional[PromptMessage] = None,
model_settings: Optional[ModelSettings | OpenAIChatSettings | GeminiSettings | AnthropicSettings] = None,
output_type: Optional[Any] = None,
output_type: Optional[OutputType] = None,
) -> None:
"""Initialize a Prompt object.

Expand All @@ -257,8 +251,9 @@ class Prompt:
model_settings (Optional[ModelSettings | OpenAIChatSettings | GeminiSettings | AnthropicSettings]):
Optional model-specific settings (temperature, max_tokens, etc.)
If None, provider default settings will be used
output_type (Optional[Pydantic BaseModel | Score]):
Optional structured output format.The provided format will be parsed into a JSON schema for structured outputs
output_type (Optional[OutputT]):
Optional structured output type.The provided format will be parsed into a JSON schema for structured outputs.
This is typically a pydantic BaseModel.

Raises:
TypeError: If message types are invalid or incompatible with the provider
Expand Down Expand Up @@ -715,30 +710,45 @@ _ResponseType: TypeAlias = Union[
"AnthropicMessageResponse",
]

class AgentResponse:
OutT = TypeVar(
"OutT",
default=str,
)

class AgentResponse(Generic[OutT]):
"""Agent response generic over OutputDataT.

The structured_output property returns OutputDataT type.

Examples:
>>> agent = Agent(provider=Provider.OpenAI)
>>> response: AgentResponse[WeatherData] = agent.execute_prompt(prompt, output_type=WeatherData)
>>> weather: WeatherData = response.structured_output
"""

@property
def id(self) -> str:
"""The ID of the agent response."""

@property
def response(self) -> _ResponseType:
"""The response of the agent. This can be an OpenAIChatResponse, GenerateContentResponse,
or AnthropicMessageResponse depending on the provider used.
"""
"""The response of the agent."""

@property
def token_usage(self) -> Any:
"""Returns the token usage of the agent response if supported"""

@property
def log_probs(self) -> ResponseLogProbs:
"""Returns the log probabilities of the agent response if supported.
This is primarily used for debugging and analysis purposes.
"""
"""Returns the log probabilities of the agent response if supported."""

@property
def structured_output(self) -> Any:
"""Returns the structured output of the agent response if supported."""
def structured_output(self) -> OutT:
"""Returns the structured output of the agent response.

The type is determined by the Agent's OutputType generic parameter
or the output_type argument passed to execute_task/execute_prompt.
"""

def response_text(self) -> str:
"""The response text from the agent if available, otherwise an empty string."""
Expand All @@ -747,7 +757,7 @@ class Task:
def __init__(
self,
agent_id: str,
prompt: Prompt,
prompt: Prompt[OutputType],
id: Optional[str] = None,
dependencies: List[str] = [],
max_retries: int = 3,
Expand All @@ -757,7 +767,7 @@ class Task:
Args:
agent_id (str):
The ID of the agent that will execute the task.
prompt (Prompt):
prompt (Prompt[OutputType]):
The prompt to use for the task.
id (Optional[str]):
The ID of the task. If None, a random uuid7 will be generated.
Expand Down Expand Up @@ -794,6 +804,27 @@ class TaskList:
"""Dictionary of tasks in the TaskList where keys are task IDs and values are Task objects."""

class Agent:
"""Create an Agent object.

Generic over OutputType which determines the structured output type.
By default, OutputType is str if no output_type is specified.

Examples:
>>> # Default agent (OutputType = str)
>>> agent = Agent(provider=Provider.OpenAI)
>>> response = agent.execute_prompt(prompt)
>>> text: str = response.structured_output

>>> # Typed agent with Pydantic model
>>> class WeatherData(BaseModel):
... temperature: float
... condition: str
>>>
>>> agent = Agent(provider=Provider.OpenAI)
>>> response = agent.execute_prompt(prompt, output_type=WeatherData)
>>> weather: WeatherData = response.structured_output
"""

def __init__(
self,
provider: Provider | str,
Expand All @@ -803,58 +834,45 @@ class Agent:

Args:
provider (Provider | str):
The provider to use for the agent. This can be a Provider enum or a string
representing the provider.
The provider to use for the agent.
system_instruction (Optional[PromptMessage]):
The system message to use for the agent. This can be a string, a list of strings,
a Message object, or a list of Message objects. If None, no system message will be used.
This is added to all tasks that the agent executes. If a given task contains it's own
system message, the agent's system message will be prepended to the task's system message.

Example:
```python
agent = Agent(
provider=Provider.OpenAI,
system_instructions="You are a helpful assistant.",
)
```
The system message to use for the agent.
"""

@property
def system_instruction(self) -> List[Any]:
"""The system message to use for the agent. This is a list of Message objects."""
"""The system message to use for the agent."""

def execute_task(
self,
task: Task,
output_type: Optional[Any] = None,
) -> AgentResponse:
output_type: type[OutT] | None = None,
) -> AgentResponse[OutT]:
"""Execute a task.

Args:
task (Task):
The task to execute.
output_type (Optional[Any]):
The output type to use for the task. This can either be a Pydantic `BaseModel` class
or a supported PotatoHead response type such as `Score`.
output_type (Optional[OutT]):
The output type to use for the task.

Returns:
AgentResponse:
AgentResponse[OutT]:
The response from the agent after executing the task.
"""

def execute_prompt(
self,
prompt: Prompt,
output_type: Optional[Any] = None,
) -> AgentResponse:
output_type: type[OutT] | None = None,
) -> AgentResponse[OutT]:
"""Execute a prompt.

Args:
prompt (Prompt):
The prompt to execute.
output_type (Optional[Any]):
The output type to use for the task. This can either be a Pydantic `BaseModel` class
or a supported potato_head response type such as `Score`.
output_type (Optional[OutT]):
The output type to use for the task.

Returns:
AgentResponse:
Expand All @@ -863,7 +881,7 @@ class Agent:

@property
def id(self) -> str:
"""The ID of the agent. This is a random uuid7 that is generated when the agent is created."""
"""The ID of the agent."""

class Workflow:
def __init__(self, name: str) -> None:
Expand Down
60 changes: 0 additions & 60 deletions py-potato/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.