diff --git a/src/promptflow-core/promptflow/_utils/async_utils.py b/src/promptflow-core/promptflow/_utils/async_utils.py index da346c7ccc7..dd019233bf8 100644 --- a/src/promptflow-core/promptflow/_utils/async_utils.py +++ b/src/promptflow-core/promptflow/_utils/async_utils.py @@ -42,14 +42,14 @@ def async_run_allowing_running_loop(async_func, *args, **kwargs): return asyncio.run(async_func(*args, **kwargs)) -def run_async_function_sync(func): +def async_to_sync(func): def wrapper(*args, **kwargs): return async_run_allowing_running_loop(func, *args, **kwargs) return wrapper -def run_sync_function_async(func): +def sync_to_async(func): async def wrapper(*args, **kwargs): with ThreadPoolExecutorWithContext() as executor: partial_func = functools.partial(func, *args, **kwargs) diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index 29db8191b79..fbf2dc6ae86 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -13,7 +13,7 @@ from promptflow._core.log_manager import NodeLogManager from promptflow._core.run_tracker import RunTracker from promptflow._core.tool_meta_generator import PythonLoadError -from promptflow._utils.async_utils import run_async_function_sync, run_sync_function_async +from promptflow._utils.async_utils import async_to_sync, sync_to_async from promptflow._utils.dataclass_serializer import convert_eager_flow_output_to_dict from promptflow._utils.exception_utils import ExceptionPresenter from promptflow._utils.logger_utils import logger @@ -432,11 +432,11 @@ def _initialize_function(self): inputs, _, _, _ = function_to_interface(func) self._inputs = {k: v.to_flow_input_definition() for k, v in inputs.items()} if inspect.iscoroutinefunction(func): - self._func = run_async_function_sync(func) + self._func = async_to_sync(func) self._func_async = func else: self._func = func - self._func_async = run_sync_function_async(func) + self._func_async = sync_to_async(func) return func def _initialize_aggr_function(self, flow_obj: object): @@ -450,11 +450,11 @@ def _initialize_aggr_function(self, flow_obj: object): if not hasattr(aggr_func, "__original_function"): aggr_func = _traced(aggr_func) if inspect.iscoroutinefunction(aggr_func): - self._aggr_func = run_async_function_sync(aggr_func) + self._aggr_func = async_to_sync(aggr_func) self._aggr_func_async = aggr_func else: self._aggr_func = aggr_func - self._aggr_func_async = run_sync_function_async(aggr_func) + self._aggr_func_async = sync_to_async(aggr_func) self._aggr_input_name = list(sign.parameters.keys())[0] def _parse_flow_file(self):