diff --git a/src/promptflow/promptflow/_core/tracer.py b/src/promptflow/promptflow/_core/tracer.py index 658f6d46910..9a42a93b134 100644 --- a/src/promptflow/promptflow/_core/tracer.py +++ b/src/promptflow/promptflow/_core/tracer.py @@ -5,9 +5,9 @@ import inspect import json import logging +from collections.abc import Iterator from contextvars import ContextVar from datetime import datetime -from types import GeneratorType from typing import Optional from promptflow._core.generator_proxy import GeneratorProxy, generate_from_proxy @@ -106,7 +106,7 @@ def pop(cls, output=None, error: Optional[Exception] = None): def _pop(self, output=None, error: Optional[Exception] = None): last_trace = self._trace_stack[-1] - if isinstance(output, GeneratorType): + if isinstance(output, Iterator): output = GeneratorProxy(output) if output is not None: last_trace.output = self.to_serializable(output) diff --git a/src/promptflow/tests/executor/unittests/_core/test_generator_proxy.py b/src/promptflow/tests/executor/unittests/_core/test_generator_proxy.py index 7f8a2c74788..6ed5fe2e0c6 100644 --- a/src/promptflow/tests/executor/unittests/_core/test_generator_proxy.py +++ b/src/promptflow/tests/executor/unittests/_core/test_generator_proxy.py @@ -8,6 +8,10 @@ def generator(): yield i +def iterator(): + return iter([0, 1, 2]) + + @pytest.mark.unittest def test_generator_proxy_next(): proxy = GeneratorProxy(generator()) @@ -42,3 +46,39 @@ def test_generate_from_proxy(): assert i == next(original_generator) assert proxy.items == [0, 1, 2] + + +@pytest.mark.unittest +def test_iterator_proxy_next(): + proxy = GeneratorProxy(iterator()) + assert proxy.items == [] + assert next(proxy) == 0 + assert next(proxy) == 1 + assert next(proxy) == 2 + + with pytest.raises(StopIteration): + next(proxy) + + assert proxy.items == [0, 1, 2] + + +@pytest.mark.unittest +def test_iterator_proxy_iter(): + original_iterator = iterator() + proxy = GeneratorProxy(iterator()) + + for num in proxy: + assert num == next(original_iterator) + + assert proxy.items == [0, 1, 2] + + +@pytest.mark.unittest +def test_generate_from_iterator_proxy(): + proxy = GeneratorProxy(iterator()) + original_iterator = iterator() + + for i in generate_from_proxy(proxy): + assert i == next(original_iterator) + + assert proxy.items == [0, 1, 2]