Skip to content

Commit

Permalink
[PromptFlow] Optimize the start process logic and add some tests (#624)
Browse files Browse the repository at this point in the history
# Description

- Optimize the start-process logic: when the task_queue is empty, don't start
a new process.
- Add some tests.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
Hhhilulu authored Oct 10, 2023
1 parent fb2e70b commit cd55ffa
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/promptflow/promptflow/executor/_line_execution_process_pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import contextvars
import math
import time
import multiprocessing
import os
import queue
Expand Down Expand Up @@ -50,7 +49,7 @@ def __init__(self, executor_creation_func):
self.is_ready = False
self._executor_creation_func = executor_creation_func

def start_new(self):
def start_new(self, task_queue: Queue):
input_queue = Queue()
output_queue = Queue()
self.input_queue = input_queue
Expand Down Expand Up @@ -86,7 +85,9 @@ def start_new(self):
except queue.Empty:
logger.info(f"Process {process.pid} did not send ready message, exit.")
self.end()
self.start_new()
# If there are no more tasks, the process is not re-created
if not task_queue.empty():
self.start_new(task_queue)

def end(self):
# When process failed to start and the task_queue is empty.
Expand All @@ -112,7 +113,7 @@ def format_current_process(self, line_number: int, is_completed=False):
logger.info(
f"Process name: {process_name}, Process id: {process_pid}, Line number: {line_number} start execution.")

return f"Process name({process_name})-Process id({process_pid})"
return f"Process name({process_name})-Process id({process_pid})-Line number({line_number})"


class LineExecutionProcessPool:
Expand Down Expand Up @@ -183,12 +184,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def _timeout_process_wrapper(self, task_queue: Queue, idx: int, timeout_time, result_list):
healthy_ensured_process = HealthyEnsuredProcess(self._executor_creation_func)
healthy_ensured_process.start_new()
healthy_ensured_process.start_new(task_queue)

if not healthy_ensured_process.process.is_alive():
return

while True:
try:
while not healthy_ensured_process.is_ready and not task_queue.empty():
time.sleep(1)
args = task_queue.get(timeout=1)
except queue.Empty:
logger.info(f"Process {idx} queue empty, exit.")
Expand Down Expand Up @@ -228,8 +230,9 @@ def _timeout_process_wrapper(self, task_queue: Queue, idx: int, timeout_time, re
)
result_list.append(result)
self._completed_idx[line_number] = healthy_ensured_process.format_current_process(line_number, True)
healthy_ensured_process.end()
healthy_ensured_process.start_new()
if not task_queue.empty():
healthy_ensured_process.end()
healthy_ensured_process.start_new(task_queue)

self._processing_idx.pop(line_number)
log_progress(
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import time
from multiprocessing import Queue
from unittest.mock import Mock, patch

import pytest

from promptflow.executor._line_execution_process_pool import HealthyEnsuredProcess


def executor_creation_func(storage):
    """No-op executor factory stub: ignores ``storage`` and returns ``None`` immediately."""
    return None


def executor_creation_func_timeout(storage):
    """Executor factory stub that ignores ``storage`` and blocks for 60 seconds.

    The tests in this file use it to build a process that does not become
    ready.
    """
    stall_seconds = 60
    time.sleep(stall_seconds)


def end_process(healthy_ensured_process, poll_interval=1, timeout=60):
    """Repeatedly call ``end()`` until the wrapped process is no longer alive.

    :param healthy_ensured_process: Object exposing ``process.is_alive()`` and ``end()``.
    :param poll_interval: Seconds to sleep after each ``end()`` call before
        re-checking liveness. Defaults to 1.
    :param timeout: Maximum number of seconds to keep trying, or ``None`` to
        wait indefinitely. Defaults to 60.
    :raises TimeoutError: If the process is still alive once ``timeout`` has elapsed.
    """
    # A deadline keeps a process that refuses to die from hanging the test run.
    deadline = None if timeout is None else time.monotonic() + timeout
    # ``process`` is re-read on every iteration in case ``end()`` swaps it out.
    while healthy_ensured_process.process.is_alive():
        if deadline is not None and time.monotonic() > deadline:
            raise TimeoutError(f"Process is still alive after {timeout} seconds.")
        healthy_ensured_process.end()
        time.sleep(poll_interval)


@pytest.mark.unittest
class TestHealthyEnsuredProcess:
    """Unit tests for ``HealthyEnsuredProcess`` start-up, shutdown and log formatting."""

    def test_healthy_ensured_process(self):
        """With an executor factory that returns at once, the process becomes ready."""
        healthy_ensured_process = HealthyEnsuredProcess(executor_creation_func)
        assert healthy_ensured_process.is_ready is False
        task_queue = Queue()
        healthy_ensured_process.start_new(task_queue)
        try:
            assert healthy_ensured_process.process.is_alive()
            assert healthy_ensured_process.is_ready is True
        finally:
            # Always stop the child process, even when an assertion above fails.
            end_process(healthy_ensured_process)
        assert healthy_ensured_process.process.is_alive() is False

    def test_unhealthy_process(self):
        """With an executor factory that blocks, the process is alive but not ready."""
        healthy_ensured_process = HealthyEnsuredProcess(executor_creation_func_timeout)
        assert healthy_ensured_process.is_ready is False
        task_queue = Queue()
        healthy_ensured_process.start_new(task_queue)
        try:
            assert healthy_ensured_process.process.is_alive() is True
            assert healthy_ensured_process.is_ready is False
        finally:
            # Always stop the child process, even when an assertion above fails.
            end_process(healthy_ensured_process)
        assert healthy_ensured_process.process.is_alive() is False

    def test_format_current_process(self):
        """``format_current_process`` returns the name/pid/line-number summary string."""
        healthy_ensured_process = HealthyEnsuredProcess(executor_creation_func)
        # A plain Mock stands in for the child process. ``name`` is assigned
        # after construction because ``Mock(name=...)`` names the mock itself
        # instead of creating a ``name`` attribute.
        process_stub = Mock()
        process_stub.name = "process_name"
        process_stub.pid = 123
        healthy_ensured_process.process = process_stub
        line_number = 13
        formatted_message = healthy_ensured_process.format_current_process(line_number)
        expected_log_message = (
            f"Process name({process_stub.name})-Process id({process_stub.pid})-Line number({line_number})"
        )
        assert formatted_message == expected_log_message

    @patch('promptflow.executor._line_execution_process_pool.logger.info', autospec=True)
    def test_format_completed_process(self, mock_logger_info):
        """With ``is_completed=True`` the "completed." message is logged exactly once."""
        healthy_ensured_process = HealthyEnsuredProcess(executor_creation_func)
        # Same stand-in as above; see the note there about assigning ``name``.
        process_stub = Mock()
        process_stub.name = "process_name"
        process_stub.pid = 123
        healthy_ensured_process.process = process_stub
        line_number = 13
        # Discard any calls recorded before the method under test runs.
        mock_logger_info.reset_mock()
        healthy_ensured_process.format_current_process(line_number, True)
        expected_log_message = (
            f"Process name: {process_stub.name}, Process id: {process_stub.pid}, "
            f"Line number: {line_number} completed."
        )
        mock_logger_info.assert_called_once_with(expected_log_message)
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest

from multiprocessing import Queue
from promptflow.executor._line_execution_process_pool import QueueRunStorage
from promptflow.contracts.run_info import FlowRunInfo
from promptflow.contracts.run_info import RunInfo as NodeRunInfo


@pytest.mark.unittest
class TestLineExecutionProcessPool:
    """Unit tests checking that ``QueueRunStorage`` puts persisted run infos on its queue."""

    # Upper bound, in seconds, for reading back from the queue. A bounded
    # ``get`` makes the test fail with ``queue.Empty`` instead of hanging
    # forever when nothing was enqueued.
    _QUEUE_GET_TIMEOUT = 10

    def test_persist_node_run(self):
        """``persist_node_run`` places an equal ``NodeRunInfo`` on the queue."""
        output_queue = Queue()
        run_storage = QueueRunStorage(output_queue)
        node_run_info = NodeRunInfo(
            node="node1",
            flow_run_id="flow_run_id",
            run_id="run_id",
            status="status",
            inputs="inputs",
            output="output",
            metrics="metrics",
            error="error",
            parent_run_id="parent_run_id",
            start_time="start_time",
            end_time="end_time",
            index="index",
            api_calls="api_calls",
            variant_id="variant_id",
            cached_run_id="cached_run_id",
            cached_flow_run_id="cached_flow_run_id",
            logs="logs",
            system_metrics="system_metrics",
            result="result",
        )
        run_storage.persist_node_run(node_run_info)
        assert output_queue.get(timeout=self._QUEUE_GET_TIMEOUT) == node_run_info

    def test_persist_flow_run(self):
        """``persist_flow_run`` places an equal ``FlowRunInfo`` on the queue."""
        output_queue = Queue()
        run_storage = QueueRunStorage(output_queue)
        flow_run_info = FlowRunInfo(
            run_id="run_id",
            status="status",
            inputs="inputs",
            output="output",
            metrics="metrics",
            request="request",
            root_run_id="root_run_id",
            source_run_id="source_run_id",
            flow_id="flow_id",
            error="error",
            parent_run_id="parent_run_id",
            start_time="start_time",
            end_time="end_time",
            index="index",
            api_calls="api_calls",
            variant_id="variant_id",
            system_metrics="system_metrics",
            result="result",
        )
        run_storage.persist_flow_run(flow_run_info)
        assert output_queue.get(timeout=self._QUEUE_GET_TIMEOUT) == flow_run_info

0 comments on commit cd55ffa

Please sign in to comment.