Skip to content

Commit 64ada16

Browse files
fix(py): serialize langchain objects in testing.log_outputs, fix rich display table [LS-3237] (#1609)
Co-authored-by: jacoblee93 <[email protected]>
1 parent 9aa4df7 commit 64ada16

File tree

6 files changed

+288
-250
lines changed

6 files changed

+288
-250
lines changed

python/langsmith/pytest_plugin.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import time
88
from collections import defaultdict
99
from threading import Lock
10+
from typing import Any
1011

1112
import pytest
1213

@@ -75,6 +76,7 @@ def pytest_runtest_call(item):
7576
request_obj = getattr(item, "_request", None)
7677
if request_obj is not None and "request" not in item.funcargs:
7778
item.funcargs["request"] = request_obj
79+
if request_obj is not None and "request" not in item._fixtureinfo.argnames:
7880
# Create a new FuncFixtureInfo instance with updated argnames
7981
item._fixtureinfo = type(item._fixtureinfo)(
8082
argnames=item._fixtureinfo.argnames + ("request",),
@@ -133,27 +135,11 @@ def update_process_status(self, process_id, status):
133135

134136
with self.status_lock:
135137
current_status = self.process_status.get(process_id, {})
136-
if status.get("feedback"):
137-
current_status["feedback"] = {
138-
**current_status.get("feedback", {}),
139-
**status.pop("feedback"),
140-
}
141-
if status.get("inputs"):
142-
current_status["inputs"] = {
143-
**current_status.get("inputs", {}),
144-
**status.pop("inputs"),
145-
}
146-
if status.get("reference_outputs"):
147-
current_status["reference_outputs"] = {
148-
**current_status.get("reference_outputs", {}),
149-
**status.pop("reference_outputs"),
150-
}
151-
if status.get("outputs"):
152-
current_status["outputs"] = {
153-
**current_status.get("outputs", {}),
154-
**status.pop("outputs"),
155-
}
156-
self.process_status[process_id] = {**current_status, **status}
138+
self.process_status[process_id] = _merge_statuses(
139+
status,
140+
current_status,
141+
unpack=["feedback", "inputs", "reference_outputs", "outputs"],
142+
)
157143
self.live.update(self.generate_tables())
158144

159145
def pytest_runtest_logstart(self, nodeid):
@@ -246,9 +232,11 @@ def _generate_table(self, suite_name: str):
246232
f"{_abbreviate(k, max_len=max_dynamic_col_width)}: {int(v) if isinstance(v, bool) else v}" # noqa: E501
247233
for k, v in status.get("feedback", {}).items()
248234
)
249-
inputs = json.dumps(status.get("inputs", {}))
250-
reference_outputs = json.dumps(status.get("reference_outputs", {}))
251-
outputs = json.dumps(status.get("outputs", {}))
235+
inputs = _dumps_with_fallback(status.get("inputs", {}))
236+
reference_outputs = _dumps_with_fallback(
237+
status.get("reference_outputs", {})
238+
)
239+
outputs = _dumps_with_fallback(status.get("outputs", {}))
252240
table.add_row(
253241
_abbreviate_test_name(str(pid), max_len=max_dynamic_col_width),
254242
_abbreviate(inputs, max_len=max_dynamic_col_width),
@@ -339,3 +327,21 @@ def _abbreviate_test_name(test_name: str, max_len: int) -> str:
339327
return "..." + file[-file_len:] + "::" + test
340328
else:
341329
return test_name
330+
331+
332+
def _merge_statuses(update: dict, current: dict, *, unpack: list[str]) -> dict:
333+
for path in unpack:
334+
if path_update := update.pop(path, None):
335+
path_current = current.get(path, {})
336+
if isinstance(path_update, dict) and isinstance(path_current, dict):
337+
current[path] = {**path_current, **path_update}
338+
else:
339+
current[path] = path_update
340+
return {**current, **update}
341+
342+
343+
def _dumps_with_fallback(obj: Any) -> str:
344+
try:
345+
return json.dumps(obj)
346+
except Exception:
347+
return "unserializable"

python/langsmith/testing/_internal.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,7 @@ def test_foo() -> None:
10761076
"LANGSMITH_TRACING environment variable to 'true')."
10771077
)
10781078
raise ValueError(msg)
1079+
outputs = _dumpd(outputs)
10791080
run_tree.add_outputs(outputs)
10801081
test_case.log_outputs(outputs)
10811082

@@ -1302,3 +1303,25 @@ def _stringify(x: Any) -> str:
13021303
return dumps_json(x).decode("utf-8", errors="surrogateescape")
13031304
except Exception:
13041305
return str(x)
1306+
1307+
1308+
def _dumpd(x: Any) -> Any:
1309+
"""Serialize LangChain Serializable objects."""
1310+
dumpd = _get_langchain_dumpd()
1311+
if not dumpd:
1312+
return x
1313+
try:
1314+
serialized = dumpd(x)
1315+
return serialized
1316+
except Exception:
1317+
return x
1318+
1319+
1320+
@functools.lru_cache
1321+
def _get_langchain_dumpd() -> Optional[Callable]:
1322+
try:
1323+
from langchain_core.load import dumpd
1324+
1325+
return dumpd
1326+
except ImportError:
1327+
return None

0 commit comments

Comments
 (0)