Skip to content

Commit

Permalink
[Improvement][Executor] Pass dynamic line_timeout_sec when executing lines in batch run (#1834)
Browse files Browse the repository at this point in the history

# Description

Pass dynamic `line_timeout_sec` when executing lines in batch run.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce breaking changes.**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.
  • Loading branch information
PeiwenGaoMS authored Jan 29, 2024
1 parent d3a9b27 commit 9d415c4
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 24 deletions.
4 changes: 1 addition & 3 deletions src/promptflow/promptflow/batch/_batch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ def __init__(
self._storage = storage
self._kwargs = kwargs

self._batch_timeout_sec = (
batch_timeout_sec if batch_timeout_sec else get_int_env_var("PF_BATCH_TIMEOUT_SEC", None)
)
self._batch_timeout_sec = batch_timeout_sec or get_int_env_var("PF_BATCH_TIMEOUT_SEC")
self._line_timeout_sec = get_int_env_var("PF_LINE_TIMEOUT_SEC", LINE_TIMEOUT_SEC)

# set it to True when the batch run is canceled
Expand Down
36 changes: 29 additions & 7 deletions src/promptflow/promptflow/executor/_line_execution_process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,24 +237,40 @@ def _handle_output_queue_messages(self, output_queue: Queue, result_list):
return None

def _monitor_workers_and_process_tasks_in_thread(
self, task_queue: Queue, result_list: List[LineResult], index: int, input_queue: Queue, output_queue: Queue
self,
task_queue: Queue,
result_list: List[LineResult],
index: int,
input_queue: Queue,
output_queue: Queue,
batch_start_time: datetime,
):
index, process_id, process_name = self._get_process_info(index)

batch_start_time = datetime.utcnow()
# Entering the while loop requires two conditions:
# 1. The task queue is not empty, meaning there are lines yet to be executed.
# 2. The batch run has not reached the batch timeout limit.
while not self._batch_timeout_expired(batch_start_time):
self._processes_manager.ensure_healthy()
try:
args = task_queue.get(timeout=1)
# Get task from task_queue
inputs, line_number, run_id = task_queue.get(timeout=1)
except queue.Empty:
break

# Calculate the line timeout for the current line.
line_timeout_sec = self._line_timeout_sec
if self._batch_timeout_sec:
remaining_execution_time = (
self._batch_timeout_sec - (datetime.utcnow() - batch_start_time).total_seconds()
)
if remaining_execution_time <= 0:
break
line_timeout_sec = min(line_timeout_sec, remaining_execution_time)

# Put task into input_queue
args = (inputs, line_number, run_id, line_timeout_sec)
input_queue.put(args)
inputs, line_number, run_id = args

self._processing_idx[line_number] = format_current_process_info(process_name, process_id, line_number)
log_process_status(process_name, process_id, line_number)
Expand All @@ -265,7 +281,7 @@ def _monitor_workers_and_process_tasks_in_thread(
returned_node_run_infos = {}

# Responsible for checking the output queue messages and processing them within a specified timeout period.
while not self._line_timeout_expired(start_time) and not self._batch_timeout_expired(batch_start_time):
while not self._batch_timeout_expired(batch_start_time) and not self._line_timeout_expired(start_time):
# Monitor process aliveness.
crashed = not self._is_process_alive(process_id)
if crashed:
Expand Down Expand Up @@ -460,6 +476,7 @@ def run(self, batch_inputs):
),
):
try:
batch_start_time = datetime.utcnow()
args_list = [
(
self._task_queue, # Shared task queue for all sub processes to read the input data.
Expand All @@ -469,6 +486,7 @@ def run(self, batch_inputs):
self._input_queues[i],
# Specific output queue for the sub process, used to receive results from it.
self._output_queues[i],
batch_start_time,
)
for i in range(self._n_process)
]
Expand Down Expand Up @@ -573,13 +591,16 @@ def _log_set_worker_count(self, worker_count, estimated_available_worker_count):
)


