Fix Comments
crazygao committed Oct 23, 2023
1 parent c8cf75a commit 6673310
Showing 9 changed files with 39 additions and 53 deletions.
1 change: 1 addition & 0 deletions src/promptflow/promptflow/_core/run_tracker.py
@@ -177,6 +177,7 @@ def _update_flow_run_info_with_node_runs(self, run_info):
run_info.system_metrics = run_info.system_metrics or {}
run_info.system_metrics.update(self.collect_metrics(child_run_infos, self.OPENAI_AGGREGATE_METRICS))
if os.environ.get("PF_RECORDING_MODE", None) == "replay":
# some tests require this metric to be set.
run_info.system_metrics["total_tokens"] = 0

def _node_run_postprocess(self, run_info: RunInfo, output, ex: Optional[Exception]):
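
Most of this commit keys off the same PF_RECORDING_MODE environment switch. For orientation, a minimal sketch of the three states it encodes; the helper itself is illustrative, since the diff reads the variable inline at each call site:

import os

def recording_mode() -> str:
    # Illustrative helper: "record" writes LLM calls to storage_record.json,
    # "replay" reads them back (and pins OpenAI metrics, as above); any
    # other value, or no value, means a live run.
    mode = os.environ.get("PF_RECORDING_MODE", None)
    return mode if mode in ("record", "replay") else "live"
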
5 changes: 5 additions & 0 deletions src/promptflow/promptflow/_core/tool_record.py
@@ -8,6 +8,11 @@


class ToolRecord(ToolProvider):
"""
ToolRecord records the inputs and outputs of the LLM tool; in replay mode,
this tool reads the cached result from storage_record.json.
"""

@tool
def completion(toolType: str, *args, **kwargs) -> str:
# "AzureOpenAI" = args[0], this is type indicator, there may be more than one indicators
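
For context, a sketch of what the replay path inside completion likely does: rebuild the same ordered hash dict the recorder writes (see the _record_node_run code moved out of local storage below), then look the answer up instead of calling the LLM. RecordStorage.get_record is an assumed counterpart to the set_record call shown later:

import collections

from promptflow._utils.tool_utils import get_inputs_for_prompt_template
from promptflow._utils.utils import RecordStorage

def replay_completion(flow_folder, prompt_tpl: str, inputs: dict) -> str:
    # Keep only the inputs the prompt template actually references,
    # then add the template itself, exactly as the recorder does.
    hash_dict = {
        keyword: inputs[keyword]
        for keyword in get_inputs_for_prompt_template(prompt_tpl)
        if keyword in inputs
    }
    hash_dict["prompt"] = prompt_tpl
    hash_dict = collections.OrderedDict(sorted(hash_dict.items()))
    # Assumed API: read back what set_record(flow_folder, hash_dict, output) stored.
    return RecordStorage.get_record(flow_folder, hash_dict)
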
@@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import collections
import datetime
import json
import logging
import os
import shutil
from dataclasses import asdict, dataclass
from functools import partial
@@ -28,14 +26,12 @@
get_run_output_path,
)
from promptflow._sdk._errors import BulkRunException
from promptflow._sdk._utils import generate_flow_tools_json
from promptflow._sdk._utils import generate_flow_tools_json, record_node_run
from promptflow._sdk.entities import Run
from promptflow._sdk.entities._flow import Flow
from promptflow._utils.dataclass_serializer import serialize
from promptflow._utils.exception_utils import PromptflowExceptionPresenter
from promptflow._utils.logger_utils import LogContext
from promptflow._utils.tool_utils import get_inputs_for_prompt_template
from promptflow._utils.utils import RecordStorage
from promptflow.contracts.run_info import FlowRunInfo
from promptflow.contracts.run_info import RunInfo as NodeRunInfo
from promptflow.contracts.run_info import Status
@@ -216,24 +212,6 @@ def _dump_meta_file(self) -> None:
with open(self._meta_path, mode="w", encoding=DEFAULT_ENCODING) as f:
json.dump({"batch_size": LOCAL_STORAGE_BATCH_SIZE}, f)

def _record_node_run(self, run_info: NodeRunInfo) -> None:
"""Persist node run record to local storage."""
if os.environ.get("PF_RECORDING_MODE", None) == "record":
for api_call in run_info.api_calls:
hashDict = {}
if "name" in api_call and api_call["name"].startswith("AzureOpenAI"):
flow_folder = self._flow_path

prompt_tpl = api_call["inputs"]["prompt"]
prompt_tpl_inputs = get_inputs_for_prompt_template(prompt_tpl)

for keyword in prompt_tpl_inputs:
if keyword in api_call["inputs"]:
hashDict[keyword] = api_call["inputs"][keyword]
hashDict["prompt"] = prompt_tpl
hashDict = collections.OrderedDict(sorted(hashDict.items()))
RecordStorage.set_record(flow_folder, hashDict, run_info.output)

def dump_snapshot(self, flow: Flow) -> None:
"""Dump flow directory to snapshot folder, input file will be dumped after the run."""
shutil.copytree(
@@ -390,7 +368,7 @@ def persist_node_run(self, run_info: NodeRunInfo) -> None:
line_number = 0 if node_run_record.line_number is None else node_run_record.line_number
filename = f"{str(line_number).zfill(self.LINE_NUMBER_WIDTH)}.jsonl"
node_run_record.dump(node_folder / filename, run_name=self._run.name)
self._record_node_run(run_info)
record_node_run(run_info, self._flow_path)

def persist_flow_run(self, run_info: FlowRunInfo) -> None:
"""Persist line run record to local storage."""
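
The _record_node_run method removed above was moved rather than dropped: persist_node_run now delegates to a shared record_node_run(run_info, flow_path) helper in promptflow._sdk._utils. A minimal sketch of that helper, assuming it keeps the removed body essentially verbatim:

import collections
import os

from promptflow._utils.tool_utils import get_inputs_for_prompt_template
from promptflow._utils.utils import RecordStorage
from promptflow.contracts.run_info import RunInfo as NodeRunInfo

def record_node_run(run_info: NodeRunInfo, flow_folder) -> None:
    """Persist a node run record to local storage in record mode."""
    if os.environ.get("PF_RECORDING_MODE", None) != "record":
        return
    for api_call in run_info.api_calls:
        # The "AzureOpenAI" name prefix marks calls worth caching for replay.
        if "name" in api_call and api_call["name"].startswith("AzureOpenAI"):
            prompt_tpl = api_call["inputs"]["prompt"]
            hash_dict = {
                keyword: api_call["inputs"][keyword]
                for keyword in get_inputs_for_prompt_template(prompt_tpl)
                if keyword in api_call["inputs"]
            }
            hash_dict["prompt"] = prompt_tpl
            hash_dict = collections.OrderedDict(sorted(hash_dict.items()))
            RecordStorage.set_record(flow_folder, hash_dict, run_info.output)
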
5 changes: 4 additions & 1 deletion src/promptflow/promptflow/_utils/utils.py
@@ -35,7 +35,10 @@ def __getattr__(self, item):


class RecordStorage:
# static class for recording
"""
RecordStorage is a static class that manages the recording file storage_record.json.
"""

runItems: Dict[str, Dict[str, str]] = {}

@staticmethod
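
For readers new to the recording file, a sketch of how a static class like this can manage storage_record.json. The get_record/set_record entry points mirror the calls in the diffs above; the on-disk layout and the hashing of the ordered input dict are assumptions:

import hashlib
import json
from pathlib import Path
from typing import Dict

class RecordStorageSketch:
    # Per-flow cache of recorded items, keyed by input hash.
    runItems: Dict[str, Dict[str, str]] = {}

    @staticmethod
    def _hash(hash_dict) -> str:
        # A stable key over the ordered inputs, so record and replay agree.
        return hashlib.sha1(json.dumps(hash_dict).encode("utf-8")).hexdigest()

    @staticmethod
    def set_record(flow_folder, hash_dict, output) -> None:
        record_file = Path(flow_folder) / "storage_record.json"
        items = json.loads(record_file.read_text()) if record_file.exists() else {}
        items[RecordStorageSketch._hash(hash_dict)] = output
        record_file.write_text(json.dumps(items, indent=2))

    @staticmethod
    def get_record(flow_folder, hash_dict) -> str:
        record_file = Path(flow_folder) / "storage_record.json"
        items = json.loads(record_file.read_text())
        return items[RecordStorageSketch._hash(hash_dict)]
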
@@ -3,7 +3,6 @@
import multiprocessing
import os
import queue
import time
from datetime import datetime
from functools import partial
from logging import INFO
1 change: 1 addition & 0 deletions src/promptflow/promptflow/executor/_result.py
@@ -78,6 +78,7 @@ def get_status_summary(self):

def get_openai_metrics(self):
if os.environ.get("PF_RECORDING_MODE", None) == "replay":
# Some tests request the metrics in replay mode.
total_metrics = {"total_tokens": 0, "duration": 0}
return total_metrics
node_run_infos = chain(self._get_line_run_infos(), self._get_aggr_run_infos())
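
This guard pairs with the run_tracker change above: a replay run never calls the OpenAI API, so metrics are pinned rather than aggregated. A condensed sketch of the shared pattern, with illustrative names:

import os

def openai_metrics(compute_real_metrics):
    # Replay runs make no OpenAI calls; zeroed metrics keep test
    # assertions deterministic. Live runs aggregate as before.
    if os.environ.get("PF_RECORDING_MODE", None) == "replay":
        return {"total_tokens": 0, "duration": 0}
    return compute_real_metrics()
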
4 changes: 3 additions & 1 deletion src/promptflow/promptflow/executor/_tool_resolver.py
@@ -245,7 +245,9 @@ def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
return ResolvedTool(updated_node, tool, api_func, init_args)

def _resolve_replay_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
# Provider must be prepared.
# In replay mode, replace the original tool with a just_return tool;
# the tool itself just returns the saved record from storage_record.json
# and performs no logic.
if (node.api == "completion" or node.api == "chat") and (
node.connection == "azure_open_ai_connection" or node.provider == "AzureOpenAI"
):
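
A sketch of the dispatch the new comment describes. The condition is taken from the diff; binding the provider name as the first positional argument follows the "AzureOpenAI" = args[0] convention in tool_record.py, and the partial wiring is an assumption:

from functools import partial

def resolve_replay_api(node, record_completion):
    # Return the replay callable for a supported LLM node, or None to
    # fall through to normal resolution.
    is_llm_api = node.api in ("completion", "chat")
    is_aoai = node.connection == "azure_open_ai_connection" or node.provider == "AzureOpenAI"
    if is_llm_api and is_aoai:
        # ToolRecord.completion reads the provider from args[0] as its
        # type indicator, so bind it up front.
        return partial(record_completion, node.provider)
    return None
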
25 changes: 0 additions & 25 deletions src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py
@@ -1,22 +1,16 @@
import os
import sys
from pathlib import Path
from typing import List
from unittest.mock import MagicMock

import pandas as pd
import pytest
from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, RESOURCE_ID_FORMAT
from pytest_mock import MockFixture

from promptflow._cli._pf_azure.entry import main
from promptflow._sdk._configuration import ConfigFileNotFound, Configuration
from promptflow._sdk._constants import VIS_PORTAL_URL_TMPL
from promptflow._utils.context_utils import _change_working_dir
from promptflow.azure.operations._run_operations import RunOperations

CONFIG_DATA_ROOT = Path(__file__).parent.parent.parent / "test_configs" / "configs"


def run_pf_command(*args, cwd=None):
origin_argv, origin_cwd = sys.argv, os.path.abspath(os.curdir)
@@ -49,25 +43,6 @@ def test_pf_azure_version(self, capfd):
out, err = capfd.readouterr()
assert out == "0.0.1\n"

def test_get_workspace_from_config(self):
# New instance instead of get_instance() to avoid side effect
conf = Configuration(overrides={"connection.provider": "azureml"})
# Test config within flow folder
target_folder = CONFIG_DATA_ROOT / "mock_flow1"
with _change_working_dir(target_folder):
config1 = conf.get_connection_provider()
assert config1 == "azureml:" + RESOURCE_ID_FORMAT.format("sub1", "rg1", AZUREML_RESOURCE_PROVIDER, "ws1")
# Test config using flow parent folder
target_folder = CONFIG_DATA_ROOT / "mock_flow2"
with _change_working_dir(target_folder):
config2 = conf.get_connection_provider()
assert config2 == "azureml:" + RESOURCE_ID_FORMAT.format(
"sub_default", "rg_default", AZUREML_RESOURCE_PROVIDER, "ws_default"
)
# Test config not found
with pytest.raises(ConfigFileNotFound):
Configuration._get_workspace_from_config(path=CONFIG_DATA_ROOT.parent)

def test_run_show(self, mocker: MockFixture, operation_scope_args):
mocked = mocker.patch.object(RunOperations, "get")
# show_run will print the run object, so we need to mock the return value
24 changes: 23 additions & 1 deletion src/promptflow/tests/sdk_cli_test/unittests/test_config.py
@@ -5,9 +5,12 @@

import pytest

from promptflow._sdk._configuration import Configuration
from promptflow._sdk._configuration import ConfigFileNotFound, Configuration
from promptflow._utils.context_utils import _change_working_dir

CONFIG_DATA_ROOT = Path(__file__).parent.parent.parent / "test_configs" / "configs"
AZUREML_RESOURCE_PROVIDER = "Microsoft.MachineLearningServices"
RESOURCE_ID_FORMAT = "/subscriptions/{}/resourceGroups/{}/providers/{}/workspaces/{}"


@pytest.fixture
@@ -34,3 +37,22 @@ def test_get_or_set_installation_id(self, config):
def test_config_instance(self, config):
new_config = Configuration.get_instance()
assert new_config is config

def test_get_workspace_from_config(self):
# New instance instead of get_instance() to avoid side effect
conf = Configuration(overrides={"connection.provider": "azureml"})
# Test config within flow folder
target_folder = CONFIG_DATA_ROOT / "mock_flow1"
with _change_working_dir(target_folder):
config1 = conf.get_connection_provider()
assert config1 == "azureml:" + RESOURCE_ID_FORMAT.format("sub1", "rg1", AZUREML_RESOURCE_PROVIDER, "ws1")
# Test config using flow parent folder
target_folder = CONFIG_DATA_ROOT / "mock_flow2"
with _change_working_dir(target_folder):
config2 = conf.get_connection_provider()
assert config2 == "azureml:" + RESOURCE_ID_FORMAT.format(
"sub_default", "rg_default", AZUREML_RESOURCE_PROVIDER, "ws_default"
)
# Test config not found
with pytest.raises(ConfigFileNotFound):
Configuration._get_workspace_from_config(path=CONFIG_DATA_ROOT.parent)
