From 54a29fba6fe9f8c316d8e9c4bd827573446859af Mon Sep 17 00:00:00 2001 From: Cheng Liu <51689021+liucheng-ms@users.noreply.github.com> Date: Thu, 21 Sep 2023 18:17:36 +0800 Subject: [PATCH] [Executor] Add iterator output support (#576) # Description This pull request refactors the output check logic to use `GeneratorProxy`. Previously, when the output was a `GeneratorType`, we would wrap a proxy class that holds the generator as output. However, this was an over condition, as we only need to wrap a proxy class on iterator types as output. Therefore, we did the following changes: - Replace `isinstance(output, GeneratorType)` with `isinstance(output, Iterator)` on output wrapper logic. This pull request fixes #561 # All Promptflow Contribution checklist: - [X] **The pull request does not introduce [breaking changes]** - [X] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).** ## General Guidelines and Best Practices - [X] Title of the pull request is clear and informative. - [X] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [X] Pull request includes test coverage for the included changes. --- src/promptflow/promptflow/_core/tracer.py | 4 +- .../unittests/_core/test_generator_proxy.py | 40 +++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/promptflow/promptflow/_core/tracer.py b/src/promptflow/promptflow/_core/tracer.py index 658f6d46910..9a42a93b134 100644 --- a/src/promptflow/promptflow/_core/tracer.py +++ b/src/promptflow/promptflow/_core/tracer.py @@ -5,9 +5,9 @@ import inspect import json import logging +from collections.abc import Iterator from contextvars import ContextVar from datetime import datetime -from types import GeneratorType from typing import Optional from promptflow._core.generator_proxy import GeneratorProxy, generate_from_proxy @@ -106,7 +106,7 @@ def pop(cls, output=None, error: Optional[Exception] = None): def _pop(self, output=None, error: Optional[Exception] = None): last_trace = self._trace_stack[-1] - if isinstance(output, GeneratorType): + if isinstance(output, Iterator): output = GeneratorProxy(output) if output is not None: last_trace.output = self.to_serializable(output) diff --git a/src/promptflow/tests/executor/unittests/_core/test_generator_proxy.py b/src/promptflow/tests/executor/unittests/_core/test_generator_proxy.py index 7f8a2c74788..6ed5fe2e0c6 100644 --- a/src/promptflow/tests/executor/unittests/_core/test_generator_proxy.py +++ b/src/promptflow/tests/executor/unittests/_core/test_generator_proxy.py @@ -8,6 +8,10 @@ def generator(): yield i +def iterator(): + return iter([0, 1, 2]) + + @pytest.mark.unittest def test_generator_proxy_next(): proxy = GeneratorProxy(generator()) @@ -42,3 +46,39 @@ def test_generate_from_proxy(): assert i == next(original_generator) assert proxy.items == [0, 1, 2] + + +@pytest.mark.unittest +def test_iterator_proxy_next(): + proxy = GeneratorProxy(iterator()) + assert proxy.items == [] + assert next(proxy) == 0 + assert next(proxy) == 1 + assert next(proxy) == 2 + + with pytest.raises(StopIteration): + next(proxy) + + assert proxy.items == [0, 1, 2] + + +@pytest.mark.unittest +def test_iterator_proxy_iter(): + original_iterator = iterator() + proxy = GeneratorProxy(iterator()) + + for num in proxy: + assert num == next(original_iterator) + + assert proxy.items == [0, 1, 2] + + +@pytest.mark.unittest +def test_generate_from_iterator_proxy(): + proxy = GeneratorProxy(iterator()) + original_iterator = iterator() + + for i in generate_from_proxy(proxy): + assert i == next(original_iterator) + + assert proxy.items == [0, 1, 2]