Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): Better organized op func methods #3898

Merged
merged 7 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions tests/trace/test_op_return_forms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

import weave
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.op import _add_accumulator
from weave.trace.weave_client import get_ref
from weave.trace_server import trace_server_interface as tsi

Expand Down Expand Up @@ -110,7 +110,7 @@ def fn():
size -= 1
yield size

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

for item in fn():
pass
Expand All @@ -137,7 +137,7 @@ async def fn():
size -= 1
yield size

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

async for item in fn():
pass
Expand Down Expand Up @@ -172,7 +172,7 @@ def __next__(self):
def fn():
return MyIterator()

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

for item in fn():
pass
Expand Down Expand Up @@ -208,7 +208,7 @@ async def __anext__(self):
def fn():
return MyAsyncIterator()

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

async for item in fn():
pass
Expand All @@ -234,7 +234,7 @@ def fn():
size -= 1
yield size

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

fn()

Expand All @@ -260,7 +260,7 @@ async def fn():
size -= 1
yield size

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

fn()

Expand Down Expand Up @@ -294,7 +294,7 @@ def __next__(self):
def fn():
return MyIterator()

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

fn()

Expand Down Expand Up @@ -329,7 +329,7 @@ async def __anext__(self):
def fn():
return MyAsyncIterator()

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

fn()

Expand All @@ -354,7 +354,7 @@ def fn():
size -= 1
yield size

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

for item in fn():
if item == 5:
Expand Down Expand Up @@ -382,7 +382,7 @@ async def fn():
size -= 1
yield size

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

async for item in fn():
if item == 5:
Expand Down Expand Up @@ -418,7 +418,7 @@ def __next__(self):
def fn():
return MyIterator()

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

for item in fn():
if item == 5:
Expand Down Expand Up @@ -455,7 +455,7 @@ async def __anext__(self):
def fn():
return MyAsyncIterator()

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

async for item in fn():
if item == 5:
Expand Down Expand Up @@ -484,7 +484,7 @@ def fn():
if size == 5:
raise ValueError("test")

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

try:
for item in fn():
Expand Down Expand Up @@ -517,7 +517,7 @@ async def fn():
if size == 5:
raise ValueError("test")

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

try:
async for item in fn():
Expand Down Expand Up @@ -558,7 +558,7 @@ def __next__(self):
def fn():
return MyIterator()

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

try:
for item in fn():
Expand Down Expand Up @@ -600,7 +600,7 @@ async def __anext__(self):
def fn():
return MyAsyncIterator()

add_accumulator(fn, lambda inputs: simple_list_accumulator)
_add_accumulator(fn, lambda inputs: simple_list_accumulator)

try:
async for item in fn():
Expand Down
22 changes: 11 additions & 11 deletions tests/trace/test_tracing_resilience.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tests.trace.util import DummyTestException
from weave.trace.context import call_context
from weave.trace.context.tests_context import raise_on_captured_errors
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.op import _add_accumulator


def assert_no_current_call():
Expand Down Expand Up @@ -144,7 +144,7 @@ def simple_op():
def make_accumulator(*args, **kwargs):
raise DummyTestException("FAILURE!")

add_accumulator(simple_op, make_accumulator=make_accumulator)
_add_accumulator(simple_op, make_accumulator=make_accumulator)

return simple_op()

Expand Down Expand Up @@ -179,7 +179,7 @@ async def simple_op():
def make_accumulator(*args, **kwargs):
raise DummyTestException("FAILURE!")

add_accumulator(simple_op, make_accumulator=make_accumulator)
_add_accumulator(simple_op, make_accumulator=make_accumulator)

return simple_op()

Expand Down Expand Up @@ -212,7 +212,7 @@ def accumulate(*args, **kwargs):

return accumulate

add_accumulator(simple_op, make_accumulator=make_accumulator)
_add_accumulator(simple_op, make_accumulator=make_accumulator)

return simple_op()

Expand Down Expand Up @@ -252,7 +252,7 @@ def accumulate(*args, **kwargs):

return accumulate

add_accumulator(simple_op, make_accumulator=make_accumulator)
_add_accumulator(simple_op, make_accumulator=make_accumulator)

return simple_op()

Expand Down Expand Up @@ -291,7 +291,7 @@ def accumulate(*args, **kwargs):
def should_accumulate(*args, **kwargs):
raise DummyTestException("FAILURE!")

