# Patch summary (whitespace-collapsed unified diff touching two files).
#
# 1. src/promptflow/promptflow/executor/_line_execution_process_pool.py
#    - _determine_worker_count: estimated_available_worker_count now starts as None and
#      get_available_max_worker_count() is called only when self._use_fork is false.
#    - If PF_WORKER_COUNT is set and > 0 it is still logged via _log_set_worker_count and returned as-is.
#    - Otherwise the worker count is min() over a `factors` dict holding default_worker_count
#      (self._DEFAULT_WORKER_COUNT), row_count (self._nlines) and, only when the estimate is not None
#      and > 0, estimated_worker_count_based_on_memory_usage; one info message logs the result and the dict.
#    - _calculate_worker_count_with_fork and _calculate_worker_count_without_fork are removed together
#      with their separate "Using fork" / "Not using fork" log messages.
#
# 2. src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
#    - The fork-mode and spawn-mode tests build the same `factors` dict and assert the new
#      "by taking the minimum value among the factors of {factors}." log line instead of the old messages.
#      The expected text embeds the dict's formatted form, so key insertion order must match the implementation.
#    - The `-126,6 +129,7` hunk adds one line next to `assert pool._n_process == n_process`; with the
#      whitespace collapsed it is unclear whether the added line is that assert or a blank line -- TODO confirm.
#
# NOTE(review): min(factors.values()) evaluates to 0 if self._nlines is 0 -- confirm callers never
#   create the pool with zero rows.
# NOTE(review): the spawn-mode test always places estimated_available_worker_count in the expected
#   factors, while the implementation adds it only when it is > 0 -- presumably every parametrised
#   estimate is positive; verify against the test parameters.
diff --git a/src/promptflow/promptflow/executor/_line_execution_process_pool.py b/src/promptflow/promptflow/executor/_line_execution_process_pool.py index 0d6405a91b5a..6e3bd3b083b1 100644 --- a/src/promptflow/promptflow/executor/_line_execution_process_pool.py +++ b/src/promptflow/promptflow/executor/_line_execution_process_pool.py @@ -406,20 +406,35 @@ def _generate_thread_status_messages(self, pool: ThreadPool, total_count: int): def _determine_worker_count(self): worker_count = get_int_env_var("PF_WORKER_COUNT") + estimated_available_worker_count = None - # Starting a new process in non-fork mode requires to allocate memory. Calculate the maximum number of processes - # based on available memory to avoid memory bursting. - estimated_available_worker_count = get_available_max_worker_count() if not self._use_fork else None + if not self._use_fork: + # Starting a new process in non-fork mode requires to allocate memory. Calculate the maximum number of + # processes based on available memory to avoid memory bursting. + estimated_available_worker_count = get_available_max_worker_count() # If the environment variable PF_WORKER_COUNT exists and valid, use the value as the worker_count. 
if worker_count is not None and worker_count > 0: self._log_set_worker_count(worker_count, estimated_available_worker_count) return worker_count - if not self._use_fork: - return self._calculate_worker_count_without_fork(estimated_available_worker_count) - else: - return self._calculate_worker_count_with_fork() + # If the environment variable PF_WORKER_COUNT is not set or invalid, take the minimum value among the + # factors: default_worker_count, row_count and estimated_worker_count_based_on_memory_usage + factors = { + "default_worker_count": self._DEFAULT_WORKER_COUNT, + "row_count": self._nlines, + } + + if estimated_available_worker_count is not None and estimated_available_worker_count > 0: + factors.update({ + "estimated_worker_count_based_on_memory_usage": estimated_available_worker_count, + }) + + # Take the minimum value as the result + worker_count = min(factors.values()) + bulk_logger.info( + f"Set process count to {worker_count} by taking the minimum value among the factors of {factors}.") + return worker_count def _log_set_worker_count(self, worker_count, estimated_available_worker_count): bulk_logger.info( @@ -430,26 +445,6 @@ def _log_set_worker_count(self, worker_count, estimated_available_worker_count): f"({estimated_available_worker_count}) that estimated by system available memory. 
This may " f"cause memory exhaustion") - def _calculate_worker_count_with_fork(self): - worker_count = min(self._DEFAULT_WORKER_COUNT, self._nlines) - bulk_logger.info("Using fork to create new process.") - bulk_logger.info( - f"Set process count to {worker_count} by taking the minimum value among the " - f"default worker_count ({self._DEFAULT_WORKER_COUNT}) and the row count ({self._nlines}).") - return worker_count - - def _calculate_worker_count_without_fork(self, estimated_available_worker_count): - worker_count = min(self._DEFAULT_WORKER_COUNT, self._nlines, estimated_available_worker_count) - bulk_logger.info("Not using fork to create new process.") - bulk_logger.info( - "The environment variable PF_WORKER_COUNT is not set or invalid. Calculate the worker count based " - "on the currently memory usage.") - bulk_logger.info( - f"Set process count to {worker_count} by taking the minimum value among estimated process " - f"count ({estimated_available_worker_count}), the row count ({self._nlines}) and the default " - f"worker count ({self._DEFAULT_WORKER_COUNT}).") - return worker_count - def _exec_line( executor: FlowExecutor, diff --git a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py index 37626d02514d..313e130b9531 100644 --- a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py +++ b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py @@ -85,10 +85,13 @@ def test_fork_mode_parallelism_in_subprocess( f"Set process count to {pf_worker_count} with the environment " f"variable 'PF_WORKER_COUNT'.") else: - mock_logger.info.assert_any_call("Using fork to create new process.") + factors = { + "default_worker_count": pool._DEFAULT_WORKER_COUNT, + "row_count": pool._nlines, + } mock_logger.info.assert_any_call( f"Set process count to {n_process} by taking the minimum 
value among the " - f"default worker_count ({pool._DEFAULT_WORKER_COUNT}) and the row count ({nlines})." + f"factors of {factors}." ) @@ -126,6 +129,7 @@ def test_spawn_mode_parallelism_in_subprocess( False, None, ) as pool: + assert pool._n_process == n_process if is_set_environ_pf_worker_count and is_calculation_smaller_than_set: mock_logger.info.assert_any_call( @@ -140,15 +144,13 @@ def test_spawn_mode_parallelism_in_subprocess( f"Set process count to {pf_worker_count} with the environment " f"variable 'PF_WORKER_COUNT'.") elif not is_set_environ_pf_worker_count: - mock_logger.info.assert_any_call("Not using fork to create new process.") - mock_logger.info.assert_any_call( - "The environment variable PF_WORKER_COUNT is not set or invalid. Calculate the worker " - "count based on the currently memory usage." - ) + factors = { + "default_worker_count": pool._DEFAULT_WORKER_COUNT, + "row_count": pool._nlines, + "estimated_worker_count_based_on_memory_usage": estimated_available_worker_count + } mock_logger.info.assert_any_call( - f"Set process count to {n_process} by taking the minimum value among estimated " - f"process count ({estimated_available_worker_count}), the row count ({nlines}) and the " - f"default worker count ({pool._DEFAULT_WORKER_COUNT})." + f"Set process count to {n_process} by taking the minimum value among the factors of {factors}." )