FEAT Consolidate Export Conversations into one method #628

Merged
8 changes: 4 additions & 4 deletions doc/code/memory/9_exporting_data.ipynb
@@ -61,11 +61,11 @@
"#csv_file_path = RESULTS_PATH / \"conversation_and_scores_csv_example.csv\"\n",
" \n",
"# # Export the data to a JSON file\n",
"conversation_with_scores = memory.export_all_conversations(file_path=json_file_path, export_type=\"json\")\n",
"conversation_with_scores = memory.export_conversations(file_path=json_file_path, export_type=\"json\")\n",
"print(f\"Exported conversation with scores to JSON: {json_file_path}\")\n",
" \n",
"# Export the data to a CSV file\n",
"# conversation_with_scores = memory.export_all_conversations(file_path=csv_file_path, export_type=\"csv\")\n",
"# conversation_with_scores = memory.export_conversations(file_path=csv_file_path, export_type=\"csv\")\n",
"# print(f\"Exported conversation with scores to CSV: {csv_file_path}\")"
]
},
@@ -141,11 +141,11 @@
"csv_file_path = RESULTS_PATH / \"conversation_and_scores_csv_example.csv\"\n",
"\n",
"# Export the data to a JSON file\n",
"# conversation_with_scores = azure_memory.export_conversation_by_id(conversation_id=conversation_id, file_path=json_file_path, export_type=\"json\")\n",
"# conversation_with_scores = azure_memory.export_conversations(conversation_id=conversation_id, file_path=json_file_path, export_type=\"json\")\n",
"# print(f\"Exported conversation with scores to JSON: {json_file_path}\")\n",
"\n",
"# Export the data to a CSV file\n",
"conversation_with_scores = azure_memory.export_conversation_by_id(conversation_id=conversation_id, file_path=json_file_path, export_type=\"csv\")\n",
"conversation_with_scores = azure_memory.export_conversations(conversation_id=conversation_id, file_path=json_file_path, export_type=\"csv\")\n",
"print(f\"Exported conversation with scores to CSV: {csv_file_path}\")\n",
"\n",
"# Cleanup memory resources\n",
8 changes: 4 additions & 4 deletions doc/code/memory/9_exporting_data.py
@@ -64,11 +64,11 @@
# csv_file_path = RESULTS_PATH / "conversation_and_scores_csv_example.csv"

# # Export the data to a JSON file
-conversation_with_scores = memory.export_all_conversations(file_path=json_file_path, export_type="json")
+conversation_with_scores = memory.export_conversations(file_path=json_file_path, export_type="json")
print(f"Exported conversation with scores to JSON: {json_file_path}")

# Export the data to a CSV file
-# conversation_with_scores = memory.export_all_conversations(file_path=csv_file_path, export_type="csv")
+# conversation_with_scores = memory.export_conversations(file_path=csv_file_path, export_type="csv")
# print(f"Exported conversation with scores to CSV: {csv_file_path}")

# %% [markdown]
@@ -120,11 +120,11 @@
csv_file_path = RESULTS_PATH / "conversation_and_scores_csv_example.csv"

# Export the data to a JSON file
-# conversation_with_scores = azure_memory.export_conversation_by_id(conversation_id=conversation_id, file_path=json_file_path, export_type="json")
+# conversation_with_scores = azure_memory.export_conversations(conversation_id=conversation_id, file_path=json_file_path, export_type="json")
# print(f"Exported conversation with scores to JSON: {json_file_path}")

# Export the data to a CSV file
-conversation_with_scores = azure_memory.export_conversation_by_id(
+conversation_with_scores = azure_memory.export_conversations(
conversation_id=conversation_id, file_path=json_file_path, export_type="csv"
)
print(f"Exported conversation with scores to CSV: {csv_file_path}")
123 changes: 73 additions & 50 deletions pyrit/memory/memory_interface.py
@@ -256,6 +256,28 @@ def get_prompt_request_pieces(
not_data_type: Optional[str] = None,
converted_value_sha256: Optional[list[str]] = None,
) -> list[PromptRequestPiece]:
"""
Retrieves a list of PromptRequestPiece objects based on the specified filters.

Args:
orchestrator_id (Optional[str | uuid.UUID], optional): The ID of the orchestrator. Defaults to None.
conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None.
prompt_ids (Optional[list[str] | list[uuid.UUID]], optional): A list of prompt IDs. Defaults to None.
labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None.
sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None.
sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None.
original_values (Optional[list[str]], optional): A list of original values. Defaults to None.
converted_values (Optional[list[str]], optional): A list of converted values. Defaults to None.
data_type (Optional[str], optional): The data type to filter by. Defaults to None.
not_data_type (Optional[str], optional): The data type to exclude. Defaults to None.
converted_value_sha256 (Optional[list[str]], optional): A list of SHA256 hashes of converted values.
Defaults to None.
Returns:
list[PromptRequestPiece]: A list of PromptRequestPiece objects that match the specified filters.
Raises:
Exception: If there is an error retrieving the prompts,
an exception is logged and an empty list is returned.
"""
conditions = []
if orchestrator_id:
conditions.append(self._get_prompt_pieces_orchestrator_conditions(orchestrator_id=str(orchestrator_id)))
@@ -374,28 +396,6 @@ def duplicate_conversation_excluding_last_turn(

return new_conversation_id

-def export_conversation_by_orchestrator_id(
-    self, *, orchestrator_id: str, file_path: Path = None, export_type: str = "json"
-):
-    """
-    Exports conversation data with the given orchestrator ID to a specified file.
-    This will contain all conversations that were sent by the same orchestrator.
-
-    Args:
-        orchestrator_id (str): The ID of the orchestrator from which to export conversations.
-        file_path (str): The path to the file where the data will be exported.
-            If not provided, a default path using RESULTS_PATH will be constructed.
-        export_type (str): The format of the export. Defaults to "json".
-    """
-    data = self.get_prompt_request_pieces(orchestrator_id=orchestrator_id)
-
-    # If file_path is not provided, construct a default using the exporter's results_path
-    if not file_path:
-        file_name = f"{str(orchestrator_id)}.{export_type}"
-        file_path = RESULTS_PATH / file_name
-
-    self.exporter.export_data(data, file_path=file_path, export_type=export_type)

def add_request_response_to_memory(self, *, request: PromptRequestResponse) -> None:
"""
Inserts a list of prompt request pieces into the memory storage.
@@ -523,25 +523,6 @@ def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> lis
memory_entries = self.get_prompt_request_pieces(conversation_id=conversation_id)
return [ChatMessage(role=me.role, content=me.converted_value) for me in memory_entries] # type: ignore

-def export_conversation_by_id(self, *, conversation_id: str, file_path: Path = None, export_type: str = "json"):
-    """
-    Exports conversation data with the given conversation ID to a specified file.
-
-    Args:
-        conversation_id (str): The ID of the conversation to export.
-        file_path (str): The path to the file where the data will be exported.
-            If not provided, a default path using RESULTS_PATH will be constructed.
-        export_type (str): The format of the export. Defaults to "json".
-    """
-    data = self.get_prompt_request_pieces(conversation_id=conversation_id)
-
-    # If file_path is not provided, construct a default using the exporter's results_path
-    if not file_path:
-        file_name = f"{conversation_id}.{export_type}"
-        file_path = RESULTS_PATH / file_name
-
-    self.exporter.export_data(data, file_path=file_path, export_type=export_type)

def get_seed_prompts(
self,
*,
@@ -740,19 +721,61 @@ def get_seed_prompt_groups(
seed_prompt_groups = SeedPromptDataset.group_seed_prompts_by_prompt_group_id(seed_prompts)
return seed_prompt_groups

-def export_all_conversations(self, *, file_path: Optional[Path] = None, export_type: str = "json"):
+def export_conversations(
+    self,
+    *,
+    orchestrator_id: Optional[str | uuid.UUID] = None,
+    conversation_id: Optional[str | uuid.UUID] = None,
+    prompt_ids: Optional[list[str] | list[uuid.UUID]] = None,
+    labels: Optional[dict[str, str]] = None,
+    sent_after: Optional[datetime] = None,
+    sent_before: Optional[datetime] = None,
+    original_values: Optional[list[str]] = None,
+    converted_values: Optional[list[str]] = None,
+    data_type: Optional[str] = None,
+    not_data_type: Optional[str] = None,
+    converted_value_sha256: Optional[list[str]] = None,
+    file_path: Optional[Path] = None,
+    export_type: str = "json",
+):
"""
-Exports all conversations with scores to a specified file.
+Exports conversation data with the given inputs to a specified file.
+Defaults to all conversations if no filters are provided.

Args:
-    file_path (str): The path to the file where the data will be exported.
-        If not provided, a default path using RESULTS_PATH will be constructed.
-    export_type (str): The format of the export. Defaults to "json".
-"""
-all_prompt_pieces = self.get_prompt_request_pieces()
+    orchestrator_id (Optional[str | uuid.UUID], optional): The ID of the orchestrator. Defaults to None.
+    conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None.
+    prompt_ids (Optional[list[str] | list[uuid.UUID]], optional): A list of prompt IDs. Defaults to None.
+    labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None.
+    sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None.
+    sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None.
+    original_values (Optional[list[str]], optional): A list of original values. Defaults to None.
+    converted_values (Optional[list[str]], optional): A list of converted values. Defaults to None.
+    data_type (Optional[str], optional): The data type to filter by. Defaults to None.
+    not_data_type (Optional[str], optional): The data type to exclude. Defaults to None.
+    converted_value_sha256 (Optional[list[str]], optional): A list of SHA256 hashes of converted values.
+        Defaults to None.
+    file_path (Optional[Path], optional): The path to the file where the data will be exported.
+        Defaults to None.
+    export_type (str, optional): The format of the export. Defaults to "json".
+"""
+data = self.get_prompt_request_pieces(
+    orchestrator_id=orchestrator_id,
+    conversation_id=conversation_id,
+    prompt_ids=prompt_ids,
+    labels=labels,
+    sent_after=sent_after,
+    sent_before=sent_before,
+    original_values=original_values,
+    converted_values=converted_values,
+    data_type=data_type,
+    not_data_type=not_data_type,
+    converted_value_sha256=converted_value_sha256,
+)

# If file_path is not provided, construct a default using the exporter's results_path
if not file_path:
file_name = f"conversations.{export_type}"
file_name = f"exported_conversations_on_{datetime.now().strftime('%Y_%m_%d')}.{export_type}"
file_path = RESULTS_PATH / file_name

-self.exporter.export_data(all_prompt_pieces, file_path=file_path, export_type=export_type)
+self.exporter.export_data(data, file_path=file_path, export_type=export_type)
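Because the new method forwards every filter straight into `get_prompt_request_pieces`, callers can also export arbitrary slices of memory in one call. A minimal sketch under the same assumed import paths as above, with the label key and date range chosen purely for illustration; when `file_path` is omitted, a dated default of the form `exported_conversations_on_<YYYY_MM_DD>.<export_type>` is written under `RESULTS_PATH`, per the diff above.

```python
from datetime import datetime, timedelta

from pyrit.memory import DuckDBMemory  # assumed import path

memory = DuckDBMemory()

# Export only prompts that carry a given memory label and were sent in the last week.
memory.export_conversations(
    labels={"operation": "example_op"},             # illustrative label key/value
    sent_after=datetime.now() - timedelta(days=7),  # illustrative date filter
    export_type="json",
    # file_path omitted: a default like exported_conversations_on_<YYYY_MM_DD>.json
    # is constructed under RESULTS_PATH.
)
```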
8 changes: 4 additions & 4 deletions tests/unit/memory/test_memory_interface.py
@@ -555,15 +555,15 @@ def test_export_conversation_by_orchestrator_id_file_created(
):
orchestrator1_id = sample_conversations[0].orchestrator_identifier["id"]

-# Default path in export_conversation_by_orchestrator_id()
+# Default path in export_conversations()
file_name = f"{orchestrator1_id}.json"
file_path = Path(RESULTS_PATH, file_name)

duckdb_instance.exporter = MemoryExporter()

with patch("pyrit.memory.duckdb_memory.DuckDBMemory.get_prompt_request_pieces") as mock_get:
mock_get.return_value = sample_conversations
-duckdb_instance.export_conversation_by_orchestrator_id(orchestrator_id=orchestrator1_id)
+duckdb_instance.export_conversations(orchestrator_id=orchestrator1_id, file_path=file_path)

# Verify file was created
assert file_path.exists()
@@ -1245,7 +1245,7 @@ def test_export_all_conversations_with_scores_correct_data(duckdb_instance: Memo
mock_get_pieces.return_value = [MagicMock(original_prompt_id="1234", converted_value="sample piece")]
mock_get_scores.return_value = [MagicMock(prompt_request_response_id="1234", score_value=10)]

-duckdb_instance.export_all_conversations(file_path=file_path)
+duckdb_instance.export_conversations(file_path=file_path)

pos_arg, named_args = mock_export_data.call_args
assert str(named_args["file_path"]) == temp_file.file.name
@@ -1268,7 +1268,7 @@ def test_export_all_conversations_with_scores_empty_data(duckdb_instance: Memory
mock_get_pieces.return_value = []
mock_get_scores.return_value = []

-duckdb_instance.export_all_conversations(file_path=file_path)
+duckdb_instance.export_conversations(file_path=file_path)
mock_export_data.assert_called_once_with(expected_data, file_path=file_path, export_type="json")

