diff --git a/src/promptflow-core/promptflow/_utils/async_utils.py b/src/promptflow-core/promptflow/_utils/async_utils.py index afa648e8524..0bec611f936 100644 --- a/src/promptflow-core/promptflow/_utils/async_utils.py +++ b/src/promptflow-core/promptflow/_utils/async_utils.py @@ -4,9 +4,11 @@ import asyncio import contextvars +import functools from concurrent.futures import ThreadPoolExecutor from promptflow._utils.utils import set_context +from promptflow.tracing import ThreadPoolExecutorWithContext def _has_running_loop() -> bool: @@ -38,3 +40,21 @@ def async_run_allowing_running_loop(async_func, *args, **kwargs): return executor.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result() else: return asyncio.run(async_func(*args, **kwargs)) + + +def run_async_function_sync(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return async_run_allowing_running_loop(func, *args, **kwargs) + + return wrapper + + +def run_sync_function_async(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + with ThreadPoolExecutorWithContext() as executor: + partial_func = functools.partial(func, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor(executor, partial_func) + + return wrapper diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index 744fa5f6834..bd42cdec1b0 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import dataclasses import importlib @@ -14,7 +13,7 @@ from promptflow._core.log_manager import NodeLogManager from promptflow._core.run_tracker import RunTracker from promptflow._core.tool_meta_generator import PythonLoadError -from promptflow._utils.async_utils import async_run_allowing_running_loop +from promptflow._utils.async_utils import run_async_function_sync, run_sync_function_async from promptflow._utils.dataclass_serializer import convert_eager_flow_output_to_dict from promptflow._utils.exception_utils import ExceptionPresenter from promptflow._utils.logger_utils import logger @@ -35,7 +34,6 @@ from promptflow.executor._result import AggregationResult, LineResult from promptflow.storage import AbstractRunStorage from promptflow.storage._run_storage import DefaultRunStorage -from promptflow.tracing import ThreadPoolExecutorWithContext from promptflow.tracing._trace import _traced from promptflow.tracing._tracer import Tracer from promptflow.tracing.contracts.trace import TraceType @@ -135,10 +133,7 @@ def _exec_line( line_run_id = run_info.run_id try: Tracer.start_tracing(line_run_id) - if self._is_async: - output = asyncio.run(self._func(**inputs)) - else: - output = self._func(**inputs) + output = self._func(**inputs) output = self._stringify_generator_output(output) if not allow_generator_output else output traces = Tracer.end_tracing(line_run_id) # Should convert output to dict before storing it to run info, since we will add key 'line_number' to it, @@ -262,11 +257,7 @@ async def _exec_line_async( line_run_id = run_info.run_id try: Tracer.start_tracing(line_run_id) - if self._is_async: - output = await self._func(**inputs) - else: - partial_func = partial(self._func, **inputs) - output = await asyncio.get_event_loop().run_in_executor(None, partial_func) + output = await self._func_async(**inputs) output = self._stringify_generator_output(output) if not allow_generator_output else output traces = Tracer.end_tracing(line_run_id) output_dict = convert_eager_flow_output_to_dict(output) @@ -429,14 +420,21 @@ def _initialize_function(self): if inspect.ismethod(func): # For class method, the original function is a function reference that not bound to any object, # so we need to pass the instance to it. + if func.__qualname__.endswith(".__call__"): + name = func.__qualname__[: -len(".__call__")] func = _traced( partial(getattr(func, "__original_function"), self=func.__self__), trace_type=TraceType.FLOW, - name=func.__qualname__, + name=name, ) else: func = _traced(getattr(func, "__original_function"), trace_type=TraceType.FLOW) - self._func = func + if inspect.iscoroutinefunction(func): + self._func = run_async_function_sync(func) + self._func_async = func + else: + self._func = func + self._func_async = run_sync_function_async(func) inputs, _, _, _ = function_to_interface(self._func) self._inputs = {k: v.to_flow_input_definition() for k, v in inputs.items()} self._is_async = inspect.iscoroutinefunction(self._func) @@ -453,21 +451,11 @@ def _initialize_aggr_function(self, flow_obj: object): if not hasattr(aggr_func, "__original_function"): aggr_func = _traced(aggr_func) if inspect.iscoroutinefunction(aggr_func): - - def run_async_function_sync(*args, **kwargs): - return async_run_allowing_running_loop(aggr_func, *args, **kwargs) - - self._aggr_func = run_async_function_sync + self._aggr_func = run_async_function_sync(aggr_func) self._aggr_func_async = aggr_func else: - - async def run_sync_function_async(*args, **kwargs): - with ThreadPoolExecutorWithContext() as executor: - partial_func = partial(aggr_func, *args, **kwargs) - return await asyncio.get_event_loop().run_in_executor(executor, partial_func) - self._aggr_func = aggr_func - self._aggr_func_async = run_sync_function_async + self._aggr_func_async = run_sync_function_async(aggr_func) self._aggr_input_name = list(sign.parameters.keys())[0] def _parse_flow_file(self):