[Internal][Executor] Run exec_line concurrently and refine cancel logic in batch engine (#1387)

# Description

Run `exec_line` concurrently in the batch engine (bounded by a semaphore, with a default concurrency of 10) and refine the cancel logic so that a canceled batch run returns the line results and aggregation results completed so far.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
PeiwenGaoMS authored Dec 15, 2023
1 parent aca0a02 commit 42b5def
Showing 5 changed files with 139 additions and 88 deletions.
15 changes: 13 additions & 2 deletions src/promptflow/promptflow/_utils/utils.py
@@ -141,11 +141,22 @@ def log_progress(
count: int,
total_count: int,
formatter="Finished {count} / {total_count} lines.",
*,
last_log_count: Optional[int] = None,
):
# Calculate log_interval to determine when to log progress.
# If total_count is less than 100, log every 10% of total_count; otherwise, log every 10 lines.
log_interval = min(10, max(int(total_count / 10), 1))
if count > 0 and (count % log_interval == 0 or count == total_count):

# If last_log_count is not None, determine whether to log based on whether the difference
# between the current count and the previous count exceeds log_interval.
# Otherwise, decide based on whether the current count is evenly divisible by log_interval.
if last_log_count:
log_flag = (count - last_log_count) >= log_interval
else:
log_flag = count % log_interval == 0

if count > 0 and (log_flag or count == total_count):
average_execution_time = round((datetime.utcnow().timestamp() - run_start_time.timestamp()) / count, 2)
estimated_execution_time = round(average_execution_time * (total_count - count), 2)
logger.info(formatter.format(count=count, total_count=total_count))
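
With lines now completing concurrently, the finished count can jump past an exact multiple of the interval between checks, so the new keyword-only `last_log_count` parameter gates on the delta since the last log instead of on divisibility. A minimal sketch of the decision, restated as a standalone helper (`should_log` is our name for illustration, not part of the diff):

```python
from typing import Optional

# Hypothetical helper isolating the gating logic of log_progress above.
def should_log(count: int, total_count: int, last_log_count: Optional[int] = None) -> bool:
    # Roughly every 10% of total_count for runs under 100 lines, every 10 lines otherwise.
    log_interval = min(10, max(int(total_count / 10), 1))
    if last_log_count:
        # Concurrent path: log once the delta since the last log reaches the interval.
        log_flag = (count - last_log_count) >= log_interval
    else:
        # Sequential path: log on exact multiples of the interval.
        log_flag = count % log_interval == 0
    return count > 0 and (log_flag or count == total_count)

# should_log(25, 100, last_log_count=14) -> True: delta 11 >= interval 10.
```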
@@ -167,7 +178,7 @@ def extract_user_frame_summaries(frame_summaries: List[traceback.FrameSummary]):
# If the current frame is in tool.py and the next frame is not in _core folder
# then we can say that the next frame is in user code.
if cur_file == tool_file and not next_file.startswith(core_folder):
return frame_summaries[i + 1:]
return frame_summaries[i + 1 :]
return frame_summaries


43 changes: 23 additions & 20 deletions src/promptflow/promptflow/batch/_base_executor_proxy.py
@@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import asyncio
from datetime import datetime
from json import JSONDecodeError
from pathlib import Path
@@ -66,6 +67,10 @@ async def exec_aggregation_async(
"""Execute aggregation nodes"""
raise NotImplementedError()

async def ensure_executor_health(self):
"""Ensure the executor service is healthy before execution"""
pass


class APIBasedExecutorProxy(AbstractExecutorProxy):
@property
@@ -84,8 +89,6 @@ async def exec_line_async(
run_id: Optional[str] = None,
) -> LineResult:
start_time = datetime.utcnow()
# ensure service health
await self._ensure_executor_health()
# call execution api to get line results
url = self.api_endpoint + "/Execution"
payload = {"run_id": run_id, "line_number": index, "inputs": inputs}
@@ -104,8 +107,6 @@ async def exec_aggregation_async(
aggregation_inputs: Mapping[str, Any],
run_id: Optional[str] = None,
) -> AggregationResult:
# ensure service health
await self._ensure_executor_health()
# call aggregation api to get aggregation result
async with httpx.AsyncClient() as client:
url = self.api_endpoint + "/Aggregation"
@@ -114,6 +115,24 @@ async def exec_aggregation_async(
result = self._process_http_response(response)
return AggregationResult.deserialize(result)

async def ensure_executor_health(self):
"""Ensure the executor service is healthy before calling the API to get the results
During testing, we observed that the executor service started quickly on Windows.
However, there is a noticeable delay in booting on Linux.
So we set a specific waiting period. If the executor service fails to return to normal
within the allocated timeout, an exception is thrown to indicate a potential problem.
"""
waiting_health_timeout = 5
start_time = datetime.utcnow()
while (datetime.utcnow() - start_time).seconds < waiting_health_timeout:
if await self._check_health():
return
# wait for 1s to prevent calling the API too frequently
await asyncio.sleep(1)
raise ExecutorServiceUnhealthy(f"{EXECUTOR_UNHEALTHY_MESSAGE}. Please resubmit your flow and try again.")

def _process_http_response(self, response: httpx.Response):
if response.status_code == 200:
# if the status code is 200, the response is the json dict of a line result
@@ -133,22 +152,6 @@ def _process_http_response(self, response: httpx.Response):
)
return ExceptionPresenter.create(unexpected_error).to_dict()

async def _ensure_executor_health(self):
"""Ensure the executor service is healthy before calling the API to get the results
During testing, we observed that the executor service started quickly on Windows.
However, there is a noticeable delay in booting on Linux.
So we set a specific waiting period. If the executor service fails to return to normal
within the allocated timeout, an exception is thrown to indicate a potential problem.
"""
waiting_health_timeout = 5
start_time = datetime.utcnow()
while (datetime.utcnow() - start_time).seconds < waiting_health_timeout:
if await self._check_health():
return
raise ExecutorServiceUnhealthy(f"{EXECUTOR_UNHEALTHY_MESSAGE}. Please resubmit your flow and try again.")

async def _check_health(self):
try:
health_url = self.api_endpoint + "/health"
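
The diff collapses the rest of `_check_health`; a plausible completion, as a hedged sketch that assumes the service exposes `/health` and that any HTTP 200 counts as healthy:

```python
# Hedged reconstruction of the truncated body (an assumption, not the
# verbatim source): any exception, e.g. the service not yet listening,
# is treated as unhealthy so that ensure_executor_health keeps polling.
async def _check_health(self):
    try:
        health_url = self.api_endpoint + "/health"
        async with httpx.AsyncClient() as client:
            response = await client.get(health_url)
        return response.status_code == 200
    except Exception:
        return False
```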
142 changes: 101 additions & 41 deletions src/promptflow/promptflow/batch/_batch_engine.py
@@ -2,13 +2,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import asyncio
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional

from httpx import ConnectError

from promptflow._constants import LINE_NUMBER_KEY, FlowLanguage
from promptflow._core._errors import UnexpectedError
from promptflow._core.operation_context import OperationContext
@@ -25,7 +24,6 @@
from promptflow.batch._base_executor_proxy import AbstractExecutorProxy
from promptflow.batch._batch_inputs_processor import BatchInputsProcessor
from promptflow.batch._csharp_executor_proxy import CSharpExecutorProxy
from promptflow.batch._errors import ExecutorServiceUnhealthy
from promptflow.batch._python_executor_proxy import PythonExecutorProxy
from promptflow.batch._result import BatchResult
from promptflow.contracts.flow import Flow
@@ -36,6 +34,8 @@
from promptflow.storage._run_storage import AbstractRunStorage

OUTPUT_FILE_NAME = "output.jsonl"
# TODO: will remain consistent with PF_WORKER_COUNT in the future
DEFAULT_CONCURRENCY = 10


class BatchEngine:
@@ -135,20 +135,14 @@ def run(
# run flow in batch mode
with _change_working_dir(self._working_dir):
return async_run_allowing_running_loop(
self._exec_batch, batch_inputs, run_id, output_dir, raise_on_line_failure
self._exec_in_task, batch_inputs, run_id, output_dir, raise_on_line_failure
)
except Exception as e:
bulk_logger.error(f"Error occurred while executing batch run. Exception: {str(e)}")
if isinstance(e, ConnectError) or isinstance(e, ExecutorServiceUnhealthy):
bulk_logger.warning("The batch run may have been canceled or encountered other issues.")
return BatchResult.create(
self._start_time, datetime.utcnow(), [], AggregationResult({}, {}, {}), status=Status.Canceled
)
elif isinstance(e, PromptflowException):
if isinstance(e, PromptflowException):
raise e
else:
# For unexpected error, we need to wrap it to SystemErrorException.
# This allows us to see the stack trace inside.
# for unexpected error, we need to wrap it to SystemErrorException to allow us to see the stack trace.
unexpected_error = UnexpectedError(
target=ErrorTarget.BATCH,
message_format=(
@@ -158,34 +152,60 @@ def run(
)
raise unexpected_error from e
finally:
# destroy executor proxy if the batch run is not cancelled
# TODO: add a lock to avoid destroy proxy twice
if not self._is_canceled:
self._executor_proxy.destroy()
self._executor_proxy.destroy()

def cancel(self):
"""Cancel the batch run"""
self._is_canceled = True
self._executor_proxy.destroy()

async def _exec_batch(
async def _exec_in_task(
self,
batch_inputs: List[Dict[str, Any]],
run_id: str = None,
output_dir: Path = None,
raise_on_line_failure: bool = False,
) -> BatchResult:
# apply default value in early stage, so we can use it both in line execution and aggregation nodes execution.
# if the batch run is canceled, asyncio.CancelledError will be raised and no results will be returned,
# so we pass empty line results list and aggr results and update them in _exec so that when the batch
# run is canceled we can get the current completed line results and aggr results.
line_results: List[LineResult] = []
aggr_result = AggregationResult({}, {}, {})
task = asyncio.create_task(
self._exec(line_results, aggr_result, batch_inputs, run_id, output_dir, raise_on_line_failure)
)
while not task.done():
# check whether the task is completed or canceled every 1s
await asyncio.sleep(1)
if self._is_canceled:
task.cancel()
# use current completed line results and aggregation results to create a BatchResult
return BatchResult.create(
self._start_time, datetime.utcnow(), line_results, aggr_result, status=Status.Canceled
)
return task.result()
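
The supervisor pattern here is general: run the real work as a task, poll a cancel flag once per second, and on cancellation hand back whatever the worker has accumulated in the shared, mutable result containers. A standalone sketch under assumed names (`work`, `results`, and `is_canceled` are hypothetical stand-ins):

```python
import asyncio
from typing import Awaitable, Callable, List

# Sketch of the _exec_in_task supervisor loop: `work` appends into `results`
# as it goes, so cancellation can still return the partial results.
async def supervise(
    work: Callable[[List], Awaitable[List]],
    results: List,
    is_canceled: Callable[[], bool],
) -> List:
    task = asyncio.create_task(work(results))
    while not task.done():
        # Check for completion or cancellation every second.
        await asyncio.sleep(1)
        if is_canceled():
            task.cancel()
            return results  # whatever the worker finished before the cancel
    return task.result()
```

Sharing the mutable `line_results` list (rather than returning a fresh list from `_exec`) is what makes the partial results visible after `task.cancel()`.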

async def _exec(
self,
line_results: List[LineResult],
aggr_result: AggregationResult,
batch_inputs: List[Dict[str, Any]],
run_id: str = None,
output_dir: Path = None,
raise_on_line_failure: bool = False,
) -> BatchResult:
await self._executor_proxy.ensure_executor_health()
# apply default value in early stage, so we can use it both in line and aggregation nodes execution.
batch_inputs = [
apply_default_value_for_input(self._flow.inputs, each_line_input) for each_line_input in batch_inputs
]
run_id = run_id or str(uuid.uuid4())

# execute lines
if isinstance(self._executor_proxy, PythonExecutorProxy):
line_results = self._executor_proxy._exec_batch(batch_inputs, output_dir, run_id)
line_results.extend(self._executor_proxy._exec_batch(batch_inputs, output_dir, run_id))
else:
line_results = await self._exec_batch_internal(batch_inputs, run_id)
await self._exec_batch(line_results, batch_inputs, run_id)
handle_line_failures([r.run_info for r in line_results], raise_on_line_failure)
aggr_results = await self._exec_aggregation_internal(batch_inputs, line_results, run_id)

# persist outputs to output dir
outputs = [
@@ -195,30 +215,52 @@ async def _exec_batch(
]
self._persist_outputs(outputs, output_dir)

# execute aggregation nodes
aggr_exec_result = await self._exec_aggregation(batch_inputs, line_results, run_id)
# use the execution result to update aggr_result to make sure we can get the aggr_result in _exec_in_task
self._update_aggr_result(aggr_result, aggr_exec_result)
# summarize info from the line results and aggr results into the batch result
self._end_time = datetime.utcnow()
return BatchResult.create(self._start_time, self._end_time, line_results, aggr_results)
return BatchResult.create(self._start_time, datetime.utcnow(), line_results, aggr_result)

async def _exec_batch_internal(
async def _exec_batch(
self,
line_results: List[LineResult],
batch_inputs: List[Mapping[str, Any]],
run_id: Optional[str] = None,
) -> List[LineResult]:
line_results = []
semaphore = asyncio.Semaphore(DEFAULT_CONCURRENCY)
pending = [
asyncio.create_task(self._exec_line_under_semaphore(semaphore, line_inputs, i, run_id))
for i, line_inputs in enumerate(batch_inputs)
]

total_lines = len(batch_inputs)
# TODO: concurrent calls to exec_line instead of for loop
for i, each_line_input in enumerate(batch_inputs):
# TODO: catch line run failed to avoid one line break others
line_result = await self._executor_proxy.exec_line_async(each_line_input, i, run_id=run_id)
for node_run in line_result.node_run_infos.values():
self._storage.persist_node_run(node_run)
self._storage.persist_flow_run(line_result.run_info)
line_results.append(line_result)
# log the progress of the batch run
log_progress(self._start_time, bulk_logger, len(line_results), total_lines)
return line_results
completed_line = 0
while completed_line < total_lines:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
completed_line_results = [task.result() for task in done]
self._persist_run_info(completed_line_results)
line_results.extend(completed_line_results)
log_progress(
self._start_time,
bulk_logger,
len(line_results),
total_lines,
last_log_count=completed_line,
)
completed_line = len(line_results)

async def _exec_line_under_semaphore(
self,
semaphore,
inputs: Mapping[str, Any],
index: Optional[int] = None,
run_id: Optional[str] = None,
):
async with semaphore:
return await self._executor_proxy.exec_line_async(inputs, index, run_id)
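
Together, the semaphore and `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)` give bounded fan-out with incremental draining: at most `DEFAULT_CONCURRENCY` lines run at once, and results are collected in completion order so progress can be logged as lines finish. A generic restatement of the pattern (`exec_line` is a hypothetical stand-in for the executor proxy call):

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

async def run_bounded(
    exec_line: Callable[[Dict[str, Any], int], Awaitable[Any]],
    inputs: List[Dict[str, Any]],
    concurrency: int = 10,
) -> List[Any]:
    semaphore = asyncio.Semaphore(concurrency)

    async def guarded(line: Dict[str, Any], index: int):
        # At most `concurrency` exec_line calls are in flight at a time.
        async with semaphore:
            return await exec_line(line, index)

    pending = [asyncio.create_task(guarded(line, i)) for i, line in enumerate(inputs)]
    results: List[Any] = []
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        # Drain in completion order, not input order.
        results.extend(task.result() for task in done)
    return results
```

Because results arrive in completion order, anything order-sensitive downstream has to key off each result's line number rather than its position in the list.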

async def _exec_aggregation_internal(
async def _exec_aggregation(
self,
batch_inputs: List[dict],
line_results: List[LineResult],
@@ -246,13 +288,18 @@ async def _exec_aggregation_internal(
)
succeeded_aggregation_inputs = collect_lines(succeeded, aggregation_inputs)
try:
aggr_results = await self._executor_proxy.exec_aggregation_async(
aggr_result = await self._executor_proxy.exec_aggregation_async(
succeeded_inputs, succeeded_aggregation_inputs, run_id
)
# if the flow language is python, we have already persisted node run infos during execution.
# so we should persist node run infos in aggr_result for other languages.
if not isinstance(self._executor_proxy, PythonExecutorProxy):
for node_run in aggr_result.node_run_infos.values():
self._storage.persist_node_run(node_run)
bulk_logger.info("Finish executing aggregation nodes.")
return aggr_results
return aggr_result
except PromptflowException as e:
# For PromptflowException, we already do classification, so throw directly.
# for PromptflowException, we already do classification, so throw directly.
raise e
except Exception as e:
error_type_and_message = f"({e.__class__.__name__}) {e}"
@@ -264,7 +311,20 @@
error_type_and_message=error_type_and_message,
) from e

def _persist_run_info(self, line_results: List[LineResult]):
"""Persist node run infos and flow run info in line result to storage"""
for line_result in line_results:
for node_run in line_result.node_run_infos.values():
self._storage.persist_node_run(node_run)
self._storage.persist_flow_run(line_result.run_info)

def _persist_outputs(self, outputs: List[Mapping[str, Any]], output_dir: Path):
"""Persist outputs to json line file in output directory"""
output_file = output_dir / OUTPUT_FILE_NAME
dump_list_to_jsonl(output_file, outputs)

def _update_aggr_result(self, aggr_result: AggregationResult, aggr_exec_result: AggregationResult):
"""Update aggregation result with the aggregation execution result"""
aggr_result.metrics = aggr_exec_result.metrics
aggr_result.node_run_infos = aggr_exec_result.node_run_infos
aggr_result.output = aggr_exec_result.output
25 changes: 1 addition & 24 deletions src/promptflow/promptflow/batch/_csharp_executor_proxy.py
@@ -7,10 +7,9 @@
from pathlib import Path
from typing import Any, Mapping, Optional

from promptflow._constants import LINE_NUMBER_KEY
from promptflow._sdk._constants import DEFAULT_ENCODING, FLOW_TOOLS_JSON, PROMPT_FLOW_DIR_NAME
from promptflow.batch._base_executor_proxy import APIBasedExecutorProxy
from promptflow.executor._result import AggregationResult, LineResult
from promptflow.executor._result import AggregationResult
from promptflow.storage._run_storage import AbstractRunStorage

EXECUTOR_SERVICE_DOMAIN = "http://localhost:"
@@ -56,28 +55,6 @@ def create(
process = subprocess.Popen(command)
return cls(process, port)

async def exec_line_async(
self,
inputs: Mapping[str, Any],
index: Optional[int] = None,
run_id: Optional[str] = None,
) -> LineResult:
line_result = await super().exec_line_async(inputs, index, run_id)
# TODO: check if we should ask C# executor to keep unmatched inputs, although it's not so straightforward
# for executor service to do so.
# local_storage_operations.load_inputs_and_outputs now have an assumption that there is an extra
# line_number key in the inputs.
# This key will be appended to the inputs in below call stack:
# BatchEngine.run =>
# BatchInputsProcessor.process_batch_inputs =>
# ... =>
# BatchInputsProcessor._merge_input_dicts_by_line
# For python, it will be kept in the returned line_result.run_info.inputs
# For csharp, it will be dropped by executor service for now
# Append it here for now to make behavior consistent among ExecutorProxy.
line_result.run_info.inputs[LINE_NUMBER_KEY] = index
return line_result

def destroy(self):
"""Destroy the executor"""
if self._process and self._process.poll() is None: