Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PromptFlow] Optimize memory computation #1328

Merged
merged 32 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8cbf7b9
Optimize memory computation
Hhhilulu Nov 30, 2023
3aae7df
Optimize memory computation
Hhhilulu Nov 30, 2023
45f9c28
Optimize memory computation
Hhhilulu Nov 30, 2023
864443a
Add os.environ.get("PF_WORKER_COUNT")
Hhhilulu Dec 1, 2023
fa4196c
Add some test case and optimize logic
Hhhilulu Dec 1, 2023
8daab14
Delete the previous environment variable
Hhhilulu Dec 3, 2023
a8af8fe
Delete the previous environment variable
Hhhilulu Dec 4, 2023
279a8e6
Optimize log information
Hhhilulu Dec 4, 2023
89a67ed
fix flask8
Hhhilulu Dec 4, 2023
4ac4cea
fix flask8
Hhhilulu Dec 4, 2023
9b3d695
delete "the"
Hhhilulu Dec 4, 2023
3c61469
Extract as public function and refine worker_count
Hhhilulu Dec 4, 2023
467f2e0
refine load_worker_count_in_env
Hhhilulu Dec 4, 2023
1b5a1c3
Rename variable
Hhhilulu Dec 4, 2023
8c77dba
Refine load_worker_count_in_env
Hhhilulu Dec 5, 2023
8ac1eda
Optimize log
Hhhilulu Dec 5, 2023
7273b63
fix comments
Hhhilulu Dec 5, 2023
5d3f7f8
Rename variable
Hhhilulu Dec 5, 2023
0bed290
fix comments
Hhhilulu Dec 5, 2023
c6bd663
delete useless function
Hhhilulu Dec 5, 2023
76286b7
remove flow_executor's worker_count
Hhhilulu Dec 5, 2023
f062661
Optimize log content
Hhhilulu Dec 5, 2023
96216c9
Refine _determine_worker_count
Hhhilulu Dec 6, 2023
8882ba2
fix typo
Hhhilulu Dec 6, 2023
05ed6c1
Optimization logs
Hhhilulu Dec 6, 2023
01653a8
Remove bracket
Hhhilulu Dec 6, 2023
130b1b1
Add space
Hhhilulu Dec 6, 2023
beff59d
Optimization log
Hhhilulu Dec 6, 2023
1ac3321
Refine _determine_worker_count logic
Hhhilulu Dec 7, 2023
3c3cb17
fix flask8
Hhhilulu Dec 7, 2023
2c1a1b6
Refine if logic
Hhhilulu Dec 7, 2023
c1e228a
Redefine the calculation of factors
Hhhilulu Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/promptflow/promptflow/_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def extract_user_frame_summaries(frame_summaries: List[traceback.FrameSummary]):
# If the current frame is in tool.py and the next frame is not in _core folder
# then we can say that the next frame is in user code.
if cur_file == tool_file and not next_file.startswith(core_folder):
return frame_summaries[i + 1 :]
return frame_summaries[i + 1:]
return frame_summaries


Expand Down Expand Up @@ -247,3 +247,19 @@ def parse_ua_to_dict(ua):
key, value = item.split("/")
ua_dict[key] = value
return ua_dict


def get_int_env_var(env_var_name, default_value=None):
    """Read an environment variable and return its value as an int.

    :param env_var_name: Name of the environment variable to read.
    :param default_value: Value returned when the variable is unset or its
        value cannot be converted to an int. Defaults to None.
    :return: The variable's value as an int, or ``default_value``.
    """
    try:
        # When the variable is unset and default_value is None, int(None)
        # raises TypeError and the default (None) is returned below.
        return int(os.environ.get(env_var_name, default_value))
    except (TypeError, ValueError):
        # Narrowed from `except Exception`: only conversion failures should
        # fall back to the default; anything else is a real bug.
        return default_value
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import contextvars
import math
import multiprocessing
import os
import queue
Expand All @@ -21,7 +20,7 @@
from promptflow._utils.logger_utils import LogContext, bulk_logger
from promptflow._utils.multimedia_utils import _process_recursively, persist_multimedia_data
from promptflow._utils.thread_utils import RepeatLogTimer
from promptflow._utils.utils import log_progress, set_context
from promptflow._utils.utils import log_progress, set_context, get_int_env_var
from promptflow.contracts.multimedia import Image
from promptflow.contracts.run_info import FlowRunInfo
from promptflow.contracts.run_info import RunInfo as NodeRunInfo
Expand Down Expand Up @@ -136,6 +135,8 @@ def format_current_process(self, line_number: int, is_completed=False):


