Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Support async callable classes in flat_map() #51180

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 67 additions & 4 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,12 +510,75 @@ def transform_fn(rows: Iterable[Row], _: TaskContext) -> Iterable[Row]:

def _generate_transform_fn_for_flat_map(
    fn: UserDefinedFunction,
) -> MapTransformCallable[Row, Row]:
    """Return the row-level transform that applies ``fn`` for ``flat_map()``.

    Args:
        fn: The user-defined function. If it is a coroutine function, the
            transform is built by ``_generate_transform_fn_for_async_flat_map``;
            otherwise ``fn`` is called synchronously on each row and is
            expected to return an iterable of output rows.

    Returns:
        A callable taking ``(rows, task_context)`` that yields every output
        row produced by ``fn``.
    """
    if inspect.iscoroutinefunction(fn):
        # Per the original note: the UDF is a callable class with an async
        # generator `__call__` method, so hand it to the async builder.
        return _generate_transform_fn_for_async_flat_map(fn)

    def transform_fn(rows: Iterable[Row], _: TaskContext) -> Iterable[Row]:
        for input_row in rows:
            for produced_row in fn(input_row):
                # Validate each output row before handing it downstream.
                _validate_row_output(produced_row)
                yield produced_row

    return transform_fn


def _generate_transform_fn_for_async_flat_map(
    fn: UserDefinedFunction,
) -> MapTransformCallable[Row, Row]:
    """Build a ``flat_map()`` transform that drives an async UDF.

    ``fn(row)`` is awaited and the object it returns is consumed with
    ``async for``. The work is scheduled on
    ``ray.data._map_actor_context.udf_map_asyncio_loop`` and the produced rows
    are handed back to the calling thread through a thread-safe queue.

    Args:
        fn: Coroutine function that, when awaited with a row, returns an
            async iterator of output rows.

    Returns:
        A callable taking ``(rows, task_context)`` that yields validated
        output rows. When ``execution_options.preserve_order`` is set on the
        current ``DataContext``, rows are yielded grouped by input row, in
        input order; otherwise they are yielded as soon as they are produced.

    Raises:
        Exception: Any exception raised while processing a row is re-raised
            from the returned generator when it is reached in the output
            stream.
    """

    def transform_fn(rows: Iterable[Row], _: TaskContext) -> Iterable[Row]:
        # Use a queue to store outputs from async generator calls.
        # Output rows are put into this queue from the event loop and
        # yielded from it below as they become available.
        output_row_queue = queue.Queue()
        # Sentinel object to signal that all rows have been processed.
        sentinel = object()

        async def process_row(row: Row, emit) -> None:
            # `emit` receives every output row (or the raised exception) for
            # this input row; it is either the shared queue's `put` or a
            # per-row buffer's `append`.
            try:
                output_row_iterator = await fn(row)
                async for output_row in output_row_iterator:
                    emit(output_row)
            except Exception as e:
                # Forward the exception so the consumer can re-raise it.
                emit(e)

        async def process_all_rows() -> None:
            try:
                loop = ray.data._map_actor_context.udf_map_asyncio_loop
                ctx = ray.data.DataContext.get_current()

                if ctx.execution_options.preserve_order:
                    # All rows still run concurrently, but each one writes to
                    # its own buffer. Writing directly to the shared queue
                    # would interleave outputs of different rows and lose the
                    # input order.
                    pending = []
                    for row in rows:
                        buffer = []
                        task = loop.create_task(process_row(row, buffer.append))
                        pending.append((task, buffer))
                    # Flush the buffers strictly in input order.
                    for task, buffer in pending:
                        await task
                        for item in buffer:
                            output_row_queue.put(item)
                else:
                    # Order does not matter: stream rows as they are produced.
                    tasks = [
                        loop.create_task(process_row(row, output_row_queue.put))
                        for row in rows
                    ]
                    for task in asyncio.as_completed(tasks):
                        await task
            finally:
                # Always unblock the consumer, even if scheduling failed.
                output_row_queue.put(sentinel)

        # Use the existing event loop to create and run Tasks to process each row
        loop = ray.data._map_actor_context.udf_map_asyncio_loop
        asyncio.run_coroutine_threadsafe(process_all_rows(), loop)

        # Yield results as they become available.
        while True:
            out_row = output_row_queue.get()
            if out_row is sentinel:
                # Break out of the loop when the sentinel is received.
                break
            if isinstance(out_row, Exception):
                raise out_row
            _validate_row_output(out_row)
            yield out_row

    return transform_fn

Expand Down
26 changes: 26 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,6 +1676,32 @@ async def __call__(self, batch):
)


def test_flat_map_async_generator(shutdown_only):
    """flat_map() accepts a callable class whose `__call__` is an async generator."""

    async def fetch_data(row_id):
        return {"id": row_id}

    class AsyncActor:
        async def __call__(self, row):
            # `row_id` rather than `id`, to avoid shadowing the builtin.
            row_id = row["id"]
            task1 = asyncio.create_task(fetch_data(row_id))
            task2 = asyncio.create_task(fetch_data(row_id + 1))
            yield await task1
            # Stagger the second output per row so that generators for
            # different rows interleave, without making the test slow.
            await asyncio.sleep((row_id % 5) * 0.1)
            yield await task2

    n = 10
    # Even ids only: each input row {"id": i} yields ids i and i + 1, so the
    # combined output must cover every id in range(n) exactly once.
    ds = ray.data.from_items([{"id": i} for i in range(0, n, 2)])
    ds = ds.flat_map(AsyncActor, concurrency=1, max_concurrency=2)
    output = ds.take_all()
    assert sorted(extract_values("id", output)) == list(range(0, n)), output


def test_map_batches_async_exception_propagation(shutdown_only):
ray.shutdown()
ray.init(num_cpus=2)
Expand Down