[PromptFlow] Optimize memory computation (#1328)
# Description
**Optimize memory computation**

1. Optimize the logic for calculating the maximum number of processes.
2. If the user sets the environment variable "PF_WORKER_COUNT", use its
value as the number of processes in preference to the calculated value.

The following table summarizes the logic for determining the number of
processes and corresponding log messages under different conditions
(based on whether fork is used and the setting of the environment
variable PF_WORKER_COUNT):

<div style="overflow-x:auto;">
  <table style="table-layout: auto;">
    <tr>
<th style="white-space: nowrap;">Scenario &nbsp; &nbsp; &nbsp;</th>
<th style="white-space: nowrap;">Fork_Used &nbsp; &nbsp;&nbsp;
&nbsp;&nbsp;&nbsp;</th>
      <th style="white-space: nowrap;">PF_WORKER_COUNT Set</th>
      <th style="white-space: nowrap;">Worker Count</th>
      <th style="white-space: nowrap;">NLines</th>
<th style="white-space: nowrap;">Estimated Available Worker Count</th>
      <th style="white-space: nowrap;">Calculation of Processes</th>
    </tr>
    <tr>
      <td>1</td>
      <td>No(spawn)</td>
      <td>No</td>
      <td>16</td>
      <td>4</td>
      <td><strong>2</strong></td>
      <td>Min(16, 4, 2) = <strong>2</strong></td>
    </tr>
    <tr>
      <td colspan="8" style="white-space: normal;">
          <strong style="font-size: larger;">Log Message:</strong><br>
Set process count to <strong>2</strong> by taking the minimum value
among the factors of {'default_worker_count': 16, 'row_count': 4,
'estimated_worker_count_based_on_memory_usage': 2}.</td>
    </tr>
    <tr>
      <td>2</td>
      <td>No(spawn)</td>
      <td>Yes (<strong>2</strong>)</td>
      <td>2</td>
      <td>4</td>
      <td>2</td>
      <td><strong>PF_WORKER_COUNT = 2</strong></td>
    </tr>
    <tr>
      <td colspan="8" style="white-space: normal;">
          <strong style="font-size: larger;">Log Message:</strong><br>
Set process count to <strong>2</strong> with the environment variable
'PF_WORKER_COUNT'.
        </td>
    </tr>
    <tr>
      <td>3</td>
      <td>No(spawn)</td>
      <td>Yes (<strong>6</strong>)</td>
      <td>6</td>
      <td>4</td>
      <td>2</td>
      <td><strong>PF_WORKER_COUNT = 6</strong></td>
    </tr>
    <tr>
      <td colspan="8" style="white-space: normal;">
          <strong style="font-size: larger;">Log Message:</strong><br>
Set process count to <strong>6</strong> with the environment variable
'PF_WORKER_COUNT'.<br>
<strong style="font-size: larger;">Warning Message:</strong><br>
The current process count (<strong>6</strong>) is larger than
recommended process count (2) that estimated by system available memory.
This may cause memory exhaustion.</td>
    </tr>
    <tr>
      <td>4</td>
      <td>Yes(fork)</td>
      <td>No</td>
      <td>16</td>
      <td><strong>4</strong></td>
      <td>N/A</td>
      <td>Min(16, 4) = <strong>4</strong></td>
    </tr>
    <tr>
      <td colspan="8" style="white-space: normal;">
          <strong style="font-size: larger;">Log Message:</strong><br>
Set process count to <strong>4</strong> by taking the minimum value
among the factors of {'default_worker_count': 16, 'row_count': 4}.</td>
    </tr>
    <tr>
      <td>5</td>
      <td>Yes(fork)</td>
      <td>Yes (<strong>3</strong>)</td>
      <td>3</td>
      <td>4</td>
      <td>N/A</td>
      <td><strong>PF_WORKER_COUNT = 3</strong></td>
    </tr>
    <tr>
      <td colspan="8" style="white-space: normal;">
          <strong style="font-size: larger;">Log Message:</strong><br>
Set process count to <strong>3</strong> with the environment variable
'PF_WORKER_COUNT'.</td>
    </tr>
  </table>
</div>
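
The decision logic in the table can be restated as a short standalone sketch. The function name `determine_worker_count` and its parameters are illustrative assumptions, not the executor's actual method (which reads the fork setting, the row count, and the environment from the pool instance). Here `estimated` stands in for the memory-based estimate and is `None` when fork is used.

```python
from typing import Optional

DEFAULT_WORKER_COUNT = 16  # mirrors _DEFAULT_WORKER_COUNT in the diff


def determine_worker_count(
    env_worker_count: Optional[int],
    row_count: int,
    estimated: Optional[int],
) -> int:
    """Illustrative restatement of the table above (hypothetical helper)."""
    # A valid, positive PF_WORKER_COUNT always wins, even over the memory estimate.
    if env_worker_count is not None and env_worker_count > 0:
        return env_worker_count
    # Otherwise take the minimum of the factors that are available and positive.
    factors = [DEFAULT_WORKER_COUNT, row_count, estimated]
    return min(f for f in factors if f is not None and f > 0)


# Scenarios 1-5 from the table: (PF_WORKER_COUNT, NLines, memory estimate)
scenarios = [
    (None, 4, 2),     # 1: spawn, variable not set
    (2, 4, 2),        # 2: spawn, variable set to 2
    (6, 4, 2),        # 3: spawn, variable set to 6 (exceeds the estimate)
    (None, 4, None),  # 4: fork, variable not set
    (3, 4, None),     # 5: fork, variable set to 3
]
results = [determine_worker_count(*s) for s in scenarios]
print(results)  # [2, 2, 6, 4, 3]
```

In scenario 3 the sketch returns 6 without complaint; in the commit itself the same case also emits the warning shown in the table, because the requested count exceeds the memory-based estimate.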

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There is a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
Hhhilulu authored Dec 7, 2023
1 parent e21b79f commit 0b18ee6
Showing 6 changed files with 337 additions and 104 deletions.
18 changes: 17 additions & 1 deletion src/promptflow/promptflow/_utils/utils.py
@@ -167,7 +167,7 @@ def extract_user_frame_summaries(frame_summaries: List[traceback.FrameSummary]):
# If the current frame is in tool.py and the next frame is not in _core folder
# then we can say that the next frame is in user code.
if cur_file == tool_file and not next_file.startswith(core_folder):
return frame_summaries[i + 1 :]
return frame_summaries[i + 1:]
return frame_summaries


@@ -247,3 +247,19 @@ def parse_ua_to_dict(ua):
key, value = item.split("/")
ua_dict[key] = value
return ua_dict


def get_int_env_var(env_var_name, default_value=None):
"""
The function `get_int_env_var` retrieves an integer environment variable value, with an optional
default value if the variable is not set or cannot be converted to an integer.
:param env_var_name: The name of the environment variable you want to retrieve the value of
:param default_value: The default value is the value that will be returned if the environment
variable is not found or if it cannot be converted to an integer
:return: an integer value.
"""
try:
return int(os.environ.get(env_var_name, default_value))
except Exception:
return default_value
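
A minimal standalone sketch of how the new helper behaves, reproducing the function body from the diff above. The variable name `DEMO_WORKER_COUNT` is a made-up placeholder for this demonstration only. Note that when the variable is unset and no default is given, `int(None)` raises `TypeError`, which the broad `except` turns into the default of `None`.

```python
import os


def get_int_env_var(env_var_name, default_value=None):
    """Copy of the helper added in this commit, reproduced for illustration."""
    try:
        return int(os.environ.get(env_var_name, default_value))
    except Exception:
        return default_value


name = "DEMO_WORKER_COUNT"  # hypothetical variable name used only for this demo

os.environ[name] = "8"
valid = get_int_env_var(name)             # parses the string "8"

os.environ[name] = "eight"
invalid = get_int_env_var(name, 5)        # not an integer, falls back to the default

del os.environ[name]
unset = get_int_env_var(name)             # variable missing, no default given
unset_default = get_int_env_var(name, 5)  # variable missing, default supplied

print(valid, invalid, unset, unset_default)  # 8 5 None 5
```

A value such as "0" or "-3" still parses successfully here; rejecting non-positive counts is left to the caller, which is why `_determine_worker_count` checks `worker_count > 0` before using the value.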
86 changes: 56 additions & 30 deletions src/promptflow/promptflow/executor/_line_execution_process_pool.py
@@ -1,5 +1,4 @@
import contextvars
import math
import multiprocessing
import os
import queue
@@ -21,7 +20,7 @@
from promptflow._utils.logger_utils import LogContext, bulk_logger
from promptflow._utils.multimedia_utils import _process_recursively, persist_multimedia_data
from promptflow._utils.thread_utils import RepeatLogTimer
from promptflow._utils.utils import log_progress, set_context
from promptflow._utils.utils import log_progress, set_context, get_int_env_var
from promptflow.contracts.multimedia import Image
from promptflow.contracts.run_info import FlowRunInfo
from promptflow.contracts.run_info import RunInfo as NodeRunInfo
@@ -136,6 +135,8 @@ def format_current_process(self, line_number: int, is_completed=False):


class LineExecutionProcessPool:
_DEFAULT_WORKER_COUNT = 16

def __init__(
self,
flow_executor: FlowExecutor,
@@ -149,7 +150,6 @@ def __init__(
self._run_id = run_id
self._variant_id = variant_id
self._validate_inputs = validate_inputs
self._worker_count = flow_executor._worker_count
multiprocessing_start_method = os.environ.get("PF_BATCH_METHOD")
sys_start_methods = multiprocessing.get_all_start_methods()
if multiprocessing_start_method and multiprocessing_start_method not in sys_start_methods:
@@ -194,15 +194,8 @@ def __enter__(self):
self._completed_idx = manager.dict()

self._inputs_queue = Queue()
# Starting a new process in non-fork mode requires to allocate memory. Determine the maximum number of processes
# based on available memory to avoid memory bursting.
if not self._use_fork:
available_max_worker_count = get_available_max_worker_count()
self._n_process = min(self._worker_count, self._nlines, available_max_worker_count)
bulk_logger.info(f"Not using fork, process count: {self._n_process}")
else:
self._n_process = min(self._worker_count, self._nlines)
bulk_logger.info(f"Using fork, process count: {self._n_process}")
self._n_process = self._determine_worker_count()

pool = ThreadPool(self._n_process, initializer=set_context, initargs=(contextvars.copy_context(),))
self._pool = pool

@@ -411,6 +404,43 @@ def _generate_thread_status_messages(self, pool: ThreadPool, total_count: int):
msgs.append("Processing Lines: " + ", ".join(lines) + ".")
return msgs

def _determine_worker_count(self):
worker_count = get_int_env_var("PF_WORKER_COUNT")

# Starting a new process in non-fork mode requires to allocate memory. Calculate the maximum number of processes
# based on available memory to avoid memory bursting.
estimated_available_worker_count = get_available_max_worker_count() if not self._use_fork else None

# If the environment variable PF_WORKER_COUNT exists and valid, use the value as the worker_count.
if worker_count is not None and worker_count > 0:
self._log_set_worker_count(worker_count, estimated_available_worker_count)
return worker_count

# If the environment variable PF_WORKER_COUNT is not set or invalid, take the minimum value among the
# factors: default_worker_count, row_count and estimated_worker_count_based_on_memory_usage
factors = {
"default_worker_count": self._DEFAULT_WORKER_COUNT,
"row_count": self._nlines,
"estimated_worker_count_based_on_memory_usage": estimated_available_worker_count,
}

valid_factors = {k: v for k, v in factors.items() if v is not None and v > 0}

# Take the minimum value as the result
worker_count = min(valid_factors.values())
bulk_logger.info(
f"Set process count to {worker_count} by taking the minimum value among the factors of {valid_factors}.")
return worker_count

def _log_set_worker_count(self, worker_count, estimated_available_worker_count):
bulk_logger.info(
f"Set process count to {worker_count} with the environment variable 'PF_WORKER_COUNT'.")
if estimated_available_worker_count is not None and estimated_available_worker_count < worker_count:
bulk_logger.warning(
f"The current process count ({worker_count}) is larger than recommended process count "
f"({estimated_available_worker_count}) that estimated by system available memory. This may "
f"cause memory exhaustion")


def _exec_line(
executor: FlowExecutor,
@@ -481,7 +511,6 @@ def create_executor_fork(*, flow_executor: FlowExecutor, storage: AbstractRunSto
run_tracker=run_tracker,
cache_manager=flow_executor._cache_manager,
loaded_tools=flow_executor._loaded_tools,
worker_count=flow_executor._worker_count,
raise_ex=False,
line_timeout_sec=flow_executor._line_timeout_sec,
)
@@ -540,31 +569,28 @@ def create_executor_legacy(*, flow, connections, loaded_tools, cache_manager, st
def get_available_max_worker_count():
pid = os.getpid()
mem_info = psutil.virtual_memory()
total_memory = mem_info.total / (1024 * 1024) # in MB
total_memory_in_use = mem_info.used / (1024 * 1024) # in MB
available_memory = mem_info.available / (1024 * 1024) # in MB
process = psutil.Process(pid)
process_memory_info = process.memory_info()
process_memory = process_memory_info.rss / (1024 * 1024) # in MB
# To ensure system stability, reserve memory for system usage.
available_max_worker_count = math.floor((available_memory - 0.3 * total_memory) / process_memory)
if available_max_worker_count < 1:
# For the case of vector db, at most 1/3 of the memory will be used, which is 33% of the memory
# In this scenario, the "available_max_worker_count" may be 0, which will cause an error
# "Number of processes must be at least 1" when creating ThreadPool
# So set "available_max_worker_count" to 1 if it's less than 1
estimated_available_worker_count = int(available_memory // process_memory)
if estimated_available_worker_count < 1:
# TODO: For the case of vector db, Optimize execution logic
# 1. Let the main process not consume memory because it does not actually invoke
# 2. When the degree of parallelism is 1, main process executes the task directly and not
# create the child process
bulk_logger.warning(f"Available max worker count {available_max_worker_count} is less than 1, set it to 1.")
available_max_worker_count = 1
bulk_logger.info(
f"""Process {pid} uses {process_memory},
total memory {total_memory}, total memory in use: {total_memory_in_use},
available memory: {available_memory}, available max worker count: {available_max_worker_count}"""
)
return available_max_worker_count
bulk_logger.warning(
f"Current system's available memory is {available_memory}MB, less than the memory "
f"{process_memory}MB required by the process. The maximum available worker count is 1.")
estimated_available_worker_count = 1
else:
bulk_logger.info(
f"Current system's available memory is {available_memory}MB, "
f"memory consumption of current process is {process_memory}MB, "
f"estimated available worker count is {available_memory}/{process_memory} "
f"= {estimated_available_worker_count}"
)
return estimated_available_worker_count
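
To make the change in the estimate concrete, the sketch below compares the removed formula with the new one on plain numbers (all in MB). The function names `old_estimate` and `new_estimate` are illustrative; they take the memory figures as arguments rather than querying `psutil`, and they omit the logging. The inputs are borrowed from the parametrized test that this commit removes.

```python
import math


def old_estimate(total_mb: float, available_mb: float, process_mb: float) -> int:
    # Removed logic: reserve 30% of total memory before dividing.
    count = math.floor((available_mb - 0.3 * total_mb) / process_mb)
    return max(count, 1)  # never report fewer than one worker


def new_estimate(available_mb: float, process_mb: float) -> int:
    # New logic: divide available memory by the current process's footprint.
    count = int(available_mb // process_mb)
    return max(count, 1)  # never report fewer than one worker


# 1024 MB total, 64 MB per process, two levels of available memory.
low_old = old_estimate(1024.0, 128.0, 64.0)   # raw value is -3, clamped to 1
low_new = new_estimate(128.0, 64.0)           # 128 // 64 = 2
high_old = old_estimate(1024.0, 768.0, 64.0)  # floor(460.8 / 64) = 7
high_new = new_estimate(768.0, 64.0)          # 768 // 64 = 12
tight = new_estimate(32.0, 64.0)              # 0 workers fit, clamped to 1

print(low_old, low_new, high_old, high_new, tight)  # 1 2 7 12 1
```

With the 30% reserve gone, the same machine yields a larger estimate (12 instead of 7 in the second case), and the clamp to 1 now applies only when available memory is smaller than the process itself.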


def get_multiprocessing_context(multiprocessing_start_method=None):
14 changes: 0 additions & 14 deletions src/promptflow/promptflow/executor/flow_executor.py
@@ -5,7 +5,6 @@
import copy
import functools
import inspect
import os
import uuid
from pathlib import Path
from threading import current_thread
@@ -88,7 +87,6 @@ def __init__(
cache_manager: AbstractCacheManager,
loaded_tools: Mapping[str, Callable],
*,
worker_count=None,
raise_ex: bool = False,
working_dir=None,
line_timeout_sec=LINE_TIMEOUT_SEC,
@@ -106,8 +104,6 @@ def __init__(
:type cache_manager: ~promptflow._core.cache_manager.AbstractCacheManager
:param loaded_tools: A mapping of tool names to their corresponding functions.
:type loaded_tools: Mapping[str, Callable]
:param worker_count: The number of workers to use for parallel execution of the Flow.
:type worker_count: int or None
:param raise_ex: Whether to raise an exception if an error occurs during execution.
:type raise_ex: bool
:param working_dir: The working directory to use for execution.
@@ -126,16 +122,6 @@ def __init__(
self._connections = connections
self._aggregation_inputs_references = get_aggregation_inputs_properties(flow)
self._aggregation_nodes = {node.name for node in self._flow.nodes if node.aggregation}
if worker_count is not None:
self._worker_count = worker_count
else:
try:
worker_count = int(os.environ.get("PF_WORKER_COUNT", self._DEFAULT_WORKER_COUNT))
self._worker_count = worker_count
except Exception:
self._worker_count = self._DEFAULT_WORKER_COUNT
if self._worker_count <= 0:
self._worker_count = self._DEFAULT_WORKER_COUNT
self._run_tracker = run_tracker
self._cache_manager = cache_manager
self._loaded_tools = loaded_tools
28 changes: 27 additions & 1 deletion src/promptflow/tests/executor/unittests/_utils/test_utils.py
@@ -1,6 +1,8 @@
import pytest
import os
from unittest.mock import patch

from promptflow._utils.utils import is_json_serializable
from promptflow._utils.utils import is_json_serializable, get_int_env_var


class MyObj:
@@ -12,3 +14,27 @@ class TestUtils:
@pytest.mark.parametrize("value, expected_res", [(None, True), (1, True), ("", True), (MyObj(), False)])
def test_is_json_serializable(self, value, expected_res):
assert is_json_serializable(value) == expected_res

@pytest.mark.parametrize(
"env_var, env_value, default_value, expected_result",
[
("TEST_VAR", "10", None, 10), # Valid integer string
("TEST_VAR", "invalid", None, None), # Invalid integer strings
("TEST_VAR", None, 5, 5), # Environment variable does not exist
("TEST_VAR", "10", 5, 10), # Valid integer string with a default value
("TEST_VAR", "invalid", 5, 5), # Invalid integer string with a default value
])
def test_get_int_env_var(self, env_var, env_value, default_value, expected_result):
with patch.dict(os.environ, {env_var: env_value} if env_value is not None else {}):
assert get_int_env_var(env_var, default_value) == expected_result

@pytest.mark.parametrize(
"env_var, env_value, expected_result",
[
("TEST_VAR", "10", 10), # Valid integer string
("TEST_VAR", "invalid", None), # Invalid integer strings
("TEST_VAR", None, None), # Environment variable does not exist
])
def test_get_int_env_var_without_default_vaue(self, env_var, env_value, expected_result):
with patch.dict(os.environ, {env_var: env_value} if env_value is not None else {}):
assert get_int_env_var(env_var) == expected_result
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pytest

from promptflow import tool
from promptflow.contracts.flow import FlowInputDefinition
from promptflow.contracts.tool import ValueType
from promptflow.executor._line_execution_process_pool import get_available_max_worker_count

from promptflow.executor.flow_executor import (
FlowExecutor,
_ensure_node_result_is_serializable,
@@ -178,31 +178,3 @@ def test_streaming_tool_should_be_consumed_and_merged(self):
def test_non_streaming_tool_should_not_be_affected(self):
func = _ensure_node_result_is_serializable(non_streaming_tool)
assert func() == 1


class TestGetAvailableMaxWorkerCount:
@pytest.mark.parametrize(
"total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count",
[
(1024.0, 128.0, 64.0, 1, -3), # available_memory - 0.3 * total_memory < 0
(1024.0, 307.20, 64.0, 1, 0), # available_memory - 0.3 * total_memory = 0
(1024.0, 768.0, 64.0, 7, 7), # available_memory - 0.3 * total_memory > 0
],
)
def test_get_available_max_worker_count(
self, total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count
):
with patch("psutil.virtual_memory") as mock_mem:
mock_mem.return_value.total = total_memory * 1024 * 1024
mock_mem.return_value.available = available_memory * 1024 * 1024
with patch("psutil.Process") as mock_process:
mock_process.return_value.memory_info.return_value.rss = process_memory * 1024 * 1024
with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
mock_logger.warning.return_value = None
max_worker_count = get_available_max_worker_count()
assert max_worker_count == expected_max_worker_count
if actual_calculate_worker_count < 1:
mock_logger.warning.assert_called_with(
f"Available max worker count {actual_calculate_worker_count} is less than 1, "
"set it to 1."
)