Skip to content

Commit

Permalink
Refine exec line in script executor
Browse files Browse the repository at this point in the history
  • Loading branch information
Lina Tang committed Apr 25, 2024
1 parent 97f059e commit d39d220
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 26 deletions.
20 changes: 20 additions & 0 deletions src/promptflow-core/promptflow/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import asyncio
import contextvars
import functools
from concurrent.futures import ThreadPoolExecutor

from promptflow._utils.utils import set_context
from promptflow.tracing import ThreadPoolExecutorWithContext


def _has_running_loop() -> bool:
Expand Down Expand Up @@ -38,3 +40,21 @@ def async_run_allowing_running_loop(async_func, *args, **kwargs):
return executor.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()
else:
return asyncio.run(async_func(*args, **kwargs))


def run_async_function_sync(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return async_run_allowing_running_loop(func, *args, **kwargs)

return wrapper


def run_sync_function_async(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
with ThreadPoolExecutorWithContext() as executor:
partial_func = functools.partial(func, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(executor, partial_func)

return wrapper
40 changes: 14 additions & 26 deletions src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import contextlib
import dataclasses
import importlib
Expand All @@ -14,7 +13,7 @@
from promptflow._core.log_manager import NodeLogManager
from promptflow._core.run_tracker import RunTracker
from promptflow._core.tool_meta_generator import PythonLoadError
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow._utils.async_utils import run_async_function_sync, run_sync_function_async
from promptflow._utils.dataclass_serializer import convert_eager_flow_output_to_dict
from promptflow._utils.exception_utils import ExceptionPresenter
from promptflow._utils.logger_utils import logger
Expand All @@ -35,7 +34,6 @@
from promptflow.executor._result import AggregationResult, LineResult
from promptflow.storage import AbstractRunStorage
from promptflow.storage._run_storage import DefaultRunStorage
from promptflow.tracing import ThreadPoolExecutorWithContext
from promptflow.tracing._trace import _traced
from promptflow.tracing._tracer import Tracer
from promptflow.tracing.contracts.trace import TraceType
Expand Down Expand Up @@ -135,10 +133,7 @@ def _exec_line(
line_run_id = run_info.run_id
try:
Tracer.start_tracing(line_run_id)
if self._is_async:
output = asyncio.run(self._func(**inputs))
else:
output = self._func(**inputs)
output = self._func(**inputs)
output = self._stringify_generator_output(output) if not allow_generator_output else output
traces = Tracer.end_tracing(line_run_id)
# Should convert output to dict before storing it to run info, since we will add key 'line_number' to it,
Expand Down Expand Up @@ -262,11 +257,7 @@ async def _exec_line_async(
line_run_id = run_info.run_id
try:
Tracer.start_tracing(line_run_id)
if self._is_async:
output = await self._func(**inputs)
else:
partial_func = partial(self._func, **inputs)
output = await asyncio.get_event_loop().run_in_executor(None, partial_func)
output = await self._func_async(**inputs)
output = self._stringify_generator_output(output) if not allow_generator_output else output
traces = Tracer.end_tracing(line_run_id)
output_dict = convert_eager_flow_output_to_dict(output)
Expand Down Expand Up @@ -429,14 +420,21 @@ def _initialize_function(self):
if inspect.ismethod(func):
# For class method, the original function is a function reference that not bound to any object,
# so we need to pass the instance to it.
if func.__qualname__.endswith(".__call__"):
name = func.__qualname__[: -len(".__call__")]
func = _traced(
partial(getattr(func, "__original_function"), self=func.__self__),
trace_type=TraceType.FLOW,
name=func.__qualname__,
name=name,
)
else:
func = _traced(getattr(func, "__original_function"), trace_type=TraceType.FLOW)
self._func = func
if inspect.iscoroutinefunction(func):
self._func = run_async_function_sync(func)
self._func_async = func
else:
self._func = func
self._func_async = run_sync_function_async(func)
inputs, _, _, _ = function_to_interface(self._func)
self._inputs = {k: v.to_flow_input_definition() for k, v in inputs.items()}
self._is_async = inspect.iscoroutinefunction(self._func)
Expand All @@ -453,21 +451,11 @@ def _initialize_aggr_function(self, flow_obj: object):
if not hasattr(aggr_func, "__original_function"):
aggr_func = _traced(aggr_func)
if inspect.iscoroutinefunction(aggr_func):

def run_async_function_sync(*args, **kwargs):
return async_run_allowing_running_loop(aggr_func, *args, **kwargs)

self._aggr_func = run_async_function_sync
self._aggr_func = run_async_function_sync(aggr_func)
self._aggr_func_async = aggr_func
else:

async def run_sync_function_async(*args, **kwargs):
with ThreadPoolExecutorWithContext() as executor:
partial_func = partial(aggr_func, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(executor, partial_func)

self._aggr_func = aggr_func
self._aggr_func_async = run_sync_function_async
self._aggr_func_async = run_sync_function_async(aggr_func)
self._aggr_input_name = list(sign.parameters.keys())[0]

def _parse_flow_file(self):
Expand Down

0 comments on commit d39d220

Please sign in to comment.