diff --git a/src/promptflow/promptflow/batch/_batch_engine.py b/src/promptflow/promptflow/batch/_batch_engine.py
index ad5c167372d..2738a85e8d0 100644
--- a/src/promptflow/promptflow/batch/_batch_engine.py
+++ b/src/promptflow/promptflow/batch/_batch_engine.py
@@ -119,9 +119,7 @@ def __init__(
         self._storage = storage
         self._kwargs = kwargs

-        self._batch_timeout_sec = (
-            batch_timeout_sec if batch_timeout_sec else get_int_env_var("PF_BATCH_TIMEOUT_SEC", None)
-        )
+        self._batch_timeout_sec = batch_timeout_sec or get_int_env_var("PF_BATCH_TIMEOUT_SEC")
         self._line_timeout_sec = get_int_env_var("PF_LINE_TIMEOUT_SEC", LINE_TIMEOUT_SEC)

         # set it to True when the batch run is canceled
diff --git a/src/promptflow/promptflow/executor/_line_execution_process_pool.py b/src/promptflow/promptflow/executor/_line_execution_process_pool.py
index 4ef81604d18..6704f56af17 100644
--- a/src/promptflow/promptflow/executor/_line_execution_process_pool.py
+++ b/src/promptflow/promptflow/executor/_line_execution_process_pool.py
@@ -237,24 +237,40 @@ def _handle_output_queue_messages(self, output_queue: Queue, result_list):
         return None

     def _monitor_workers_and_process_tasks_in_thread(
-        self, task_queue: Queue, result_list: List[LineResult], index: int, input_queue: Queue, output_queue: Queue
+        self,
+        task_queue: Queue,
+        result_list: List[LineResult],
+        index: int,
+        input_queue: Queue,
+        output_queue: Queue,
+        batch_start_time: datetime,
     ):
         index, process_id, process_name = self._get_process_info(index)

-        batch_start_time = datetime.utcnow()
         # Entering the while loop requires two conditions:
         # 1. The task queue is not empty, meaning there are lines yet to be executed.
         # 2. The batch run has not reached the batch timeout limit.
         while not self._batch_timeout_expired(batch_start_time):
             self._processes_manager.ensure_healthy()
             try:
-                args = task_queue.get(timeout=1)
+                # Get task from task_queue
+                inputs, line_number, run_id = task_queue.get(timeout=1)
             except queue.Empty:
                 break

+            # Calculate the line timeout for the current line.
+            line_timeout_sec = self._line_timeout_sec
+            if self._batch_timeout_sec:
+                remaining_execution_time = (
+                    self._batch_timeout_sec - (datetime.utcnow() - batch_start_time).total_seconds()
+                )
+                if remaining_execution_time <= 0:
+                    break
+                line_timeout_sec = min(line_timeout_sec, remaining_execution_time)
+
+            # Put task into input_queue
+            args = (inputs, line_number, run_id, line_timeout_sec)
             input_queue.put(args)
-            inputs, line_number, run_id = args

             self._processing_idx[line_number] = format_current_process_info(process_name, process_id, line_number)
             log_process_status(process_name, process_id, line_number)
@@ -265,7 +281,7 @@ def _monitor_workers_and_process_tasks_in_thread(
             returned_node_run_infos = {}

             # Responsible for checking the output queue messages and processing them within a specified timeout period.
-            while not self._line_timeout_expired(start_time) and not self._batch_timeout_expired(batch_start_time):
+            while not self._batch_timeout_expired(batch_start_time) and not self._line_timeout_expired(start_time):
                 # Monitor process aliveness.
                 crashed = not self._is_process_alive(process_id)
                 if crashed:
@@ -460,6 +476,7 @@ def run(self, batch_inputs):
                 ),
             ):
                 try:
+                    batch_start_time = datetime.utcnow()
                     args_list = [
                         (
                             self._task_queue,  # Shared task queue for all sub processes to read the input data.
@@ -469,6 +486,7 @@ def run(self, batch_inputs):
                             self._input_queues[i],
                             # Specific output queue for the sub process, used to receive results from it.
                             self._output_queues[i],
+                            batch_start_time,
                         )
                         for i in range(self._n_process)
                     ]
@@ -573,13 +591,16 @@ def _log_set_worker_count(self, worker_count, estimated_available_worker_count):
         )


-def _exec_line(executor: FlowExecutor, output_queue: Queue, *, inputs: dict, run_id, index: int):
+def _exec_line(
+    executor: FlowExecutor, output_queue: Queue, *, inputs: dict, run_id: str, index: int, line_timeout_sec: int
+):
     try:
         line_result = executor.exec_line(
             inputs=inputs,
             run_id=run_id,
             index=index,
             node_concurrency=DEFAULT_CONCURRENCY_BULK,
+            line_timeout_sec=line_timeout_sec,
         )
         if line_result is not None:
             # For eager flow, the output may be a dataclass which is not picklable, we need to convert it to dict.
@@ -660,13 +681,14 @@ def exec_line_for_queue(executor_creation_func, input_queue: Queue, output_queue

     while True:
         try:
-            inputs, line_number, run_id = input_queue.get(timeout=1)
+            inputs, line_number, run_id, line_timeout_sec = input_queue.get(timeout=1)
             result = _exec_line(
                 executor=executor,
                 output_queue=output_queue,
                 inputs=inputs,
                 run_id=run_id,
                 index=line_number,
+                line_timeout_sec=line_timeout_sec,
             )
             output_queue.put(result)
         except queue.Empty:
diff --git a/src/promptflow/promptflow/executor/flow_executor.py b/src/promptflow/promptflow/executor/flow_executor.py
index 76db71562bd..bc7689fae5a 100644
--- a/src/promptflow/promptflow/executor/flow_executor.py
+++ b/src/promptflow/promptflow/executor/flow_executor.py
@@ -694,6 +694,7 @@ def exec_line(
         validate_inputs: bool = True,
         node_concurrency=DEFAULT_CONCURRENCY_FLOW,
         allow_generator_output: bool = False,
+        line_timeout_sec: Optional[int] = None,
     ) -> LineResult:
         """Execute a single line of the flow.

@@ -711,10 +712,14 @@ def exec_line(
         :type node_concurrency: int
         :param allow_generator_output: Whether to allow generator output.
         :type allow_generator_output: bool
+        :param line_timeout_sec: The maximum time to wait for a line of output.
+        :type line_timeout_sec: Optional[int]
         :return: The result of executing the line.
         :rtype: ~promptflow.executor._result.LineResult
         """
         self._node_concurrency = node_concurrency
+        # TODO: Pass line_timeout_sec to flow node scheduler instead of updating self._line_timeout_sec
+        self._line_timeout_sec = line_timeout_sec or self._line_timeout_sec
         inputs = apply_default_value_for_input(self._flow.inputs, inputs)
         # For flow run, validate inputs as default
         with self._run_tracker.node_log_manager:
diff --git a/src/promptflow/tests/executor/e2etests/test_batch_timeout.py b/src/promptflow/tests/executor/e2etests/test_batch_timeout.py
index 157dbd8124c..e8125cc35b6 100644
--- a/src/promptflow/tests/executor/e2etests/test_batch_timeout.py
+++ b/src/promptflow/tests/executor/e2etests/test_batch_timeout.py
@@ -3,9 +3,11 @@

 import pytest

+from promptflow._utils.exception_utils import ExceptionPresenter
 from promptflow.batch import BatchEngine
 from promptflow.batch._result import BatchResult, LineError
 from promptflow.contracts.run_info import Status
+from promptflow.executor._errors import BatchExecutionTimeoutError, LineExecutionTimeoutError

 from ..utils import MemoryRunStorage, get_flow_folder, get_flow_inputs_file, get_yaml_file

@@ -119,12 +121,16 @@ def test_batch_with_one_line_timeout(self, flow_folder, dev_connections):
         assert len(mem_run_storage._node_runs) == 6, "Node runs are persisted in memory storage."

     @pytest.mark.parametrize(
-        "flow_folder",
+        "flow_folder, line_timeout_sec, batch_timeout_sec, expected_error",
         [
-            ONE_LINE_OF_BULK_TEST_TIMEOUT,
+            (ONE_LINE_OF_BULK_TEST_TIMEOUT, 600, 5, BatchExecutionTimeoutError(2, 5)),
+            (ONE_LINE_OF_BULK_TEST_TIMEOUT, 3, 600, LineExecutionTimeoutError(2, 3)),
+            (ONE_LINE_OF_BULK_TEST_TIMEOUT, 3, 5, LineExecutionTimeoutError(2, 3)),
+            # TODO: Will change to BatchExecutionTimeoutError after refining the implementation of batch timeout.
+            # (ONE_LINE_OF_BULK_TEST_TIMEOUT, 3, 3, LineExecutionTimeoutError(2, 3)),
         ],
     )
-    def test_batch_timeout(self, flow_folder):
+    def test_batch_timeout(self, flow_folder, line_timeout_sec, batch_timeout_sec, expected_error):
         mem_run_storage = MemoryRunStorage()
         batch_engine = BatchEngine(
             get_yaml_file(flow_folder),
@@ -132,7 +138,8 @@ def test_batch_timeout(self, flow_folder):
             connections={},
             storage=mem_run_storage,
         )
-        batch_engine._batch_timeout_sec = 5
+        batch_engine._line_timeout_sec = line_timeout_sec
+        batch_engine._batch_timeout_sec = batch_timeout_sec

         input_dirs = {"data": get_flow_inputs_file(flow_folder, file_name="samples.json")}
         output_dir = Path(mkdtemp())
@@ -145,10 +152,9 @@ def test_batch_timeout(self, flow_folder):
         assert batch_results.total_lines == 3
         assert batch_results.completed_lines == 2
         assert batch_results.failed_lines == 1
-        # TODO: Currently, the node status is incomplete.
-        # We will assert the correct result after refining the implementation of batch timeout.
         assert batch_results.node_status == {
             "my_python_tool_with_failed_line.completed": 2,
+            "my_python_tool_with_failed_line.canceled": 1,
             "my_python_tool.completed": 3,
         }

@@ -157,16 +163,16 @@ def test_batch_timeout(self, flow_folder):
         assert batch_results.error_summary.failed_system_error_lines == 0
         assert isinstance(batch_results.error_summary.error_list[0], LineError)
         assert batch_results.error_summary.error_list[0].line_number == 2
-        assert batch_results.error_summary.error_list[0].error["code"] == "UserError"
-        assert batch_results.error_summary.error_list[0].error["referenceCode"] == "Batch"
-        assert batch_results.error_summary.error_list[0].error["innerError"]["code"] == "BatchExecutionTimeoutError"
-        assert (
-            batch_results.error_summary.error_list[0].error["message"]
-            == "Line 2 execution terminated due to the total batch run exceeding the batch timeout (5s)."
-        )
+
+        actual_error_dict = batch_results.error_summary.error_list[0].error
+        expected_error_dict = ExceptionPresenter.create(expected_error).to_dict()
+        assert actual_error_dict["code"] == expected_error_dict["code"]
+        assert actual_error_dict["message"] == expected_error_dict["message"]
+        assert actual_error_dict["referenceCode"] == expected_error_dict["referenceCode"]
+        assert actual_error_dict["innerError"]["code"] == expected_error_dict["innerError"]["code"]

         # assert mem_run_storage persists run infos correctly
         assert len(mem_run_storage._flow_runs) == 3, "Flow runs are persisted in memory storage."
         # TODO: Currently, the node status is incomplete.
         # We will assert the correct result after refining the implementation of batch timeout.
-        assert len(mem_run_storage._node_runs) == 5, "Node runs are persisted in memory storage."
+        assert len(mem_run_storage._node_runs) == 6, "Node runs are persisted in memory storage."
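
Note: the per-line timeout clamping added in _monitor_workers_and_process_tasks_in_thread can be read in isolation as the minimal sketch below. The standalone helper name clamp_line_timeout and its signature are illustrative only (the real code is inlined in the monitor thread), but the arithmetic mirrors the diff, including the use of datetime.utcnow() measured against the shared batch_start_time.

from datetime import datetime
from typing import Optional


def clamp_line_timeout(
    line_timeout_sec: int,
    batch_timeout_sec: Optional[int],
    batch_start_time: datetime,
) -> Optional[float]:
    """Return the effective timeout for the next line, or None when the batch budget is spent.

    Without a batch timeout, the configured line timeout is used as-is; otherwise it is
    capped by the time remaining in the batch run. A None result corresponds to the
    monitor thread breaking out of its loop instead of dispatching the line.
    """
    if not batch_timeout_sec:
        return line_timeout_sec
    remaining = batch_timeout_sec - (datetime.utcnow() - batch_start_time).total_seconds()
    if remaining <= 0:
        return None
    return min(line_timeout_sec, remaining)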
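
Note: test_batch_timeout now compares only the stable fields of the serialized error against ExceptionPresenter.create(expected_error).to_dict(), rather than hard-coding message strings. A sketch of that comparison as a reusable helper; the helper itself is hypothetical, while the field names are exactly the ones the test checks.

from promptflow._utils.exception_utils import ExceptionPresenter


def assert_error_dict_matches(actual_error_dict: dict, expected_exception: Exception) -> None:
    """Assert that a serialized line error matches the expected exception on its stable fields."""
    expected = ExceptionPresenter.create(expected_exception).to_dict()
    assert actual_error_dict["code"] == expected["code"]
    assert actual_error_dict["message"] == expected["message"]
    assert actual_error_dict["referenceCode"] == expected["referenceCode"]
    assert actual_error_dict["innerError"]["code"] == expected["innerError"]["code"]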
diff --git a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
index 8b7ffb350de..278336c9bb1 100644
--- a/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
+++ b/src/promptflow/tests/executor/unittests/processpool/test_line_execution_process_pool.py
@@ -270,6 +270,7 @@ def test_exec_line(self, flow_folder, dev_connections, mocker: MockFixture):
             inputs=line_inputs,
             run_id=run_id,
             index=0,
+            line_timeout_sec=600,
         )
         assert isinstance(line_result, LineResult)

@@ -295,6 +296,7 @@ def test_exec_line_failed_when_line_execution_not_start(self, flow_folder, dev_c
             inputs=line_inputs,
             run_id=run_id,
             index=0,
+            line_timeout_sec=600,
         )
         assert isinstance(line_result, LineResult)
         assert line_result.run_info.error["message"] == test_error_msg
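
Note: because line_timeout_sec is keyword-only with no default on _exec_line, any direct caller (such as the unit tests above) must now pass it explicitly. A rough usage sketch under assumed inputs; the flow path, input dict, and run id below are placeholders, not values from this change.

import multiprocessing

from promptflow.executor import FlowExecutor
from promptflow.executor._line_execution_process_pool import _exec_line

# Placeholder flow definition and inputs; substitute a real flow folder and its input schema.
executor = FlowExecutor.create("path/to/flow/flow.dag.yaml", connections={})
output_queue = multiprocessing.Queue()

line_result = _exec_line(
    executor=executor,
    output_queue=output_queue,
    inputs={"text": "hello"},
    run_id="dummy_run_id",
    index=0,
    line_timeout_sec=600,  # the per-line timeout (seconds) now required by the signature
)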