def _exec_line(executor: FlowExecutor, output_queue: Queue, *, inputs: dict, run_id, index: int):
def _exec_line(
executor: FlowExecutor, output_queue: Queue, *, inputs: dict, run_id: str, index: int, line_timeout_sec: int
):
try:
line_result = executor.exec_line(
inputs=inputs,
run_id=run_id,
index=index,
node_concurrency=DEFAULT_CONCURRENCY_BULK,
line_timeout_sec=line_timeout_sec,
)
if line_result is not None:
# For eager flow, the output may be a dataclass which is not picklable, we need to convert it to dict.
Expand Down Expand Up @@ -660,13 +681,14 @@ def exec_line_for_queue(executor_creation_func, input_queue: Queue, output_queue

while True:
try:
inputs, line_number, run_id = input_queue.get(timeout=1)
inputs, line_number, run_id, line_timeout_sec = input_queue.get(timeout=1)
result = _exec_line(
executor=executor,
output_queue=output_queue,
inputs=inputs,
run_id=run_id,
index=line_number,
line_timeout_sec=line_timeout_sec,
)
output_queue.put(result)
except queue.Empty:
Expand Down
5 changes: 5 additions & 0 deletions src/promptflow/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,7 @@ def exec_line(
validate_inputs: bool = True,
node_concurrency=DEFAULT_CONCURRENCY_FLOW,
allow_generator_output: bool = False,
line_timeout_sec: Optional[int] = None,
) -> LineResult:
"""Execute a single line of the flow.
Expand All @@ -711,10 +712,14 @@ def exec_line(
:type node_concurrency: int
:param allow_generator_output: Whether to allow generator output.
:type allow_generator_output: bool
:param line_timeout_sec: The maximum time to wait for a line of output.
:type line_timeout_sec: Optional[int]
:return: The result of executing the line.
:rtype: ~promptflow.executor._result.LineResult
"""
self._node_concurrency = node_concurrency
# TODO: Pass line_timeout_sec to flow node scheduler instead of updating self._line_timeout_sec
self._line_timeout_sec = line_timeout_sec or self._line_timeout_sec
inputs = apply_default_value_for_input(self._flow.inputs, inputs)
# For flow run, validate inputs as default
with self._run_tracker.node_log_manager:
Expand Down
34 changes: 20 additions & 14 deletions src/promptflow/tests/executor/e2etests/test_batch_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import pytest

from promptflow._utils.exception_utils import ExceptionPresenter
from promptflow.batch import BatchEngine
from promptflow.batch._result import BatchResult, LineError
from promptflow.contracts.run_info import Status
from promptflow.executor._errors import BatchExecutionTimeoutError, LineExecutionTimeoutError

from ..utils import MemoryRunStorage, get_flow_folder, get_flow_inputs_file, get_yaml_file

Expand Down Expand Up @@ -119,20 +121,25 @@ def test_batch_with_one_line_timeout(self, flow_folder, dev_connections):
assert len(mem_run_storage._node_runs) == 6, "Node runs are persisted in memory storage."

@pytest.mark.parametrize(
"flow_folder",
"flow_folder, line_timeout_sec, batch_timeout_sec, expected_error",
[
ONE_LINE_OF_BULK_TEST_TIMEOUT,
(ONE_LINE_OF_BULK_TEST_TIMEOUT, 600, 5, BatchExecutionTimeoutError(2, 5)),
(ONE_LINE_OF_BULK_TEST_TIMEOUT, 3, 600, LineExecutionTimeoutError(2, 3)),
(ONE_LINE_OF_BULK_TEST_TIMEOUT, 3, 5, LineExecutionTimeoutError(2, 3)),
# TODO: Will change to BatchExecutionTimeoutError after refining the implementation of batch timeout.
# (ONE_LINE_OF_BULK_TEST_TIMEOUT, 3, 3, LineExecutionTimeoutError(2, 3)),
],
)
def test_batch_timeout(self, flow_folder):
def test_batch_timeout(self, flow_folder, line_timeout_sec, batch_timeout_sec, expected_error):
mem_run_storage = MemoryRunStorage()
batch_engine = BatchEngine(
get_yaml_file(flow_folder),
get_flow_folder(flow_folder),
connections={},
storage=mem_run_storage,
)
batch_engine._batch_timeout_sec = 5
batch_engine._line_timeout_sec = line_timeout_sec
batch_engine._batch_timeout_sec = batch_timeout_sec

input_dirs = {"data": get_flow_inputs_file(flow_folder, file_name="samples.json")}
output_dir = Path(mkdtemp())
Expand All @@ -145,10 +152,9 @@ def test_batch_timeout(self, flow_folder):
assert batch_results.total_lines == 3
assert batch_results.completed_lines == 2
assert batch_results.failed_lines == 1
# TODO: Currently, the node status is incomplete.
# We will assert the correct result after refining the implementation of batch timeout.
assert batch_results.node_status == {
"my_python_tool_with_failed_line.completed": 2,
"my_python_tool_with_failed_line.canceled": 1,
"my_python_tool.completed": 3,
}

Expand All @@ -157,16 +163,16 @@ def test_batch_timeout(self, flow_folder):
assert batch_results.error_summary.failed_system_error_lines == 0
assert isinstance(batch_results.error_summary.error_list[0], LineError)
assert batch_results.error_summary.error_list[0].line_number == 2
assert batch_results.error_summary.error_list[0].error["code"] == "UserError"
assert batch_results.error_summary.error_list[0].error["referenceCode"] == "Batch"
assert batch_results.error_summary.error_list[0].error["innerError"]["code"] == "BatchExecutionTimeoutError"
assert (
batch_results.error_summary.error_list[0].error["message"]
== "Line 2 execution terminated due to the total batch run exceeding the batch timeout (5s)."
)

actual_error_dict = batch_results.error_summary.error_list[0].error
expected_error_dict = ExceptionPresenter.create(expected_error).to_dict()
assert actual_error_dict["code"] == expected_error_dict["code"]
assert actual_error_dict["message"] == expected_error_dict["message"]
assert actual_error_dict["referenceCode"] == expected_error_dict["referenceCode"]
assert actual_error_dict["innerError"]["code"] == expected_error_dict["innerError"]["code"]

# assert mem_run_storage persists run infos correctly
assert len(mem_run_storage._flow_runs) == 3, "Flow runs are persisted in memory storage."
# TODO: Currently, the node status is incomplete.
# We will assert the correct result after refining the implementation of batch timeout.
assert len(mem_run_storage._node_runs) == 5, "Node runs are persisted in memory storage."
assert len(mem_run_storage._node_runs) == 6, "Node runs are persisted in memory storage."
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def test_exec_line(self, flow_folder, dev_connections, mocker: MockFixture):
inputs=line_inputs,
run_id=run_id,
index=0,
line_timeout_sec=600,
)
assert isinstance(line_result, LineResult)

Expand All @@ -295,6 +296,7 @@ def test_exec_line_failed_when_line_execution_not_start(self, flow_folder, dev_c
inputs=line_inputs,
run_id=run_id,
index=0,
line_timeout_sec=600,
)
assert isinstance(line_result, LineResult)
assert line_result.run_info.error["message"] == test_error_msg
Expand Down

0 comments on commit 9d415c4

Please sign in to comment.