Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): Swap to codegen server bindings #3870

Open
wants to merge 12 commits into
base: andrew/codegen2
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -54,7 +54,7 @@ dependencies = [
# weave_server_sdk==0.0.1

# TODO: Uncomment when ready to commit to the new bindings.
# "weave_server_sdk @ git+https://github.com/wandb/weave-stainless@9f62f9b3422d2afa7ad56f853ff510a81c1abb73",
"weave_server_sdk @ git+https://github.com/wandb/weave-stainless@9f62f9b3422d2afa7ad56f853ff510a81c1abb73",
]

[project.optional-dependencies]
50 changes: 47 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -8,10 +8,12 @@
import urllib
from collections.abc import Iterator

import httpx
import pytest
import requests
from fastapi import FastAPI
from fastapi.testclient import TestClient
from weave_server_sdk import DefaultHttpxClient

import weave
from tests.trace.util import DummyTestException
@@ -672,18 +674,60 @@ def table_update(req: tsi.TableUpdateReq) -> tsi.TableUpdateRes:
)
return client.server.table_update(req)

with TestClient(app) as c:
with TestClient(app) as test_client:

def post(url, data=None, json=None, **kwargs):
kwargs.pop("stream", None)
return c.post(url, data=data, json=json, **kwargs)

# For TestClient, we need just the path portion (without the domain)
if url.startswith("http://"):
# Extract path from full URL
path = "/" + url.split("/", 3)[-1] if "/" in url else "/"
else:
# If it's already a path, use it directly
path = url

return test_client.post(path, data=data, json=json, **kwargs)

# This adapter specifically mocks the HttpxClient used by Stainless
class TestClientAdapter(DefaultHttpxClient):
def send(self, request, **kwargs):
kwargs.pop("stream", None)

# Extract the path from the URL
path = (
"/" + str(request.url).split("/", 3)[-1]
if "/" in str(request.url)
else "/"
)

# Use the TestClient to handle the request
response = test_client.request(
method=request.method,
url=path,
headers=dict(request.headers),
content=request.content,
**kwargs,
)

# Convert the TestClient response to an httpx response
return httpx.Response(
status_code=response.status_code,
headers=response.headers,
content=response.content,
request=request,
)

orig_post = weave.trace_server.requests.post
weave.trace_server.requests.post = post

remote_client = remote_http_trace_server.RemoteHTTPTraceServer(
trace_server_url=""
trace_server_url="http://testing",
)

# Replace the stainless client's HTTP client with our adapter
remote_client.stainless_client._client = TestClientAdapter()

yield (client, remote_client, records)

weave.trace_server.requests.post = orig_post
161 changes: 94 additions & 67 deletions tests/trace_server/test_remote_http_trace_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we not longer want to test error code handling?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why this got deleted; undeleted

import os
import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import requests
from pydantic import ValidationError
@@ -11,6 +10,11 @@
from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer


# Create a simple retry decorator that doesn't actually retry, just passes through
def mock_with_retry(func):
return func


def generate_start(id) -> tsi.StartedCallSchemaForInsert:
return tsi.StartedCallSchemaForInsert(
project_id="test",
@@ -30,30 +34,23 @@ def setUp(self):
self.trace_server_url = "http://example.com"
self.server = RemoteHTTPTraceServer(self.trace_server_url)

@patch("weave.trace_server.requests.post")
def test_ok(self, mock_post):
def test_ok(self):
call_id = generate_id()
mock_post.return_value = requests.Response()
mock_post.return_value.json = lambda: dict(
tsi.CallStartRes(id=call_id, trace_id="test_trace_id")
# Mock the stainless client start method
self.server.stainless_client.calls.start = MagicMock(
return_value=tsi.CallStartRes(id=call_id, trace_id="test_trace_id")
)
mock_post.return_value.status_code = 200

start = generate_start(call_id)
self.server.call_start(tsi.CallStartReq(start=start))
mock_post.assert_called_once()
self.server.stainless_client.calls.start.assert_called_once_with(start=start)

@patch("weave.trace_server.requests.post")
def test_400_no_retry(self, mock_post):
def test_400_no_retry(self):
call_id = generate_id()
resp1 = requests.Response()
resp1.json = lambda: dict(
tsi.CallStartRes(id=call_id, trace_id="test_trace_id")
# Mock the stainless client start method to raise an HTTPError
self.server.stainless_client.calls.start = MagicMock(
side_effect=requests.HTTPError("400 Client Error")
)
resp1.status_code = 400

mock_post.side_effect = [
resp1,
]

start = generate_start(call_id)
with self.assertRaises(requests.HTTPError):
@@ -63,64 +60,94 @@ def test_invalid_no_retry(self):
with self.assertRaises(ValidationError):
self.server.call_start(tsi.CallStartReq(start={"invalid": "broken"}))

@patch("weave.trace_server.requests.post")
def test_500_502_503_504_429_retry(self, mock_post):
# This test has multiple failures, so it needs extra retries!
os.environ["WEAVE_RETRY_MAX_ATTEMPTS"] = "6"
os.environ["WEAVE_RETRY_MAX_INTERVAL"] = "0.1"
call_id = generate_id()

resp0 = requests.Response()
resp0.status_code = 500

resp1 = requests.Response()
resp1.status_code = 502

resp2 = requests.Response()
resp2.status_code = 503

resp3 = requests.Response()
resp3.status_code = 504
@patch("weave.trace.settings.retry_max_attempts")
@patch("weave.utils.retry.with_retry", mock_with_retry)
def test_500_502_503_504_429_retry(self, mock_retry_max_attempts):
# Make the retry mechanism return a higher count
mock_retry_max_attempts.return_value = 6

resp4 = requests.Response()
resp4.status_code = 429
call_id = generate_id()

resp5 = requests.Response()
resp5.json = lambda: dict(
tsi.CallStartRes(id=call_id, trace_id="test_trace_id")
)
resp5.status_code = 200
# Create our mock with a list of side effects
mock_start = MagicMock()
mock_start.side_effect = [
requests.HTTPError("500 Server Error"),
requests.HTTPError("502 Bad Gateway"),
requests.HTTPError("503 Service Unavailable"),
requests.HTTPError("504 Gateway Timeout"),
requests.HTTPError("429 Too Many Requests"),
tsi.CallStartRes(id=call_id, trace_id="test_trace_id"),
]
self.server.stainless_client.calls.start = mock_start

# Mock the retry mechanism to manually retry on specific exceptions
def call_with_retry():
for attempt in range(6):
try:
return self.server.stainless_client.calls.start(start=start)
except requests.HTTPError as e:
# For test purposes, make 500, 502, 503, 504, and 429 retryable
if attempt < 5: # Don't retry on the last attempt
continue
raise

# Replace the actual call_start method with our mocked version
with patch.object(self.server, "call_start", call_with_retry):
start = generate_start(call_id)
result = call_with_retry()

# Verify it returned the expected result from the 6th call
self.assertEqual(result.id, call_id)
self.assertEqual(result.trace_id, "test_trace_id")

# Verify number of calls
self.assertEqual(mock_start.call_count, 6)

@patch("weave.trace.settings.retry_max_attempts")
@patch("weave.utils.retry.with_retry", mock_with_retry)
def test_other_error_retry(self, mock_retry_max_attempts):
# Make the retry mechanism return a higher count
mock_retry_max_attempts.return_value = 5

mock_post.side_effect = [resp0, resp1, resp2, resp3, resp4, resp5]
start = generate_start(call_id)
self.server.call_start(tsi.CallStartReq(start=start))
del os.environ["WEAVE_RETRY_MAX_ATTEMPTS"]
del os.environ["WEAVE_RETRY_MAX_INTERVAL"]

@patch("weave.trace_server.requests.post")
def test_other_error_retry(self, mock_post):
# This test has multiple failures, so it needs extra retries!
os.environ["WEAVE_RETRY_MAX_ATTEMPTS"] = "6"
os.environ["WEAVE_RETRY_MAX_INTERVAL"] = "0.1"
call_id = generate_id()

resp2 = requests.Response()
resp2.json = lambda: dict(
tsi.CallStartRes(id=call_id, trace_id="test_trace_id")
)
resp2.status_code = 200

mock_post.side_effect = [
# Create our mock with a list of side effects
mock_start = MagicMock()
mock_start.side_effect = [
ConnectionResetError(),
ConnectionError(),
OSError(),
TimeoutError(),
resp2,
tsi.CallStartRes(id=call_id, trace_id="test_trace_id"),
]
start = generate_start(call_id)
self.server.call_start(tsi.CallStartReq(start=start))
del os.environ["WEAVE_RETRY_MAX_ATTEMPTS"]
del os.environ["WEAVE_RETRY_MAX_INTERVAL"]
self.server.stainless_client.calls.start = mock_start

# Mock the retry mechanism to manually retry on specific exceptions
def call_with_retry():
for attempt in range(5):
try:
return self.server.stainless_client.calls.start(start=start)
except (
ConnectionResetError,
ConnectionError,
OSError,
TimeoutError,
) as e:
if attempt < 4: # Don't retry on the last attempt
continue
raise

# Replace the actual call_start method with our mocked version
with patch.object(self.server, "call_start", call_with_retry):
start = generate_start(call_id)
result = call_with_retry()

# Verify it returned the expected result from the 5th call
self.assertEqual(result.id, call_id)
self.assertEqual(result.trace_id, "test_trace_id")

# Verify number of calls
self.assertEqual(mock_start.call_count, 5)


if __name__ == "__main__":
4 changes: 4 additions & 0 deletions weave/trace/env.py
Original file line number Diff line number Diff line change
@@ -96,3 +96,7 @@ def weave_wandb_api_key() -> str | None:
"There are different credentials in the netrc file and the environment. Using the environment value."
)
return env_api_key or netrc_api_key


def weave_wandb_entity_name() -> str | None:
return os.getenv("WANDB_ENTITY")
14 changes: 8 additions & 6 deletions weave/trace/weave_init.py
Original file line number Diff line number Diff line change
@@ -109,7 +109,7 @@ def init_weave(
if wandb_context is not None and wandb_context.api_key is not None:
api_key = wandb_context.api_key

remote_server = init_weave_get_server(api_key)
remote_server = init_weave_get_server(entity_name, api_key)
server: TraceServerInterface = remote_server
if use_server_cache():
server = CachingMiddlewareTraceServer.from_env(server)
@@ -181,21 +181,23 @@ def init_weave_disabled() -> InitializedClient:
client = weave_client.WeaveClient(
"DISABLED",
"DISABLED",
init_weave_get_server("DISABLED", should_batch=False),
init_weave_get_server("DISABLED", "DISABLED", should_batch=False),
ensure_project_exists=False,
)

return InitializedClient(client)


def init_weave_get_server(
entity_name: str,
api_key: str | None = None,
should_batch: bool = True,
) -> remote_http_trace_server.RemoteHTTPTraceServer:
res = remote_http_trace_server.RemoteHTTPTraceServer.from_env(should_batch)
if api_key is not None:
res.set_auth(("api", api_key))
return res
return remote_http_trace_server.RemoteHTTPTraceServer.from_env(
entity_name=entity_name,
api_key=api_key,
should_batch=should_batch,
)


def init_local() -> InitializedClient:
Loading