Skip to content

Commit 39eec6e

Browse files
authored
chore(weave): Move accumulator exts into op (#3874)
1 parent de989d2 commit 39eec6e

File tree

18 files changed

+395
-446
lines changed

18 files changed

+395
-446
lines changed

tests/trace/test_op_return_forms.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
import weave
4-
from weave.trace.op_extensions.accumulator import add_accumulator
4+
from weave.trace.op import _add_accumulator
55
from weave.trace.weave_client import get_ref
66
from weave.trace_server import trace_server_interface as tsi
77

@@ -110,7 +110,7 @@ def fn():
110110
size -= 1
111111
yield size
112112

113-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
113+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
114114

115115
for item in fn():
116116
pass
@@ -137,7 +137,7 @@ async def fn():
137137
size -= 1
138138
yield size
139139

140-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
140+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
141141

142142
async for item in fn():
143143
pass
@@ -172,7 +172,7 @@ def __next__(self):
172172
def fn():
173173
return MyIterator()
174174

175-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
175+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
176176

177177
for item in fn():
178178
pass
@@ -208,7 +208,7 @@ async def __anext__(self):
208208
def fn():
209209
return MyAsyncIterator()
210210

211-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
211+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
212212

213213
async for item in fn():
214214
pass
@@ -234,7 +234,7 @@ def fn():
234234
size -= 1
235235
yield size
236236

237-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
237+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
238238

239239
fn()
240240

@@ -260,7 +260,7 @@ async def fn():
260260
size -= 1
261261
yield size
262262

263-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
263+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
264264

265265
fn()
266266

@@ -294,7 +294,7 @@ def __next__(self):
294294
def fn():
295295
return MyIterator()
296296

297-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
297+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
298298

299299
fn()
300300

@@ -329,7 +329,7 @@ async def __anext__(self):
329329
def fn():
330330
return MyAsyncIterator()
331331

332-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
332+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
333333

334334
fn()
335335

@@ -354,7 +354,7 @@ def fn():
354354
size -= 1
355355
yield size
356356

357-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
357+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
358358

359359
for item in fn():
360360
if item == 5:
@@ -382,7 +382,7 @@ async def fn():
382382
size -= 1
383383
yield size
384384

385-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
385+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
386386

387387
async for item in fn():
388388
if item == 5:
@@ -418,7 +418,7 @@ def __next__(self):
418418
def fn():
419419
return MyIterator()
420420

421-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
421+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
422422

423423
for item in fn():
424424
if item == 5:
@@ -455,7 +455,7 @@ async def __anext__(self):
455455
def fn():
456456
return MyAsyncIterator()
457457

458-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
458+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
459459

460460
async for item in fn():
461461
if item == 5:
@@ -484,7 +484,7 @@ def fn():
484484
if size == 5:
485485
raise ValueError("test")
486486

487-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
487+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
488488

489489
try:
490490
for item in fn():
@@ -517,7 +517,7 @@ async def fn():
517517
if size == 5:
518518
raise ValueError("test")
519519

520-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
520+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
521521

522522
try:
523523
async for item in fn():
@@ -558,7 +558,7 @@ def __next__(self):
558558
def fn():
559559
return MyIterator()
560560

561-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
561+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
562562

563563
try:
564564
for item in fn():
@@ -600,7 +600,7 @@ async def __anext__(self):
600600
def fn():
601601
return MyAsyncIterator()
602602

603-
add_accumulator(fn, lambda inputs: simple_list_accumulator)
603+
_add_accumulator(fn, lambda inputs: simple_list_accumulator)
604604

605605
try:
606606
async for item in fn():

tests/trace/test_tracing_resilience.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tests.trace.util import DummyTestException
1515
from weave.trace.context import call_context
1616
from weave.trace.context.tests_context import raise_on_captured_errors
17-
from weave.trace.op_extensions.accumulator import add_accumulator
17+
from weave.trace.op import _add_accumulator
1818

1919

2020
def assert_no_current_call():
@@ -144,7 +144,7 @@ def simple_op():
144144
def make_accumulator(*args, **kwargs):
145145
raise DummyTestException("FAILURE!")
146146

147-
add_accumulator(simple_op, make_accumulator=make_accumulator)
147+
_add_accumulator(simple_op, make_accumulator=make_accumulator)
148148

149149
return simple_op()
150150

@@ -179,7 +179,7 @@ async def simple_op():
179179
def make_accumulator(*args, **kwargs):
180180
raise DummyTestException("FAILURE!")
181181

182-
add_accumulator(simple_op, make_accumulator=make_accumulator)
182+
_add_accumulator(simple_op, make_accumulator=make_accumulator)
183183

184184
return simple_op()
185185

@@ -212,7 +212,7 @@ def accumulate(*args, **kwargs):
212212

213213
return accumulate
214214

215-
add_accumulator(simple_op, make_accumulator=make_accumulator)
215+
_add_accumulator(simple_op, make_accumulator=make_accumulator)
216216

217217
return simple_op()
218218

@@ -252,7 +252,7 @@ def accumulate(*args, **kwargs):
252252

253253
return accumulate
254254

255-
add_accumulator(simple_op, make_accumulator=make_accumulator)
255+
_add_accumulator(simple_op, make_accumulator=make_accumulator)
256256

257257
return simple_op()
258258

@@ -291,7 +291,7 @@ def accumulate(*args, **kwargs):
291291
def should_accumulate(*args, **kwargs):
292292
raise DummyTestException("FAILURE!")
293293

294-
add_accumulator(
294+
_add_accumulator(
295295
simple_op,
296296
make_accumulator=make_accumulator,
297297
should_accumulate=should_accumulate,
@@ -336,7 +336,7 @@ def accumulate(*args, **kwargs):
336336
def should_accumulate(*args, **kwargs):
337337
raise DummyTestException("FAILURE!")
338338

339-
add_accumulator(
339+
_add_accumulator(
340340
simple_op,
341341
make_accumulator=make_accumulator,
342342
should_accumulate=should_accumulate,
@@ -383,7 +383,7 @@ def accumulate(*args, **kwargs):
383383
def on_finish_post_processor(*args, **kwargs):
384384
raise DummyTestException("FAILURE!")
385385

386-
add_accumulator(
386+
_add_accumulator(
387387
simple_op,
388388
make_accumulator=make_accumulator,
389389
on_finish_post_processor=on_finish_post_processor,
@@ -429,7 +429,7 @@ def accumulate(*args, **kwargs):
429429
def on_finish_post_processor(*args, **kwargs):
430430
raise DummyTestException("FAILURE!")
431431

432-
add_accumulator(
432+
_add_accumulator(
433433
simple_op,
434434
make_accumulator=make_accumulator,
435435
on_finish_post_processor=on_finish_post_processor,
@@ -468,7 +468,7 @@ def accumulate(*args, **kwargs):
468468

469469
return accumulate
470470

471-
add_accumulator(simple_op, make_accumulator=make_accumulator)
471+
_add_accumulator(simple_op, make_accumulator=make_accumulator)
472472

473473
return simple_op()
474474

@@ -498,7 +498,7 @@ def accumulate(*args, **kwargs):
498498

499499
return accumulate
500500

501-
add_accumulator(simple_op, make_accumulator=make_accumulator)
501+
_add_accumulator(simple_op, make_accumulator=make_accumulator)
502502

503503
return simple_op()
504504

weave/integrations/anthropic/anthropic_sdk.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import weave
99
from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
1010
from weave.trace.autopatch import IntegrationSettings, OpSettings
11-
from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator
11+
from weave.trace.op import _add_accumulator, _IteratorWrapper
1212

1313
if TYPE_CHECKING:
1414
from anthropic.lib.streaming import MessageStream
@@ -77,7 +77,7 @@ def wrapper(fn: Callable) -> Callable:
7777
"We need to do this so we can check if `stream` is used"
7878
op_kwargs = settings.model_dump()
7979
op = weave.op(fn, **op_kwargs)
80-
return add_accumulator(
80+
return _add_accumulator(
8181
op, # type: ignore
8282
make_accumulator=lambda inputs: anthropic_accumulator,
8383
should_accumulate=should_use_accumulator,
@@ -101,7 +101,7 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
101101
"We need to do this so we can check if `stream` is used"
102102
op_kwargs = settings.model_dump()
103103
op = weave.op(_fn_wrapper(fn), **op_kwargs)
104-
return add_accumulator(
104+
return _add_accumulator(
105105
op, # type: ignore
106106
make_accumulator=lambda inputs: anthropic_accumulator,
107107
should_accumulate=should_use_accumulator,
@@ -170,7 +170,7 @@ def create_stream_wrapper(settings: OpSettings) -> Callable[[Callable], Callable
170170
def wrapper(fn: Callable) -> Callable:
171171
op_kwargs = settings.model_dump()
172172
op = weave.op(fn, **op_kwargs)
173-
return add_accumulator(
173+
return _add_accumulator(
174174
op, # type: ignore
175175
make_accumulator=lambda _: anthropic_stream_accumulator,
176176
should_accumulate=lambda _: True,

weave/integrations/bedrock/bedrock_sdk.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import TYPE_CHECKING, Any, Callable, Optional
22

33
import weave
4-
from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator
4+
from weave.trace.op import _add_accumulator, _IteratorWrapper
55
from weave.trace.weave_client import Call
66

77
if TYPE_CHECKING:
@@ -133,7 +133,7 @@ def get(self, key: str, default: Any = None) -> Any:
133133
)
134134
return self
135135

136-
return add_accumulator(
136+
return _add_accumulator(
137137
op,
138138
make_accumulator=lambda _: bedrock_stream_accumulator,
139139
should_accumulate=lambda _: True,

weave/integrations/cohere/cohere_sdk.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import weave
88
from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
99
from weave.trace.autopatch import IntegrationSettings, OpSettings
10-
from weave.trace.op_extensions.accumulator import add_accumulator
10+
from weave.trace.op import _add_accumulator
1111

1212
if TYPE_CHECKING:
1313
from cohere.types.non_streamed_chat_response import NonStreamedChatResponse
@@ -167,7 +167,7 @@ def cohere_stream_wrapper(settings: OpSettings) -> Callable:
167167
def wrapper(fn: Callable) -> Callable:
168168
op_kwargs = settings.model_dump()
169169
op = weave.op(fn, **op_kwargs)
170-
return add_accumulator(op, lambda inputs: cohere_accumulator)
170+
return _add_accumulator(op, lambda inputs: cohere_accumulator)
171171

172172
return wrapper
173173

@@ -176,7 +176,7 @@ def cohere_stream_wrapper_v2(settings: OpSettings) -> Callable:
176176
def wrapper(fn: Callable) -> Callable:
177177
op_kwargs = settings.model_dump()
178178
op = weave.op(fn, **op_kwargs)
179-
return add_accumulator(op, lambda inputs: cohere_accumulator_v2)
179+
return _add_accumulator(op, lambda inputs: cohere_accumulator_v2)
180180

181181
return wrapper
182182

weave/integrations/google_ai_studio/google_ai_studio_sdk.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import weave
88
from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
99
from weave.trace.autopatch import IntegrationSettings, OpSettings
10-
from weave.trace.op_extensions.accumulator import add_accumulator
10+
from weave.trace.op import _add_accumulator
1111
from weave.trace.serialization.serialize import dictify
1212
from weave.trace.weave_client import Call
1313

@@ -100,7 +100,7 @@ def wrapper(fn: Callable) -> Callable:
100100

101101
op = weave.op(fn, **op_kwargs)
102102
op._set_on_finish_handler(gemini_on_finish)
103-
return add_accumulator(
103+
return _add_accumulator(
104104
op, # type: ignore
105105
make_accumulator=lambda inputs: gemini_accumulator,
106106
should_accumulate=lambda inputs: isinstance(inputs, dict)
@@ -125,7 +125,7 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
125125

126126
op = weave.op(_fn_wrapper(fn), **op_kwargs)
127127
op._set_on_finish_handler(gemini_on_finish)
128-
return add_accumulator(
128+
return _add_accumulator(
129129
op, # type: ignore
130130
make_accumulator=lambda inputs: gemini_accumulator,
131131
should_accumulate=lambda inputs: isinstance(inputs, dict)

weave/integrations/groq/groq_sdk.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import weave
77
from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
88
from weave.trace.autopatch import IntegrationSettings, OpSettings
9-
from weave.trace.op_extensions.accumulator import add_accumulator
9+
from weave.trace.op import _add_accumulator
1010

1111
if TYPE_CHECKING:
1212
from groq.types.chat import ChatCompletion, ChatCompletionChunk
@@ -93,7 +93,7 @@ def groq_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]:
9393
def wrapper(fn: Callable) -> Callable:
9494
op_kwargs = settings.model_dump()
9595
op = weave.op(fn, **op_kwargs)
96-
return add_accumulator(
96+
return _add_accumulator(
9797
op, # type: ignore
9898
make_accumulator=lambda inputs: groq_accumulator,
9999
should_accumulate=should_use_accumulator,

0 commit comments

Comments
 (0)