Skip to content

Commit

Permalink
[Executor][Bugfix] Add skip/activate condition to node_dependencies i…
Browse files Browse the repository at this point in the history
…n _is_node_ready (#312)

# Description

### Bug impact:
When a node has no inputs but has an activate config, the executor raises
`InputNotFoundFromAncestorNodeOutput`.

### Bug root cause:
The skip/activate conditions of the node are not added to the node
dependencies when judging whether a node is ready.

### Bug fix:
Add skip/activate conditions to node_dependencies in `_is_node_ready`

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes]**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.
  • Loading branch information
PeiwenGaoMS authored Sep 6, 2023
1 parent c1b3cb1 commit cc73084
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 14 deletions.
15 changes: 15 additions & 0 deletions src/promptflow/promptflow/executor/_dag_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from promptflow.contracts.flow import InputAssignment, InputValueType, Node
from promptflow.executor import _input_assignment_parser
from promptflow.executor._errors import ReferenceNodeBypassed


class DAGManager:
Expand Down Expand Up @@ -73,6 +74,12 @@ def completed(self) -> bool:
def _is_node_ready(self, node: Node) -> bool:
"""Returns True if the node is ready to be executed."""
node_dependencies = [i for i in node.inputs.values()]
# Add skip and activate conditions as node dependencies
if node.skip:
node_dependencies.extend([node.skip.condition, node.skip.return_value])
if node.activate:
node_dependencies.append(node.activate.condition)

for node_dependency in node_dependencies:
if (
node_dependency.value_type == InputValueType.NODE_REFERENCE
Expand All @@ -86,6 +93,14 @@ def _is_node_bypassable(self, node: Node) -> bool:
"""Returns True if the node should be bypassed."""
# Bypass node if the skip condition is met
if self._is_skip_condition_met(node):
if self._is_node_dependency_bypassed(node.skip.return_value):
raise ReferenceNodeBypassed(
message_format="The node {reference_node_name} referenced by {node_name} has been bypassed, "
"so the value of this node cannot be returned. Please refer to the node that "
"will not be bypassed as the default return value.",
reference_node_name=node.skip.return_value.value,
node_name=node.name,
)
skip_return = self._get_node_dependency_value(node.skip.return_value)
# This is not a good practice, but we need to update the default output of bypassed node
# to completed_nodes_outputs. We will remove these after skip config is deprecated.
Expand Down
13 changes: 10 additions & 3 deletions src/promptflow/promptflow/executor/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,12 @@ class NodeConcurrencyNotFound(SystemErrorException):
class NodeReferenceError(UserErrorException):
    """Exception raised when node reference not found or unsupported.

    The text supplied by the caller is prefixed with ``"Invalid node reference: "``
    before it is forwarded to the base exception.

    Args:
        message: Ready-to-display error text. When non-empty it is the one that
            gets prefixed, and ``message_format`` is forwarded untouched.
        message_format: Template text; prefixed only when ``message`` is empty.
            The placeholders are left as-is and forwarded together with
            ``**kwargs`` to the base exception.
        target: Error target forwarded to the base exception.
        **kwargs: Extra keyword arguments forwarded to the base exception
            (for example, the values referenced by ``message_format``).
    """

    def __init__(self, message="", message_format="", target=ErrorTarget.FLOW_EXECUTOR, **kwargs):
        # Only one of the two texts receives the prefix; ``message`` wins when both are given.
        if message:
            message = f"Invalid node reference: {message}"
        elif message_format:
            message_format = f"Invalid node reference: {message_format}"
        super().__init__(message=message, message_format=message_format, target=target, **kwargs)


class UnsupportedReference(NodeReferenceError):
Expand All @@ -149,6 +152,10 @@ class InvalidReferenceProperty(NodeReferenceError):
pass


class ReferenceNodeBypassed(NodeReferenceError):
    """Exception raised when a node's skip condition is met but the node referenced
    by its skip return value has itself been bypassed, so no value can be returned."""

    pass


class LineExecutionTimeoutError(UserErrorException):
"""Exception raised when single line execution timeout"""

Expand Down
22 changes: 11 additions & 11 deletions src/promptflow/tests/executor/e2etests/test_activate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@
get_yaml_file,
)

ACTIVATE_FLOW_TEST_CASES = ["conditional_flow_with_activate", "activate_with_no_inputs"]


@pytest.mark.usefixtures("dev_connections")
@pytest.mark.e2etest
class TestExecutorActivate:
@pytest.mark.parametrize("flow_folder", ACTIVATE_FLOW_TEST_CASES)
def test_flow_run_activate(self, dev_connections, flow_folder):
    """Execute one line of each activate sample flow and validate the line result.

    Args:
        dev_connections: Connections fixture passed to ``FlowExecutor.create``.
        flow_folder: Name of the sample flow folder under test.
    """
    executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections)
    results = executor.exec_line(get_flow_inputs(flow_folder))
    # Assert the flow result
    expected_result = get_flow_expected_result(flow_folder)
    # The loaded expectation may be a list of per-line entries; a single-line run
    # is compared against the first entry. Anything else is used as loaded
    # (previously the helper function object itself was assigned here by mistake).
    if isinstance(expected_result, list):
        expected_result = expected_result[0]
    self.assert_activate_flow_run_result(results, expected_result)