add_accumulator(
_add_accumulator(
simple_op,
make_accumulator=make_accumulator,
should_accumulate=should_accumulate,
Expand Down Expand Up @@ -336,7 +336,7 @@ def accumulate(*args, **kwargs):
def should_accumulate(*args, **kwargs):
raise DummyTestException("FAILURE!")

add_accumulator(
_add_accumulator(
simple_op,
make_accumulator=make_accumulator,
should_accumulate=should_accumulate,
Expand Down Expand Up @@ -383,7 +383,7 @@ def accumulate(*args, **kwargs):
def on_finish_post_processor(*args, **kwargs):
raise DummyTestException("FAILURE!")

add_accumulator(
_add_accumulator(
simple_op,
make_accumulator=make_accumulator,
on_finish_post_processor=on_finish_post_processor,
Expand Down Expand Up @@ -429,7 +429,7 @@ def accumulate(*args, **kwargs):
def on_finish_post_processor(*args, **kwargs):
raise DummyTestException("FAILURE!")

add_accumulator(
_add_accumulator(
simple_op,
make_accumulator=make_accumulator,
on_finish_post_processor=on_finish_post_processor,
Expand Down Expand Up @@ -468,7 +468,7 @@ def accumulate(*args, **kwargs):

return accumulate

add_accumulator(simple_op, make_accumulator=make_accumulator)
_add_accumulator(simple_op, make_accumulator=make_accumulator)

return simple_op()

Expand Down Expand Up @@ -498,7 +498,7 @@ def accumulate(*args, **kwargs):

return accumulate

add_accumulator(simple_op, make_accumulator=make_accumulator)
_add_accumulator(simple_op, make_accumulator=make_accumulator)

return simple_op()

Expand Down
8 changes: 4 additions & 4 deletions weave/integrations/anthropic/anthropic_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import weave
from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator
from weave.trace.op import _add_accumulator, _IteratorWrapper

if TYPE_CHECKING:
from anthropic.lib.streaming import MessageStream
Expand Down Expand Up @@ -77,7 +77,7 @@ def wrapper(fn: Callable) -> Callable:
"We need to do this so we can check if `stream` is used"
op_kwargs = settings.model_dump()
op = weave.op(fn, **op_kwargs)
return add_accumulator(
return _add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: anthropic_accumulator,
should_accumulate=should_use_accumulator,
Expand All @@ -101,7 +101,7 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
"We need to do this so we can check if `stream` is used"
op_kwargs = settings.model_dump()
op = weave.op(_fn_wrapper(fn), **op_kwargs)
return add_accumulator(
return _add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: anthropic_accumulator,
should_accumulate=should_use_accumulator,
Expand Down Expand Up @@ -170,7 +170,7 @@ def create_stream_wrapper(settings: OpSettings) -> Callable[[Callable], Callable
def wrapper(fn: Callable) -> Callable:
op_kwargs = settings.model_dump()
op = weave.op(fn, **op_kwargs)
return add_accumulator(
return _add_accumulator(
op, # type: ignore
make_accumulator=lambda _: anthropic_stream_accumulator,
should_accumulate=lambda _: True,
Expand Down
4 changes: 2 additions & 2 deletions weave/integrations/bedrock/bedrock_sdk.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, Any, Callable, Optional

import weave
from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator
from weave.trace.op import _add_accumulator, _IteratorWrapper
from weave.trace.weave_client import Call

if TYPE_CHECKING:
Expand Down Expand Up @@ -133,7 +133,7 @@ def get(self, key: str, default: Any = None) -> Any:
)
return self

return add_accumulator(
return _add_accumulator(
op,
make_accumulator=lambda _: bedrock_stream_accumulator,
should_accumulate=lambda _: True,
Expand Down
6 changes: 3 additions & 3 deletions weave/integrations/cohere/cohere_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import weave
from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.op import _add_accumulator

if TYPE_CHECKING:
from cohere.types.non_streamed_chat_response import NonStreamedChatResponse
Expand Down Expand Up @@ -167,7 +167,7 @@ def cohere_stream_wrapper(settings: OpSettings) -> Callable:
def wrapper(fn: Callable) -> Callable:
op_kwargs = settings.model_dump()
op = weave.op(fn, **op_kwargs)
return add_accumulator(op, lambda inputs: cohere_accumulator)
return _add_accumulator(op, lambda inputs: cohere_accumulator)

return wrapper

Expand All @@ -176,7 +176,7 @@ def cohere_stream_wrapper_v2(settings: OpSettings) -> Callable:
def wrapper(fn: Callable) -> Callable:
op_kwargs = settings.model_dump()
op = weave.op(fn, **op_kwargs)
return add_accumulator(op, lambda inputs: cohere_accumulator_v2)
return _add_accumulator(op, lambda inputs: cohere_accumulator_v2)

return wrapper

Expand Down
6 changes: 3 additions & 3 deletions weave/integrations/google_ai_studio/google_ai_studio_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import weave
from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.op import _add_accumulator
from weave.trace.serialization.serialize import dictify
from weave.trace.weave_client import Call

Expand Down Expand Up @@ -100,7 +100,7 @@ def wrapper(fn: Callable) -> Callable:

op = weave.op(fn, **op_kwargs)
op._set_on_finish_handler(gemini_on_finish)
return add_accumulator(
return _add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: gemini_accumulator,
should_accumulate=lambda inputs: isinstance(inputs, dict)
Expand All @@ -125,7 +125,7 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:

op = weave.op(_fn_wrapper(fn), **op_kwargs)
op._set_on_finish_handler(gemini_on_finish)
return add_accumulator(
return _add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: gemini_accumulator,
should_accumulate=lambda inputs: isinstance(inputs, dict)
Expand Down
4 changes: 2 additions & 2 deletions weave/integrations/groq/groq_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import weave
from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.op import _add_accumulator

if TYPE_CHECKING:
from groq.types.chat import ChatCompletion, ChatCompletionChunk
Expand Down Expand Up @@ -93,7 +93,7 @@ def groq_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op_kwargs = settings.model_dump()
op = weave.op(fn, **op_kwargs)
return add_accumulator(
return _add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: groq_accumulator,
should_accumulate=should_use_accumulator,
Expand Down
Loading