[SDK][Internal] Support command node in orchestrator (#1855)
# Description

Support command nodes in the experiment orchestrator:
- Rename the experiment node type `CODE` to `COMMAND`, and replace `ScriptNode`/`ScriptNodeSchema`
with `CommandNode`/`CommandNodeSchema` (command, code, inputs, outputs, environment variables,
runtime, display name).
- Snapshot the command node source code when an experiment is created.
- Add `RunTypes.COMMAND` and persist command node runs by adding `command` and `outputs` to
`FlowRunProperties` and the `Run` entity, exposed as dump-only fields on `RunSchema`.
- Add `ExperimentCommandRunError` and extend the orchestrator to execute command nodes.
- Cover the new behavior with e2e tests based on a `basic-script-template` experiment (sketched below).
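
Roughly, the flow exercised by the new e2e tests looks like the sketch below. The template path is
illustrative and the experiments API is internal (`_experiments`), so treat this as a sketch rather
than a public contract:

```python
from promptflow import PFClient
from promptflow._sdk._load_functions import load_common
from promptflow._sdk.entities._experiment import Experiment, ExperimentTemplate

# Load an experiment template that mixes flow nodes and command nodes, create the
# experiment (command node sources are snapshotted), then start it.
template = load_common(ExperimentTemplate, source="experiments/basic-script-template/basic-script.exp.yaml")
experiment = Experiment.from_template(template)

client = PFClient()
experiment = client._experiments.create_or_update(experiment)
experiment = client._experiments.start(experiment.name)

print(experiment.status)     # e.g. TERMINATED once all node runs finish
print(experiment.node_runs)  # one entry per node, including the command nodes
```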

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Signed-off-by: Brynn Yin <[email protected]>
brynn-code authored Jan 26, 2024
1 parent 2035e82 commit 22aa805
Showing 14 changed files with 470 additions and 72 deletions.
6 changes: 5 additions & 1 deletion src/promptflow/promptflow/_sdk/_constants.py
@@ -155,6 +155,7 @@ class RunTypes:
BATCH = "batch"
EVALUATION = "evaluation"
PAIRWISE_EVALUATE = "pairwise_evaluate"
COMMAND = "command"


class AzureRunTypes:
@@ -241,6 +242,9 @@ class FlowRunProperties:
NODE_VARIANT = "node_variant"
RUN = "run"
SYSTEM_METRICS = "system_metrics"
# Experiment command node fields only
COMMAND = "command"
OUTPUTS = "outputs"


class CommonYamlFields:
@@ -398,7 +402,7 @@ class DownloadedRun:

class ExperimentNodeType(object):
FLOW = "flow"
CODE = "code"
COMMAND = "command"


class ExperimentStatus(object):
6 changes: 6 additions & 0 deletions src/promptflow/promptflow/_sdk/_errors.py
@@ -167,3 +167,9 @@ class DownloadInternalError(SDKInternalError):
"""Exception raised if download internal error."""

pass


class ExperimentCommandRunError(SDKError):
"""Exception raised if experiment validation failed."""

pass
3 changes: 3 additions & 0 deletions src/promptflow/promptflow/_sdk/_submitter/__init__.py
@@ -1,3 +1,6 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from .run_submitter import RunSubmitter
from .test_submitter import TestSubmitter
from .utils import (
297 changes: 253 additions & 44 deletions src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py

Large diffs are not rendered by default.
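
The orchestrator changes are not rendered above, so the following is only a hedged sketch of the kind
of logic a command node run needs: resolve placeholders, run the command inside its snapshot folder,
and surface failures as the new `ExperimentCommandRunError`. The helper name and the `${inputs.x}`
placeholder style are assumptions, not the actual implementation:

```python
import subprocess
from pathlib import Path

from promptflow._sdk._errors import ExperimentCommandRunError


def run_command_node(command: str, code: str, inputs: dict, outputs: dict) -> None:
    """Hypothetical helper: execute a command node inside its snapshot folder."""
    # Substitute assumed ${inputs.x} / ${outputs.y} placeholders with concrete values.
    resolved = {**{f"inputs.{k}": v for k, v in inputs.items()},
                **{f"outputs.{k}": v for k, v in outputs.items()}}
    for name, value in resolved.items():
        command = command.replace("${" + name + "}", str(value))
    # Make sure declared output folders exist before the command writes to them.
    for output_path in outputs.values():
        Path(output_path).mkdir(parents=True, exist_ok=True)
    result = subprocess.run(command, shell=True, cwd=code, capture_output=True, text=True)
    if result.returncode != 0:
        raise ExperimentCommandRunError(
            f"Command node failed with exit code {result.returncode}: {result.stderr}"
        )
```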

1 change: 1 addition & 0 deletions src/promptflow/promptflow/_sdk/_submitter/run_submitter.py
@@ -79,6 +79,7 @@ def _validate_inputs(cls, run: Run):
def _submit_bulk_run(
self, flow: Union[ProtectedFlow, EagerFlow], run: Run, local_storage: LocalStorageOperations
) -> dict:
logger.info(f"Submitting run {run.name}, reach logs at {local_storage.logger.file_path}.")
run_id = run.name
if flow.language == FlowLanguage.CSharp:
connections = []
4 changes: 4 additions & 0 deletions src/promptflow/promptflow/_sdk/_submitter/utils.py
@@ -41,8 +41,11 @@
from promptflow._sdk.entities._flow import Flow, ProtectedFlow
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.flow_utils import dump_flow_dag, load_flow_dag
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow.contracts.flow import Flow as ExecutableFlow

logger = get_cli_sdk_logger()


def overwrite_variant(flow_dag: dict, tuning_node: str = None, variant: str = None, drop_node_variants: bool = False):
# need to overwrite default variant if tuning node and variant not specified.
@@ -263,6 +266,7 @@ def resolve_environment_variables(cls, environment_variables: dict, client=None)
if not environment_variables:
return None
connection_names = get_used_connection_names_from_dict(environment_variables)
logger.debug("Used connection names: %s", connection_names)
connections = cls.resolve_connection_names(connection_names=connection_names, client=client)
update_dict_value_with_connections(built_connections=connections, connection_dict=environment_variables)

54 changes: 41 additions & 13 deletions src/promptflow/promptflow/_sdk/entities/_experiment.py
@@ -28,12 +28,12 @@
from promptflow._sdk.entities._validation import MutableValidationResult, SchemaValidatableMixin
from promptflow._sdk.entities._yaml_translatable import YAMLTranslatableMixin
from promptflow._sdk.schemas._experiment import (
CommandNodeSchema,
ExperimentDataSchema,
ExperimentInputSchema,
ExperimentSchema,
ExperimentTemplateSchema,
FlowNodeSchema,
ScriptNodeSchema,
)
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow.contracts.tool import ValueType
@@ -114,7 +114,6 @@ def __init__(
self.environment_variables = environment_variables or {}
self.connections = connections or {}
self._properties = properties or {}
self._creation_context = kwargs.get("creation_context", None)
# init here to make sure those fields initialized in all branches.
self.path = path
# default run name: flow directory name + timestamp
@@ -141,23 +140,52 @@ def _save_snapshot(self, target):
self.path = saved_flow_path.resolve().absolute().as_posix()


class ScriptNode(YAMLTranslatableMixin):
def __init__(self, source, inputs, name, display_name=None, runtime=None, environment_variables=None, **kwargs):
self.type = ExperimentNodeType.CODE
self.display_name = display_name
class CommandNode(YAMLTranslatableMixin):
def __init__(
self,
command,
name,
inputs=None,
outputs=None,
runtime=None,
environment_variables=None,
code=None,
display_name=None,
**kwargs,
):
self.type = ExperimentNodeType.COMMAND
self.name = name
self.source = source
self.inputs = inputs
self.display_name = display_name
self.code = code
self.command = command
self.inputs = inputs or {}
self.outputs = outputs or {}
self.runtime = runtime
self.environment_variables = environment_variables or {}

@classmethod
def _get_schema_cls(cls):
return ScriptNodeSchema
return CommandNodeSchema

def _save_snapshot(self, target):
# Do nothing for script node for now
pass
"""Save command source to experiment snapshot."""
Path(target).mkdir(parents=True, exist_ok=True)
saved_path = Path(target) / self.name
if not self.code:
# Create an empty folder
saved_path.mkdir(parents=True, exist_ok=True)
self.code = saved_path.resolve().absolute().as_posix()
return
code = Path(self.code)
if not code.exists():
raise ExperimentValueError(f"Command node code {code} does not exist.")
if code.is_dir():
shutil.copytree(src=self.code, dst=saved_path)
else:
saved_path.mkdir(parents=True, exist_ok=True)
shutil.copy(src=self.code, dst=saved_path)
logger.debug(f"Command node source saved to {saved_path}.")
self.code = saved_path.resolve().absolute().as_posix()


class ExperimentTemplate(YAMLTranslatableMixin, SchemaValidatableMixin):
@@ -334,9 +362,9 @@ def _from_orm_object(cls, obj: ORMExperiment) -> "Experiment":
nodes.append(
FlowNode._load_from_dict(node_dict, context=context, additional_message="Failed to load node.")
)
elif node_dict["type"] == ExperimentNodeType.CODE:
elif node_dict["type"] == ExperimentNodeType.COMMAND:
nodes.append(
ScriptNode._load_from_dict(node_dict, context=context, additional_message="Failed to load node.")
CommandNode._load_from_dict(node_dict, context=context, additional_message="Failed to load node.")
)
else:
raise Exception(f"Unknown node type {node_dict['type']}")
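
A small usage sketch of the snapshot behavior added above. `_save_snapshot` is internal and is
normally driven by `Experiment.from_template`; the paths here are hypothetical:

```python
from pathlib import Path

from promptflow._sdk.entities._experiment import CommandNode

# Without code: an empty snapshot folder named after the node is created and
# becomes the node's code path.
echo = CommandNode(command="echo hello", name="echo")
echo._save_snapshot("./snapshots")
print(echo.code)  # .../snapshots/echo as an absolute POSIX path

# With code: the source directory (or file) is copied into the snapshot;
# a missing path raises ExperimentValueError.
Path("./scripts").mkdir(exist_ok=True)  # hypothetical source folder
gen_data = CommandNode(command="python generate_data.py", name="gen_data", code="./scripts")
gen_data._save_snapshot("./snapshots")
print(gen_data.code)  # .../snapshots/gen_data
```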
11 changes: 10 additions & 1 deletion src/promptflow/promptflow/_sdk/entities/_run.py
@@ -128,7 +128,7 @@ def __init__(
**kwargs,
):
# TODO: remove when RUN CRUD don't depend on this
self.type = RunTypes.BATCH
self.type = kwargs.get("type", RunTypes.BATCH)
self.data = data
self.column_mapping = column_mapping
self.display_name = display_name
@@ -181,6 +181,8 @@ def __init__(
self._output_path = Path(source)
self._runtime = kwargs.get("runtime", None)
self._resources = kwargs.get("resources", None)
self._outputs = kwargs.get("outputs", None)
self._command = kwargs.get("command", None)

@property
def created_on(self) -> str:
@@ -204,6 +206,10 @@ def properties(self) -> Dict[str, str]:
result[FlowRunProperties.RUN] = run_name
if self.variant:
result[FlowRunProperties.NODE_VARIANT] = self.variant
if self._command:
result[FlowRunProperties.COMMAND] = self._command
if self._outputs:
result[FlowRunProperties.OUTPUTS] = self._outputs
elif self._run_source == RunInfoSources.EXISTING_RUN:
result = {
FlowRunProperties.OUTPUT_PATH: Path(self.source).resolve().as_posix(),
@@ -245,6 +251,9 @@ def _from_orm_object(cls, obj: ORMRun) -> "Run":
properties={FlowRunProperties.SYSTEM_METRICS: properties_json.get(FlowRunProperties.SYSTEM_METRICS, {})},
# compatible with old runs, their run_source is empty, treat them as local
run_source=obj.run_source or RunInfoSources.LOCAL,
# experiment command node only fields
command=properties_json.get(FlowRunProperties.COMMAND, None),
outputs=properties_json.get(FlowRunProperties.OUTPUTS, None),
)

@classmethod
29 changes: 19 additions & 10 deletions src/promptflow/promptflow/_sdk/schemas/_experiment.py
@@ -15,16 +15,18 @@
from promptflow._sdk.schemas._run import RunSchema


class ScriptNodeSchema(metaclass=PatchedSchemaMeta):
class CommandNodeSchema(YamlFileSchema):
# TODO: Not finalized now. Need to revisit.
name = fields.Str(required=True)
type = StringTransformedEnum(allowed_values=ExperimentNodeType.CODE, required=True)
path = UnionField([LocalPathField(required=True), fields.Str(required=True)])
display_name = fields.Str()
type = StringTransformedEnum(allowed_values=ExperimentNodeType.COMMAND, required=True)
code = LocalPathField(default=".")
command = fields.Str(required=True)
inputs = fields.Dict(keys=fields.Str)
outputs = fields.Dict(keys=fields.Str, values=LocalPathField(allow_none=True))
environment_variables = fields.Dict(keys=fields.Str, values=fields.Str)
# runtime field, only available for cloud run
runtime = fields.Str() # TODO: Revisit the required fields
display_name = fields.Str()
environment_variables = fields.Dict(keys=fields.Str, values=fields.Str)


class FlowNodeSchema(RunSchema):
@@ -58,24 +60,31 @@ class ExperimentTemplateSchema(YamlFileSchema):
description = fields.Str()
data = fields.List(NestedField(ExperimentDataSchema)) # Optional
inputs = fields.List(NestedField(ExperimentInputSchema)) # Optional
nodes = fields.List(UnionField([NestedField(FlowNodeSchema), NestedField(ScriptNodeSchema)]), required=True)
nodes = fields.List(
UnionField(
[
NestedField(CommandNodeSchema),
NestedField(FlowNodeSchema),
]
),
required=True,
)

@post_load
def resolve_nodes(self, data, **kwargs):
from promptflow._sdk.entities._experiment import FlowNode, ScriptNode
from promptflow._sdk.entities._experiment import CommandNode, FlowNode

nodes = data.get("nodes", [])

resolved_nodes = []
for node in nodes:
if not isinstance(node, dict):
continue
node_type = node.get("type", None)
if node_type == ExperimentNodeType.FLOW:
resolved_nodes.append(FlowNode._load_from_dict(data=node, context=self.context, additional_message=""))
elif node_type == ExperimentNodeType.CODE:
elif node_type == ExperimentNodeType.COMMAND:
resolved_nodes.append(
ScriptNode._load_from_dict(data=node, context=self.context, additional_message="")
CommandNode._load_from_dict(data=node, context=self.context, additional_message="")
)
else:
raise ValueError(f"Unknown node type {node_type} for node {node}.")
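
For reference, a node entry the new `CommandNodeSchema` is meant to accept looks roughly like the
dict below (field names come from the schema above; the values and the `${inputs.count}` placeholder
style are illustrative only). In practice this would live under `nodes:` in an `*.exp.yaml`
experiment template and be dispatched to `CommandNode` by `resolve_nodes`:

```python
command_node = {
    "name": "gen_data",
    "type": "command",    # ExperimentNodeType.COMMAND
    "code": "./scripts",  # optional, defaults to "."
    "command": "python generate_data.py --count ${inputs.count}",
    "inputs": {"count": 3},
    "outputs": {"data_path": "./outputs/gen_data"},
    "environment_variables": {"LOG_LEVEL": "DEBUG"},
    "display_name": "generate data",
}
```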
5 changes: 5 additions & 0 deletions src/promptflow/promptflow/_sdk/schemas/_run.py
@@ -102,6 +102,11 @@ class RunSchema(YamlFileSchema):
connections = fields.Dict(keys=fields.Str(), values=fields.Dict(keys=fields.Str()))
# endregion: context

# region: command node
command = fields.Str(dump_only=True)
outputs = fields.Dict(keys=fields.Str(), dump_only=True)
# endregion: command node

@post_load
def resolve_dot_env_file(self, data, **kwargs):
return _resolve_dot_env_file(data, **kwargs)
6 changes: 3 additions & 3 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
@@ -1939,7 +1939,7 @@ def test_experiment_start(self, monkeypatch, capfd, local_client):
"experiment",
"create",
"--template",
f"{EXPERIMENT_DIR}/basic-no-script-template/basic.exp.yaml",
f"{EXPERIMENT_DIR}/basic-script-template/basic-script.exp.yaml",
"--name",
exp_name,
)
@@ -1956,8 +1956,8 @@
out, _ = capfd.readouterr()
assert ExperimentStatus.TERMINATED in out
exp = local_client._experiments.get(name=exp_name)
assert len(exp.node_runs["main"]) > 0
assert len(exp.node_runs["eval"]) > 0
assert len(exp.node_runs) == 4
assert all(len(exp.node_runs[node_name]) > 0 for node_name in exp.node_runs)
metrics = local_client.runs.get_metrics(name=exp.node_runs["eval"][0]["name"])
assert "accuracy" in metrics

45 changes: 45 additions & 0 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py
@@ -7,6 +7,7 @@
from promptflow._sdk._constants import ExperimentStatus, RunStatus
from promptflow._sdk._load_functions import load_common
from promptflow._sdk.entities._experiment import (
CommandNode,
Experiment,
ExperimentData,
ExperimentInput,
@@ -51,6 +52,36 @@ def test_experiment_from_template(self):
assert experiment_dict["nodes"][1].items() == expected["nodes"][1].items()
assert experiment_dict.items() >= expected.items()

def test_experiment_from_template_with_script_node(self):
template_path = EXP_ROOT / "basic-script-template" / "basic-script.exp.yaml"
# Load template and create experiment
template = load_common(ExperimentTemplate, source=template_path)
experiment = Experiment.from_template(template)
# Assert command node load correctly
assert len(experiment.nodes) == 4
expected = dict(yaml.load(open(template_path, "r", encoding="utf-8").read()))
experiment_dict = experiment._to_dict()
assert isinstance(experiment.nodes[0], CommandNode)
assert isinstance(experiment.nodes[1], FlowNode)
assert isinstance(experiment.nodes[2], FlowNode)
assert isinstance(experiment.nodes[3], CommandNode)
gen_data_snapshot_path = experiment._output_dir / "snapshots" / "gen_data"
echo_snapshot_path = experiment._output_dir / "snapshots" / "echo"
expected["nodes"][0]["code"] = gen_data_snapshot_path.absolute().as_posix()
expected["nodes"][3]["code"] = echo_snapshot_path.absolute().as_posix()
expected["nodes"][3]["environment_variables"] = {}
assert experiment_dict["nodes"][0].items() == expected["nodes"][0].items()
assert experiment_dict["nodes"][3].items() == expected["nodes"][3].items()
# Assert snapshots
assert gen_data_snapshot_path.exists()
file_count = len(list(gen_data_snapshot_path.rglob("*")))
assert file_count == 1
assert (gen_data_snapshot_path / "generate_data.py").exists()
# Assert no file exists in echo path
assert echo_snapshot_path.exists()
file_count = len(list(echo_snapshot_path.rglob("*")))
assert file_count == 0

def test_experiment_create_and_get(self):
template_path = EXP_ROOT / "basic-no-script-template" / "basic.exp.yaml"
# Load template and create experiment
Expand Down Expand Up @@ -85,3 +116,17 @@ def test_experiment_start(self):
assert eval_run.display_name == "eval"
metrics = client.runs.get_metrics(name=eval_run.name)
assert "accuracy" in metrics

@pytest.mark.usefixtures("use_secrets_config_file", "recording_injection", "setup_local_connection")
def test_experiment_with_script_start(self):
template_path = EXP_ROOT / "basic-script-template" / "basic-script.exp.yaml"
# Load template and create experiment
template = load_common(ExperimentTemplate, source=template_path)
experiment = Experiment.from_template(template)
client = PFClient()
exp = client._experiments.create_or_update(experiment)
exp = client._experiments.start(exp.name)
assert exp.status == ExperimentStatus.TERMINATED
assert len(exp.node_runs) == 4
for key, val in exp.node_runs.items():
assert val[0]["status"] == RunStatus.COMPLETED, f"Node {key} run failed"
