Skip to content

Commit

Permalink
[Executor] Provide exec_aggregation_async for script executor (#2945)
Browse files Browse the repository at this point in the history
# Description

Provide exec_aggregation_async for script executor.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce breaking changes.**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Lina Tang <[email protected]>
  • Loading branch information
lumoslnt and Lina Tang authored Apr 25, 2024
1 parent 9cb7869 commit 97f059e
Showing 1 changed file with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from promptflow.executor._result import AggregationResult, LineResult
from promptflow.storage import AbstractRunStorage
from promptflow.storage._run_storage import DefaultRunStorage
from promptflow.tracing import ThreadPoolExecutorWithContext
from promptflow.tracing._trace import _traced
from promptflow.tracing._tracer import Tracer
from promptflow.tracing.contracts.trace import TraceType
Expand Down Expand Up @@ -193,10 +194,31 @@ def _exec_aggregation(
) -> AggregationResult:
output, metrics = None, {}
try:
if inspect.iscoroutinefunction(self._aggr_func):
output = async_run_allowing_running_loop(self._aggr_func, **{self._aggr_input_name: inputs})
else:
output = self._aggr_func(**{self._aggr_input_name: inputs})
output = self._aggr_func(**{self._aggr_input_name: inputs})
metrics = output if isinstance(output, dict) else {"metrics": output}
for k, v in metrics.items():
log_metric(k, v)
except Exception:
pass
return AggregationResult(output, metrics, {})

async def exec_aggregation_async(
    self,
    inputs: Mapping[str, Any],
    aggregation_inputs: List[Any],
    run_id: Optional[str] = None,
):
    """Asynchronously run the flow's aggregation function over the collected line outputs.

    :param inputs: The original flow inputs (currently unused by the aggregation path here).
    :param aggregation_inputs: The per-line outputs gathered for aggregation.
    :param run_id: Optional parent run id; a fresh UUID is generated when absent.
    :return: An ``AggregationResult``; empty when no aggregation function is registered.
    """
    # Nothing to do when the flow defines no aggregation function.
    if not self._aggr_func:
        return AggregationResult({}, {}, {})
    # Keep parity with dag flow: the aggregation step runs under a
    # "_reduce"-suffixed run id derived from the parent run (or a new UUID).
    base_run_id = run_id if run_id is not None else str(uuid.uuid4())
    reduce_run_id = f"{base_run_id}_reduce"
    with self._update_operation_context_for_aggregation(reduce_run_id):
        return await self._exec_aggregation_async(aggregation_inputs)

async def _exec_aggregation_async(self, inputs):
output = None
try:
output = await self._aggr_func_async(**{self._aggr_input_name: inputs})
metrics = output if isinstance(output, dict) else {"metrics": output}
for k, v in metrics.items():
log_metric(k, v)
Expand Down Expand Up @@ -430,7 +452,22 @@ def _initialize_aggr_function(self, flow_obj: object):
)
if not hasattr(aggr_func, "__original_function"):
aggr_func = _traced(aggr_func)
self._aggr_func = aggr_func
if inspect.iscoroutinefunction(aggr_func):

def run_async_function_sync(*args, **kwargs):
return async_run_allowing_running_loop(aggr_func, *args, **kwargs)

self._aggr_func = run_async_function_sync
self._aggr_func_async = aggr_func
else:

async def run_sync_function_async(*args, **kwargs):
with ThreadPoolExecutorWithContext() as executor:
partial_func = partial(aggr_func, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(executor, partial_func)

self._aggr_func = aggr_func
self._aggr_func_async = run_sync_function_async
self._aggr_input_name = list(sign.parameters.keys())[0]

def _parse_flow_file(self):
Expand Down

0 comments on commit 97f059e

Please sign in to comment.