Skip to content

Commit

Permalink
use forkserver
Browse files Browse the repository at this point in the history
  • Loading branch information
Hhhilulu committed Dec 6, 2023
1 parent 26934f4 commit eee60a4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 17 deletions.
36 changes: 22 additions & 14 deletions src/promptflow/promptflow/executor/_line_execution_process_pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import contextvars
import math
import multiprocessing
import os
import queue
import signal
Expand Down Expand Up @@ -30,12 +29,9 @@
from promptflow.executor._errors import LineExecutionTimeoutError
from promptflow.executor._result import LineResult
from promptflow.executor.flow_executor import DEFAULT_CONCURRENCY_BULK, FlowExecutor
from ._preloaded_resolved_tools import preloaded_obj
from promptflow.storage import AbstractRunStorage
from . import shared_vars

flow_file = None
connections = None
working_dir = None
from promptflow.storage import AbstractRunStorage


def signal_handler(signum, frame):
Expand Down Expand Up @@ -156,6 +152,7 @@ def __init__(
self._validate_inputs = validate_inputs
self._worker_count = flow_executor._worker_count
multiprocessing_start_method = os.environ.get("PF_BATCH_METHOD")
import multiprocessing
sys_start_methods = multiprocessing.get_all_start_methods()
if multiprocessing_start_method and multiprocessing_start_method not in sys_start_methods:
bulk_logger.warning(
Expand All @@ -172,17 +169,16 @@ def __init__(
multiprocessing_start_method = "forkserver"
self.context = get_multiprocessing_context(multiprocessing_start_method)
use_fork = self.context.get_start_method() == "fork"
use_forkserver = self.context.get_start_method() == "forkserver"
# When using fork, we use this method to create the executor to avoid reloading the flow
# which would consume a lot more memory.
if use_fork:
self._executor_creation_func = partial(create_executor_fork, flow_executor=flow_executor)
elif flow_executor._flow_file:
global flow_file
flow_file = flow_executor._flow_file
global connections
connections = flow_executor._connections
global working_dir
working_dir = flow_executor._working_dir
elif use_forkserver:
shared_vars.flow_file = flow_executor._flow_file
shared_vars.connections = flow_executor._connections
shared_vars.working_dir = flow_executor._working_dir
from ._preloaded_resolved_tools import preloaded_obj
self._executor_creation_func = partial(
FlowExecutor.create,
flow_file=flow_executor._flow_file,
Expand All @@ -191,6 +187,15 @@ def __init__(
loaded_tools=preloaded_obj.tools,
raise_ex=False,
)
elif flow_executor._flow_file:
self._executor_creation_func = partial(
FlowExecutor.create,
flow_file=flow_executor._flow_file,
connections=flow_executor._connections,
working_dir=flow_executor._working_dir,
loaded_tools=flow_executor._loaded_tools,
raise_ex=False,
)
else: # Legacy flow executor, will be deprecated with the legacy pf portal.
self._executor_creation_func = partial(
create_executor_legacy,
Expand All @@ -200,6 +205,7 @@ def __init__(
cache_manager=flow_executor._cache_manager,
)
self._use_fork = use_fork
self._use_forkserver = use_forkserver
self._storage = flow_executor._run_tracker._storage
self._flow_id = flow_executor._flow_id
self._log_interval = flow_executor._log_interval
Expand All @@ -214,7 +220,7 @@ def __enter__(self):
self._inputs_queue = Queue()
# Starting a new process in non-fork mode requires allocating memory. Determine the maximum number of processes
# based on available memory to avoid memory bursting.
if not self._use_fork:
if not self._use_fork and not self._use_forkserver:
available_max_worker_count = get_available_max_worker_count()
self._n_process = min(self._worker_count, self._nlines, available_max_worker_count)
bulk_logger.info(f"Not using fork, process count: {self._n_process}")
Expand Down Expand Up @@ -381,6 +387,7 @@ def run(self, batch_inputs):
):
try:
if self.context.get_start_method() == "forkserver":
from ._preloaded_resolved_tools import preloaded_obj
self.context.set_forkserver_preload(['_preloaded_resolved_tools'])
# The variable 'async_result' here is not the actual result of the batch run
# but an AsyncResult object that can be used to check whether the execution has finished
Expand Down Expand Up @@ -588,6 +595,7 @@ def get_available_max_worker_count():


def get_multiprocessing_context(multiprocessing_start_method=None):
import multiprocessing
if multiprocessing_start_method is not None:
context = multiprocessing.get_context(multiprocessing_start_method)
bulk_logger.info(f"Set start method to {multiprocessing_start_method}.")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from promptflow.contracts.flow import Flow
from promptflow.executor._tool_resolver import ToolResolver
from promptflow._utils.context_utils import _change_working_dir
from ._line_execution_process_pool import flow_file, connections, working_dir
from .shared_vars import flow_file, connections, working_dir


class PreloadeResolvedTools:
Expand All @@ -13,11 +13,11 @@ def __init__(self, flow_file, connections, working_dir):
tool_resolver = ToolResolver(working_dir, connections, package_tool_keys)

with _change_working_dir(working_dir):
self.tools = [tool_resolver.resolve_tool_by_node(node) for node in flow.nodes]
self._tools = [tool_resolver.resolve_tool_by_node(node) for node in flow.nodes]

@property
def tools(self):
return self.tools
return self._tools


preloaded_obj = PreloadeResolvedTools(flow_file, connections, working_dir)
3 changes: 3 additions & 0 deletions src/promptflow/promptflow/executor/shared_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Module-level placeholders shared between the line execution process pool and
# the preloaded resolved-tools module. All three start unset (None); the process
# pool assigns them when the "forkserver" start method is selected, and the
# preloaded-tools module imports them from here.
flow_file = connections = working_dir = None

0 comments on commit eee60a4

Please sign in to comment.