From 0b18ee62484c0b90d8f967f428ef291843ba8de9 Mon Sep 17 00:00:00 2001
From: Hhhilulu <115983968+Hhhilulu@users.noreply.github.com>
Date: Thu, 7 Dec 2023 14:08:05 +0800
Subject: [PATCH] [PromptFlow] Optimize memory computation (#1328)

# Description

**Optimize memory computation**

1. Optimize the calculation logic for the maximum number of processes.
2. Prioritize the environment variable "PF_WORKER_COUNT" as the process count when it is set by the user.

The following table summarizes the logic for determining the number of processes and the corresponding log messages under different conditions (based on whether fork is used and the setting of the environment variable PF_WORKER_COUNT):

<div style="overflow-x:auto;">
<table style="table-layout: auto;">
  <tr>
    <th style="white-space: nowrap;">Scenario</th>
    <th style="white-space: nowrap;">Fork Used</th>
    <th style="white-space: nowrap;">PF_WORKER_COUNT Set</th>
    <th style="white-space: nowrap;">Worker Count</th>
    <th style="white-space: nowrap;">NLines</th>
    <th style="white-space: nowrap;">Estimated Available Worker Count</th>
    <th style="white-space: nowrap;">Calculation of Processes</th>
  </tr>
  <tr>
    <td>1</td>
    <td>No (spawn)</td>
    <td>No</td>
    <td>16</td>
    <td>4</td>
    <td><strong>2</strong></td>
    <td>Min(16, 4, 2) = <strong>2</strong></td>
  </tr>
  <tr>
    <td colspan="7" style="white-space: normal;">
      <strong style="font-size: larger;">Log Message:</strong><br>
      Set process count to <strong>2</strong> by taking the minimum value among the factors of {'default_worker_count': 16, 'row_count': 4, 'estimated_worker_count_based_on_memory_usage': 2}.
    </td>
  </tr>
  <tr>
    <td>2</td>
    <td>No (spawn)</td>
    <td>Yes (<strong>2</strong>)</td>
    <td>2</td>
    <td>4</td>
    <td>2</td>
    <td><strong>PF_WORKER_COUNT = 2</strong></td>
  </tr>
  <tr>
    <td colspan="7" style="white-space: normal;">
      <strong style="font-size: larger;">Log Message:</strong><br>
      Set process count to <strong>2</strong> with the environment variable 'PF_WORKER_COUNT'.
    </td>
  </tr>
  <tr>
    <td>3</td>
    <td>No (spawn)</td>
    <td>Yes (<strong>6</strong>)</td>
    <td>6</td>
    <td>4</td>
    <td>2</td>
    <td><strong>PF_WORKER_COUNT = 6</strong></td>
  </tr>
  <tr>
    <td colspan="7" style="white-space: normal;">
      <strong style="font-size: larger;">Log Message:</strong><br>
      Set process count to <strong>6</strong> with the environment variable 'PF_WORKER_COUNT'.<br>
      <strong style="font-size: larger;">Warning Message:</strong><br>
      The current process count (<strong>6</strong>) is larger than recommended process count (2) that estimated by system available memory. This may cause memory exhaustion.
    </td>
  </tr>
  <tr>
    <td>4</td>
    <td>Yes (fork)</td>
    <td>No</td>
    <td>16</td>
    <td><strong>4</strong></td>
    <td>N/A</td>
    <td>Min(16, 4) = <strong>4</strong></td>
  </tr>
  <tr>
    <td colspan="7" style="white-space: normal;">
      <strong style="font-size: larger;">Log Message:</strong><br>
      Set process count to <strong>4</strong> by taking the minimum value among the factors of {'default_worker_count': 16, 'row_count': 4}.
    </td>
  </tr>
  <tr>
    <td>5</td>
    <td>Yes (fork)</td>
    <td>Yes (<strong>3</strong>)</td>
    <td>3</td>
    <td>4</td>
    <td>N/A</td>
    <td><strong>PF_WORKER_COUNT = 3</strong></td>
  </tr>
  <tr>
    <td colspan="7" style="white-space: normal;">
      <strong style="font-size: larger;">Log Message:</strong><br>
      Set process count to <strong>3</strong> with the environment variable 'PF_WORKER_COUNT'.
    </td>
  </tr>
</table>
</div>
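To make the selection order above easier to follow, here is a minimal, self-contained sketch of the same precedence. It is illustrative only: `determine_worker_count` and the `estimate_workers_by_memory` callable are hypothetical names for this description, not the code this PR ships.

```python
import os


def determine_worker_count(nlines, use_fork, estimate_workers_by_memory, default_worker_count=16):
    """Sketch of the selection logic summarized in the table above."""
    # 1. An explicit, positive PF_WORKER_COUNT always wins (scenarios 2, 3 and 5).
    try:
        worker_count = int(os.environ.get("PF_WORKER_COUNT", ""))
    except ValueError:
        worker_count = None
    if worker_count is not None and worker_count > 0:
        return worker_count

    # 2. Otherwise take the minimum of the applicable factors (scenarios 1 and 4).
    factors = {"default_worker_count": default_worker_count, "row_count": nlines}
    if not use_fork:
        # The memory-based estimate only applies when new processes must allocate their own memory.
        factors["estimated_worker_count_based_on_memory_usage"] = estimate_workers_by_memory()
    return min(v for v in factors.values() if v and v > 0)
```

For example, with `nlines=4`, no `PF_WORKER_COUNT`, and a memory-based estimate of 2, this returns `min(16, 4, 2) = 2`, matching scenario 1.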
# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
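For reference, the "Estimated Available Worker Count" column above is essentially available system memory divided by the current process's memory footprint. The snippet below is a rough standalone sketch of that estimate using psutil; the helper name and the clamping to 1 are illustrative assumptions, not the exact helper in this patch.

```python
import os

import psutil


def estimate_available_worker_count():
    """Rough sketch: how many copies of the current process fit into available memory."""
    available_mb = psutil.virtual_memory().available / (1024 * 1024)
    process_mb = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    # Never report less than one worker, otherwise the pool could not be created at all.
    return max(1, int(available_mb // process_mb))
```

With 128 MB available and a 64 MB process, this yields an estimate of 2, which is where the value in scenarios 1-3 comes from.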
+ """ + try: + return int(os.environ.get(env_var_name, default_value)) + except Exception: + return default_value diff --git a/src/promptflow/promptflow/executor/_line_execution_process_pool.py b/src/promptflow/promptflow/executor/_line_execution_process_pool.py index 1da653b4a49..3e8f50f5e68 100644 --- a/src/promptflow/promptflow/executor/_line_execution_process_pool.py +++ b/src/promptflow/promptflow/executor/_line_execution_process_pool.py @@ -1,5 +1,4 @@ import contextvars -import math import multiprocessing import os import queue @@ -21,7 +20,7 @@ from promptflow._utils.logger_utils import LogContext, bulk_logger from promptflow._utils.multimedia_utils import _process_recursively, persist_multimedia_data from promptflow._utils.thread_utils import RepeatLogTimer -from promptflow._utils.utils import log_progress, set_context +from promptflow._utils.utils import log_progress, set_context, get_int_env_var from promptflow.contracts.multimedia import Image from promptflow.contracts.run_info import FlowRunInfo from promptflow.contracts.run_info import RunInfo as NodeRunInfo @@ -136,6 +135,8 @@ def format_current_process(self, line_number: int, is_completed=False): class LineExecutionProcessPool: + _DEFAULT_WORKER_COUNT = 16 + def __init__( self, flow_executor: FlowExecutor, @@ -149,7 +150,6 @@ def __init__( self._run_id = run_id self._variant_id = variant_id self._validate_inputs = validate_inputs - self._worker_count = flow_executor._worker_count multiprocessing_start_method = os.environ.get("PF_BATCH_METHOD") sys_start_methods = multiprocessing.get_all_start_methods() if multiprocessing_start_method and multiprocessing_start_method not in sys_start_methods: @@ -194,15 +194,8 @@ def __enter__(self): self._completed_idx = manager.dict() self._inputs_queue = Queue() - # Starting a new process in non-fork mode requires to allocate memory. Determine the maximum number of processes - # based on available memory to avoid memory bursting. - if not self._use_fork: - available_max_worker_count = get_available_max_worker_count() - self._n_process = min(self._worker_count, self._nlines, available_max_worker_count) - bulk_logger.info(f"Not using fork, process count: {self._n_process}") - else: - self._n_process = min(self._worker_count, self._nlines) - bulk_logger.info(f"Using fork, process count: {self._n_process}") + self._n_process = self._determine_worker_count() + pool = ThreadPool(self._n_process, initializer=set_context, initargs=(contextvars.copy_context(),)) self._pool = pool @@ -411,6 +404,43 @@ def _generate_thread_status_messages(self, pool: ThreadPool, total_count: int): msgs.append("Processing Lines: " + ", ".join(lines) + ".") return msgs + def _determine_worker_count(self): + worker_count = get_int_env_var("PF_WORKER_COUNT") + + # Starting a new process in non-fork mode requires to allocate memory. Calculate the maximum number of processes + # based on available memory to avoid memory bursting. + estimated_available_worker_count = get_available_max_worker_count() if not self._use_fork else None + + # If the environment variable PF_WORKER_COUNT exists and valid, use the value as the worker_count. 
+        if worker_count is not None and worker_count > 0:
+            self._log_set_worker_count(worker_count, estimated_available_worker_count)
+            return worker_count
+
+        # If the environment variable PF_WORKER_COUNT is not set or is invalid, take the minimum value among the
+        # factors: default_worker_count, row_count and estimated_worker_count_based_on_memory_usage
+        factors = {
+            "default_worker_count": self._DEFAULT_WORKER_COUNT,
+            "row_count": self._nlines,
+            "estimated_worker_count_based_on_memory_usage": estimated_available_worker_count,
+        }
+
+        valid_factors = {k: v for k, v in factors.items() if v is not None and v > 0}
+
+        # Take the minimum value as the result
+        worker_count = min(valid_factors.values())
+        bulk_logger.info(
+            f"Set process count to {worker_count} by taking the minimum value among the factors of {valid_factors}.")
+        return worker_count
+
+    def _log_set_worker_count(self, worker_count, estimated_available_worker_count):
+        bulk_logger.info(
+            f"Set process count to {worker_count} with the environment variable 'PF_WORKER_COUNT'.")
+        if estimated_available_worker_count is not None and estimated_available_worker_count < worker_count:
+            bulk_logger.warning(
+                f"The current process count ({worker_count}) is larger than recommended process count "
+                f"({estimated_available_worker_count}) that estimated by system available memory. This may "
+                f"cause memory exhaustion")
+

 def _exec_line(
     executor: FlowExecutor,
@@ -481,7 +511,6 @@ create_executor_fork(*, flow_executor: FlowExecutor, storage: AbstractRunSto
         run_tracker=run_tracker,
         cache_manager=flow_executor._cache_manager,
         loaded_tools=flow_executor._loaded_tools,
-        worker_count=flow_executor._worker_count,
         raise_ex=False,
         line_timeout_sec=flow_executor._line_timeout_sec,
     )
@@ -540,31 +569,28 @@ def create_executor_legacy(*, flow, connections, loaded_tools, cache_manager, st
 def get_available_max_worker_count():
     pid = os.getpid()
     mem_info = psutil.virtual_memory()
-    total_memory = mem_info.total / (1024 * 1024)  # in MB
-    total_memory_in_use = mem_info.used / (1024 * 1024)  # in MB
     available_memory = mem_info.available / (1024 * 1024)  # in MB
     process = psutil.Process(pid)
     process_memory_info = process.memory_info()
     process_memory = process_memory_info.rss / (1024 * 1024)  # in MB
-    # To ensure system stability, reserve memory for system usage.
-    available_max_worker_count = math.floor((available_memory - 0.3 * total_memory) / process_memory)
-    if available_max_worker_count < 1:
-        # For the case of vector db, at most 1/3 of the memory will be used, which is 33% of the memory
-        # In this scenario, the "available_max_worker_count" may be 0, which will cause an error
-        # "Number of processes must be at least 1" when creating ThreadPool
-        # So set "available_max_worker_count" to 1 if it's less than 1
+    estimated_available_worker_count = int(available_memory // process_memory)
+    if estimated_available_worker_count < 1:
         # TODO: For the case of vector db, Optimize execution logic
         # 1. Let the main process not consume memory because it does not actually invoke
         # 2. When the degree of parallelism is 1, main process executes the task directly and not
         # create the child process
-        bulk_logger.warning(f"Available max worker count {available_max_worker_count} is less than 1, set it to 1.")
-        available_max_worker_count = 1
-    bulk_logger.info(
-        f"""Process {pid} uses {process_memory},
-     total memory {total_memory}, total memory in use: {total_memory_in_use},
-     available memory: {available_memory}, available max worker count: {available_max_worker_count}"""
-    )
-    return available_max_worker_count
+        bulk_logger.warning(
+            f"Current system's available memory is {available_memory}MB, less than the memory "
+            f"{process_memory}MB required by the process. The maximum available worker count is 1.")
+        estimated_available_worker_count = 1
+    else:
+        bulk_logger.info(
+            f"Current system's available memory is {available_memory}MB, "
+            f"memory consumption of current process is {process_memory}MB, "
+            f"estimated available worker count is {available_memory}/{process_memory} "
+            f"= {estimated_available_worker_count}"
+        )
+    return estimated_available_worker_count


 def get_multiprocessing_context(multiprocessing_start_method=None):
diff --git a/src/promptflow/promptflow/executor/flow_executor.py b/src/promptflow/promptflow/executor/flow_executor.py
index 4160058ce16..375f7bf8842 100644
--- a/src/promptflow/promptflow/executor/flow_executor.py
+++ b/src/promptflow/promptflow/executor/flow_executor.py
@@ -5,7 +5,6 @@
 import copy
 import functools
 import inspect
-import os
 import uuid
 from pathlib import Path
 from threading import current_thread
@@ -88,7 +87,6 @@ def __init__(
         cache_manager: AbstractCacheManager,
         loaded_tools: Mapping[str, Callable],
         *,
-        worker_count=None,
         raise_ex: bool = False,
         working_dir=None,
         line_timeout_sec=LINE_TIMEOUT_SEC,
@@ -106,8 +104,6 @@
         :type cache_manager: ~promptflow._core.cache_manager.AbstractCacheManager
         :param loaded_tools: A mapping of tool names to their corresponding functions.
         :type loaded_tools: Mapping[str, Callable]
-        :param worker_count: The number of workers to use for parallel execution of the Flow.
-        :type worker_count: int or None
         :param raise_ex: Whether to raise an exception if an error occurs during execution.
         :type raise_ex: bool
         :param working_dir: The working directory to use for execution.
@@ -126,16 +122,6 @@ def __init__(
         self._connections = connections
         self._aggregation_inputs_references = get_aggregation_inputs_properties(flow)
         self._aggregation_nodes = {node.name for node in self._flow.nodes if node.aggregation}
-        if worker_count is not None:
-            self._worker_count = worker_count
-        else:
-            try:
-                worker_count = int(os.environ.get("PF_WORKER_COUNT", self._DEFAULT_WORKER_COUNT))
-                self._worker_count = worker_count
-            except Exception:
-                self._worker_count = self._DEFAULT_WORKER_COUNT
-        if self._worker_count <= 0:
-            self._worker_count = self._DEFAULT_WORKER_COUNT
         self._run_tracker = run_tracker
         self._cache_manager = cache_manager
         self._loaded_tools = loaded_tools
diff --git a/src/promptflow/tests/executor/unittests/_utils/test_utils.py b/src/promptflow/tests/executor/unittests/_utils/test_utils.py
index d85ff4ff800..617d9eec5ef 100644
--- a/src/promptflow/tests/executor/unittests/_utils/test_utils.py
+++ b/src/promptflow/tests/executor/unittests/_utils/test_utils.py
@@ -1,6 +1,8 @@
 import pytest
+import os
+from unittest.mock import patch

-from promptflow._utils.utils import is_json_serializable
+from promptflow._utils.utils import is_json_serializable, get_int_env_var


 class MyObj:
@@ -12,3 +14,27 @@ class TestUtils:
     @pytest.mark.parametrize("value, expected_res", [(None, True), (1, True), ("", True), (MyObj(), False)])
     def test_is_json_serializable(self, value, expected_res):
         assert is_json_serializable(value) == expected_res
+
+    @pytest.mark.parametrize(
+        "env_var, env_value, default_value, expected_result",
+        [
+            ("TEST_VAR", "10", None, 10),  # Valid integer string
+            ("TEST_VAR", "invalid", None, None),  # Invalid integer string
+            ("TEST_VAR", None, 5, 5),  # Environment variable does not exist
+            ("TEST_VAR", "10", 5, 10),  # Valid integer string with a default value
+            ("TEST_VAR", "invalid", 5, 5),  # Invalid integer string with a default value
+        ])
+    def test_get_int_env_var(self, env_var, env_value, default_value, expected_result):
+        with patch.dict(os.environ, {env_var: env_value} if env_value is not None else {}):
+            assert get_int_env_var(env_var, default_value) == expected_result
+
+    @pytest.mark.parametrize(
+        "env_var, env_value, expected_result",
+        [
+            ("TEST_VAR", "10", 10),  # Valid integer string
+            ("TEST_VAR", "invalid", None),  # Invalid integer string
+            ("TEST_VAR", None, None),  # Environment variable does not exist
+        ])
+    def test_get_int_env_var_without_default_value(self, env_var, env_value, expected_result):
+        with patch.dict(os.environ, {env_var: env_value} if env_value is not None else {}):
+            assert get_int_env_var(env_var) == expected_result
diff --git a/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py b/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py
index 268a37dd20d..784ca44e6ef 100644
--- a/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py
+++ b/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py
@@ -1,11 +1,11 @@
-from unittest.mock import Mock, patch
+from unittest.mock import Mock

 import pytest

 from promptflow import tool
 from promptflow.contracts.flow import FlowInputDefinition
 from promptflow.contracts.tool import ValueType
-from promptflow.executor._line_execution_process_pool import get_available_max_worker_count
+
 from promptflow.executor.flow_executor import (
     FlowExecutor,
     _ensure_node_result_is_serializable,
@@ -178,31 +178,3 @@ def test_streaming_tool_should_be_consumed_and_merged(self):

     def test_non_streaming_tool_should_not_be_affected(self):
         func = _ensure_node_result_is_serializable(non_streaming_tool)
         assert func() == 1
-
-
-class TestGetAvailableMaxWorkerCount:
-    @pytest.mark.parametrize(
-        "total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count",
-        [
-            (1024.0, 128.0, 64.0, 1, -3),  # available_memory - 0.3 * total_memory < 0
-            (1024.0, 307.20, 64.0, 1, 0),  # available_memory - 0.3 * total_memory = 0
-            (1024.0, 768.0, 64.0, 7, 7),  # available_memory - 0.3 * total_memory > 0
-        ],
-    )
-    def test_get_available_max_worker_count(
-        self, total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count
-    ):
-        with patch("psutil.virtual_memory") as mock_mem:
-            mock_mem.return_value.total = total_memory * 1024 * 1024
-            mock_mem.return_value.available = available_memory * 1024 * 1024
-            with patch("psutil.Process") as mock_process:
-                mock_process.return_value.memory_info.return_value.rss = process_memory * 1024 * 1024
-                with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
-                    mock_logger.warning.return_value = None
-                    max_worker_count = get_available_max_worker_count()
-                    assert max_worker_count == expected_max_worker_count
-                    if actual_calculate_worker_count < 1:
-                        mock_logger.warning.assert_called_with(
-                            f"Available max worker count {actual_calculate_worker_count} is less than 1, "
-                            "set it to 1."
-                        )
diff --git a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
index 1fed0c670c5..05607ed9b51 100644
--- a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
+++ b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
@@ -17,6 +17,7 @@
     LineExecutionProcessPool,
     _exec_line,
     get_multiprocessing_context,
+    get_available_max_worker_count
 )
 from promptflow.executor._result import LineResult

@@ -25,32 +26,137 @@

 SAMPLE_FLOW = "web_classification_no_variants"

-@pytest.mark.unittest
-class TestLineExecutionProcessPool:
-    def get_line_inputs(self, flow_folder=""):
-        if flow_folder:
-            inputs = self.get_bulk_inputs(flow_folder)
-            return inputs[0]
-        return {
-            "url": "https://www.microsoft.com/en-us/windows/",
-            "text": "some_text",
-        }
+def get_line_inputs(flow_folder=""):
+    if flow_folder:
+        inputs = get_bulk_inputs(flow_folder)
+        return inputs[0]
+    return {
+        "url": "https://www.microsoft.com/en-us/windows/",
+        "text": "some_text",
+    }
+
-    def get_bulk_inputs(self, nlinee=4, flow_folder="", sample_inputs_file="", return_dict=False):
-        if flow_folder:
-            if not sample_inputs_file:
-                sample_inputs_file = "samples.json"
-            inputs = get_flow_sample_inputs(flow_folder, sample_inputs_file=sample_inputs_file)
-            if isinstance(inputs, list) and len(inputs) > 0:
+def get_bulk_inputs(nlines=4, flow_folder="", sample_inputs_file="", return_dict=False):
+    if flow_folder:
+        if not sample_inputs_file:
+            sample_inputs_file = "samples.json"
+        inputs = get_flow_sample_inputs(flow_folder, sample_inputs_file=sample_inputs_file)
+        if isinstance(inputs, list) and len(inputs) > 0:
+            return inputs
+        elif isinstance(inputs, dict):
+            if return_dict:
                 return inputs
-            elif isinstance(inputs, dict):
-                if return_dict:
-                    return inputs
-            return [inputs]
+            return [inputs]
+        else:
+            raise Exception(f"Invalid type of bulk input: {inputs}")
+    return [get_line_inputs() for _ in range(nlines)]
+
+
+@pytest.mark.skip("This is a subprocess function used for testing and cannot be tested alone.")
+def test_fork_mode_parallelism_in_subprocess(
+    dev_connections,
+    flow_folder,
+    is_set_environ_pf_worker_count,
+    pf_worker_count,
+    n_process):
+
+    if is_set_environ_pf_worker_count:
+        os.environ["PF_WORKER_COUNT"] = pf_worker_count
+    executor = FlowExecutor.create(
+        get_yaml_file(flow_folder),
+        dev_connections,
+    )
+    run_id = str(uuid.uuid4())
+    bulk_inputs = get_bulk_inputs()
+    nlines = len(bulk_inputs)
+
+    with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
+        with LineExecutionProcessPool(
+            executor,
+            nlines,
+            run_id,
+            "",
+            False,
+            None,
+        ) as pool:
+            assert pool._n_process == n_process
+            if is_set_environ_pf_worker_count:
+                mock_logger.info.assert_any_call(
+                    f"Set process count to {pf_worker_count} with the environment "
+                    f"variable 'PF_WORKER_COUNT'.")
             else:
-                raise Exception(f"Invalid type of bulk input: {inputs}")
-        return [self.get_line_inputs() for _ in range(nlinee)]
+                factors = {
+                    "default_worker_count": pool._DEFAULT_WORKER_COUNT,
+                    "row_count": pool._nlines,
+                }
+                mock_logger.info.assert_any_call(
+                    f"Set process count to {n_process} by taking the minimum value among the "
+                    f"factors of {factors}."
+                )
+
+
+@pytest.mark.skip("This is a subprocess function used for testing and cannot be tested alone.")
+def test_spawn_mode_parallelism_in_subprocess(
+    dev_connections,
+    flow_folder,
+    is_set_environ_pf_worker_count,
+    is_calculation_smaller_than_set,
+    pf_worker_count,
+    estimated_available_worker_count,
+    n_process
+):
+    os.environ["PF_BATCH_METHOD"] = "spawn"
+    if is_set_environ_pf_worker_count:
+        os.environ["PF_WORKER_COUNT"] = pf_worker_count
+    executor = FlowExecutor.create(
+        get_yaml_file(flow_folder),
+        dev_connections,
+    )
+    run_id = str(uuid.uuid4())
+    bulk_inputs = get_bulk_inputs()
+    nlines = len(bulk_inputs)
+
+    with patch("psutil.virtual_memory") as mock_mem:
+        mock_mem.return_value.available = 128.0 * 1024 * 1024
+        with patch("psutil.Process") as mock_process:
+            mock_process.return_value.memory_info.return_value.rss = 64 * 1024 * 1024
+            with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
+                with LineExecutionProcessPool(
+                    executor,
+                    nlines,
+                    run_id,
+                    "",
+                    False,
+                    None,
+                ) as pool:
+
+                    assert pool._n_process == n_process
+                    if is_set_environ_pf_worker_count and is_calculation_smaller_than_set:
+                        mock_logger.info.assert_any_call(
+                            f"Set process count to {pf_worker_count} with the environment "
+                            f"variable 'PF_WORKER_COUNT'.")
+                        mock_logger.warning.assert_any_call(
+                            f"The current process count ({pf_worker_count}) is larger than recommended process count "
+                            f"({estimated_available_worker_count}) that estimated by system available memory. This may "
+                            f"cause memory exhaustion")
+                    elif is_set_environ_pf_worker_count and not is_calculation_smaller_than_set:
+                        mock_logger.info.assert_any_call(
+                            f"Set process count to {pf_worker_count} with the environment "
+                            f"variable 'PF_WORKER_COUNT'.")
+                    elif not is_set_environ_pf_worker_count:
+                        factors = {
+                            "default_worker_count": pool._DEFAULT_WORKER_COUNT,
+                            "row_count": pool._nlines,
+                            "estimated_worker_count_based_on_memory_usage": estimated_available_worker_count
+                        }
+                        mock_logger.info.assert_any_call(
+                            f"Set process count to {n_process} by taking the minimum value among the factors "
+                            f"of {factors}."
+                        )
+
+@pytest.mark.unittest
+class TestLineExecutionProcessPool:
     def create_line_execution_process_pool(self, dev_connections):
         executor = FlowExecutor.create(
             get_yaml_file(SAMPLE_FLOW),
@@ -58,7 +164,7 @@
             line_timeout_sec=1,
         )
         run_id = str(uuid.uuid4())
-        bulk_inputs = self.get_bulk_inputs()
+        bulk_inputs = get_bulk_inputs()
         nlines = len(bulk_inputs)
         line_execution_process_pool = LineExecutionProcessPool(
             executor,
@@ -84,7 +190,7 @@ def test_line_execution_process_pool(self, flow_folder, dev_connections):
         executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections)
         executor._log_interval = 1
         run_id = str(uuid.uuid4())
-        bulk_inputs = self.get_bulk_inputs()
+        bulk_inputs = get_bulk_inputs()
         nlines = len(bulk_inputs)
         run_id = run_id or str(uuid.uuid4())
         with LineExecutionProcessPool(
@@ -114,7 +220,7 @@ def test_line_execution_not_completed(self, flow_folder, dev_connections):
             line_timeout_sec=1,
         )
         run_id = str(uuid.uuid4())
-        bulk_inputs = self.get_bulk_inputs()
+        bulk_inputs = get_bulk_inputs()
         nlines = len(bulk_inputs)
         with LineExecutionProcessPool(
             executor,
@@ -147,7 +253,7 @@ def test_exec_line(self, flow_folder, dev_connections, mocker: MockFixture):
             line_timeout_sec=1,
         )
         run_id = str(uuid.uuid4())
-        line_inputs = self.get_line_inputs()
+        line_inputs = get_line_inputs()
         line_result = _exec_line(
             executor=executor,
             output_queue=output_queue,
@@ -178,7 +284,7 @@ def test_exec_line_failed_when_line_execution_not_start(self, flow_folder, dev_c
             message=test_error_msg, target=ErrorTarget.AZURE_RUN_STORAGE
         )
         run_id = str(uuid.uuid4())
-        line_inputs = self.get_line_inputs()
+        line_inputs = get_line_inputs()
         line_result = _exec_line(
             executor=executor,
             output_queue=output_queue,
@@ -243,7 +349,7 @@ def test_process_pool_run_with_exception(self, flow_folder, dev_connections, moc
             dev_connections,
         )
         run_id = str(uuid.uuid4())
-        bulk_inputs = self.get_bulk_inputs()
+        bulk_inputs = get_bulk_inputs()
         nlines = len(bulk_inputs)
         with LineExecutionProcessPool(
             executor,
@@ -258,3 +364,104 @@ def test_process_pool_run_with_exception(self, flow_folder, dev_connections, moc
             assert e.value.message == test_error_msg
             assert e.value.target == ErrorTarget.AZURE_RUN_STORAGE
             assert e.value.error_codes[0] == "UserError"
+
+    @pytest.mark.parametrize(
+        (
+            "flow_folder",
+            "is_set_environ_pf_worker_count",
+            "pf_worker_count",
+            "n_process"
+        ),
+        [
+            (SAMPLE_FLOW, True, "3", 3),
+            (SAMPLE_FLOW, False, None, 4)
+        ],
+    )
+    def test_process_pool_parallelism_in_fork_mode(
+            self,
+            dev_connections,
+            flow_folder,
+            is_set_environ_pf_worker_count,
+            pf_worker_count,
+            n_process):
+        p = multiprocessing.Process(
+            target=test_fork_mode_parallelism_in_subprocess,
+            args=(dev_connections,
+                  flow_folder,
+                  is_set_environ_pf_worker_count,
+                  pf_worker_count,
+                  n_process))
+        p.start()
+        p.join()
+        assert p.exitcode == 0
+
+    @pytest.mark.parametrize(
+        (
+            "flow_folder",
+            "is_set_environ_pf_worker_count",
+            "is_calculation_smaller_than_set",
+            "pf_worker_count",
+            "estimated_available_worker_count",
+            "n_process"
+        ),
+        [
+            (SAMPLE_FLOW, True, False, "2", 4, 2),
+            (SAMPLE_FLOW, True, True, "6", 2, 6),
+            (SAMPLE_FLOW, False, True, None, 2, 2)
+        ],
+    )
+    def test_process_pool_parallelism_in_spawn_mode(
+        self,
+        dev_connections,
+        flow_folder,
+        is_set_environ_pf_worker_count,
+        is_calculation_smaller_than_set,
+        pf_worker_count,
+        estimated_available_worker_count,
+        n_process
+    ):
+        p = multiprocessing.Process(
+            target=test_spawn_mode_parallelism_in_subprocess,
+            args=(dev_connections,
+                  flow_folder,
+                  is_set_environ_pf_worker_count,
+                  is_calculation_smaller_than_set,
+                  pf_worker_count,
+                  estimated_available_worker_count,
+                  n_process))
+        p.start()
+        p.join()
+        assert p.exitcode == 0
+
+
+class TestGetAvailableMaxWorkerCount:
+    @pytest.mark.parametrize(
+        "available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count",
+        [
+            (128.0, 64.0, 2, 2),  # available_memory/process_memory > 1
+            (63.0, 64.0, 1, 0),  # available_memory/process_memory < 1
+        ],
+    )
+    def test_get_available_max_worker_count(
+        self, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count
+    ):
+        with patch("psutil.virtual_memory") as mock_mem:
+            mock_mem.return_value.available = available_memory * 1024 * 1024
+            with patch("psutil.Process") as mock_process:
+                mock_process.return_value.memory_info.return_value.rss = process_memory * 1024 * 1024
+                with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
+                    mock_logger.warning.return_value = None
+                    estimated_available_worker_count = get_available_max_worker_count()
+                    assert estimated_available_worker_count == expected_max_worker_count
+                    if actual_calculate_worker_count < 1:
+                        mock_logger.warning.assert_called_with(
+                            f"Current system's available memory is {available_memory}MB, less than the memory "
+                            f"{process_memory}MB required by the process. The maximum available worker count is 1."
+                        )
+                    else:
+                        mock_logger.info.assert_called_with(
+                            f"Current system's available memory is {available_memory}MB, "
+                            f"memory consumption of current process is {process_memory}MB, "
+                            f"estimated available worker count is {available_memory}/{process_memory} "
+                            f"= {actual_calculate_worker_count}"
+                        )