Fix Comments

crazygao committed Oct 30, 2023
1 parent 5fb93e0 commit 4ff7d8d
Showing 11 changed files with 173 additions and 77 deletions.
29 changes: 15 additions & 14 deletions src/promptflow/tests/sdk_cli_test/conftest.py
@@ -92,15 +92,6 @@ def setup_local_connection(local_client):
_connection_setup = True


- @pytest.fixture
- def remove_local_connection(local_client):
- local_client.connections.delete("azure_open_ai_connection")
- local_client.connections.delete("serp_connection")
- local_client.connections.delete("custom_connection")
- local_client.connections.delete("gpt2_connection")
- local_client.connections.delete("open_ai_connection")


@pytest.fixture
def flow_serving_client(mocker: MockerFixture):
model_path = (Path(MODEL_ROOT) / "basic-with-connection").resolve().absolute().as_posix()
@@ -152,16 +143,26 @@ def serving_client_python_stream_tools(mocker: MockerFixture):

@pytest.fixture
def mock_for_recordings(request: pytest.FixtureRequest, mocker: MockerFixture) -> None:
- recording_folder: Path = RECORDINGS_TEST_CONFIGS_ROOT / request.cls.__name__
+ """
+ mock_for_recordings is the entry point for recording/replaying mode.
+ The environment variable PF_RECORDING_MODE controls this test feature.
+ Record: is_recording() returns True, is_replaying() returns False.
+ Node run info (currently for llm nodes) is saved as the following key-value pair:
+ Key: ordered dict of all inputs => sha1 hash value.
+ Value: base64 of the output value.
+ Replay: is_recording() returns False, is_replaying() returns True.
+ All llm nodes are hijacked with a customized tool that calculates the hash of the inputs and returns the recorded outputs.
+ """
+ recording_file: Path = RECORDINGS_TEST_CONFIGS_ROOT / f"{str(request.cls.__name__).lower()}_storage_record.json"
if is_recording():
- recording_folder.mkdir(parents=True, exist_ok=True)
+ RECORDINGS_TEST_CONFIGS_ROOT.mkdir(parents=True, exist_ok=True)
mocker.patch(
"promptflow._sdk.operations._local_storage_operations.LocalStorageOperations.persist_node_run",
- mock_persist_node_run(recording_folder),
+ mock_persist_node_run(recording_file),
)
mocker.patch(
"promptflow._sdk.operations._flow_operations.FlowOperations._test",
- mock_flowoperations_test(recording_folder),
+ mock_flowoperations_test(recording_file),
)

if is_replaying():
@@ -172,7 +173,7 @@ def mock_for_recordings(request: pytest.FixtureRequest, mocker: MockerFixture) -> None:

mocker.patch(
"promptflow.executor._tool_resolver.ToolResolver.resolve_tool_by_node",
- mock_toolresolver_resolve_tool_by_node(recording_folder),
+ mock_toolresolver_resolve_tool_by_node(recording_file),
)
mocker.patch(
"promptflow._sdk._utils.get_local_connections_from_executable", mock_get_local_connections_from_executable
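
The docstring added above describes the record/replay key-value scheme. As a minimal sketch
(the helper names here are illustrative only, not the promptflow implementation), the key is
a SHA1 hash over an ordered dict of the node inputs and the value is the base64-encoded output:

import base64
import collections
import hashlib

def make_record_key(inputs: dict) -> str:
    # Key: ordered dict of all inputs => sha1 hash value.
    ordered = collections.OrderedDict(sorted(inputs.items()))
    return hashlib.sha1(str(ordered).encode("utf-8")).hexdigest()

def make_record_value(output: str) -> str:
    # Value: base64 of the output value.
    return base64.b64encode(output.encode("utf-8")).decode("utf-8")

# Identical inputs always map to the same key, so a replay run can look up
# the recorded output without calling the LLM.
key = make_record_key({"prompt": "Hello {{name}}", "name": "world"})
value = make_record_value("Hello world!")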
1 change: 1 addition & 0 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
@@ -85,6 +85,7 @@ def test_basic_flow_run(self) -> None:
)
assert "Completed" in f.getvalue()

@pytest.mark.skipif(is_replaying(), reason="Instable in replay mode.")
def test_basic_flow_run_batch_and_eval(self) -> None:
run_id = str(uuid.uuid4())
f = io.StringIO()
@@ -849,6 +849,7 @@ def test_executor_logs_in_batch_run_logs(self, pf: PFClient) -> None:
# so it is expected to be printed here
assert "Starting run without column mapping may lead to unexpected results." in logs

@pytest.mark.skipif(is_replaying(), reason="Recording and replaying cannot support image input")
def test_basic_image_flow_bulk_run(self, pf, local_client) -> None:
image_flow_path = f"{FLOWS_DIR}/python_tool_with_simple_image"
data_path = f"{image_flow_path}/image_inputs/inputs.jsonl"
@@ -31,7 +31,7 @@ def mock_update_run_func(self, run_info: NodeRunInfo):
run_info.system_metrics["total_tokens"] = 0


- def mock_persist_node_run(recording_folder: Path):
+ def mock_persist_node_run(recording_file: Path):
# Mock of LocalStorageOperations.persist_node_run; it records the batch node run info in recording mode.
def _mock_persist_node_run(self, run_info: NodeRunInfo) -> None:
node_run_record = NodeRunRecord.from_run_info(run_info)
@@ -41,12 +41,12 @@ def _mock_persist_node_run(self, run_info: NodeRunInfo) -> None:
line_number = 0 if node_run_record.line_number is None else node_run_record.line_number
filename = f"{str(line_number).zfill(self.LINE_NUMBER_WIDTH)}.jsonl"
node_run_record.dump(node_folder / filename, run_name=self._run.name)
- record_node_run(node_run_record.run_info, recording_folder)
+ record_node_run(node_run_record.run_info, recording_file)

return _mock_persist_node_run


- def mock_flowoperations_test(recording_folder: Path):
+ def mock_flowoperations_test(recording_file: Path):
def _mock_flowoperations_test(
self,
flow: Union[str, os.PathLike],
@@ -79,7 +79,7 @@ def _mock_flowoperations_test(
environment_variables=environment_variables,
stream=True,
)
- record_node_run(result, recording_folder)
+ record_node_run(result, recording_file)
return result
else:
result_flow_test: LineResult = submitter.flow_test(
@@ -88,7 +88,7 @@
stream_log=stream_log,
allow_generator_output=allow_generator_output and is_chat_flow,
)
- record_node_run(result_flow_test.run_info, recording_folder)
+ record_node_run(result_flow_test.run_info, recording_file)
return result_flow_test

return _mock_flowoperations_test
@@ -100,7 +100,7 @@ def mock_bulkresult_get_openai_metrics(self):
return total_metrics


- def mock_toolresolver_resolve_tool_by_node(recording_folder: Path):
+ def mock_toolresolver_resolve_tool_by_node(recording_file: Path):
# Mock for _tool_resolver.py; currently llm nodes are resolved to the recording-utility tool just_return.
def _resolve_replay_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
# in replay mode, replace original tool with just_return tool
@@ -111,7 +111,7 @@ def _resolve_replay_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
):
prompt_tpl = self._load_source_content(node)
prompt_tpl_inputs = get_inputs_for_prompt_template(prompt_tpl)
- callable = partial(just_return, "AzureOpenAI", prompt_tpl, prompt_tpl_inputs, recording_folder)
+ callable = partial(just_return, "AzureOpenAI", prompt_tpl, prompt_tpl_inputs, recording_file)
return ResolvedTool(node=node, definition=None, callable=callable, init_args={})
else:
return None
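
The is_recording()/is_replaying() helpers patched around above are not defined in this diff;
judging by the conftest.py docstring, they presumably switch on PF_RECORDING_MODE. A plausible
sketch (assumed behavior, including the "record"/"replay" values):

import os

def is_recording() -> bool:
    # Assumed: PF_RECORDING_MODE selects the mode ("record" or "replay").
    return os.environ.get("PF_RECORDING_MODE") == "record"

def is_replaying() -> bool:
    return os.environ.get("PF_RECORDING_MODE") == "replay"

def recording_or_replaying() -> bool:
    return is_recording() or is_replaying()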
@@ -32,6 +32,14 @@ def recording_or_replaying() -> bool:
return is_recording() or is_replaying()


+ class RecordItemMissingException(Exception):
+ pass


+ class RecordFileMissingException(Exception):
+ pass


class RecordStorage:
"""
RecordStorage static class to manage recording file storage_record.json
@@ -42,79 +50,72 @@ class RecordStorage:
runItems: Dict[str, Dict[str, str]] = {}

@staticmethod
- def write_file(flow_directory: Path) -> None:
+ def write_file(recording_file: Path) -> None:

- path_hash = hashlib.sha1(str(flow_directory.parts[-4:-1]).encode("utf-8")).hexdigest()
+ path_hash = hashlib.sha1(str(recording_file.parts[-4:]).encode("utf-8")).hexdigest()
file_content = RecordStorage.runItems.get(path_hash, None)
if file_content is not None:
- with open(flow_directory / "storage_record.json", "w+") as fp:
+ with open(recording_file, "w+") as fp:
json.dump(RecordStorage.runItems[path_hash], fp, indent=4)

@staticmethod
- def load_file(flow_directory: Path) -> None:
- path_hash = hashlib.sha1(str(flow_directory.parts[-4:-1]).encode("utf-8")).hexdigest()
+ def load_file(recording_file: Path) -> None:
+ path_hash = hashlib.sha1(str(recording_file.parts[-4:]).encode("utf-8")).hexdigest()
local_content = RecordStorage.runItems.get(path_hash, None)
if not local_content:
- if not os.path.exists(flow_directory / "storage_record.json"):
+ if not os.path.exists(recording_file):
return
- with open(flow_directory / "storage_record.json", "r", encoding="utf-8") as fp:
+ with open(recording_file, "r", encoding="utf-8") as fp:
RecordStorage.runItems[path_hash] = json.load(fp)

@staticmethod
- def get_record(flow_directory: Path, hashDict: OrderedDict) -> str:
- # special deal remove text_content, because it is not stable.
- if "text_content" in hashDict:
- hashDict.pop("text_content")
+ def get_record(recording_file: Path, hashDict: OrderedDict) -> str:
hash_value: str = hashlib.sha1(str(hashDict).encode("utf-8")).hexdigest()
- path_hash: str = hashlib.sha1(str(flow_directory.parts[-4:-1]).encode("utf-8")).hexdigest()
+ path_hash: str = hashlib.sha1(str(recording_file.parts[-4:]).encode("utf-8")).hexdigest()
file_item: Dict[str, str] = RecordStorage.runItems.get(path_hash, None)
if file_item is None:
- RecordStorage.load_file(flow_directory)
+ RecordStorage.load_file(recording_file)
file_item = RecordStorage.runItems.get(path_hash, None)
if file_item is not None:
item = file_item.get(hash_value, None)
if item is not None:
real_item = base64.b64decode(bytes(item, "utf-8")).decode()
return real_item
else:
- raise BaseException(
- f"Record item not found in folder {flow_directory}.\n"
+ raise RecordItemMissingException(
+ f"Record item not found in file {recording_file}.\n"
f"Path hash {path_hash}\nHash value: {hash_value}\n"
f"Hash dict: {hashDict}\nHashed values: {json.dumps(hashDict)}\n"
)
else:
raise BaseException(f"Record file not found in folder {flow_directory}.")
raise RecordFileMissingException(f"Record file not found in folder {recording_file}.")

@staticmethod
- def set_record(flow_directory: Path, hashDict: OrderedDict, output: object) -> None:
- # special deal remove text_content, because it is not stable.
- if "text_content" in hashDict:
- hashDict.pop("text_content")
+ def set_record(recording_file: Path, hashDict: OrderedDict, output: object) -> None:
hash_value: str = hashlib.sha1(str(hashDict).encode("utf-8")).hexdigest()
- path_hash: str = hashlib.sha1(str(flow_directory.parts[-4:-1]).encode("utf-8")).hexdigest()
+ path_hash: str = hashlib.sha1(str(recording_file.parts[-4:]).encode("utf-8")).hexdigest()
output_base64: str = base64.b64encode(bytes(output, "utf-8")).decode(encoding="utf-8")
current_saved_record: Dict[str, str] = RecordStorage.runItems.get(path_hash, None)
if current_saved_record is None:
- RecordStorage.load_file(flow_directory)
+ RecordStorage.load_file(recording_file)
if RecordStorage.runItems is None:
RecordStorage.runItems = {}
if (RecordStorage.runItems.get(path_hash, None)) is None:
RecordStorage.runItems[path_hash] = {}
RecordStorage.runItems[path_hash][hash_value] = output_base64
- RecordStorage.write_file(flow_directory)
+ RecordStorage.write_file(recording_file)
else:
saved_output = current_saved_record.get(hash_value, None)
if saved_output is not None and saved_output == output_base64:
return
else:
current_saved_record[hash_value] = output_base64
- RecordStorage.write_file(flow_directory)
+ RecordStorage.write_file(recording_file)


- class ToolRecord(ToolProvider):
+ class ToolRecordPlayer(ToolProvider):
"""
- ToolRecord Record inputs and outputs of llm tool, in replay mode,
+ ToolRecordPlayer records the inputs and outputs of the llm tool; in replay mode,
this tool will read the cached result from storage_record.json
"""

@@ -123,7 +124,7 @@ def completion(toolType: str, *args, **kwargs) -> str:
# "AzureOpenAI" = args[0], this is type indicator, there may be more than one indicators
prompt_tmpl = args[1]
prompt_tpl_inputs = args[2]
- working_folder = args[3]
+ recording_file = args[3]

hashDict = {}
for keyword in prompt_tpl_inputs:
@@ -132,14 +133,14 @@
hashDict["prompt"] = prompt_tmpl
hashDict = collections.OrderedDict(sorted(hashDict.items()))

- real_item = RecordStorage.get_record(working_folder, hashDict)
+ real_item = RecordStorage.get_record(recording_file, hashDict)
return real_item


@tool
def just_return(toolType: str, *args, **kwargs) -> str:
# Replay: Promptflow internal test tool; gets all inputs and returns the recorded output
- return ToolRecord().completion(toolType, *args, **kwargs)
+ return ToolRecordPlayer().completion(toolType, *args, **kwargs)


def _record_node_run(run_info: NodeRunInfo, flow_folder: Path, api_call: Dict[str, Any]) -> None:
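
A quick round-trip of the refactored RecordStorage API above (the path and inputs here are
invented for illustration; note that set_record and get_record now take the recording file
itself rather than a flow directory):

from collections import OrderedDict
from pathlib import Path
# Assumes RecordStorage and the new exception classes are importable from the
# recording-utilities module shown in the diff above.

recording_file = Path("recordings/testcli_storage_record.json")
recording_file.parent.mkdir(parents=True, exist_ok=True)
inputs = OrderedDict(sorted({"name": "pf", "prompt": "Hi {{name}}"}.items()))

# Recording mode: persist the LLM output keyed by the hashed inputs.
RecordStorage.set_record(recording_file, inputs, "Hi pf!")

# Replay mode: the same inputs return the recorded output; a miss now raises
# RecordItemMissingException / RecordFileMissingException instead of a bare BaseException.
assert RecordStorage.get_record(recording_file, inputs) == "Hi pf!"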

This file was deleted.

This file was deleted.

This file was deleted.