class LineExecutionProcessPool:
_DEFAULT_WORKER_COUNT = 16

def __init__(
self,
flow_executor: FlowExecutor,
Expand All @@ -149,7 +150,6 @@ def __init__(
self._run_id = run_id
self._variant_id = variant_id
self._validate_inputs = validate_inputs
self._worker_count = flow_executor._worker_count
multiprocessing_start_method = os.environ.get("PF_BATCH_METHOD")
sys_start_methods = multiprocessing.get_all_start_methods()
if multiprocessing_start_method and multiprocessing_start_method not in sys_start_methods:
Expand Down Expand Up @@ -194,15 +194,8 @@ def __enter__(self):
self._completed_idx = manager.dict()

self._inputs_queue = Queue()
# Starting a new process in non-fork mode requires to allocate memory. Determine the maximum number of processes
# based on available memory to avoid memory bursting.
if not self._use_fork:
Hhhilulu marked this conversation as resolved.
Show resolved Hide resolved
available_max_worker_count = get_available_max_worker_count()
self._n_process = min(self._worker_count, self._nlines, available_max_worker_count)
bulk_logger.info(f"Not using fork, process count: {self._n_process}")
else:
self._n_process = min(self._worker_count, self._nlines)
bulk_logger.info(f"Using fork, process count: {self._n_process}")
self._n_process = self._determine_worker_count()

pool = ThreadPool(self._n_process, initializer=set_context, initargs=(contextvars.copy_context(),))
self._pool = pool

Expand Down Expand Up @@ -411,6 +404,43 @@ def _generate_thread_status_messages(self, pool: ThreadPool, total_count: int):
msgs.append("Processing Lines: " + ", ".join(lines) + ".")
return msgs

def _determine_worker_count(self):
    """Decide how many worker processes the pool should spawn.

    Precedence: a valid positive ``PF_WORKER_COUNT`` environment variable
    wins outright; otherwise the result is the minimum of the default
    worker count, the number of input rows, and (in non-fork mode only)
    the worker count estimated from available system memory.
    """
    worker_count = get_int_env_var("PF_WORKER_COUNT")

    # Starting a new process in non-fork mode requires to allocate memory.
    # Calculate the maximum number of processes based on available memory
    # to avoid memory bursting.
    estimated_available_worker_count = get_available_max_worker_count() if not self._use_fork else None

    # If the environment variable PF_WORKER_COUNT exists and valid, use the value as the worker_count.
    if worker_count is not None and worker_count > 0:
        self._log_set_worker_count(worker_count, estimated_available_worker_count)
        return worker_count

    # If the environment variable PF_WORKER_COUNT is not set or invalid, take the minimum value among the
    # factors: default_worker_count, row_count and estimated_worker_count_based_on_memory_usage
    factors = {
        "default_worker_count": self._DEFAULT_WORKER_COUNT,
        "row_count": self._nlines,
        "estimated_worker_count_based_on_memory_usage": estimated_available_worker_count,
    }

    # Drop factors that are missing (fork mode) or non-positive.
    valid_factors = {k: v for k, v in factors.items() if v is not None and v > 0}

    # Take the minimum value as the result
    worker_count = min(valid_factors.values())
    bulk_logger.info(
        f"Set process count to {worker_count} by taking the minimum value among the factors of {valid_factors}.")
    return worker_count

def _log_set_worker_count(self, worker_count, estimated_available_worker_count):
    """Log that the worker count was taken from the PF_WORKER_COUNT env var.

    Emits an additional warning when the configured count exceeds the
    memory-based estimate; ``estimated_available_worker_count`` is None when
    no estimate was computed (fork mode), which suppresses the warning.
    """
    bulk_logger.info(
        f"Set process count to {worker_count} with the environment variable 'PF_WORKER_COUNT'.")
    if estimated_available_worker_count is not None and estimated_available_worker_count < worker_count:
        bulk_logger.warning(
            f"The current process count ({worker_count}) is larger than recommended process count "
            f"({estimated_available_worker_count}) that estimated by system available memory. This may "
            f"cause memory exhaustion")


def _exec_line(
executor: FlowExecutor,
Expand Down Expand Up @@ -481,7 +511,6 @@ def create_executor_fork(*, flow_executor: FlowExecutor, storage: AbstractRunSto
run_tracker=run_tracker,
cache_manager=flow_executor._cache_manager,
loaded_tools=flow_executor._loaded_tools,
worker_count=flow_executor._worker_count,
raise_ex=False,
line_timeout_sec=flow_executor._line_timeout_sec,
)
Expand Down Expand Up @@ -540,31 +569,28 @@ def create_executor_legacy(*, flow, connections, loaded_tools, cache_manager, st
def get_available_max_worker_count():
    """Estimate how many worker processes current available memory supports.

    Uses the current process's resident set size as the per-worker memory
    cost and divides the system's available memory by it. Always returns at
    least 1 so that pool creation never fails with a zero process count.

    :return: Estimated worker count as a positive int (>= 1).
    """
    pid = os.getpid()
    mem_info = psutil.virtual_memory()
    available_memory = mem_info.available / (1024 * 1024)  # in MB
    process = psutil.Process(pid)
    process_memory_info = process.memory_info()
    process_memory = process_memory_info.rss / (1024 * 1024)  # in MB
    estimated_available_worker_count = int(available_memory // process_memory)
    if estimated_available_worker_count < 1:
        # TODO: For the case of vector db, Optimize execution logic
        # 1. Let the main process not consume memory because it does not actually invoke
        # 2. When the degree of parallelism is 1, main process executes the task directly and not
        #    create the child process
        bulk_logger.warning(
            f"Current system's available memory is {available_memory}MB, less than the memory "
            f"{process_memory}MB required by the process. The maximum available worker count is 1.")
        estimated_available_worker_count = 1
    else:
        bulk_logger.info(
            f"Current system's available memory is {available_memory}MB, "
            f"memory consumption of current process is {process_memory}MB, "
            f"estimated available worker count is {available_memory}/{process_memory} "
            f"= {estimated_available_worker_count}"
        )
    return estimated_available_worker_count


def get_multiprocessing_context(multiprocessing_start_method=None):
Expand Down
14 changes: 0 additions & 14 deletions src/promptflow/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import copy
import functools
import inspect
import os
import uuid
from pathlib import Path
from threading import current_thread
Expand Down Expand Up @@ -88,7 +87,6 @@ def __init__(
cache_manager: AbstractCacheManager,
loaded_tools: Mapping[str, Callable],
*,
worker_count=None,
raise_ex: bool = False,
working_dir=None,
line_timeout_sec=LINE_TIMEOUT_SEC,
Expand All @@ -106,8 +104,6 @@ def __init__(
:type cache_manager: ~promptflow._core.cache_manager.AbstractCacheManager
:param loaded_tools: A mapping of tool names to their corresponding functions.
:type loaded_tools: Mapping[str, Callable]
:param worker_count: The number of workers to use for parallel execution of the Flow.
:type worker_count: int or None
:param raise_ex: Whether to raise an exception if an error occurs during execution.
:type raise_ex: bool
:param working_dir: The working directory to use for execution.
Expand All @@ -126,16 +122,6 @@ def __init__(
self._connections = connections
self._aggregation_inputs_references = get_aggregation_inputs_properties(flow)
self._aggregation_nodes = {node.name for node in self._flow.nodes if node.aggregation}
if worker_count is not None:
self._worker_count = worker_count
else:
try:
worker_count = int(os.environ.get("PF_WORKER_COUNT", self._DEFAULT_WORKER_COUNT))
self._worker_count = worker_count
except Exception:
self._worker_count = self._DEFAULT_WORKER_COUNT
if self._worker_count <= 0:
self._worker_count = self._DEFAULT_WORKER_COUNT
self._run_tracker = run_tracker
self._cache_manager = cache_manager
self._loaded_tools = loaded_tools
Expand Down
28 changes: 27 additions & 1 deletion src/promptflow/tests/executor/unittests/_utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import os
from unittest.mock import patch

from promptflow._utils.utils import is_json_serializable
from promptflow._utils.utils import is_json_serializable, get_int_env_var


class MyObj:
Expand All @@ -12,3 +14,27 @@ class TestUtils:
@pytest.mark.parametrize("value, expected_res", [(None, True), (1, True), ("", True), (MyObj(), False)])
def test_is_json_serializable(self, value, expected_res):
    """Builtin scalars (None, int, str) are JSON-serializable; arbitrary objects are not."""
    assert is_json_serializable(value) == expected_res

@pytest.mark.parametrize(
    "env_var, env_value, default_value, expected_result",
    [
        ("TEST_VAR", "10", None, 10),  # Valid integer string
        ("TEST_VAR", "invalid", None, None),  # Invalid integer strings
        ("TEST_VAR", None, 5, 5),  # Environment variable does not exist
        ("TEST_VAR", "10", 5, 10),  # Valid integer string with a default value
        ("TEST_VAR", "invalid", 5, 5),  # Invalid integer string with a default value
    ])
def test_get_int_env_var(self, env_var, env_value, default_value, expected_result):
    """get_int_env_var returns the parsed int, falling back to the default when unset or unparsable."""
    # patch.dict replaces os.environ only for the duration of the block,
    # so the real environment is restored even if the assertion fails.
    with patch.dict(os.environ, {env_var: env_value} if env_value is not None else {}):
        assert get_int_env_var(env_var, default_value) == expected_result

@pytest.mark.parametrize(
    "env_var, env_value, expected_result",
    [
        ("TEST_VAR", "10", 10),  # Valid integer string
        ("TEST_VAR", "invalid", None),  # Invalid integer strings
        ("TEST_VAR", None, None),  # Environment variable does not exist
    ])
# NOTE(review): method name has a typo ("vaue" -> "value"); left unchanged here
# because renaming would alter the pytest-discovered test id.
def test_get_int_env_var_without_default_vaue(self, env_var, env_value, expected_result):
    """Without a default, get_int_env_var returns None when the variable is unset or unparsable."""
    with patch.dict(os.environ, {env_var: env_value} if env_value is not None else {}):
        assert get_int_env_var(env_var) == expected_result
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pytest

from promptflow import tool
from promptflow.contracts.flow import FlowInputDefinition
from promptflow.contracts.tool import ValueType
from promptflow.executor._line_execution_process_pool import get_available_max_worker_count

from promptflow.executor.flow_executor import (
FlowExecutor,
_ensure_node_result_is_serializable,
Expand Down Expand Up @@ -178,31 +178,3 @@ def test_streaming_tool_should_be_consumed_and_merged(self):
def test_non_streaming_tool_should_not_be_affected(self):
func = _ensure_node_result_is_serializable(non_streaming_tool)
assert func() == 1


class TestGetAvailableMaxWorkerCount:
@pytest.mark.parametrize(
"total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count",
[
(1024.0, 128.0, 64.0, 1, -3), # available_memory - 0.3 * total_memory < 0
(1024.0, 307.20, 64.0, 1, 0), # available_memory - 0.3 * total_memory = 0
(1024.0, 768.0, 64.0, 7, 7), # available_memory - 0.3 * total_memory > 0
],
)
def test_get_available_max_worker_count(
self, total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count
):
with patch("psutil.virtual_memory") as mock_mem:
mock_mem.return_value.total = total_memory * 1024 * 1024
mock_mem.return_value.available = available_memory * 1024 * 1024
with patch("psutil.Process") as mock_process:
mock_process.return_value.memory_info.return_value.rss = process_memory * 1024 * 1024
with patch("promptflow.executor._line_execution_process_pool.bulk_logger") as mock_logger:
mock_logger.warning.return_value = None
max_worker_count = get_available_max_worker_count()
assert max_worker_count == expected_max_worker_count
if actual_calculate_worker_count < 1:
mock_logger.warning.assert_called_with(
f"Available max worker count {actual_calculate_worker_count} is less than 1, "
"set it to 1."
)
Loading
Loading