From 062e76916f21763030d47be13c1b6582b342add3 Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Fri, 7 Mar 2025 12:53:09 -0800 Subject: [PATCH 1/2] fix(weave): @weave.op decorator destroys type information --- weave/trace/op.py | 62 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 97c98d452ab6..fc07aa7777d0 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -7,7 +7,7 @@ import random import sys import traceback -from collections.abc import Coroutine, Mapping +from collections.abc import Callable, Coroutine, Mapping from dataclasses import dataclass from functools import partial, wraps from types import MethodType @@ -21,6 +21,11 @@ cast, overload, runtime_checkable, + TypeVar, + cast, + ParamSpec, + TypeGuard, + Generic, ) from weave.trace import box, settings @@ -562,15 +567,25 @@ def add(a: int, b: int) -> int: PostprocessOutputFunc = Callable[..., Any] +# Type variables for preserving function signatures +# Captures the function signature of the decorated function +P = ParamSpec("P") +# Captures the return type of the decorated function +R = TypeVar("R") +# Captures the type of the decorated function +T = TypeVar("T", bound=Callable[..., Any]) + + @overload def op( - func: Callable, + func: Callable[P, R], *, name: str | None = None, call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, -) -> Op: ... + tracing_sample_rate: float = 1.0, +) -> Callable[P, R]: ... @overload @@ -580,18 +595,19 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, -) -> Callable[[Callable], Op]: ... + tracing_sample_rate: float = 1.0, +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... def op( - func: Callable | None = None, + func: Callable[P, R] | None = None, *, name: str | None = None, - call_display_name: str | CallDisplayNameFunc | None = None, - postprocess_inputs: PostprocessInputsFunc | None = None, - postprocess_output: PostprocessOutputFunc | None = None, + call_display_name: str | Callable[["Call"], str] | None = None, + postprocess_inputs: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + postprocess_output: Callable[..., Any] | None = None, tracing_sample_rate: float = 1.0, -) -> Callable[[Callable], Op] | Op: +) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -619,8 +635,16 @@ def op( tracing_sample_rate (float): The sampling rate for tracing this function. Defaults to 1.0 (always trace). Returns: - Union[Callable[[Any], Op], Op]: If called without arguments, returns a decorator. - If called with a function, returns the decorated function as an Op. + Union[Callable[[Callable[P, R]], Callable[P, R]], Callable[P, R]]: + P is the ParamSpec that captures all the parameters of the original function + R is the TypeVar that captures the return type of the original function + Callable[[Callable[P, R]], Callable[P, R]]: + The decorator is directly applied to a function (without parentheses). + Immediately processes the function and returns the wrapped function with the same signature as the original. + + Callable[[Callable[P, R]], Callable[P, R]]: + The decorator is called with arguments. + The decorator returns another decorator function that will later be applied to the target function. Raises: ValueError: If the decorated object is not a function or method. @@ -650,28 +674,28 @@ async def extract(): if not 0 <= tracing_sample_rate <= 1: raise ValueError("tracing_sample_rate must be between 0 and 1") - def op_deco(func: Callable) -> Op: + def op_deco(func: Callable[P, R]) -> Callable[P, R]: # Check function type is_method = _is_unbound_method(func) is_async = inspect.iscoroutinefunction(func) - def create_wrapper(func: Callable) -> Op: + def create_wrapper(func: Callable[P, R]) -> Callable[P, R]: if is_async: @wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration] + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # pyright: ignore[reportRedeclaration] res, _ = await _do_call_async( cast(Op, wrapper), *args, __should_raise=True, **kwargs ) - return res + return cast(R, res) else: @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: res, _ = _do_call( cast(Op, wrapper), *args, __should_raise=True, **kwargs ) - return res + return cast(R, res) # Tack these helpers on to our wrapper wrapper.resolve_fn = func # type: ignore @@ -717,12 +741,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: ) wrapper.call_display_name = call_display_name # type: ignore - return cast(Op, wrapper) + return cast(Callable[P, R], wrapper) return create_wrapper(func) if func is None: - return op_deco + return cast(Callable[[Callable[P, R]], Callable[P, R]], op_deco) return op_deco(func) From 2406867a075fcdff120a742961bc7f0ce49f2691 Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Tue, 11 Mar 2025 11:27:28 -0700 Subject: [PATCH 2/2] refactor op to take parameters and return type --- weave/trace/op.py | 285 +++++++++++++++++++++++++++------------------- 1 file changed, 168 insertions(+), 117 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index fc07aa7777d0..902d683f6ba8 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -16,16 +16,14 @@ Any, Callable, Optional, + ParamSpec, Protocol, TypedDict, + TypeGuard, + TypeVar, cast, overload, runtime_checkable, - TypeVar, - cast, - ParamSpec, - TypeGuard, - Generic, ) from weave.trace import box, settings @@ -132,8 +130,14 @@ class WeaveKwargs(TypedDict): display_name: str | None +# Type variables for preserving function signatures +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T", bound=Callable[..., Any]) + + @runtime_checkable -class Op(Protocol): +class Op(Protocol[P, R]): """ The interface for Op-ified functions and methods. @@ -171,7 +175,7 @@ class Op(Protocol): _set_on_finish_handler: Callable[[OnFinishHandlerType], None] _on_finish_handler: OnFinishHandlerType | None - __call__: Callable[..., Any] + __call__: Callable[P, R] __self__: Any # `_tracing_enabled` is a runtime-only flag that can be used to disable @@ -186,19 +190,22 @@ class Op(Protocol): tracing_sample_rate: float -def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None: +def _set_on_input_handler(func: Op[P, R], on_input: OnInputHandlerType) -> None: + """Set the on_input handler for an op.""" if func._on_input_handler is not None: raise ValueError("Cannot set on_input_handler multiple times") func._on_input_handler = on_input -def _set_on_output_handler(func: Op, on_output: OnOutputHandlerType) -> None: +def _set_on_output_handler(func: Op[P, R], on_output: OnOutputHandlerType) -> None: + """Set the on_output handler for an op.""" if func._on_output_handler is not None: raise ValueError("Cannot set on_output_handler multiple times") func._on_output_handler = on_output -def _set_on_finish_handler(func: Op, on_finish: OnFinishHandlerType) -> None: +def _set_on_finish_handler(func: Op[P, R], on_finish: OnFinishHandlerType) -> None: + """Set the on_finish handler for an op.""" if func._on_finish_handler is not None: raise ValueError("Cannot set on_finish_handler multiple times") func._on_finish_handler = on_finish @@ -223,9 +230,23 @@ def _is_unbound_method(func: Callable) -> bool: class OpCallError(Exception): ... -def _default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedInputs: +def _default_on_input_handler( + func: Op[P, R], args: tuple, kwargs: dict +) -> ProcessedInputs: + """Default handler for processing inputs to an op.""" try: sig = inspect.signature(func) + except ValueError: + # This can happen for built-in functions + return ProcessedInputs( + original_args=args, + original_kwargs=kwargs, + args=args, + kwargs=kwargs, + inputs={}, + ) + + try: inputs = sig.bind(*args, **kwargs).arguments except TypeError as e: raise OpCallError(f"Error calling {func.name}: {e}") @@ -241,8 +262,20 @@ def _default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedI def _create_call( - func: Op, *args: Any, __weave: WeaveKwargs | None = None, **kwargs: Any + func: Op[P, R], *args: Any, __weave: WeaveKwargs | None = None, **kwargs: Any ) -> Call: + """ + Create a Call object for an op call. + + Args: + func: The op to call. + *args: Positional arguments to pass to the op. + __weave: Optional WeaveKwargs to pass to the op. + **kwargs: Keyword arguments to pass to the op. + + Returns: + A Call object. + """ client = weave_client_context.require_weave_client() pargs = None @@ -277,12 +310,25 @@ def _create_call( def _execute_op( - __op: Op, + __op: Op[P, R], __call: Call, *args: Any, __should_raise: bool = True, **kwargs: Any, ) -> tuple[Any, Call] | Coroutine[Any, Any, tuple[Any, Call]]: + """ + Execute an op and return the result and the Call object. + + Args: + __op: The op to call. + __call: The Call object to use. + *args: Positional arguments to pass to the op. + __should_raise: Whether to raise exceptions from the op. + **kwargs: Keyword arguments to pass to the op. + + Returns: + A tuple of (result, Call) or a coroutine that returns a tuple of (result, Call). + """ func = __op.resolve_fn client = weave_client_context.require_weave_client() has_finished = False @@ -353,44 +399,33 @@ async def _call_async() -> tuple[Any, Call]: def call( - op: Op, + op: Op[P, R], *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = False, **kwargs: Any, ) -> tuple[Any, Call] | Coroutine[Any, Any, tuple[Any, Call]]: """ - Executes the op and returns both the result and a Call representing the execution. + Call an op and return the result and the Call object. - This function will never raise. Any errors are captured in the Call object. + This is a lower-level function that most users won't need to use directly. + It's useful if you want to get the Call object for an op call. - This method is automatically bound to any function decorated with `@weave.op`, - allowing for usage like: - - ```python - @weave.op - def add(a: int, b: int) -> int: - return a + b + Args: + op: The op to call. + *args: Positional arguments to pass to the op. + __weave: Optional WeaveKwargs to pass to the op. + __should_raise: Whether to raise exceptions from the op. + **kwargs: Keyword arguments to pass to the op. - result, call = add.call(1, 2) - ``` + Returns: + A tuple of (result, Call) or a coroutine that returns a tuple of (result, Call). """ if inspect.iscoroutinefunction(op.resolve_fn): return _do_call_async( - op, - *args, - __weave=__weave, - __should_raise=__should_raise, - **kwargs, - ) - else: - return _do_call( - op, - *args, - __weave=__weave, - __should_raise=__should_raise, - **kwargs, + op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs ) + return _do_call(op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs) def _placeholder_call() -> Call: @@ -407,12 +442,25 @@ def _placeholder_call() -> Call: def _do_call( - op: Op, + op: Op[P, R], *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = False, **kwargs: Any, ) -> tuple[Any, Call]: + """ + Execute an op synchronously and return the result and the Call object. + + Args: + op: The op to call. + *args: Positional arguments to pass to the op. + __weave: Optional WeaveKwargs to pass to the op. + __should_raise: Whether to raise exceptions from the op. + **kwargs: Keyword arguments to pass to the op. + + Returns: + A tuple of (result, Call). + """ func = op.resolve_fn call = _placeholder_call() @@ -479,12 +527,25 @@ def _do_call( async def _do_call_async( - op: Op, + op: Op[P, R], *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = False, **kwargs: Any, ) -> tuple[Any, Call]: + """ + Execute an op asynchronously and return the result and the Call object. + + Args: + op: The op to call. + *args: Positional arguments to pass to the op. + __weave: Optional WeaveKwargs to pass to the op. + __should_raise: Whether to raise exceptions from the op. + **kwargs: Keyword arguments to pass to the op. + + Returns: + A tuple of (result, Call). + """ func = op.resolve_fn call = _placeholder_call() @@ -541,20 +602,14 @@ async def _do_call_async( return res, call -def calls(op: Op) -> CallsIter: +def calls(op: Op[P, R]) -> CallsIter: """ - Get an iterator over all calls to this op. + Get an iterator over all calls to an op. - This method is automatically bound to any function decorated with `@weave.op`, - allowing for usage like: + Example usage: ```python - @weave.op - def add(a: int, b: int) -> int: - return a + b - - calls = add.calls() - for call in calls: + for call in my_op.calls(): print(call) ``` """ @@ -567,13 +622,10 @@ def add(a: int, b: int) -> int: PostprocessOutputFunc = Callable[..., Any] -# Type variables for preserving function signatures -# Captures the function signature of the decorated function -P = ParamSpec("P") -# Captures the return type of the decorated function -R = TypeVar("R") -# Captures the type of the decorated function -T = TypeVar("T", bound=Callable[..., Any]) +# Type alias for the decorated function type +OpCallable = Callable[P, R] +# Type alias for the decorator type +OpDecorator = Callable[[Callable[P, R]], Callable[P, R]] @overload @@ -585,7 +637,7 @@ def op( postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, tracing_sample_rate: float = 1.0, -) -> Callable[P, R]: ... +) -> Op[P, R]: ... @overload @@ -596,18 +648,18 @@ def op( postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, tracing_sample_rate: float = 1.0, -) -> Callable[[Callable[P, R]], Callable[P, R]]: ... +) -> Callable[[Callable[P, R]], Op[P, R]]: ... def op( func: Callable[P, R] | None = None, *, name: str | None = None, - call_display_name: str | Callable[["Call"], str] | None = None, + call_display_name: str | Callable[[Call], str] | None = None, postprocess_inputs: Callable[[dict[str, Any]], dict[str, Any]] | None = None, postprocess_output: Callable[..., Any] | None = None, tracing_sample_rate: float = 1.0, -) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]: +) -> Callable[[Callable[P, R]], Op[P, R]] | Op[P, R]: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -635,16 +687,8 @@ def op( tracing_sample_rate (float): The sampling rate for tracing this function. Defaults to 1.0 (always trace). Returns: - Union[Callable[[Callable[P, R]], Callable[P, R]], Callable[P, R]]: - P is the ParamSpec that captures all the parameters of the original function - R is the TypeVar that captures the return type of the original function - Callable[[Callable[P, R]], Callable[P, R]]: - The decorator is directly applied to a function (without parentheses). - Immediately processes the function and returns the wrapped function with the same signature as the original. - - Callable[[Callable[P, R]], Callable[P, R]]: - The decorator is called with arguments. - The decorator returns another decorator function that will later be applied to the target function. + Union[Callable[[Any], Op], Op]: If called without arguments, returns a decorator. + If called with a function, returns the decorated function as an Op. Raises: ValueError: If the decorated object is not a function or method. @@ -674,18 +718,18 @@ async def extract(): if not 0 <= tracing_sample_rate <= 1: raise ValueError("tracing_sample_rate must be between 0 and 1") - def op_deco(func: Callable[P, R]) -> Callable[P, R]: + def op_deco(func: Callable[P, R]) -> Op[P, R]: # Check function type is_method = _is_unbound_method(func) is_async = inspect.iscoroutinefunction(func) - def create_wrapper(func: Callable[P, R]) -> Callable[P, R]: + def create_wrapper(func: Callable[P, R]) -> Op[P, R]: if is_async: @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # pyright: ignore[reportRedeclaration] res, _ = await _do_call_async( - cast(Op, wrapper), *args, __should_raise=True, **kwargs + cast(Op[P, R], wrapper), *args, __should_raise=True, **kwargs ) return cast(R, res) else: @@ -693,7 +737,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # pyright: ignore[re @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: res, _ = _do_call( - cast(Op, wrapper), *args, __should_raise=True, **kwargs + cast(Op[P, R], wrapper), *args, __should_raise=True, **kwargs ) return cast(R, res) @@ -741,36 +785,33 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: ) wrapper.call_display_name = call_display_name # type: ignore - return cast(Callable[P, R], wrapper) + return cast(Op[P, R], wrapper) return create_wrapper(func) if func is None: - return cast(Callable[[Callable[P, R]], Callable[P, R]], op_deco) + return cast(Callable[[Callable[P, R]], Op[P, R]], op_deco) return op_deco(func) -def get_captured_code(op: Op) -> str: - """Get the captured code of the op. - - This only works when you get an op back from a ref. The pattern is: - - ref = weave.publish(func) - op = ref.get() - captured_code = op.get_captured_code() - """ +def get_captured_code(op: Op[P, R]) -> str: + """Get the source code of an op.""" try: - return op.art.path_contents["obj.py"].decode() # type: ignore - except Exception: - raise RuntimeError( - "Failed to get captured code for op (this only works when you get an op back from a ref)." - ) + return inspect.getsource(op.resolve_fn) + except (TypeError, OSError): + return "" def maybe_bind_method(func: Callable, self: Any = None) -> Callable | MethodType: - """Bind a function to any object (even if it's not a class) + """ + Bind a function to an object if it's an unbound method. + + Args: + func: The function to bind. + self: The object to bind the function to. - If self is None, return the function as is. + Returns: + The bound method or the original function. """ if (sig := inspect.signature(func)) and sig.parameters.get("self"): if inspect.ismethod(func) and id(func.__self__) != id(self): @@ -779,49 +820,59 @@ def maybe_bind_method(func: Callable, self: Any = None) -> Callable | MethodType return func -def maybe_unbind_method(oplike: Op | MethodType | partial) -> Op: - """Unbind an Op-like method or partial to a plain Op function. +def maybe_unbind_method(oplike: Op[P, R] | MethodType | partial) -> Op[P, R]: + """ + Unbind a method from its object if it's a bound method. + + Args: + oplike: The method to unbind. - For: - - methods, remove set `self` param - - partials, remove any preset params + Returns: + The unbound method or the original function. """ if isinstance(oplike, MethodType): + # For methods, we need to get the underlying function op = oplike.__func__ - elif isinstance(oplike, partial): # Handle cases op is defined as + elif isinstance(oplike, partial): + # For partials, we need to get the underlying function op = oplike.func else: op = oplike - return cast(Op, op) + return cast(Op[P, R], op) -def is_op(obj: Any) -> bool: - if sys.version_info < (3, 12): - return isinstance(obj, Op) +def is_op(obj: Any) -> TypeGuard[Op]: + """ + Check if an object is an op. - return all(hasattr(obj, attr) for attr in Op.__annotations__) + Args: + obj: The object to check. + Returns: + True if the object is an op, False otherwise. + """ + if sys.version_info < (3, 12): + return isinstance(obj, Op) + else: + # In Python 3.12, isinstance(obj, Protocol) is not supported + # See https://peps.python.org/pep-0544/#using-protocols + return hasattr(obj, "resolve_fn") and callable(obj) -def as_op(fn: Callable) -> Op: - """Given a @weave.op() decorated function, return its Op. - @weave.op() decorated functions are instances of Op already, so this - function should be a no-op at runtime. But you can use it to satisfy type checkers - if you need to access OpDef attributes in a typesafe way. +def as_op(fn: Callable[P, R]) -> Op[P, R]: + """ + Convert a function to an op if it's not already one. Args: - fn: A weave.op() decorated function. + fn: The function to convert. Returns: - The Op of the function. + The function as an op. """ - if not is_op(fn): - raise ValueError("fn must be a weave.op() decorated function") - - # The unbinding is necessary for methods because `MethodType` is applied after the - # func is decorated into an Op. - return maybe_unbind_method(cast(Op, fn)) + if is_op(fn): + return cast(Op[P, R], fn) + return op(fn) __docspec__ = [call, calls]