From b416e5da65a1fecaae599cf800e24c539d7bec09 Mon Sep 17 00:00:00 2001 From: zhen Date: Mon, 4 Sep 2023 17:19:16 +0800 Subject: [PATCH] [Promptflow CLI] Fix variant with quotation raise invalid format (#286) # Description Fix variant and col-mappings with quotation raise invalid format ![image](https://github.com/microsoft/promptflow/assets/17938940/cd8c76a3-e437-4f68-9181-1cb4064e28d2) Please add an informative description that covers that changes made by the pull request and link all relevant issues. # All Promptflow Contribution checklist: - [ ] **The pull request does not introduce [breaking changes]** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).** ## General Guidelines and Best Practices - [ ] Title of the pull request is clear and informative. - [ ] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. --- src/promptflow/promptflow/_cli/_params.py | 6 ++++-- src/promptflow/promptflow/_sdk/_utils.py | 15 ++++++++++++++- .../tests/sdk_cli_test/e2etests/test_cli.py | 8 +++++--- .../tests/sdk_cli_test/unittests/test_utils.py | 16 ++++++++++++++++ 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/promptflow/promptflow/_cli/_params.py b/src/promptflow/promptflow/_cli/_params.py index 7e1677b6787..128c5da8384 100644 --- a/src/promptflow/promptflow/_cli/_params.py +++ b/src/promptflow/promptflow/_cli/_params.py @@ -11,11 +11,13 @@ def __call__(self, parser, namespace, values, option_string=None): super(AppendToDictAction, self).__call__(parser, namespace, action, option_string) def get_action(self, values, option_string): # pylint: disable=no-self-use + from promptflow._sdk._utils import strip_quotation + kwargs = {} for item in values: try: - key, value = item.split("=", 1) - kwargs[key] = value + key, value = strip_quotation(item).split("=", 1) + kwargs[key] = strip_quotation(value) except ValueError: raise Exception("Usage error: {} KEY=VALUE [KEY=VALUE ...]".format(option_string)) return kwargs diff --git a/src/promptflow/promptflow/_sdk/_utils.py b/src/promptflow/promptflow/_sdk/_utils.py index c72c0ca7046..c35d2fa75a8 100644 --- a/src/promptflow/promptflow/_sdk/_utils.py +++ b/src/promptflow/promptflow/_sdk/_utils.py @@ -239,9 +239,22 @@ def load_from_dict(schema: Any, data: Dict, context: Dict, additional_message: s raise ValidationError(decorate_validation_error(schema, pretty_error, additional_message)) +def strip_quotation(value): + """ + To avoid escaping chars in command args, args will be surrounded in quotas. + Need to remove the pair of quotation first. + """ + if value.startswith('"') and value.endswith('"'): + return value[1:-1] + elif value.startswith("'") and value.endswith("'"): + return value[1:-1] + else: + return value + + def parse_variant(variant: str) -> Tuple[str, str]: variant_regex = r"\${([^.]+).([^}]+)}" - match = re.match(variant_regex, variant) + match = re.match(variant_regex, strip_quotation(variant)) if match: return match.group(1), match.group(2) else: diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py index 6851fcddded..75100ab58d0 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py @@ -86,6 +86,8 @@ def test_basic_flow_run_batch_and_eval(self) -> None: ) assert "Completed" in f.getvalue() + # Check the CLI works correctly when the parameter is surrounded by quotation, as below shown: + # --param "key=value" key="value" f = io.StringIO() with contextlib.redirect_stdout(f): run_pf_command( @@ -94,8 +96,8 @@ def test_basic_flow_run_batch_and_eval(self) -> None: "--flow", f"{FLOWS_DIR}/classification_accuracy_evaluation", "--column-mapping", - "groundtruth=${data.answer}", - "prediction=${run.outputs.category}", + "'groundtruth=${data.answer}'", + "prediction='${run.outputs.category}'", "variant_id=${data.variant_id}", "--data", f"{DATAS_DIR}/webClassification3.jsonl", @@ -309,7 +311,7 @@ def test_pf_flow_with_variant(self, capsys): "answer=Channel", "evidence=Url", "--variant", - "${summarize_text_content.variant_1}", + "'${summarize_text_content.variant_1}'", ) output_path = Path(temp_dir) / ".promptflow" / "flow-summarize_text_content-variant_1.output.json" assert output_path.exists() diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py b/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py index 352c80faa10..47654bf50e4 100644 --- a/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py +++ b/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +import argparse import os import shutil import tempfile @@ -16,6 +17,7 @@ _calculate_column_widths, list_of_dict_to_nested_dict, ) +from promptflow._cli._params import AppendToDictAction from promptflow._sdk._errors import GenerateFlowToolsJsonError from promptflow._sdk._utils import ( decrypt_secret_value, @@ -123,6 +125,20 @@ def test_list_of_dict_to_nested_dict(self): result = list_of_dict_to_nested_dict(test_list) assert result == {"node1": {"connection": "a", "deploy_name": "b"}} + def test_append_to_dict_action(self): + parser = argparse.ArgumentParser(prog="test_dict_action") + parser.add_argument("--dict", action=AppendToDictAction, nargs="+") + args = ["--dict", "key1=val1", "\'key2=val2\'", "\"key3=val3\"", "key4=\'val4\'", "key5=\"val5'"] + args = parser.parse_args(args) + expect_dict = { + "key1": "val1", + "key2": "val2", + "key3": "val3", + "key4": "val4", + "key5": "\"val5'", + } + assert args.dict[0] == expect_dict + def test_build_sorted_column_widths_tuple_list(self) -> None: columns = ["col1", "col2", "col3"] values1 = {"col1": 1, "col2": 4, "col3": 3}