Skip to content

Commit

Permalink
[Executor] Add iterator output support (#576)
Browse files Browse the repository at this point in the history
# Description

This pull request refactors the output check logic to use
`GeneratorProxy`. Previously, when the output was a `GeneratorType`, we
would wrap a proxy class that holds the generator as output. However,
this was an over condition, as we only need to wrap a proxy class on
iterator types as output. Therefore, we did the following changes:

- Replace `isinstance(output, GeneratorType)` with `isinstance(output,
Iterator)` on output wrapper logic.

This pull request fixes #561 

# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes]**
- [X] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [X] Pull request includes test coverage for the included changes.
  • Loading branch information
liucheng-ms authored Sep 21, 2023
1 parent 9c8cf8b commit 54a29fb
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_core/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import inspect
import json
import logging
from collections.abc import Iterator
from contextvars import ContextVar
from datetime import datetime
from types import GeneratorType
from typing import Optional

from promptflow._core.generator_proxy import GeneratorProxy, generate_from_proxy
Expand Down Expand Up @@ -106,7 +106,7 @@ def pop(cls, output=None, error: Optional[Exception] = None):

def _pop(self, output=None, error: Optional[Exception] = None):
last_trace = self._trace_stack[-1]
if isinstance(output, GeneratorType):
if isinstance(output, Iterator):
output = GeneratorProxy(output)
if output is not None:
last_trace.output = self.to_serializable(output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ def generator():
yield i


def iterator():
return iter([0, 1, 2])


@pytest.mark.unittest
def test_generator_proxy_next():
proxy = GeneratorProxy(generator())
Expand Down Expand Up @@ -42,3 +46,39 @@ def test_generate_from_proxy():
assert i == next(original_generator)

assert proxy.items == [0, 1, 2]


@pytest.mark.unittest
def test_iterator_proxy_next():
proxy = GeneratorProxy(iterator())
assert proxy.items == []
assert next(proxy) == 0
assert next(proxy) == 1
assert next(proxy) == 2

with pytest.raises(StopIteration):
next(proxy)

assert proxy.items == [0, 1, 2]


@pytest.mark.unittest
def test_iterator_proxy_iter():
original_iterator = iterator()
proxy = GeneratorProxy(iterator())

for num in proxy:
assert num == next(original_iterator)

assert proxy.items == [0, 1, 2]


@pytest.mark.unittest
def test_generate_from_iterator_proxy():
proxy = GeneratorProxy(iterator())
original_iterator = iterator()

for i in generate_from_proxy(proxy):
assert i == next(original_iterator)

assert proxy.items == [0, 1, 2]

0 comments on commit 54a29fb

Please sign in to comment.