
Commit 6dfa0c1

FEAT Add Scores to Data Export with PromptRequestPiece data (#617)
1 parent af72dc4 commit 6dfa0c1

13 files changed: +512 −280 lines changed
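The substance of the change: memory retrieval paths now return PromptRequestPiece objects with their stored scores attached (via a shared get_prompt_request_pieces_with_scores helper), so data exports carry score values alongside each prompt piece. A minimal usage sketch, assuming the returned pieces expose a populated scores attribute (that attribute name is inferred from the commit title, not shown in the diffs below):

from pyrit.memory import DuckDBMemory

memory = DuckDBMemory()

# After this commit, pieces returned from memory carry their associated scores (if any were stored).
for piece in memory.get_all_prompt_pieces():
    print(piece.conversation_id, piece.original_value)
    for score in piece.scores:  # assumed attribute, populated by get_prompt_request_pieces_with_scores
        print("  score:", score)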

doc/code/memory/9_exporting_data.ipynb

+109 −131 lines changed (large diffs are not rendered by default)

doc/code/memory/9_exporting_data.py

+98 −13 lines changed

@@ -6,41 +6,126 @@
 #       format_name: percent
 #       format_version: '1.3'
 #       jupytext_version: 1.16.4
+#   kernelspec:
+#     display_name: pyrit-kernel
+#     language: python
+#     name: pyrit-kernel
 # ---

 # %% [markdown]
 # # 9. Exporting Data Example
 #
-# This notebook shows all the different ways to export data from memory. This first example exports all conversations from Azure SQL memory with their respective score values in a JSON format. Without using the database query editor, these export functions allow for a quick approach to exporting data from memory.
+# This notebook shows different ways to export data from memory. This first example exports all conversations from local DuckDB memory with their respective score values in a JSON format. The data can currently be exported both as JSON file or a CSV file that will be saved in your results folder within PyRIT. The CSV export is commented out below. In this example, all conversations are exported, but by using other export functions from `memory_interface`, we can export by specific labels and other methods.

 # %%
-from pyrit.memory.azure_sql_memory import AzureSQLMemory
+from uuid import uuid4
+
 from pyrit.common import default_values
-from pathlib import Path
+from pyrit.common.path import RESULTS_PATH
+from pyrit.memory import DuckDBMemory, CentralMemory
+from pyrit.models import PromptRequestPiece, PromptRequestResponse

 default_values.load_environment_files()

-memory = AzureSQLMemory()
+memory = DuckDBMemory()
+CentralMemory.set_memory_instance(memory)
+
+conversation_id = str(uuid4())
+
+print(conversation_id)
+
+message_list = [
+    PromptRequestPiece(
+        role="user", original_value="Hi, chat bot! This is my initial prompt.", conversation_id=conversation_id
+    ),
+    PromptRequestPiece(
+        role="assistant", original_value="Nice to meet you! This is my response.", conversation_id=conversation_id
+    ),
+    PromptRequestPiece(
+        role="user",
+        original_value="Wonderful! This is my second prompt to the chat bot!",
+        conversation_id=conversation_id,
+    ),
+]
+
+memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[0]]))
+memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[1]]))
+memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[2]]))
+
+entries = memory.get_conversation(conversation_id=conversation_id)
+
+for entry in entries:
+    print(entry)

 # Define file path for export
-json_file_path = Path("conversation_and_scores_json_example")
-# csv_file_path = Path(\"conversation_and_scores_csv_example\")
+json_file_path = RESULTS_PATH / "conversation_and_scores_json_example.json"
+# csv_file_path = RESULTS_PATH / "conversation_and_scores_csv_example.csv"

-# Export the data to a JSON file
-conversation_with_scores = memory.export_all_conversations_with_scores(file_path=json_file_path, export_type="json")
+# # Export the data to a JSON file
+conversation_with_scores = memory.export_all_conversations(file_path=json_file_path, export_type="json")
 print(f"Exported conversation with scores to JSON: {json_file_path}")

 # Export the data to a CSV file
-# conversation_with_scores = memory.export_all_conversations_with_scores(file_path=json_file_path, export_type="csv")
+# conversation_with_scores = memory.export_all_conversations(file_path=csv_file_path, export_type="csv")
 # print(f"Exported conversation with scores to CSV: {csv_file_path}")

 # %% [markdown]
-# ## Importing Data as NumPy DataFrame
-#
-# You can use the exported JSON or CSV files to import the data as a NumPy DataFrame. This can be useful for various data manipulation and analysis tasks.
+# You can also use the exported JSON or CSV files to import the data as a NumPy DataFrame. This can be useful for various data manipulation and analysis tasks.

 # %%
 import pandas as pd  # type: ignore

 df = pd.read_json(json_file_path)
-df.head()
+df.head(1)
+
+# %% [markdown]
+# Next, we can export data from our Azure SQL database. In this example, we export the data by `conversation_id` and to a CSV file.
+
+# %%
+from pyrit.memory import AzureSQLMemory
+
+conversation_id = str(uuid4())
+
+message_list = [
+    PromptRequestPiece(
+        role="user", original_value="Hi, chat bot! This is my initial prompt.", conversation_id=conversation_id
+    ),
+    PromptRequestPiece(
+        role="assistant", original_value="Nice to meet you! This is my response.", conversation_id=conversation_id
+    ),
+    PromptRequestPiece(
+        role="user",
+        original_value="Wonderful! This is my second prompt to the chat bot!",
+        conversation_id=conversation_id,
+    ),
+]
+
+azure_memory = AzureSQLMemory()
+CentralMemory.set_memory_instance(azure_memory)
+
+azure_memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[0]]))
+azure_memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[1]]))
+azure_memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[2]]))
+
+
+entries = azure_memory.get_conversation(conversation_id=conversation_id)
+
+for entry in entries:
+    print(entry)
+
+# Define file path for export
+# json_file_path = RESULTS_PATH / "conversation_and_scores_json_example.json"
+csv_file_path = RESULTS_PATH / "conversation_and_scores_csv_example.csv"
+
+# Export the data to a JSON file
+# conversation_with_scores = azure_memory.export_conversation_by_id(conversation_id=conversation_id, file_path=json_file_path, export_type="json")
+# print(f"Exported conversation with scores to JSON: {json_file_path}")
+
+# Export the data to a CSV file
+conversation_with_scores = azure_memory.export_conversation_by_id(
+    conversation_id=conversation_id, file_path=csv_file_path, export_type="csv"
+)
+print(f"Exported conversation with scores to CSV: {csv_file_path}")
+
+# Cleanup memory resources
+azure_memory.dispose_engine()
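
As a follow-up to the notebook changes above, the exported file can be loaded back for analysis. A short sketch, assuming the JSON export above ran and that score fields appear as columns of the resulting frame (the exact column names are not confirmed by this diff):

import pandas as pd

from pyrit.common.path import RESULTS_PATH

# Load the JSON export produced by memory.export_all_conversations(...) above.
df = pd.read_json(RESULTS_PATH / "conversation_and_scores_json_example.json")

# Inspect which columns the export actually contains before filtering on score fields.
print(df.columns.tolist())
print(df.head(3))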

pyrit/memory/azure_sql_memory.py

+7 −6 lines changed

@@ -234,8 +234,8 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
                 conditions=PromptMemoryEntry.conversation_id == str(conversation_id),
             )  # type: ignore

-            prompt_pieces: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
-            return prompt_pieces
+            result = self.get_prompt_request_pieces_with_scores(entries)
+            return result

         except Exception as e:
             logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
@@ -267,8 +267,9 @@ def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
         """
         Fetches all entries from the specified table and returns them as model instances.
         """
-        result: list[PromptMemoryEntry] = self.query_entries(PromptMemoryEntry)
-        return [entry.get_prompt_request_piece() for entry in result]
+        entries = self.query_entries(PromptMemoryEntry)
+        result = self.get_prompt_request_pieces_with_scores(entries)
+        return result

     def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[PromptRequestPiece]:
         """
@@ -285,7 +286,7 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
                 PromptMemoryEntry,
                 conditions=PromptMemoryEntry.id.in_(prompt_ids),
             )
-            result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
+            result = self.get_prompt_request_pieces_with_scores(entries)
             return result
         except Exception as e:
             logger.exception(
@@ -319,7 +320,7 @@ def get_prompt_request_piece_by_memory_labels(
             sql_condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()})

             entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition)
-            result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
+            result = self.get_prompt_request_pieces_with_scores(entries)
             return result
         except Exception as e:
             logger.exception(
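
All of the retrieval paths above now delegate to get_prompt_request_pieces_with_scores, which this diff does not show (it presumably lives on the shared memory interface). A rough sketch of what such a helper could look like, assuming the interface already exposes get_scores_by_prompt_ids and that scores are attached to each piece via a scores attribute; both are assumptions, not the confirmed implementation:

def get_prompt_request_pieces_with_scores(self, entries) -> list[PromptRequestPiece]:
    # Convert stored entries to PromptRequestPiece objects, as the replaced code did inline.
    pieces = [entry.get_prompt_request_piece() for entry in entries]

    # Hypothetical: fetch all scores for these pieces in one query and group them by piece id.
    scores = self.get_scores_by_prompt_ids(prompt_request_response_ids=[str(p.id) for p in pieces])
    scores_by_piece_id: dict[str, list] = {}
    for score in scores:
        scores_by_piece_id.setdefault(str(score.prompt_request_response_id), []).append(score)

    # Attach the grouped scores to each piece before returning.
    for piece in pieces:
        piece.scores = scores_by_piece_id.get(str(piece.id), [])
    return pieces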

pyrit/memory/duckdb_memory.py

+6 −6 lines changed

@@ -89,7 +89,7 @@ def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
         Fetches all entries from the specified table and returns them as model instances.
         """
         entries = self.query_entries(PromptMemoryEntry)
-        result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
+        result = self.get_prompt_request_pieces_with_scores(entries)
         return result

     def get_all_embeddings(self) -> list[EmbeddingDataEntry]:
@@ -113,8 +113,8 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
             entries = self.query_entries(
                 PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == str(conversation_id)
             )
-            prompt_pieces: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
-            return prompt_pieces
+            result = self.get_prompt_request_pieces_with_scores(entries)
+            return result
         except Exception as e:
             logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
             return []
@@ -134,7 +134,7 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
                 PromptMemoryEntry,
                 conditions=PromptMemoryEntry.id.in_(prompt_ids),
             )
-            result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
+            result = self.get_prompt_request_pieces_with_scores(entries)
             return result
         except Exception as e:
             logger.exception(
@@ -161,7 +161,7 @@ def get_prompt_request_piece_by_memory_labels(
             conditions = [PromptMemoryEntry.labels.op("->>")(key) == value for key, value in memory_labels.items()]
             query_condition = and_(*conditions)
             entries = self.query_entries(PromptMemoryEntry, conditions=query_condition)
-            result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
+            result = self.get_prompt_request_pieces_with_scores(entries)
             return result
         except Exception as e:
             logger.exception(
@@ -185,7 +185,7 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Pr
                 PromptMemoryEntry,
                 conditions=PromptMemoryEntry.orchestrator_identifier.op("->>")("id") == orchestrator_id,
             )  # type: ignore
-            result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
+            result = self.get_prompt_request_pieces_with_scores(entries)
             return result
         except Exception as e:
             logger.exception(
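
The DuckDB changes mirror the Azure SQL ones: every query path now returns score-enriched pieces. For instance, the label-filter path touched above can be exercised like this (the label key and value are illustrative, not taken from the repository):

from pyrit.memory import DuckDBMemory

memory = DuckDBMemory()

# Labels are matched against the JSON labels column with DuckDB's ->> operator, per the hunk above.
pieces = memory.get_prompt_request_piece_by_memory_labels(memory_labels={"op_name": "demo_run"})
for piece in pieces:
    print(piece.conversation_id, piece.role, piece.original_value)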

pyrit/memory/memory_exporter.py

+15 −59 lines changed

@@ -3,15 +3,9 @@

 import csv
 import json
-from typing import Any, Dict, List, Union
-import uuid
-from datetime import datetime
 from pathlib import Path
-from collections.abc import MutableMapping

-from sqlalchemy.inspection import inspect
-
-from pyrit.memory.memory_models import Base
+from pyrit.models import PromptRequestPiece


 class MemoryExporter:
@@ -29,14 +23,13 @@ def __init__(self):
         }

     def export_data(
-        self, data: Union[List[Base], List[Dict]], *, file_path: Path = None, export_type: str = "json"
+        self, data: list[PromptRequestPiece], *, file_path: Path = None, export_type: str = "json"
     ):  # type: ignore
         """
         Exports the provided data to a file in the specified format.

         Args:
-            data (Union[List[Base], List[Dict]]): The data to be exported, typically a list of SQLAlchemy
-                model instances or as a list of dictionaries.
+            data (list[PromptRequestPiece]): The data to be exported, as a list of PromptRequestPiece instances.
             file_path (str): The full path, including the file name, where the data will be exported.
             export_type (str, Optional): The format for exporting data. Defaults to "json".

@@ -52,36 +45,37 @@ def export_data(
         else:
             raise ValueError(f"Unsupported export format: {export_type}")

-    def export_to_json(self, data: Union[List[Base], List[Dict]], file_path: Path = None) -> None:  # type: ignore
+    def export_to_json(self, data: list[PromptRequestPiece], file_path: Path = None) -> None:  # type: ignore
         """
         Exports the provided data to a JSON file at the specified file path.
         Each item in the data list, representing a row from the table,
         is converted to a dictionary before being written to the file.

         Args:
-            data (Union[List[Base], List[Dict]]): The data to be exported, as a list of SQLAlchemy model instances
-                or as a list of dictionaries.
+            data (list[PromptRequestPiece]): The data to be exported, as a list of PromptRequestPiece instances.
             file_path (Path): The full path, including the file name, where the data will be exported.

         Raises:
             ValueError: If no file_path is provided.
         """
         if not file_path:
             raise ValueError("Please provide a valid file path for exporting data.")
-
-        export_data = [self.model_to_dict(instance) if isinstance(instance, Base) else instance for instance in data]
+        if not data:
+            raise ValueError("No data to export.")
+        export_data = []
+        for piece in data:
+            export_data.append(piece.to_dict())
         with open(file_path, "w") as f:
             json.dump(export_data, f, indent=4)

-    def export_to_csv(self, data: Union[List[Base], List[Dict]], file_path: Path = None) -> None:  # type: ignore
+    def export_to_csv(self, data: list[PromptRequestPiece], file_path: Path = None) -> None:  # type: ignore
         """
         Exports the provided data to a CSV file at the specified file path.
         Each item in the data list, representing a row from the table,
         is converted to a dictionary before being written to the file.

         Args:
-            data (Union[List[Base], List[Dict]]): The data to be exported, as a list of SQLAlchemy model instances
-                or as a list of dictionaries.
+            data (list[PromptRequestPiece]): The data to be exported, as a list of PromptRequestPiece instances.
             file_path (Path): The full path, including the file name, where the data will be exported.

         Raises:
@@ -91,49 +85,11 @@ def export_to_csv(self, data: Union[List[Base], List[Dict]], file_path: Path = N
             raise ValueError("Please provide a valid file path for exporting data.")
         if not data:
             raise ValueError("No data to export.")
-
-        export_data = [
-            _flatten_dict(self.model_to_dict(instance)) if isinstance(instance, Base) else _flatten_dict(instance)
-            for instance in data
-        ]
+        export_data = []
+        for piece in data:
+            export_data.append(piece.to_dict())
         fieldnames = list(export_data[0].keys())
         with open(file_path, "w", newline="") as f:
             writer = csv.DictWriter(f, fieldnames=fieldnames)
             writer.writeheader()
             writer.writerows(export_data)
-
-    def model_to_dict(self, model_instance: Base):  # type: ignore
-        """
-        Converts an SQLAlchemy model instance into a dictionary, serializing
-        special data types such as UUID and datetime to string representations.
-        This ensures compatibility with JSON and other serialization formats.
-
-        Args:
-            model_instance: An instance of an SQLAlchemy model.
-
-        Returns:
-            A dictionary representation of the model instance, with special types serialized.
-        """
-        model_dict = {}
-        for column in inspect(model_instance.__class__).columns:
-            value = getattr(model_instance, column.name)
-            if isinstance(value, uuid.UUID):
-                # Convert UUID to string
-                model_dict[column.name] = str(value)
-            elif isinstance(value, datetime):
-                # Convert datetime to an ISO 8601 formatted string
-                model_dict[column.name] = value.isoformat()
-            else:
-                model_dict[column.name] = value
-        return model_dict
-
-
-def _flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping:
-    items: list[tuple[Any, Any]] = []
-    for k, v in d.items():
-        new_key = parent_key + sep + k if parent_key else k
-        if isinstance(v, MutableMapping):
-            items.extend(_flatten_dict(v, new_key, sep=sep).items())
-        else:
-            items.append((new_key, v))
-    return dict(items)
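
With the SQLAlchemy-specific model_to_dict and _flatten_dict helpers gone, MemoryExporter only needs a list of PromptRequestPiece objects and relies on their to_dict() serialization. A small sketch of calling the exporter directly (the output file name is illustrative):

from pathlib import Path

from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.models import PromptRequestPiece

exporter = MemoryExporter()
pieces = [
    PromptRequestPiece(role="user", original_value="Hi, chat bot! This is my initial prompt."),
    PromptRequestPiece(role="assistant", original_value="Nice to meet you! This is my response."),
]

# export_data dispatches to export_to_json or export_to_csv based on export_type.
exporter.export_data(pieces, file_path=Path("example_export.json"), export_type="json")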
