@@ -301,6 +301,7 @@ def test_with_expected_output(some_input: str, expected_output: str):
301
301
client = kwargs .pop ("client" , None ),
302
302
test_suite_name = kwargs .pop ("test_suite_name" , None ),
303
303
cache = ls_utils .get_cache_dir (kwargs .pop ("cache" , None )),
304
+ metadata = kwargs .pop ("metadata" , None ),
304
305
)
305
306
if kwargs :
306
307
warnings .warn (f"Unexpected keyword arguments: { kwargs .keys ()} " )
@@ -648,6 +649,7 @@ def end_run(
648
649
example_id ,
649
650
outputs ,
650
651
reference_outputs ,
652
+ metadata ,
651
653
pytest_plugin = None ,
652
654
pytest_nodeid = None ,
653
655
) -> Future :
@@ -657,6 +659,7 @@ def end_run(
657
659
example_id = example_id ,
658
660
outputs = outputs ,
659
661
reference_outputs = reference_outputs ,
662
+ metadata = metadata ,
660
663
pytest_plugin = pytest_plugin ,
661
664
pytest_nodeid = pytest_nodeid ,
662
665
)
@@ -667,12 +670,18 @@ def _end_run(
667
670
example_id ,
668
671
outputs ,
669
672
reference_outputs ,
673
+ metadata ,
670
674
pytest_plugin ,
671
675
pytest_nodeid ,
672
676
) -> None :
673
677
# TODO: remove this hack so that run durations are correct
674
678
# Ensure example is fully updated
675
- self .sync_example (example_id , inputs = run_tree .inputs , outputs = reference_outputs )
679
+ self .sync_example (
680
+ example_id ,
681
+ inputs = run_tree .inputs ,
682
+ outputs = reference_outputs ,
683
+ metadata = metadata ,
684
+ )
676
685
run_tree .end (outputs = outputs )
677
686
run_tree .patch ()
678
687
@@ -683,6 +692,7 @@ def __init__(
683
692
test_suite : _LangSmithTestSuite ,
684
693
example_id : uuid .UUID ,
685
694
run_id : uuid .UUID ,
695
+ metadata : Optional [dict ] = None ,
686
696
pytest_plugin : Any = None ,
687
697
pytest_nodeid : Any = None ,
688
698
inputs : Optional [dict ] = None ,
@@ -691,6 +701,7 @@ def __init__(
691
701
self .test_suite = test_suite
692
702
self .example_id = example_id
693
703
self .run_id = run_id
704
+ self .metadata = metadata
694
705
self .pytest_plugin = pytest_plugin
695
706
self .pytest_nodeid = pytest_nodeid
696
707
self .inputs = inputs
@@ -714,6 +725,7 @@ def sync_example(
714
725
self .example_id ,
715
726
inputs = inputs ,
716
727
outputs = outputs ,
728
+ metadata = self .metadata ,
717
729
pytest_plugin = self .pytest_plugin ,
718
730
pytest_nodeid = self .pytest_nodeid ,
719
731
)
@@ -783,6 +795,7 @@ def end_run(self, run_tree, outputs: Any) -> None:
783
795
self .example_id ,
784
796
outputs ,
785
797
reference_outputs = self ._logged_reference_outputs ,
798
+ metadata = self .metadata ,
786
799
pytest_plugin = self .pytest_plugin ,
787
800
pytest_nodeid = self .pytest_nodeid ,
788
801
)
@@ -797,14 +810,7 @@ class _UTExtra(TypedDict, total=False):
797
810
output_keys : Optional [Sequence [str ]]
798
811
test_suite_name : Optional [str ]
799
812
cache : Optional [str ]
800
-
801
-
802
- def _get_test_repr (func : Callable , sig : inspect .Signature ) -> str :
803
- name = getattr (func , "__name__" , None ) or ""
804
- description = getattr (func , "__doc__" , None ) or ""
805
- if description :
806
- description = f" - { description .strip ()} "
807
- return f"{ name } { sig } { description } "
813
+ metadata : Optional [dict ]
808
814
809
815
810
816
def _create_test_case (
@@ -816,6 +822,7 @@ def _create_test_case(
816
822
) -> _TestCase :
817
823
client = langtest_extra ["client" ] or rt .get_cached_client ()
818
824
output_keys = langtest_extra ["output_keys" ]
825
+ metadata = langtest_extra ["metadata" ]
819
826
signature = inspect .signature (func )
820
827
inputs = rh ._get_inputs_safe (signature , * args , ** kwargs ) or None
821
828
outputs = None
@@ -850,6 +857,7 @@ def _create_test_case(
850
857
test_suite ,
851
858
example_id ,
852
859
run_id = uuid .uuid4 (),
860
+ metadata = metadata ,
853
861
inputs = inputs ,
854
862
reference_outputs = outputs ,
855
863
pytest_plugin = pytest_plugin ,
@@ -881,6 +889,14 @@ def _test():
881
889
run_id = test_case .run_id ,
882
890
reference_example_id = test_case .example_id ,
883
891
inputs = test_case .inputs ,
892
+ metadata = {
893
+ # Experiment run metadata is prefixed with "ls_example_" in
894
+ # the ingest backend, but we must reproduce this behavior here
895
+ # because the example may not have been created before the trace
896
+ # starts.
897
+ f"ls_example_{ k } " : v
898
+ for k , v in (test_case .metadata or {}).items ()
899
+ },
884
900
project_name = test_case .test_suite .name ,
885
901
exceptions_to_handle = (SkipException ,),
886
902
_end_on_exit = False ,
@@ -950,6 +966,14 @@ async def _test():
950
966
run_id = test_case .run_id ,
951
967
reference_example_id = test_case .example_id ,
952
968
inputs = test_case .inputs ,
969
+ metadata = {
970
+ # Experiment run metadata is prefixed with "ls_example_" in
971
+ # the ingest backend, but we must reproduce this behavior here
972
+ # because the example may not have been created before the trace
973
+ # starts.
974
+ f"ls_example_{ k } " : v
975
+ for k , v in (test_case .metadata or {}).items ()
976
+ },
953
977
project_name = test_case .test_suite .name ,
954
978
exceptions_to_handle = (SkipException ,),
955
979
_end_on_exit = False ,
0 commit comments