def test_bulk_run_activate(self, dev_connections):
flow_folder = "conditional_flow_with_activate"
Expand All @@ -52,24 +53,23 @@ def assert_activate_bulk_run_result(self, result: BulkResult, expected_result, e

# Validate the flow line results
for i, line_result in enumerate(result.line_results):
expected_outputs = expected_result[i]["expected_outputs"]
expected_bypassed_nodes = expected_result[i]["expected_bypassed_nodes"]
self.assert_activate_flow_run_result(line_result, expected_outputs, expected_bypassed_nodes)
self.assert_activate_flow_run_result(line_result, expected_result[i])

# Validate the flow status summary
status_summary = result.get_status_summary()
assert status_summary == expected_status_summary

def assert_activate_flow_run_result(self, result: LineResult, expected_outputs, expected_bypassed_nodes):
def assert_activate_flow_run_result(self, result: LineResult, expected_result):
# Validate the flow status
assert result.run_info.status == Status.Completed

# Validate the flow output
assert isinstance(result.output, dict)
assert result.output == expected_outputs
assert result.output == expected_result["expected_outputs"]

# Validate the flow node run infos for the completed nodes
assert len(result.node_run_infos) == 9
assert len(result.node_run_infos) == expected_result["expected_node_count"]
expected_bypassed_nodes = expected_result["expected_bypassed_nodes"]
completed_nodes_run_infos = [
run_info for i, run_info in result.node_run_infos.items() if i not in expected_bypassed_nodes
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from promptflow.contracts.flow import ActivateCondition, InputAssignment, Node, SkipCondition
from promptflow.executor._dag_manager import DAGManager
from promptflow.executor._errors import ReferenceNodeBypassed


def create_test_node(name, input, skip=None, activate=None):
Expand Down Expand Up @@ -57,6 +58,29 @@ def test_pop_bypassed_nodes(self):
assert pop_bypassed_node_names(dag_manager) == expected_bypassed_nodes
assert dag_manager.bypassed_nodes.keys() == expected_bypassed_nodes

def test_pop_bypassed_nodes_with_exception(self):
    """A skip whose return value points at a bypassed node must raise ReferenceNodeBypassed.

    node2 is bypassed because its activate condition does not match the flow input.
    node3's skip condition is met once node1 completes, but its skip return value
    references node2, so popping the bypassable nodes is expected to fail.
    """
    node1 = create_test_node(
        "node1", input="${inputs.text}", activate={"when": "${inputs.text}", "is": "hello"}
    )
    node2 = create_test_node(
        "node2", input="${inputs.text}", activate={"when": "${inputs.text}", "is": "world"}
    )
    node3 = create_test_node(
        "node3",
        input="${inputs.text}",
        skip={"when": "${node1.output}", "is": "hello", "return": "${node2.output}"},
    )
    dag_manager = DAGManager([node1, node2, node3], {"text": "hello"})

    # node2's activate condition is not met, so it is bypassed up front.
    assert pop_bypassed_node_names(dag_manager) == {"node2"}

    # Completing node1 makes node3's skip condition evaluable.
    dag_manager.complete_nodes({"node1": "hello"})
    with pytest.raises(ReferenceNodeBypassed) as exc_info:
        dag_manager.pop_bypassable_nodes()

    error_message = (
        "Invalid node reference: The node node2 referenced by node3 has been bypassed, "
        "so the value of this node cannot be returned. Please refer to the node that will "
        "not be bypassed as the default return value."
    )
    actual_message = str(exc_info.value)
    assert actual_message == error_message, "Expected: {}, Actual: {}".format(error_message, actual_message)

def test_complete_nodes(self):
nodes = [create_test_node("node1", input="value1")]
dag_manager = DAGManager(nodes, flow_inputs={})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[
{
"expected_node_count": 2,
"expected_outputs":{
"text": "hello world"
},
"expected_bypassed_nodes":[]
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
inputs:
text:
type: string
outputs:
text:
type: string
reference: ${node_a.output}
nodes:
- name: node_a
type: python
source:
type: code
path: node_a.py
inputs:
input1: ${inputs.text}
- name: node_b
type: python
source:
type: code
path: node_b.py
inputs: {}
activate:
when: ${node_a.output}
is: hello world
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"text": "world"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from promptflow import tool


@tool
def my_python_tool(input1: str) -> str:
    """Return ``input1`` with the greeting ``'hello '`` put in front of it."""
    greeting = 'hello '
    return greeting + input1
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from promptflow import tool


@tool
def my_python_tool():
    """Print a marker line to stdout and return a fixed status string.

    The tool takes no inputs.
    """
    # Fixed the misspelled marker ("Avtivate") so the log line is searchable.
    print("Activate")
    return 'Executing...'
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[
{
"expected_node_count": 9,
"expected_outputs":{
"investigation_method": {
"first": "Skip job info extractor",
Expand All @@ -9,6 +10,7 @@
"expected_bypassed_nodes":["job_info_extractor", "icm_retriever"]
},
{
"expected_node_count": 9,
"expected_outputs":{
"investigation_method": {
"first": "Execute job info extractor",
Expand All @@ -18,6 +20,7 @@
"expected_bypassed_nodes":["incident_info_extractor", "icm_retriever", "kql_tsg_retriever", "tsg_retriever", "investigation_steps", "retriever_summary"]
},
{
"expected_node_count": 9,
"expected_outputs":{
"investigation_method": {
"first": "Skip job info extractor",
Expand Down

0 comments on commit cc73084

Please sign in to comment.