Commit
FEAT Consolidate Export Conversations into one method (#628)
eugeniavkim authored Dec 29, 2024
1 parent 85d6159 commit 5c47f1a
Showing 4 changed files with 85 additions and 62 deletions.
8 changes: 4 additions & 4 deletions doc/code/memory/9_exporting_data.ipynb
@@ -61,11 +61,11 @@
"#csv_file_path = RESULTS_PATH / \"conversation_and_scores_csv_example.csv\"\n",
" \n",
"# # Export the data to a JSON file\n",
"conversation_with_scores = memory.export_all_conversations(file_path=json_file_path, export_type=\"json\")\n",
"conversation_with_scores = memory.export_conversations(file_path=json_file_path, export_type=\"json\")\n",
"print(f\"Exported conversation with scores to JSON: {json_file_path}\")\n",
" \n",
"# Export the data to a CSV file\n",
"# conversation_with_scores = memory.export_all_conversations(file_path=csv_file_path, export_type=\"csv\")\n",
"# conversation_with_scores = memory.export_conversations(file_path=csv_file_path, export_type=\"csv\")\n",
"# print(f\"Exported conversation with scores to CSV: {csv_file_path}\")"
]
},
@@ -141,11 +141,11 @@
"csv_file_path = RESULTS_PATH / \"conversation_and_scores_csv_example.csv\"\n",
"\n",
"# Export the data to a JSON file\n",
"# conversation_with_scores = azure_memory.export_conversation_by_id(conversation_id=conversation_id, file_path=json_file_path, export_type=\"json\")\n",
"# conversation_with_scores = azure_memory.export_conversations(conversation_id=conversation_id, file_path=json_file_path, export_type=\"json\")\n",
"# print(f\"Exported conversation with scores to JSON: {json_file_path}\")\n",
"\n",
"# Export the data to a CSV file\n",
"conversation_with_scores = azure_memory.export_conversation_by_id(conversation_id=conversation_id, file_path=json_file_path, export_type=\"csv\")\n",
"conversation_with_scores = azure_memory.export_conversations(conversation_id=conversation_id, file_path=json_file_path, export_type=\"csv\")\n",
"print(f\"Exported conversation with scores to CSV: {csv_file_path}\")\n",
"\n",
"# Cleanup memory resources\n",
8 changes: 4 additions & 4 deletions doc/code/memory/9_exporting_data.py
@@ -64,11 +64,11 @@
# csv_file_path = RESULTS_PATH / "conversation_and_scores_csv_example.csv"

# # Export the data to a JSON file
conversation_with_scores = memory.export_all_conversations(file_path=json_file_path, export_type="json")
conversation_with_scores = memory.export_conversations(file_path=json_file_path, export_type="json")
print(f"Exported conversation with scores to JSON: {json_file_path}")

# Export the data to a CSV file
# conversation_with_scores = memory.export_all_conversations(file_path=csv_file_path, export_type="csv")
# conversation_with_scores = memory.export_conversations(file_path=csv_file_path, export_type="csv")
# print(f"Exported conversation with scores to CSV: {csv_file_path}")

# %% [markdown]
@@ -120,11 +120,11 @@
csv_file_path = RESULTS_PATH / "conversation_and_scores_csv_example.csv"

# Export the data to a JSON file
# conversation_with_scores = azure_memory.export_conversation_by_id(conversation_id=conversation_id, file_path=json_file_path, export_type="json")
# conversation_with_scores = azure_memory.export_conversations(conversation_id=conversation_id, file_path=json_file_path, export_type="json")
# print(f"Exported conversation with scores to JSON: {json_file_path}")

# Export the data to a CSV file
conversation_with_scores = azure_memory.export_conversation_by_id(
conversation_with_scores = azure_memory.export_conversations(
    conversation_id=conversation_id, file_path=csv_file_path, export_type="csv"
)
print(f"Exported conversation with scores to CSV: {csv_file_path}")
123 changes: 73 additions & 50 deletions pyrit/memory/memory_interface.py
@@ -256,6 +256,28 @@ def get_prompt_request_pieces(
not_data_type: Optional[str] = None,
converted_value_sha256: Optional[list[str]] = None,
) -> list[PromptRequestPiece]:
"""
Retrieves a list of PromptRequestPiece objects based on the specified filters.
Args:
orchestrator_id (Optional[str | uuid.UUID], optional): The ID of the orchestrator. Defaults to None.
conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None.
prompt_ids (Optional[list[str] | list[uuid.UUID]], optional): A list of prompt IDs. Defaults to None.
labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None.
sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None.
sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None.
original_values (Optional[list[str]], optional): A list of original values. Defaults to None.
converted_values (Optional[list[str]], optional): A list of converted values. Defaults to None.
data_type (Optional[str], optional): The data type to filter by. Defaults to None.
not_data_type (Optional[str], optional): The data type to exclude. Defaults to None.
converted_value_sha256 (Optional[list[str]], optional): A list of SHA256 hashes of converted values.
Defaults to None.
Returns:
list[PromptRequestPiece]: A list of PromptRequestPiece objects that match the specified filters.
Raises:
Exception: If there is an error retrieving the prompts,
an exception is logged and an empty list is returned.
"""
conditions = []
if orchestrator_id:
conditions.append(self._get_prompt_pieces_orchestrator_conditions(orchestrator_id=str(orchestrator_id)))
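For reference, a minimal sketch of how the filters documented above can be combined in a single call. The `memory` instance, the import path, and the label value are illustrative assumptions, not part of this commit:

```python
from datetime import datetime, timedelta

from pyrit.memory import DuckDBMemory  # assumed import path; any MemoryInterface implementation works

memory = DuckDBMemory()

# Retrieve only text pieces from the last week that carry an assumed example label.
recent_pieces = memory.get_prompt_request_pieces(
    labels={"operation": "example_op"},             # illustrative label filter
    sent_after=datetime.now() - timedelta(days=7),  # only pieces sent in the last 7 days
    data_type="text",                               # restrict to text data
)
```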
@@ -374,28 +396,6 @@ def duplicate_conversation_excluding_last_turn(

return new_conversation_id

def export_conversation_by_orchestrator_id(
self, *, orchestrator_id: str, file_path: Path = None, export_type: str = "json"
):
"""
Exports conversation data with the given orchestrator ID to a specified file.
This will contain all conversations that were sent by the same orchestrator.
Args:
orchestrator_id (str): The ID of the orchestrator from which to export conversations.
file_path (str): The path to the file where the data will be exported.
If not provided, a default path using RESULTS_PATH will be constructed.
export_type (str): The format of the export. Defaults to "json".
"""
data = self.get_prompt_request_pieces(orchestrator_id=orchestrator_id)

# If file_path is not provided, construct a default using the exporter's results_path
if not file_path:
file_name = f"{str(orchestrator_id)}.{export_type}"
file_path = RESULTS_PATH / file_name

self.exporter.export_data(data, file_path=file_path, export_type=export_type)
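The orchestrator-scoped export removed here is subsumed by the consolidated `export_conversations` added further down in this file. A sketch of the equivalent call, assuming `memory` and `orchestrator_id` are already defined; note that the consolidated method no longer defaults the filename to `{orchestrator_id}.json`, so callers that relied on the old default now pass `file_path` explicitly (as the updated test does):

```python
from pathlib import Path

from pyrit.common.path import RESULTS_PATH  # assumed location of RESULTS_PATH

# Previously: memory.export_conversation_by_orchestrator_id(orchestrator_id=orchestrator_id)
memory.export_conversations(
    orchestrator_id=orchestrator_id,
    file_path=Path(RESULTS_PATH, f"{orchestrator_id}.json"),  # reproduce the old default filename explicitly
    export_type="json",
)
```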

def add_request_response_to_memory(self, *, request: PromptRequestResponse) -> None:
"""
Inserts a list of prompt request pieces into the memory storage.
@@ -523,25 +523,6 @@ def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> lis
memory_entries = self.get_prompt_request_pieces(conversation_id=conversation_id)
return [ChatMessage(role=me.role, content=me.converted_value) for me in memory_entries] # type: ignore

def export_conversation_by_id(self, *, conversation_id: str, file_path: Path = None, export_type: str = "json"):
"""
Exports conversation data with the given conversation ID to a specified file.
Args:
conversation_id (str): The ID of the conversation to export.
file_path (str): The path to the file where the data will be exported.
If not provided, a default path using RESULTS_PATH will be constructed.
export_type (str): The format of the export. Defaults to "json".
"""
data = self.get_prompt_request_pieces(conversation_id=conversation_id)

# If file_path is not provided, construct a default using the exporter's results_path
if not file_path:
file_name = f"{conversation_id}.{export_type}"
file_path = RESULTS_PATH / file_name

self.exporter.export_data(data, file_path=file_path, export_type=export_type)
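Likewise, the single-conversation export removed here maps onto the consolidated method with `conversation_id` as the filter. A brief sketch, assuming `memory`, `conversation_id`, and `RESULTS_PATH` are defined as in the updated docs:

```python
# Previously: memory.export_conversation_by_id(conversation_id=conversation_id, export_type="csv")
csv_path = RESULTS_PATH / f"{conversation_id}.csv"  # illustrative path, mirroring the old default name
memory.export_conversations(conversation_id=conversation_id, file_path=csv_path, export_type="csv")
```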

def get_seed_prompts(
self,
*,
@@ -740,19 +721,61 @@ def get_seed_prompt_groups(
seed_prompt_groups = SeedPromptDataset.group_seed_prompts_by_prompt_group_id(seed_prompts)
return seed_prompt_groups

def export_all_conversations(self, *, file_path: Optional[Path] = None, export_type: str = "json"):
def export_conversations(
self,
*,
orchestrator_id: Optional[str | uuid.UUID] = None,
conversation_id: Optional[str | uuid.UUID] = None,
prompt_ids: Optional[list[str] | list[uuid.UUID]] = None,
labels: Optional[dict[str, str]] = None,
sent_after: Optional[datetime] = None,
sent_before: Optional[datetime] = None,
original_values: Optional[list[str]] = None,
converted_values: Optional[list[str]] = None,
data_type: Optional[str] = None,
not_data_type: Optional[str] = None,
converted_value_sha256: Optional[list[str]] = None,
file_path: Optional[Path] = None,
export_type: str = "json",
):
"""
Exports all conversations with scores to a specified file.
Exports conversation data with the given inputs to a specified file.
Defaults to all conversations if no filters are provided.
Args:
file_path (str): The path to the file where the data will be exported.
If not provided, a default path using RESULTS_PATH will be constructed.
export_type (str): The format of the export. Defaults to "json".
"""
all_prompt_pieces = self.get_prompt_request_pieces()
orchestrator_id (Optional[str | uuid.UUID], optional): The ID of the orchestrator. Defaults to None.
conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None.
prompt_ids (Optional[list[str] | list[uuid.UUID]], optional): A list of prompt IDs. Defaults to None.
labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None.
sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None.
sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None.
original_values (Optional[list[str]], optional): A list of original values. Defaults to None.
converted_values (Optional[list[str]], optional): A list of converted values. Defaults to None.
data_type (Optional[str], optional): The data type to filter by. Defaults to None.
not_data_type (Optional[str], optional): The data type to exclude. Defaults to None.
converted_value_sha256 (Optional[list[str]], optional): A list of SHA256 hashes of converted values.
Defaults to None.
file_path (Optional[Path], optional): The path to the file where the data will be exported.
Defaults to None.
export_type (str, optional): The format of the export. Defaults to "json".
"""
data = self.get_prompt_request_pieces(
orchestrator_id=orchestrator_id,
conversation_id=conversation_id,
prompt_ids=prompt_ids,
labels=labels,
sent_after=sent_after,
sent_before=sent_before,
original_values=original_values,
converted_values=converted_values,
data_type=data_type,
not_data_type=not_data_type,
converted_value_sha256=converted_value_sha256,
)

# If file_path is not provided, construct a default using the exporter's results_path
if not file_path:
file_name = f"conversations.{export_type}"
file_name = f"exported_conversations_on_{datetime.now().strftime('%Y_%m_%d')}.{export_type}"
file_path = RESULTS_PATH / file_name

self.exporter.export_data(all_prompt_pieces, file_path=file_path, export_type=export_type)
self.exporter.export_data(data, file_path=file_path, export_type=export_type)
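With the consolidation in place, one method covers exports that previously needed three. A hedged usage sketch, assuming `memory` is an initialized MemoryInterface implementation such as DuckDBMemory; the label key/value and the cutoff date are illustrative, and omitting `file_path` falls back to the new dated default filename shown above:

```python
from datetime import datetime

# Export all text pieces sent after an illustrative cutoff and tagged with an assumed label.
# With file_path omitted, the output lands in RESULTS_PATH as
# exported_conversations_on_<YYYY_MM_DD>.json.
memory.export_conversations(
    labels={"operator": "red_team_example"},  # illustrative memory label
    sent_after=datetime(2024, 12, 1),         # illustrative cutoff
    data_type="text",
    export_type="json",
)
```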
8 changes: 4 additions & 4 deletions tests/unit/memory/test_memory_interface.py
@@ -555,15 +555,15 @@ def test_export_conversation_by_orchestrator_id_file_created(
):
orchestrator1_id = sample_conversations[0].orchestrator_identifier["id"]

# Default path in export_conversation_by_orchestrator_id()
# Default path in export_conversations()
file_name = f"{orchestrator1_id}.json"
file_path = Path(RESULTS_PATH, file_name)

duckdb_instance.exporter = MemoryExporter()

with patch("pyrit.memory.duckdb_memory.DuckDBMemory.get_prompt_request_pieces") as mock_get:
mock_get.return_value = sample_conversations
duckdb_instance.export_conversation_by_orchestrator_id(orchestrator_id=orchestrator1_id)
duckdb_instance.export_conversations(orchestrator_id=orchestrator1_id, file_path=file_path)

# Verify file was created
assert file_path.exists()
@@ -1245,7 +1245,7 @@ def test_export_all_conversations_with_scores_correct_data(duckdb_instance: Memo
mock_get_pieces.return_value = [MagicMock(original_prompt_id="1234", converted_value="sample piece")]
mock_get_scores.return_value = [MagicMock(prompt_request_response_id="1234", score_value=10)]

duckdb_instance.export_all_conversations(file_path=file_path)
duckdb_instance.export_conversations(file_path=file_path)

pos_arg, named_args = mock_export_data.call_args
assert str(named_args["file_path"]) == temp_file.file.name
@@ -1268,7 +1268,7 @@ def test_export_all_conversations_with_scores_empty_data(duckdb_instance: Memory
mock_get_pieces.return_value = []
mock_get_scores.return_value = []

duckdb_instance.export_all_conversations(file_path=file_path)
duckdb_instance.export_conversations(file_path=file_path)
mock_export_data.assert_called_once_with(expected_data, file_path=file_path, export_type="json")
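Beyond the renamed calls above, a filter pass-through test could exercise the new keyword arguments directly. A sketch in the style of the existing tests (it reuses the same fixture, patch target, and mocking imports already present in this file; the label filter itself is an assumption):

```python
def test_export_conversations_forwards_filters(duckdb_instance: MemoryInterface):
    # Stand-in exporter so nothing is written to disk.
    duckdb_instance.exporter = MagicMock()

    with patch("pyrit.memory.duckdb_memory.DuckDBMemory.get_prompt_request_pieces") as mock_get:
        mock_get.return_value = []
        duckdb_instance.export_conversations(labels={"op": "example"}, export_type="json")

    # The labels filter should be forwarded to get_prompt_request_pieces unchanged,
    # and the (empty) result handed to the exporter exactly once.
    assert mock_get.call_args.kwargs["labels"] == {"op": "example"}
    duckdb_instance.exporter.export_data.assert_called_once()
```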

