diff --git a/src/promptflow/promptflow/_utils/utils.py b/src/promptflow/promptflow/_utils/utils.py
index fe7b3543906..5ae9dcd5f1c 100644
--- a/src/promptflow/promptflow/_utils/utils.py
+++ b/src/promptflow/promptflow/_utils/utils.py
@@ -167,7 +167,7 @@ def extract_user_frame_summaries(frame_summaries: List[traceback.FrameSummary]):
         # If the current frame is in tool.py and the next frame is not in _core folder
         # then we can say that the next frame is in user code.
         if cur_file == tool_file and not next_file.startswith(core_folder):
-            return frame_summaries[i + 1 :]
+            return frame_summaries[i + 1:]
     return frame_summaries
 
 
@@ -247,3 +247,19 @@ def parse_ua_to_dict(ua):
         key, value = item.split("/")
         ua_dict[key] = value
     return ua_dict
+
+
+def get_int_env_var(env_var_name, default_value=None):
+    """
+    Retrieve the value of an environment variable as an integer, falling back to an
+    optional default value if the variable is not set or cannot be converted to an integer.
+
+    :param env_var_name: The name of the environment variable to read
+    :param default_value: The value to return if the environment variable is not set
+    or its value cannot be converted to an integer
+    :return: The integer value of the environment variable, or the default value.
+    """
+    try:
+        return int(os.environ.get(env_var_name, default_value))
+    except Exception:
+        return default_value
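# Reviewer note: a quick sketch of how the new helper behaves (illustrative only;
# the import path is the module patched above):
#
#     import os
#     from promptflow._utils.utils import get_int_env_var
#
#     os.environ["PF_WORKER_COUNT"] = "8"
#     get_int_env_var("PF_WORKER_COUNT")       # -> 8
#     os.environ["PF_WORKER_COUNT"] = "eight"
#     get_int_env_var("PF_WORKER_COUNT", 4)    # int() raises -> falls back to 4
#     os.environ.pop("PF_WORKER_COUNT")
#     get_int_env_var("PF_WORKER_COUNT")       # unset, no default -> None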
diff --git a/src/promptflow/promptflow/executor/_line_execution_process_pool.py b/src/promptflow/promptflow/executor/_line_execution_process_pool.py
index 1da653b4a49..3e8f50f5e68 100644
--- a/src/promptflow/promptflow/executor/_line_execution_process_pool.py
+++ b/src/promptflow/promptflow/executor/_line_execution_process_pool.py
@@ -1,5 +1,4 @@
 import contextvars
-import math
 import multiprocessing
 import os
 import queue
@@ -21,7 +20,7 @@
 from promptflow._utils.logger_utils import LogContext, bulk_logger
 from promptflow._utils.multimedia_utils import _process_recursively, persist_multimedia_data
 from promptflow._utils.thread_utils import RepeatLogTimer
-from promptflow._utils.utils import log_progress, set_context
+from promptflow._utils.utils import log_progress, set_context, get_int_env_var
 from promptflow.contracts.multimedia import Image
 from promptflow.contracts.run_info import FlowRunInfo
 from promptflow.contracts.run_info import RunInfo as NodeRunInfo
@@ -136,6 +135,8 @@ def format_current_process(self, line_number: int, is_completed=False):
 
 
 class LineExecutionProcessPool:
+    _DEFAULT_WORKER_COUNT = 16
+
     def __init__(
         self,
         flow_executor: FlowExecutor,
@@ -149,7 +150,6 @@ def __init__(
         self._run_id = run_id
         self._variant_id = variant_id
         self._validate_inputs = validate_inputs
-        self._worker_count = flow_executor._worker_count
         multiprocessing_start_method = os.environ.get("PF_BATCH_METHOD")
         sys_start_methods = multiprocessing.get_all_start_methods()
         if multiprocessing_start_method and multiprocessing_start_method not in sys_start_methods:
@@ -194,15 +194,8 @@ def __enter__(self):
         self._completed_idx = manager.dict()
 
         self._inputs_queue = Queue()
-        # Starting a new process in non-fork mode requires to allocate memory. Determine the maximum number of processes
-        # based on available memory to avoid memory bursting.
-        if not self._use_fork:
-            available_max_worker_count = get_available_max_worker_count()
-            self._n_process = min(self._worker_count, self._nlines, available_max_worker_count)
-            bulk_logger.info(f"Not using fork, process count: {self._n_process}")
-        else:
-            self._n_process = min(self._worker_count, self._nlines)
-            bulk_logger.info(f"Using fork, process count: {self._n_process}")
+        self._n_process = self._determine_worker_count()
+
         pool = ThreadPool(self._n_process, initializer=set_context, initargs=(contextvars.copy_context(),))
         self._pool = pool
 
@@ -411,6 +404,43 @@ def _generate_thread_status_messages(self, pool: ThreadPool, total_count: int):
             msgs.append("Processing Lines: " + ", ".join(lines) + ".")
         return msgs
 
+    def _determine_worker_count(self):
+        worker_count = get_int_env_var("PF_WORKER_COUNT")
+
+        # Starting a new process in non-fork mode requires allocating memory. Calculate the maximum number of
+        # processes based on available memory to avoid memory bursting.
+        estimated_available_worker_count = get_available_max_worker_count() if not self._use_fork else None
+
+        # If the environment variable PF_WORKER_COUNT exists and is valid, use its value as the worker count.
+        if worker_count is not None and worker_count > 0:
+            self._log_set_worker_count(worker_count, estimated_available_worker_count)
+            return worker_count
+
+        # If PF_WORKER_COUNT is not set or invalid, take the minimum value among these factors:
+        # default_worker_count, row_count and estimated_worker_count_based_on_memory_usage.
+        factors = {
+            "default_worker_count": self._DEFAULT_WORKER_COUNT,
+            "row_count": self._nlines,
+            "estimated_worker_count_based_on_memory_usage": estimated_available_worker_count,
+        }
+
+        valid_factors = {k: v for k, v in factors.items() if v is not None and v > 0}
+
+        # Take the minimum value as the result.
+        worker_count = min(valid_factors.values())
+        bulk_logger.info(
+            f"Set process count to {worker_count} by taking the minimum value among the factors of {valid_factors}.")
+        return worker_count
+
+    def _log_set_worker_count(self, worker_count, estimated_available_worker_count):
+        bulk_logger.info(
+            f"Set process count to {worker_count} with the environment variable 'PF_WORKER_COUNT'.")
+        if estimated_available_worker_count is not None and estimated_available_worker_count < worker_count:
+            bulk_logger.warning(
+                f"The current process count ({worker_count}) is larger than the recommended process count "
+                f"({estimated_available_worker_count}) estimated by system available memory. This may "
+                f"cause memory exhaustion.")
+
 
 def _exec_line(
     executor: FlowExecutor,
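# Reviewer note: a worked example of _determine_worker_count with illustrative
# numbers. With PF_WORKER_COUNT unset, 4 input rows, and a memory-based estimate of
# 2 workers (non-fork mode), the pool size is min(16, 4, 2) = 2; in fork mode the
# estimate is None and drops out, giving min(16, 4) = 4:
#
#     factors = {"default_worker_count": 16, "row_count": 4,
#                "estimated_worker_count_based_on_memory_usage": 2}
#     valid_factors = {k: v for k, v in factors.items() if v is not None and v > 0}
#     min(valid_factors.values())  # -> 2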
@@ -481,7 +511,6 @@ def create_executor_fork(*, flow_executor: FlowExecutor, storage: AbstractRunSto
         run_tracker=run_tracker,
         cache_manager=flow_executor._cache_manager,
         loaded_tools=flow_executor._loaded_tools,
-        worker_count=flow_executor._worker_count,
         raise_ex=False,
         line_timeout_sec=flow_executor._line_timeout_sec,
     )
@@ -540,31 +569,28 @@ def create_executor_legacy(*, flow, connections, loaded_tools, cache_manager, st
 
 def get_available_max_worker_count():
     pid = os.getpid()
     mem_info = psutil.virtual_memory()
-    total_memory = mem_info.total / (1024 * 1024)  # in MB
-    total_memory_in_use = mem_info.used / (1024 * 1024)  # in MB
     available_memory = mem_info.available / (1024 * 1024)  # in MB
     process = psutil.Process(pid)
     process_memory_info = process.memory_info()
     process_memory = process_memory_info.rss / (1024 * 1024)  # in MB
-    # To ensure system stability, reserve memory for system usage.
-    available_max_worker_count = math.floor((available_memory - 0.3 * total_memory) / process_memory)
-    if available_max_worker_count < 1:
-        # For the case of vector db, at most 1/3 of the memory will be used, which is 33% of the memory
-        # In this scenario, the "available_max_worker_count" may be 0, which will cause an error
-        # "Number of processes must be at least 1" when creating ThreadPool
-        # So set "available_max_worker_count" to 1 if it's less than 1
+    estimated_available_worker_count = int(available_memory // process_memory)
+    if estimated_available_worker_count < 1:
         # TODO: For the case of vector db, Optimize execution logic
         # 1. Let the main process not consume memory because it does not actually invoke
         # 2. When the degree of parallelism is 1, main process executes the task directly and not
         # create the child process
-        bulk_logger.warning(f"Available max worker count {available_max_worker_count} is less than 1, set it to 1.")
-        available_max_worker_count = 1
-    bulk_logger.info(
-        f"""Process {pid} uses {process_memory},
-        total memory {total_memory}, total memory in use: {total_memory_in_use},
-        available memory: {available_memory}, available max worker count: {available_max_worker_count}"""
-    )
-    return available_max_worker_count
+        bulk_logger.warning(
+            f"Current system's available memory is {available_memory}MB, less than the memory "
+            f"{process_memory}MB required by the process. The maximum available worker count is 1.")
+        estimated_available_worker_count = 1
+    else:
+        bulk_logger.info(
+            f"Current system's available memory is {available_memory}MB, "
+            f"memory consumption of current process is {process_memory}MB, "
+            f"estimated available worker count is {available_memory}/{process_memory} "
+            f"= {estimated_available_worker_count}"
+        )
+    return estimated_available_worker_count
 
 
 def get_multiprocessing_context(multiprocessing_start_method=None):
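# Reviewer note: the estimate is now a plain floor division of available memory by
# the current process RSS (the old reserve of 0.3 * total memory is gone); with
# illustrative numbers: 128MB available / 64MB RSS -> 2 workers, while
# 63MB / 64MB -> 0, which the function clamps to 1 so ThreadPool creation succeeds:
#
#     int(128.0 // 64.0)  # -> 2
#     int(63.0 // 64.0)   # -> 0, returned as 1 after clamping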
diff --git a/src/promptflow/promptflow/executor/flow_executor.py b/src/promptflow/promptflow/executor/flow_executor.py
index 4160058ce16..375f7bf8842 100644
--- a/src/promptflow/promptflow/executor/flow_executor.py
+++ b/src/promptflow/promptflow/executor/flow_executor.py
@@ -5,7 +5,6 @@
 import copy
 import functools
 import inspect
-import os
 import uuid
 from pathlib import Path
 from threading import current_thread
@@ -88,7 +87,6 @@ def __init__(
         cache_manager: AbstractCacheManager,
         loaded_tools: Mapping[str, Callable],
         *,
-        worker_count=None,
         raise_ex: bool = False,
         working_dir=None,
         line_timeout_sec=LINE_TIMEOUT_SEC,
@@ -106,8 +104,6 @@ def __init__(
         :type cache_manager: ~promptflow._core.cache_manager.AbstractCacheManager
         :param loaded_tools: A mapping of tool names to their corresponding functions.
         :type loaded_tools: Mapping[str, Callable]
-        :param worker_count: The number of workers to use for parallel execution of the Flow.
-        :type worker_count: int or None
         :param raise_ex: Whether to raise an exception if an error occurs during execution.
         :type raise_ex: bool
         :param working_dir: The working directory to use for execution.
@@ -126,16 +122,6 @@ def __init__(
         self._connections = connections
         self._aggregation_inputs_references = get_aggregation_inputs_properties(flow)
         self._aggregation_nodes = {node.name for node in self._flow.nodes if node.aggregation}
-        if worker_count is not None:
-            self._worker_count = worker_count
-        else:
-            try:
-                worker_count = int(os.environ.get("PF_WORKER_COUNT", self._DEFAULT_WORKER_COUNT))
-                self._worker_count = worker_count
-            except Exception:
-                self._worker_count = self._DEFAULT_WORKER_COUNT
-        if self._worker_count <= 0:
-            self._worker_count = self._DEFAULT_WORKER_COUNT
         self._run_tracker = run_tracker
         self._cache_manager = cache_manager
         self._loaded_tools = loaded_tools
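# Reviewer note: with `worker_count` removed from FlowExecutor, parallelism is now
# configured only via the environment and resolved inside the process pool; a
# hypothetical invocation (names as in this patch):
#
#     os.environ["PF_WORKER_COUNT"] = "4"  # read later by _determine_worker_count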
diff --git a/src/promptflow/tests/executor/unittests/_utils/test_utils.py b/src/promptflow/tests/executor/unittests/_utils/test_utils.py
index d85ff4ff800..617d9eec5ef 100644
--- a/src/promptflow/tests/executor/unittests/_utils/test_utils.py
+++ b/src/promptflow/tests/executor/unittests/_utils/test_utils.py
@@ -1,6 +1,8 @@
 import pytest
+import os
+from unittest.mock import patch
 
-from promptflow._utils.utils import is_json_serializable
+from promptflow._utils.utils import is_json_serializable, get_int_env_var
 
 
 class MyObj:
@@ -12,3 +14,27 @@ class TestUtils:
     @pytest.mark.parametrize("value, expected_res", [(None, True), (1, True), ("", True), (MyObj(), False)])
     def test_is_json_serializable(self, value, expected_res):
         assert is_json_serializable(value) == expected_res
+
+    @pytest.mark.parametrize(
+        "env_var, env_value, default_value, expected_result",
+        [
+            ("TEST_VAR", "10", None, 10),  # Valid integer string
+            ("TEST_VAR", "invalid", None, None),  # Invalid integer string
+            ("TEST_VAR", None, 5, 5),  # Environment variable does not exist
+            ("TEST_VAR", "10", 5, 10),  # Valid integer string with a default value
+            ("TEST_VAR", "invalid", 5, 5),  # Invalid integer string with a default value
+        ])
+    def test_get_int_env_var(self, env_var, env_value, default_value, expected_result):
+        with patch.dict(os.environ, {env_var: env_value} if env_value is not None else {}):
+            assert get_int_env_var(env_var, default_value) == expected_result
+
+    @pytest.mark.parametrize(
+        "env_var, env_value, expected_result",
+        [
+            ("TEST_VAR", "10", 10),  # Valid integer string
+            ("TEST_VAR", "invalid", None),  # Invalid integer string
+            ("TEST_VAR", None, None),  # Environment variable does not exist
+        ])
+    def test_get_int_env_var_without_default_value(self, env_var, env_value, expected_result):
+        with patch.dict(os.environ, {env_var: env_value} if env_value is not None else {}):
+            assert get_int_env_var(env_var) == expected_result
diff --git a/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py b/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py
index 268a37dd20d..784ca44e6ef 100644
--- a/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py
+++ b/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py
@@ -1,11 +1,11 @@
-from unittest.mock import Mock, patch
+from unittest.mock import Mock
 
 import pytest
 
 from promptflow import tool
 from promptflow.contracts.flow import FlowInputDefinition
 from promptflow.contracts.tool import ValueType
-from promptflow.executor._line_execution_process_pool import get_available_max_worker_count
+
 from promptflow.executor.flow_executor import (
     FlowExecutor,
     _ensure_node_result_is_serializable,
@@ -178,31 +178,3 @@ def test_streaming_tool_should_be_consumed_and_merged(self):
 
     def test_non_streaming_tool_should_not_be_affected(self):
         func = _ensure_node_result_is_serializable(non_streaming_tool)
         assert func() == 1
-
-
-class TestGetAvailableMaxWorkerCount:
-    @pytest.mark.parametrize(
-        "total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count",
-        [
-            (1024.0, 128.0, 64.0, 1, -3),  # available_memory - 0.3 * total_memory < 0
-            (1024.0, 307.20, 64.0, 1, 0),  # available_memory - 0.3 * total_memory = 0
-            (1024.0, 768.0, 64.0, 7, 7),  # available_memory - 0.3 * total_memory > 0
-        ],
-    )
-    def test_get_available_max_worker_count(
-        self, total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count
-    ):
-        with patch("psutil.virtual_memory") as mock_mem:
-            mock_mem.return_value.total = total_memory * 1024 * 1024
-            mock_mem.return_value.available = available_memory * 1024 * 1024
-            with patch("psutil.Process") as mock_process:
-                mock_process.return_value.memory_info.return_value.rss = process_memory * 1024 * 1024
-                with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
-                    mock_logger.warning.return_value = None
-                    max_worker_count = get_available_max_worker_count()
-                    assert max_worker_count == expected_max_worker_count
-                    if actual_calculate_worker_count < 1:
-                        mock_logger.warning.assert_called_with(
-                            f"Available max worker count {actual_calculate_worker_count} is less than 1, "
-                            "set it to 1."
-                        )
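# Reviewer note: `patch.dict(os.environ, ...)` restores the environment when the
# block exits, so the parametrized cases above cannot leak TEST_VAR into each other;
# a minimal sketch (assuming TEST_VAR is not set outside the test):
#
#     with patch.dict(os.environ, {"TEST_VAR": "10"}):
#         assert get_int_env_var("TEST_VAR") == 10
#     assert "TEST_VAR" not in os.environ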
diff --git a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
index 1fed0c670c5..05607ed9b51 100644
--- a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
+++ b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
@@ -17,6 +17,7 @@
     LineExecutionProcessPool,
     _exec_line,
     get_multiprocessing_context,
+    get_available_max_worker_count,
 )
 from promptflow.executor._result import LineResult
 
@@ -25,32 +26,137 @@
 
 SAMPLE_FLOW = "web_classification_no_variants"
 
-@pytest.mark.unittest
-class TestLineExecutionProcessPool:
-    def get_line_inputs(self, flow_folder=""):
-        if flow_folder:
-            inputs = self.get_bulk_inputs(flow_folder)
-            return inputs[0]
-        return {
-            "url": "https://www.microsoft.com/en-us/windows/",
-            "text": "some_text",
-        }
-
-    def get_bulk_inputs(self, nlinee=4, flow_folder="", sample_inputs_file="", return_dict=False):
-        if flow_folder:
-            if not sample_inputs_file:
-                sample_inputs_file = "samples.json"
-            inputs = get_flow_sample_inputs(flow_folder, sample_inputs_file=sample_inputs_file)
-            if isinstance(inputs, list) and len(inputs) > 0:
-                return inputs
-            elif isinstance(inputs, dict):
-                if return_dict:
-                    return inputs
-                return [inputs]
-            else:
-                raise Exception(f"Invalid type of bulk input: {inputs}")
-        return [self.get_line_inputs() for _ in range(nlinee)]
+def get_line_inputs(flow_folder=""):
+    if flow_folder:
+        inputs = get_bulk_inputs(flow_folder)
+        return inputs[0]
+    return {
+        "url": "https://www.microsoft.com/en-us/windows/",
+        "text": "some_text",
+    }
+
+
+def get_bulk_inputs(nlines=4, flow_folder="", sample_inputs_file="", return_dict=False):
+    if flow_folder:
+        if not sample_inputs_file:
+            sample_inputs_file = "samples.json"
+        inputs = get_flow_sample_inputs(flow_folder, sample_inputs_file=sample_inputs_file)
+        if isinstance(inputs, list) and len(inputs) > 0:
+            return inputs
+        elif isinstance(inputs, dict):
+            if return_dict:
+                return inputs
+            return [inputs]
+        else:
+            raise Exception(f"Invalid type of bulk input: {inputs}")
+    return [get_line_inputs() for _ in range(nlines)]
+
+
+@pytest.mark.skip("This is a subprocess function used for testing and cannot be tested alone.")
+def test_fork_mode_parallelism_in_subprocess(
+    dev_connections,
+    flow_folder,
+    is_set_environ_pf_worker_count,
+    pf_worker_count,
+    n_process
+):
+    if is_set_environ_pf_worker_count:
+        os.environ["PF_WORKER_COUNT"] = pf_worker_count
+    executor = FlowExecutor.create(
+        get_yaml_file(flow_folder),
+        dev_connections,
+    )
+    run_id = str(uuid.uuid4())
+    bulk_inputs = get_bulk_inputs()
+    nlines = len(bulk_inputs)
+
+    with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
+        with LineExecutionProcessPool(
+            executor,
+            nlines,
+            run_id,
+            "",
+            False,
+            None,
+        ) as pool:
+            assert pool._n_process == n_process
+            if is_set_environ_pf_worker_count:
+                mock_logger.info.assert_any_call(
+                    f"Set process count to {pf_worker_count} with the environment "
+                    f"variable 'PF_WORKER_COUNT'.")
+            else:
+                factors = {
+                    "default_worker_count": pool._DEFAULT_WORKER_COUNT,
+                    "row_count": pool._nlines,
+                }
+                mock_logger.info.assert_any_call(
+                    f"Set process count to {n_process} by taking the minimum value among the "
+                    f"factors of {factors}."
+                )
+
+
+@pytest.mark.skip("This is a subprocess function used for testing and cannot be tested alone.")
+def test_spawn_mode_parallelism_in_subprocess(
+    dev_connections,
+    flow_folder,
+    is_set_environ_pf_worker_count,
+    is_calculation_smaller_than_set,
+    pf_worker_count,
+    estimated_available_worker_count,
+    n_process
+):
+    os.environ["PF_BATCH_METHOD"] = "spawn"
+    if is_set_environ_pf_worker_count:
+        os.environ["PF_WORKER_COUNT"] = pf_worker_count
+    executor = FlowExecutor.create(
+        get_yaml_file(flow_folder),
+        dev_connections,
+    )
+    run_id = str(uuid.uuid4())
+    bulk_inputs = get_bulk_inputs()
+    nlines = len(bulk_inputs)
+
+    with patch("psutil.virtual_memory") as mock_mem:
+        mock_mem.return_value.available = 128.0 * 1024 * 1024
+        with patch("psutil.Process") as mock_process:
+            mock_process.return_value.memory_info.return_value.rss = 64 * 1024 * 1024
+            with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
+                with LineExecutionProcessPool(
+                    executor,
+                    nlines,
+                    run_id,
+                    "",
+                    False,
+                    None,
+                ) as pool:
+                    assert pool._n_process == n_process
+                    if is_set_environ_pf_worker_count and is_calculation_smaller_than_set:
+                        mock_logger.info.assert_any_call(
+                            f"Set process count to {pf_worker_count} with the environment "
+                            f"variable 'PF_WORKER_COUNT'.")
+                        mock_logger.warning.assert_any_call(
+                            f"The current process count ({pf_worker_count}) is larger than the recommended process count "
+                            f"({estimated_available_worker_count}) estimated by system available memory. This may "
+                            f"cause memory exhaustion.")
+                    elif is_set_environ_pf_worker_count and not is_calculation_smaller_than_set:
+                        mock_logger.info.assert_any_call(
+                            f"Set process count to {pf_worker_count} with the environment "
+                            f"variable 'PF_WORKER_COUNT'.")
+                    elif not is_set_environ_pf_worker_count:
+                        factors = {
+                            "default_worker_count": pool._DEFAULT_WORKER_COUNT,
+                            "row_count": pool._nlines,
+                            "estimated_worker_count_based_on_memory_usage": estimated_available_worker_count,
+                        }
+                        mock_logger.info.assert_any_call(
+                            f"Set process count to {n_process} by taking the minimum value among the factors "
+                            f"of {factors}."
+                        )
+
+
+@pytest.mark.unittest
+class TestLineExecutionProcessPool:
 
     def create_line_execution_process_pool(self, dev_connections):
         executor = FlowExecutor.create(
             get_yaml_file(SAMPLE_FLOW),
@@ -58,7 +164,7 @@ def create_line_execution_process_pool(self, dev_connections):
             line_timeout_sec=1,
         )
         run_id = str(uuid.uuid4())
-        bulk_inputs = self.get_bulk_inputs()
+        bulk_inputs = get_bulk_inputs()
         nlines = len(bulk_inputs)
         line_execution_process_pool = LineExecutionProcessPool(
             executor,
@@ -84,7 +190,7 @@ def test_line_execution_process_pool(self, flow_folder, dev_connections):
         executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections)
         executor._log_interval = 1
         run_id = str(uuid.uuid4())
-        bulk_inputs = self.get_bulk_inputs()
+        bulk_inputs = get_bulk_inputs()
         nlines = len(bulk_inputs)
         run_id = run_id or str(uuid.uuid4())
         with LineExecutionProcessPool(
@@ -114,7 +220,7 @@ def test_line_execution_not_completed(self, flow_folder, dev_connections):
             line_timeout_sec=1,
         )
         run_id = str(uuid.uuid4())
-        bulk_inputs = self.get_bulk_inputs()
+        bulk_inputs = get_bulk_inputs()
         nlines = len(bulk_inputs)
         with LineExecutionProcessPool(
             executor,
@@ -147,7 +253,7 @@ def test_exec_line(self, flow_folder, dev_connections, mocker: MockFixture):
             line_timeout_sec=1,
         )
         run_id = str(uuid.uuid4())
-        line_inputs = self.get_line_inputs()
+        line_inputs = get_line_inputs()
         line_result = _exec_line(
             executor=executor,
             output_queue=output_queue,
@@ -178,7 +284,7 @@ def test_exec_line_failed_when_line_execution_not_start(self, flow_folder, dev_c
             message=test_error_msg, target=ErrorTarget.AZURE_RUN_STORAGE
         )
         run_id = str(uuid.uuid4())
-        line_inputs = self.get_line_inputs()
+        line_inputs = get_line_inputs()
         line_result = _exec_line(
             executor=executor,
             output_queue=output_queue,
@@ -243,7 +349,7 @@ def test_process_pool_run_with_exception(self, flow_folder, dev_connections, moc
             dev_connections,
         )
         run_id = str(uuid.uuid4())
-        bulk_inputs = self.get_bulk_inputs()
+        bulk_inputs = get_bulk_inputs()
         nlines = len(bulk_inputs)
         with LineExecutionProcessPool(
             executor,
@@ -258,3 +364,104 @@ def test_process_pool_run_with_exception(self, flow_folder, dev_connections, moc
             assert e.value.message == test_error_msg
             assert e.value.target == ErrorTarget.AZURE_RUN_STORAGE
             assert e.value.error_codes[0] == "UserError"
+
+    @pytest.mark.parametrize(
+        (
+            "flow_folder",
+            "is_set_environ_pf_worker_count",
+            "pf_worker_count",
+            "n_process"
+        ),
+        [
+            (SAMPLE_FLOW, True, "3", 3),
+            (SAMPLE_FLOW, False, None, 4)
+        ],
+    )
+    def test_process_pool_parallelism_in_fork_mode(
+        self,
+        dev_connections,
+        flow_folder,
+        is_set_environ_pf_worker_count,
+        pf_worker_count,
+        n_process
+    ):
+        p = multiprocessing.Process(
+            target=test_fork_mode_parallelism_in_subprocess,
+            args=(dev_connections,
+                  flow_folder,
+                  is_set_environ_pf_worker_count,
+                  pf_worker_count,
+                  n_process))
+        p.start()
+        p.join()
+        assert p.exitcode == 0
+
+    @pytest.mark.parametrize(
+        (
+            "flow_folder",
+            "is_set_environ_pf_worker_count",
+            "is_calculation_smaller_than_set",
+            "pf_worker_count",
+            "estimated_available_worker_count",
+            "n_process"
+        ),
+        [
+            (SAMPLE_FLOW, True, False, "2", 4, 2),
+            (SAMPLE_FLOW, True, True, "6", 2, 6),
+            (SAMPLE_FLOW, False, True, None, 2, 2)
+        ],
+    )
+    def test_process_pool_parallelism_in_spawn_mode(
+        self,
+        dev_connections,
+        flow_folder,
+        is_set_environ_pf_worker_count,
+        is_calculation_smaller_than_set,
+        pf_worker_count,
+        estimated_available_worker_count,
+        n_process
+    ):
+        p = multiprocessing.Process(
+            target=test_spawn_mode_parallelism_in_subprocess,
+            args=(dev_connections,
+                  flow_folder,
+                  is_set_environ_pf_worker_count,
+                  is_calculation_smaller_than_set,
+                  pf_worker_count,
+                  estimated_available_worker_count,
+                  n_process))
+        p.start()
+        p.join()
+        assert p.exitcode == 0
+
+
+class TestGetAvailableMaxWorkerCount:
+    @pytest.mark.parametrize(
+        "available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count",
+        [
+            (128.0, 64.0, 2, 2),  # available_memory / process_memory > 1
+            (63.0, 64.0, 1, 0),  # available_memory / process_memory < 1
+        ],
+    )
+    def test_get_available_max_worker_count(
+        self, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count
+    ):
+        with patch("psutil.virtual_memory") as mock_mem:
+            mock_mem.return_value.available = available_memory * 1024 * 1024
+            with patch("psutil.Process") as mock_process:
+                mock_process.return_value.memory_info.return_value.rss = process_memory * 1024 * 1024
+                with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
+                    mock_logger.warning.return_value = None
+                    estimated_available_worker_count = get_available_max_worker_count()
+                    assert estimated_available_worker_count == expected_max_worker_count
+                    if actual_calculate_worker_count < 1:
+                        mock_logger.warning.assert_called_with(
+                            f"Current system's available memory is {available_memory}MB, less than the memory "
+                            f"{process_memory}MB required by the process. The maximum available worker count is 1."
+                        )
+                    else:
+                        mock_logger.info.assert_called_with(
+                            f"Current system's available memory is {available_memory}MB, "
+                            f"memory consumption of current process is {process_memory}MB, "
+                            f"estimated available worker count is {available_memory}/{process_memory} "
+                            f"= {actual_calculate_worker_count}"
+                        )
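# Reviewer note on the test design above: the two skip-marked helpers run in a child
# process so that PF_WORKER_COUNT / PF_BATCH_METHOD mutations and the chosen
# multiprocessing start method cannot leak into the main pytest process; a failing
# assertion in the child surfaces in the parent only as a nonzero exit code:
#
#     p = multiprocessing.Process(target=test_fork_mode_parallelism_in_subprocess,
#                                 args=(...))  # args elided here
#     p.start(); p.join()
#     assert p.exitcode == 0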