diff --git a/weave/trace/op.py b/weave/trace/op.py index 97c98d452ab6..902d683f6ba8 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -7,7 +7,7 @@ import random import sys import traceback -from collections.abc import Coroutine, Mapping +from collections.abc import Callable, Coroutine, Mapping from dataclasses import dataclass from functools import partial, wraps from types import MethodType @@ -16,8 +16,11 @@ Any, Callable, Optional, + ParamSpec, Protocol, TypedDict, + TypeGuard, + TypeVar, cast, overload, runtime_checkable, @@ -127,8 +130,14 @@ class WeaveKwargs(TypedDict): display_name: str | None +# Type variables for preserving function signatures +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T", bound=Callable[..., Any]) + + @runtime_checkable -class Op(Protocol): +class Op(Protocol[P, R]): """ The interface for Op-ified functions and methods. @@ -166,7 +175,7 @@ class Op(Protocol): _set_on_finish_handler: Callable[[OnFinishHandlerType], None] _on_finish_handler: OnFinishHandlerType | None - __call__: Callable[..., Any] + __call__: Callable[P, R] __self__: Any # `_tracing_enabled` is a runtime-only flag that can be used to disable @@ -181,19 +190,22 @@ class Op(Protocol): tracing_sample_rate: float -def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None: +def _set_on_input_handler(func: Op[P, R], on_input: OnInputHandlerType) -> None: + """Set the on_input handler for an op.""" if func._on_input_handler is not None: raise ValueError("Cannot set on_input_handler multiple times") func._on_input_handler = on_input -def _set_on_output_handler(func: Op, on_output: OnOutputHandlerType) -> None: +def _set_on_output_handler(func: Op[P, R], on_output: OnOutputHandlerType) -> None: + """Set the on_output handler for an op.""" if func._on_output_handler is not None: raise ValueError("Cannot set on_output_handler multiple times") func._on_output_handler = on_output -def _set_on_finish_handler(func: Op, on_finish: OnFinishHandlerType) -> None: +def _set_on_finish_handler(func: Op[P, R], on_finish: OnFinishHandlerType) -> None: + """Set the on_finish handler for an op.""" if func._on_finish_handler is not None: raise ValueError("Cannot set on_finish_handler multiple times") func._on_finish_handler = on_finish @@ -218,9 +230,23 @@ def _is_unbound_method(func: Callable) -> bool: class OpCallError(Exception): ... -def _default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedInputs: +def _default_on_input_handler( + func: Op[P, R], args: tuple, kwargs: dict +) -> ProcessedInputs: + """Default handler for processing inputs to an op.""" try: sig = inspect.signature(func) + except ValueError: + # This can happen for built-in functions + return ProcessedInputs( + original_args=args, + original_kwargs=kwargs, + args=args, + kwargs=kwargs, + inputs={}, + ) + + try: inputs = sig.bind(*args, **kwargs).arguments except TypeError as e: raise OpCallError(f"Error calling {func.name}: {e}") @@ -236,8 +262,20 @@ def _default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedI def _create_call( - func: Op, *args: Any, __weave: WeaveKwargs | None = None, **kwargs: Any + func: Op[P, R], *args: Any, __weave: WeaveKwargs | None = None, **kwargs: Any ) -> Call: + """ + Create a Call object for an op call. + + Args: + func: The op to call. + *args: Positional arguments to pass to the op. + __weave: Optional WeaveKwargs to pass to the op. + **kwargs: Keyword arguments to pass to the op. + + Returns: + A Call object. + """ client = weave_client_context.require_weave_client() pargs = None @@ -272,12 +310,25 @@ def _create_call( def _execute_op( - __op: Op, + __op: Op[P, R], __call: Call, *args: Any, __should_raise: bool = True, **kwargs: Any, ) -> tuple[Any, Call] | Coroutine[Any, Any, tuple[Any, Call]]: + """ + Execute an op and return the result and the Call object. + + Args: + __op: The op to call. + __call: The Call object to use. + *args: Positional arguments to pass to the op. + __should_raise: Whether to raise exceptions from the op. + **kwargs: Keyword arguments to pass to the op. + + Returns: + A tuple of (result, Call) or a coroutine that returns a tuple of (result, Call). + """ func = __op.resolve_fn client = weave_client_context.require_weave_client() has_finished = False @@ -348,44 +399,33 @@ async def _call_async() -> tuple[Any, Call]: def call( - op: Op, + op: Op[P, R], *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = False, **kwargs: Any, ) -> tuple[Any, Call] | Coroutine[Any, Any, tuple[Any, Call]]: """ - Executes the op and returns both the result and a Call representing the execution. - - This function will never raise. Any errors are captured in the Call object. + Call an op and return the result and the Call object. - This method is automatically bound to any function decorated with `@weave.op`, - allowing for usage like: + This is a lower-level function that most users won't need to use directly. + It's useful if you want to get the Call object for an op call. - ```python - @weave.op - def add(a: int, b: int) -> int: - return a + b + Args: + op: The op to call. + *args: Positional arguments to pass to the op. + __weave: Optional WeaveKwargs to pass to the op. + __should_raise: Whether to raise exceptions from the op. + **kwargs: Keyword arguments to pass to the op. - result, call = add.call(1, 2) - ``` + Returns: + A tuple of (result, Call) or a coroutine that returns a tuple of (result, Call). """ if inspect.iscoroutinefunction(op.resolve_fn): return _do_call_async( - op, - *args, - __weave=__weave, - __should_raise=__should_raise, - **kwargs, - ) - else: - return _do_call( - op, - *args, - __weave=__weave, - __should_raise=__should_raise, - **kwargs, + op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs ) + return _do_call(op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs) def _placeholder_call() -> Call: @@ -402,12 +442,25 @@ def _placeholder_call() -> Call: def _do_call( - op: Op, + op: Op[P, R], *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = False, **kwargs: Any, ) -> tuple[Any, Call]: + """ + Execute an op synchronously and return the result and the Call object. + + Args: + op: The op to call. + *args: Positional arguments to pass to the op. + __weave: Optional WeaveKwargs to pass to the op. + __should_raise: Whether to raise exceptions from the op. + **kwargs: Keyword arguments to pass to the op. + + Returns: + A tuple of (result, Call). + """ func = op.resolve_fn call = _placeholder_call() @@ -474,12 +527,25 @@ def _do_call( async def _do_call_async( - op: Op, + op: Op[P, R], *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = False, **kwargs: Any, ) -> tuple[Any, Call]: + """ + Execute an op asynchronously and return the result and the Call object. + + Args: + op: The op to call. + *args: Positional arguments to pass to the op. + __weave: Optional WeaveKwargs to pass to the op. + __should_raise: Whether to raise exceptions from the op. + **kwargs: Keyword arguments to pass to the op. + + Returns: + A tuple of (result, Call). + """ func = op.resolve_fn call = _placeholder_call() @@ -536,20 +602,14 @@ async def _do_call_async( return res, call -def calls(op: Op) -> CallsIter: +def calls(op: Op[P, R]) -> CallsIter: """ - Get an iterator over all calls to this op. + Get an iterator over all calls to an op. - This method is automatically bound to any function decorated with `@weave.op`, - allowing for usage like: + Example usage: ```python - @weave.op - def add(a: int, b: int) -> int: - return a + b - - calls = add.calls() - for call in calls: + for call in my_op.calls(): print(call) ``` """ @@ -562,15 +622,22 @@ def add(a: int, b: int) -> int: PostprocessOutputFunc = Callable[..., Any] +# Type alias for the decorated function type +OpCallable = Callable[P, R] +# Type alias for the decorator type +OpDecorator = Callable[[Callable[P, R]], Callable[P, R]] + + @overload def op( - func: Callable, + func: Callable[P, R], *, name: str | None = None, call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, -) -> Op: ... + tracing_sample_rate: float = 1.0, +) -> Op[P, R]: ... @overload @@ -580,18 +647,19 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, -) -> Callable[[Callable], Op]: ... + tracing_sample_rate: float = 1.0, +) -> Callable[[Callable[P, R]], Op[P, R]]: ... def op( - func: Callable | None = None, + func: Callable[P, R] | None = None, *, name: str | None = None, - call_display_name: str | CallDisplayNameFunc | None = None, - postprocess_inputs: PostprocessInputsFunc | None = None, - postprocess_output: PostprocessOutputFunc | None = None, + call_display_name: str | Callable[[Call], str] | None = None, + postprocess_inputs: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + postprocess_output: Callable[..., Any] | None = None, tracing_sample_rate: float = 1.0, -) -> Callable[[Callable], Op] | Op: +) -> Callable[[Callable[P, R]], Op[P, R]] | Op[P, R]: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -650,28 +718,28 @@ async def extract(): if not 0 <= tracing_sample_rate <= 1: raise ValueError("tracing_sample_rate must be between 0 and 1") - def op_deco(func: Callable) -> Op: + def op_deco(func: Callable[P, R]) -> Op[P, R]: # Check function type is_method = _is_unbound_method(func) is_async = inspect.iscoroutinefunction(func) - def create_wrapper(func: Callable) -> Op: + def create_wrapper(func: Callable[P, R]) -> Op[P, R]: if is_async: @wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration] + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # pyright: ignore[reportRedeclaration] res, _ = await _do_call_async( - cast(Op, wrapper), *args, __should_raise=True, **kwargs + cast(Op[P, R], wrapper), *args, __should_raise=True, **kwargs ) - return res + return cast(R, res) else: @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: res, _ = _do_call( - cast(Op, wrapper), *args, __should_raise=True, **kwargs + cast(Op[P, R], wrapper), *args, __should_raise=True, **kwargs ) - return res + return cast(R, res) # Tack these helpers on to our wrapper wrapper.resolve_fn = func # type: ignore @@ -717,36 +785,33 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: ) wrapper.call_display_name = call_display_name # type: ignore - return cast(Op, wrapper) + return cast(Op[P, R], wrapper) return create_wrapper(func) if func is None: - return op_deco + return cast(Callable[[Callable[P, R]], Op[P, R]], op_deco) return op_deco(func) -def get_captured_code(op: Op) -> str: - """Get the captured code of the op. - - This only works when you get an op back from a ref. The pattern is: - - ref = weave.publish(func) - op = ref.get() - captured_code = op.get_captured_code() - """ +def get_captured_code(op: Op[P, R]) -> str: + """Get the source code of an op.""" try: - return op.art.path_contents["obj.py"].decode() # type: ignore - except Exception: - raise RuntimeError( - "Failed to get captured code for op (this only works when you get an op back from a ref)." - ) + return inspect.getsource(op.resolve_fn) + except (TypeError, OSError): + return "" def maybe_bind_method(func: Callable, self: Any = None) -> Callable | MethodType: - """Bind a function to any object (even if it's not a class) + """ + Bind a function to an object if it's an unbound method. - If self is None, return the function as is. + Args: + func: The function to bind. + self: The object to bind the function to. + + Returns: + The bound method or the original function. """ if (sig := inspect.signature(func)) and sig.parameters.get("self"): if inspect.ismethod(func) and id(func.__self__) != id(self): @@ -755,49 +820,59 @@ def maybe_bind_method(func: Callable, self: Any = None) -> Callable | MethodType return func -def maybe_unbind_method(oplike: Op | MethodType | partial) -> Op: - """Unbind an Op-like method or partial to a plain Op function. +def maybe_unbind_method(oplike: Op[P, R] | MethodType | partial) -> Op[P, R]: + """ + Unbind a method from its object if it's a bound method. + + Args: + oplike: The method to unbind. - For: - - methods, remove set `self` param - - partials, remove any preset params + Returns: + The unbound method or the original function. """ if isinstance(oplike, MethodType): + # For methods, we need to get the underlying function op = oplike.__func__ - elif isinstance(oplike, partial): # Handle cases op is defined as + elif isinstance(oplike, partial): + # For partials, we need to get the underlying function op = oplike.func else: op = oplike - return cast(Op, op) + return cast(Op[P, R], op) -def is_op(obj: Any) -> bool: - if sys.version_info < (3, 12): - return isinstance(obj, Op) +def is_op(obj: Any) -> TypeGuard[Op]: + """ + Check if an object is an op. - return all(hasattr(obj, attr) for attr in Op.__annotations__) + Args: + obj: The object to check. + Returns: + True if the object is an op, False otherwise. + """ + if sys.version_info < (3, 12): + return isinstance(obj, Op) + else: + # In Python 3.12, isinstance(obj, Protocol) is not supported + # See https://peps.python.org/pep-0544/#using-protocols + return hasattr(obj, "resolve_fn") and callable(obj) -def as_op(fn: Callable) -> Op: - """Given a @weave.op() decorated function, return its Op. - @weave.op() decorated functions are instances of Op already, so this - function should be a no-op at runtime. But you can use it to satisfy type checkers - if you need to access OpDef attributes in a typesafe way. +def as_op(fn: Callable[P, R]) -> Op[P, R]: + """ + Convert a function to an op if it's not already one. Args: - fn: A weave.op() decorated function. + fn: The function to convert. Returns: - The Op of the function. + The function as an op. """ - if not is_op(fn): - raise ValueError("fn must be a weave.op() decorated function") - - # The unbinding is necessary for methods because `MethodType` is applied after the - # func is decorated into an Op. - return maybe_unbind_method(cast(Op, fn)) + if is_op(fn): + return cast(Op[P, R], fn) + return op(fn) __docspec__ = [call, calls]