diff --git a/src/promptflow/promptflow/executor/_line_execution_process_pool.py b/src/promptflow/promptflow/executor/_line_execution_process_pool.py
index 45c18af4e8b3..35675428e25a 100644
--- a/src/promptflow/promptflow/executor/_line_execution_process_pool.py
+++ b/src/promptflow/promptflow/executor/_line_execution_process_pool.py
@@ -1,6 +1,5 @@
 import contextvars
 import math
-import multiprocessing
 import os
 import queue
 import signal
@@ -30,12 +29,9 @@
 from promptflow.executor._errors import LineExecutionTimeoutError
 from promptflow.executor._result import LineResult
 from promptflow.executor.flow_executor import DEFAULT_CONCURRENCY_BULK, FlowExecutor
-from ._preloaded_resolved_tools import preloaded_obj
-from promptflow.storage import AbstractRunStorage
+from . import shared_vars
 
-flow_file = None
-connections = None
-working_dir = None
+from promptflow.storage import AbstractRunStorage
 
 
 def signal_handler(signum, frame):
@@ -156,6 +152,7 @@ def __init__(
         self._validate_inputs = validate_inputs
         self._worker_count = flow_executor._worker_count
         multiprocessing_start_method = os.environ.get("PF_BATCH_METHOD")
+        import multiprocessing
         sys_start_methods = multiprocessing.get_all_start_methods()
         if multiprocessing_start_method and multiprocessing_start_method not in sys_start_methods:
             bulk_logger.warning(
@@ -172,17 +169,16 @@
             multiprocessing_start_method = "forkserver"
         self.context = get_multiprocessing_context(multiprocessing_start_method)
         use_fork = self.context.get_start_method() == "fork"
+        use_forkserver = self.context.get_start_method() == "forkserver"
         # When using fork, we use this method to create the executor to avoid reloading the flow
         # which will introduce a lot more memory.
         if use_fork:
             self._executor_creation_func = partial(create_executor_fork, flow_executor=flow_executor)
-        elif flow_executor._flow_file:
-            global flow_file
-            flow_file = flow_executor._flow_file
-            global connections
-            connections = flow_executor._connections
-            global working_dir
-            working_dir = flow_executor._working_dir
+        elif use_forkserver:
+            shared_vars.flow_file = flow_executor._flow_file
+            shared_vars.connections = flow_executor._connections
+            shared_vars.working_dir = flow_executor._working_dir
+            from ._preloaded_resolved_tools import preloaded_obj
             self._executor_creation_func = partial(
                 FlowExecutor.create,
                 flow_file=flow_executor._flow_file,
@@ -191,6 +187,15 @@
                 connections=flow_executor._connections,
                 working_dir=flow_executor._working_dir,
                 loaded_tools=preloaded_obj.tools,
                 raise_ex=False,
             )
+        elif flow_executor._flow_file:
+            self._executor_creation_func = partial(
+                FlowExecutor.create,
+                flow_file=flow_executor._flow_file,
+                connections=flow_executor._connections,
+                working_dir=flow_executor._working_dir,
+                loaded_tools=flow_executor._loaded_tools,
+                raise_ex=False,
+            )
         else:
             # Legacy flow executor, will be deprecated with the legacy pf portal.
             self._executor_creation_func = partial(
                 create_executor_legacy,
@@ -200,6 +205,7 @@
                 cache_manager=flow_executor._cache_manager,
             )
         self._use_fork = use_fork
+        self._use_forkserver = use_forkserver
         self._storage = flow_executor._run_tracker._storage
         self._flow_id = flow_executor._flow_id
         self._log_interval = flow_executor._log_interval
@@ -214,7 +220,7 @@ def __enter__(self):
         self._inputs_queue = Queue()
         # Starting a new process in non-fork mode requires to allocate memory. Determine the maximum number of processes
         # based on available memory to avoid memory bursting.
-        if not self._use_fork:
+        if not self._use_fork and not self._use_forkserver:
             available_max_worker_count = get_available_max_worker_count()
             self._n_process = min(self._worker_count, self._nlines, available_max_worker_count)
             bulk_logger.info(f"Not using fork, process count: {self._n_process}")
@@ -382,6 +388,7 @@ def run(self, batch_inputs):
         ):
             try:
                 if self.context.get_start_method() == "forkserver":
+                    from ._preloaded_resolved_tools import preloaded_obj
                     self.context.set_forkserver_preload(['_preloaded_resolved_tools'])
                 # The variable 'async_result' here is not the actual result of the batch run
                 # but an AsyncResult object that can be used to check if the execution are finished
@@ -589,6 +596,7 @@ def get_available_max_worker_count():
 
 
 def get_multiprocessing_context(multiprocessing_start_method=None):
+    import multiprocessing
     if multiprocessing_start_method is not None:
         context = multiprocessing.get_context(multiprocessing_start_method)
         bulk_logger.info(f"Set start method to {multiprocessing_start_method}.")
diff --git a/src/promptflow/promptflow/executor/_preloaded_resolved_tools.py b/src/promptflow/promptflow/executor/_preloaded_resolved_tools.py
index c8be3d2ab75c..85f6ac98223f 100644
--- a/src/promptflow/promptflow/executor/_preloaded_resolved_tools.py
+++ b/src/promptflow/promptflow/executor/_preloaded_resolved_tools.py
@@ -1,7 +1,7 @@
 from promptflow.contracts.flow import Flow
 from promptflow.executor._tool_resolver import ToolResolver
 from promptflow._utils.context_utils import _change_working_dir
-from ._line_execution_process_pool import flow_file, connections, working_dir
+from .shared_vars import flow_file, connections, working_dir
 
 
 class PreloadeResolvedTools:
@@ -13,11 +13,11 @@ def __init__(self, flow_file, connections, working_dir):
         tool_resolver = ToolResolver(working_dir, connections, package_tool_keys)
 
         with _change_working_dir(working_dir):
-            self.tools = [tool_resolver.resolve_tool_by_node(node) for node in flow.nodes]
+            self._tools = [tool_resolver.resolve_tool_by_node(node) for node in flow.nodes]
 
     @property
     def tools(self):
-        return self.tools
+        return self._tools
 
 
 preloaded_obj = PreloadeResolvedTools(flow_file, connections, working_dir)
diff --git a/src/promptflow/promptflow/executor/shared_vars.py b/src/promptflow/promptflow/executor/shared_vars.py
new file mode 100644
index 000000000000..c8178b6b456d
--- /dev/null
+++ b/src/promptflow/promptflow/executor/shared_vars.py
@@ -0,0 +1,3 @@
+flow_file = None
+connections = None
+working_dir = None
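
The mechanism this patch leans on is `multiprocessing`'s forkserver preload: module names passed to `set_forkserver_preload` are imported once by the forkserver parent process, and every worker forked from it inherits the already-initialized module state. That is why `_preloaded_resolved_tools.py` does its expensive `ToolResolver` work at module import time (building `preloaded_obj`), and why `run()` registers the module via `self.context.set_forkserver_preload(['_preloaded_resolved_tools'])` before the pool spawns workers. A minimal, self-contained sketch of the mechanism, using `json` as a stand-in for the preloaded module:

```python
import multiprocessing
import os


def work(_):
    # `json` was already imported by the forkserver parent, so this import is
    # a cheap sys.modules lookup in each worker rather than a fresh load.
    import json  # noqa: F401

    return os.getpid()


if __name__ == "__main__":
    # The forkserver start method is only available on Unix platforms.
    ctx = multiprocessing.get_context("forkserver")
    # Modules listed here are imported once by the forkserver parent; workers
    # inherit the initialized module state when they are forked from it.
    ctx.set_forkserver_preload(["json"])
    with ctx.Pool(2) as pool:
        print(pool.map(work, range(4)))
```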
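The new `shared_vars` module exists to break the import cycle between `_line_execution_process_pool` and `_preloaded_resolved_tools`: both sides now depend on a leaf module that holds nothing but three mutable slots. One subtlety the diff depends on: `from .shared_vars import flow_file, connections, working_dir` in `_preloaded_resolved_tools.py` binds the values that are current at import time, so the pool must assign `shared_vars.flow_file` (and friends) before the preloaded module is first imported — which is why the `from ._preloaded_resolved_tools import preloaded_obj` import now sits below those assignments in `__init__`. A runnable sketch of the pattern, with illustrative names rather than the PR's real modules:

```python
import sys
import types

# Stand-in for shared_vars.py: a module holding nothing but mutable slots.
# Registering it in sys.modules lets us import it like a real leaf module.
shared_vars = types.ModuleType("shared_vars")
shared_vars.flow_file = None
sys.modules["shared_vars"] = shared_vars

# Writer side (the pool's __init__): assign through the module object so that
# later readers of shared_vars.flow_file observe the update.
import shared_vars as sv

sv.flow_file = "flow.dag.yaml"

# Reader side (the preloaded module): a from-import copies the value that is
# current at import time, so the assignment above must happen first.
from shared_vars import flow_file

assert flow_file == "flow.dag.yaml"

# Later writes remain visible through module attribute access, but NOT
# through the earlier from-import binding:
sv.flow_file = "other.dag.yaml"
assert flow_file == "flow.dag.yaml"
assert sv.flow_file == "other.dag.yaml"
```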