Skip to content

Commit

Permalink
[SDK] Add orchestrator inputs, data entity (#1721)
Browse files Browse the repository at this point in the history
# Description

Please add an informative description that covers that changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Signed-off-by: Brynn Yin <[email protected]>
  • Loading branch information
brynn-code authored Jan 11, 2024
1 parent 581db99 commit af12168
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 33 deletions.
8 changes: 4 additions & 4 deletions src/promptflow/promptflow/_cli/_pf/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,24 +163,24 @@ def create_experiment(args: argparse.Namespace):
logger.debug("Creating experiment from template %s", template.name)
experiment = Experiment.from_template(template)
logger.debug("Creating experiment %s", experiment.name)
exp = _get_pf_client().experiments.create_or_update(experiment)
exp = _get_pf_client()._experiments.create_or_update(experiment)
print(json.dumps(exp._to_dict(), indent=4))


@exception_handler("List experiment")
def list_experiment(args: argparse.Namespace):
list_view_type = get_list_view_type(archived_only=args.archived_only, include_archived=args.include_archived)
results = _get_pf_client().experiments.list(args.max_results, list_view_type=list_view_type)
results = _get_pf_client()._experiments.list(args.max_results, list_view_type=list_view_type)
print(json.dumps([result._to_dict() for result in results], indent=4))


@exception_handler("Show experiment")
def show_experiment(args: argparse.Namespace):
result = _get_pf_client().experiments.get(args.name)
result = _get_pf_client()._experiments.get(args.name)
print(json.dumps(result._to_dict(), indent=4))


@exception_handler("Start experiment")
def start_experiment(args: argparse.Namespace):
result = _get_pf_client().experiments.start(args.name)
result = _get_pf_client()._experiments.start(args.name)
print(json.dumps(result._to_dict(), indent=4))
18 changes: 18 additions & 0 deletions src/promptflow/promptflow/_sdk/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,21 @@ class ExperimentNotFoundError(SDKError):
"""Exception raised if experiment cannot be found."""

pass


class ExperimentValidationError(SDKError):
"""Exception raised if experiment validation failed."""

pass


class ExperimentValueError(SDKError):
"""Exception raised if experiment validation failed."""

pass


class ExperimentHasCycle(SDKError):
"""Exception raised if experiment validation failed."""

pass
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

from promptflow._sdk._configuration import Configuration
from promptflow._sdk._constants import ExperimentNodeType, ExperimentStatus
from promptflow._sdk._errors import ExperimentHasCycle, ExperimentValueError
from promptflow._sdk._submitter import RunSubmitter
from promptflow._sdk.entities import Run
from promptflow._sdk.entities._experiment import Experiment
from promptflow._sdk.operations import RunOperations
from promptflow._sdk.operations._experiment_operations import ExperimentOperations
from promptflow._utils.logger_utils import LoggerFactory
from promptflow.exceptions import UserErrorException

logger = LoggerFactory.get_logger(name=__name__)

Expand Down Expand Up @@ -44,12 +44,11 @@ def start(self, experiment: Experiment, **kwargs):
resolved_nodes = self._ensure_nodes_order(experiment.nodes)

# Run nodes
data_dict = {data.get("name", None): data for data in experiment.data}
run_dict = {}
try:
for node in resolved_nodes:
logger.debug(f"Running node {node.name}.")
run = self._run_node(node, experiment, data_dict, run_dict)
logger.info(f"Running node {node.name}.")
run = self._run_node(node, experiment, run_dict)
# Update node run to experiment
experiment._append_node_run(node.name, run)
self.experiment_operations.create_or_update(experiment)
Expand Down Expand Up @@ -92,23 +91,23 @@ def _prepare_edges(node):
referenced_nodes.discard(node.name)
break
if not action:
raise UserErrorException(f"Experiment has circular dependency {edges!r}")
raise ExperimentHasCycle(f"Experiment has circular dependency {edges!r}")

logger.debug(f"Experiment nodes resolved order: {[node.name for node in resolved_nodes]}")
return resolved_nodes

def _run_node(self, node, experiment, data_dict, run_dict) -> Run:
def _run_node(self, node, experiment, run_dict) -> Run:
if node.type == ExperimentNodeType.FLOW:
return self._run_flow_node(node, experiment, data_dict, run_dict)
return self._run_flow_node(node, experiment, run_dict)
elif node.type == ExperimentNodeType.CODE:
return self._run_script_node(node, experiment)
raise UserErrorException(f"Unknown experiment node {node.name!r} type {node.type!r}")
raise ExperimentValueError(f"Unknown experiment node {node.name!r} type {node.type!r}")

def _run_flow_node(self, node, experiment, data_dict, run_dict):
def _run_flow_node(self, node, experiment, run_dict):
run_output_path = (Path(experiment._output_dir) / "runs" / node.name).resolve().absolute().as_posix()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
run = ExperimentRun(
experiment_data=data_dict,
experiment=experiment,
experiment_runs=run_dict,
# Use node name as prefix for run name?
name=f"{node.name}_attempt{timestamp}",
Expand All @@ -132,10 +131,30 @@ def _run_script_node(self, node, experiment):
class ExperimentRun(Run):
"""Experiment run, includes experiment running context, like data, inputs and runs."""

def __init__(self, experiment_data, experiment_runs, **kwargs):
self.experiment_data = experiment_data
def __init__(self, experiment, experiment_runs, **kwargs):
self.experiment = experiment
self.experiment_data = {data.name: data for data in experiment.data}
self.experiment_inputs = {input.name: input for input in experiment.inputs}
self.experiment_runs = experiment_runs
super().__init__(**kwargs)
self._resolve_column_mapping()

def _resolve_column_mapping(self):
"""Resolve column mapping with experiment inputs to constant values."""
logger.info(f"Start resolve node {self.display_name!r} column mapping.")
resolved_mapping = {}
for name, value in self.column_mapping.items():
if not value.startswith("${inputs."):
resolved_mapping[name] = value
continue
input_name = value.split(".")[1].replace("}", "")
if input_name not in self.experiment_inputs:
raise ExperimentValueError(
f"Node {self.display_name!r} inputs {value!r} related experiment input {input_name!r} not found."
)
resolved_mapping[name] = self.experiment_inputs[input_name].default
logger.debug(f"Resolved node {self.display_name!r} column mapping {resolved_mapping}.")
self.column_mapping = resolved_mapping


class ExperimentRunSubmitter(RunSubmitter):
Expand All @@ -155,24 +174,26 @@ def _resolve_input_dirs(self, run: ExperimentRun):
for value in inputs_mapping.values():
referenced_data, referenced_run = None, None
if value.startswith("${data."):
referenced_data = value.split(".")[1]
referenced_data = value.split(".")[1].replace("}", "")
elif value.startswith("${"):
referenced_run = value.split(".")[0].replace("${", "")
if referenced_data:
if data_name and data_name != referenced_data:
raise UserErrorException(
raise ExperimentValueError(
f"Experiment has multiple data inputs {data_name!r} and {referenced_data!r}"
)
data_name = referenced_data
if referenced_run:
if run_name and run_name != referenced_run:
raise UserErrorException(f"Experiment has multiple run inputs {run_name!r} and {referenced_run!r}")
raise ExperimentValueError(
f"Experiment has multiple run inputs {run_name!r} and {referenced_run!r}"
)
run_name = referenced_run
logger.debug(f"Resolve node {run.name} referenced data {data_name!r}, run {run_name!r}.")
# Build inputs from experiment data and run
result = {}
if data_name in run.experiment_data and run.experiment_data[data_name].get("path"):
result.update({f"data.{data_name}": run.experiment_data[data_name]["path"]})
if data_name in run.experiment_data and run.experiment_data[data_name].path:
result.update({f"data.{data_name}": run.experiment_data[data_name].path})
if run_name in run.experiment_runs:
result.update(
{
Expand Down
108 changes: 96 additions & 12 deletions src/promptflow/promptflow/_sdk/entities/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from marshmallow import Schema

from promptflow._sdk._constants import (
BASE_PATH_CONTEXT_KEY,
PARAMS_OVERRIDE_KEY,
Expand All @@ -18,22 +20,71 @@
ExperimentNodeType,
ExperimentStatus,
)
from promptflow._sdk._errors import ExperimentValidationError, ExperimentValueError
from promptflow._sdk._orm.experiment import Experiment as ORMExperiment
from promptflow._sdk._submitter import remove_additional_includes
from promptflow._sdk._utils import _merge_local_code_and_additional_includes, _sanitize_python_variable_name
from promptflow._sdk.entities import Run
from promptflow._sdk.entities._validation import MutableValidationResult, SchemaValidatableMixin
from promptflow._sdk.entities._yaml_translatable import YAMLTranslatableMixin
from promptflow._sdk.schemas._experiment import (
ExperimentDataSchema,
ExperimentInputSchema,
ExperimentSchema,
ExperimentTemplateSchema,
FlowNodeSchema,
ScriptNodeSchema,
)
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow.contracts.tool import ValueType

logger = get_cli_sdk_logger()


class ExperimentData(YAMLTranslatableMixin):
def __init__(self, name, path, **kwargs):
self.name = name
self.path = path

@classmethod
def _get_schema_cls(cls):
return ExperimentDataSchema


class ExperimentInput(YAMLTranslatableMixin):
def __init__(self, name, default, type, **kwargs):
self.name = name
self.type, self.default = self._resolve_type_and_default(type, default)

@classmethod
def _get_schema_cls(cls):
return ExperimentInputSchema

def _resolve_type_and_default(self, typ, default):
supported_types = [
ValueType.INT,
ValueType.STRING,
ValueType.DOUBLE,
ValueType.LIST,
ValueType.OBJECT,
ValueType.BOOL,
]
value_type: ValueType = next((i for i in supported_types if typ.lower() == i.value.lower()), None)
if value_type is None:
raise ExperimentValueError(f"Unknown experiment input type {typ!r}, supported are {supported_types}.")
return value_type.value, value_type.parse(default)

@classmethod
def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str = None, **kwargs):
# Override this to avoid 'type' got pop out
schema_cls = cls._get_schema_cls()
try:
loaded_data = schema_cls(context=context).load(data, **kwargs)
except Exception as e:
raise Exception(f"Load experiment input failed with {str(e)}. f{(additional_message or '')}.")
return cls(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)


class FlowNode(YAMLTranslatableMixin):
def __init__(
self,
Expand Down Expand Up @@ -109,7 +160,7 @@ def _save_snapshot(self, target):
pass


class ExperimentTemplate(YAMLTranslatableMixin):
class ExperimentTemplate(YAMLTranslatableMixin, SchemaValidatableMixin):
def __init__(self, nodes, name=None, description=None, data=None, inputs=None, **kwargs):
self._base_path = kwargs.get(BASE_PATH_CONTEXT_KEY, Path("."))
self.name = name or self._generate_name()
Expand Down Expand Up @@ -171,13 +222,38 @@ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str = No
raise Exception(f"Load experiment template failed with {str(e)}. f{(additional_message or '')}.")
return cls(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)

@classmethod
def _create_schema_for_validation(cls, context) -> Schema:
return cls._get_schema_cls()(context=context)

def _default_context(self) -> dict:
return {BASE_PATH_CONTEXT_KEY: self._base_path}

@classmethod
def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception:
return ExperimentValidationError(
message=message,
no_personal_data_message=no_personal_data_message,
)

def _customized_validate(self) -> MutableValidationResult:
"""Validate the resource with customized logic.
Override this method to add customized validation logic.
:return: The customized validation result
:rtype: MutableValidationResult
"""
pass


class Experiment(ExperimentTemplate):
def __init__(
self,
nodes,
name=None,
data=None,
inputs=None,
status=ExperimentStatus.NOT_STARTED,
node_runs=None,
properties=None,
Expand All @@ -192,7 +268,7 @@ def __init__(
self.last_end_time = kwargs.get("last_end_time", None)
self.is_archived = kwargs.get("is_archived", False)
self._output_dir = Path.home() / PROMPT_FLOW_DIR_NAME / PROMPT_FLOW_EXP_DIR_NAME / self.name
super().__init__(nodes, name=self.name, data=data, **kwargs)
super().__init__(nodes, name=self.name, data=data, inputs=inputs, **kwargs)

@classmethod
def _get_schema_cls(cls):
Expand Down Expand Up @@ -240,8 +316,8 @@ def _to_orm_object(self):
last_start_time=self.last_start_time,
last_end_time=self.last_end_time,
properties=json.dumps(self.properties),
data=json.dumps(self.data),
inputs=json.dumps(self.inputs),
data=json.dumps([item._to_dict() for item in self.data]),
inputs=json.dumps([input._to_dict() for input in self.inputs]),
nodes=json.dumps([node._to_dict() for node in self.nodes]),
node_runs=json.dumps(self.node_runs),
)
Expand All @@ -252,21 +328,28 @@ def _to_orm_object(self):
def _from_orm_object(cls, obj: ORMExperiment) -> "Experiment":
"""Create a experiment object from ORM object."""
nodes = []
context = {BASE_PATH_CONTEXT_KEY: "./"}
for node_dict in json.loads(obj.nodes):
if node_dict["type"] == ExperimentNodeType.FLOW:
nodes.append(
FlowNode._load_from_dict(
node_dict, context={BASE_PATH_CONTEXT_KEY: "./"}, additional_message="Failed to load node."
)
FlowNode._load_from_dict(node_dict, context=context, additional_message="Failed to load node.")
)
elif node_dict["type"] == ExperimentNodeType.CODE:
nodes.append(
ScriptNode._load_from_dict(
node_dict, context={BASE_PATH_CONTEXT_KEY: "./"}, additional_message="Failed to load node."
)
ScriptNode._load_from_dict(node_dict, context=context, additional_message="Failed to load node.")
)
else:
raise Exception(f"Unknown node type {node_dict['type']}")
data = [
ExperimentData._load_from_dict(item, context=context, additional_message="Failed to load experiment data")
for item in json.loads(obj.data)
]
inputs = [
ExperimentInput._load_from_dict(
item, context=context, additional_message="Failed to load experiment inputs"
)
for item in json.loads(obj.inputs)
]

return cls(
name=obj.name,
Expand All @@ -277,8 +360,8 @@ def _from_orm_object(cls, obj: ORMExperiment) -> "Experiment":
last_end_time=obj.last_end_time,
is_archived=obj.archived,
properties=json.loads(obj.properties),
data=json.loads(obj.data),
inputs=json.loads(obj.inputs),
data=data,
inputs=inputs,
nodes=nodes,
node_runs=json.loads(obj.node_runs),
)
Expand All @@ -292,6 +375,7 @@ def from_template(cls, template: ExperimentTemplate):
name=exp_name,
description=template.description,
data=copy.deepcopy(template.data),
inputs=copy.deepcopy(template.inputs),
nodes=copy.deepcopy(template.nodes),
base_path=template._base_path,
)
Expand Down
Loading

0 comments on commit af12168

Please sign in to comment.