Skip to content

Commit

Permalink
FEAT Adding labels for individual prompts (#624)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbolor21 authored Dec 22, 2024
1 parent 4d38c73 commit af72dc4
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 21 deletions.
7 changes: 6 additions & 1 deletion doc/code/targets/5_multi_modal_targets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,13 @@
"source": [
"import pathlib\n",
"from pyrit.prompt_target import OpenAIChatTarget\n",
"from pyrit.memory import CentralMemory, DuckDBMemory\n",
"from pyrit.prompt_normalizer import NormalizerRequestPiece, NormalizerRequest\n",
"from pyrit.orchestrator import PromptSendingOrchestrator\n",
"\n",
"\n",
"CentralMemory.set_memory_instance(DuckDBMemory())\n",
"\n",
"azure_openai_gpt4o_chat_target = OpenAIChatTarget()\n",
"\n",
"image_path = pathlib.Path(\".\") / \"..\" / \"..\" / \"..\" / \"assets\" / \"pyrit_architecture.png\"\n",
Expand All @@ -206,15 +210,16 @@
" NormalizerRequestPiece(\n",
" prompt_value=\"Describe this picture:\",\n",
" prompt_data_type=\"text\",\n",
" labels = {\"harm\": \"sample_harm_category\"}\n",
" ),\n",
" NormalizerRequestPiece(\n",
" prompt_value=str(image_path),\n",
" prompt_data_type=\"image_path\",\n",
" labels = {\"harm\": \"sample_other_harm_category\"}\n",
" ),\n",
" ]\n",
")\n",
"\n",
"\n",
"with PromptSendingOrchestrator(objective_target=azure_openai_gpt4o_chat_target) as orchestrator:\n",
" await orchestrator.send_normalizer_requests_async(prompt_request_list=[normalizer_request]) # type: ignore\n",
" memory = orchestrator.get_memory()\n",
Expand Down
13 changes: 7 additions & 6 deletions doc/code/targets/5_multi_modal_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.2
# jupytext_version: 1.16.4
# kernelspec:
# display_name: pyrit-311
# language: python
Expand Down Expand Up @@ -97,9 +97,13 @@
# %%
import pathlib
from pyrit.prompt_target import OpenAIChatTarget
from pyrit.memory import CentralMemory, DuckDBMemory
from pyrit.prompt_normalizer import NormalizerRequestPiece, NormalizerRequest
from pyrit.orchestrator import PromptSendingOrchestrator


CentralMemory.set_memory_instance(DuckDBMemory())

azure_openai_gpt4o_chat_target = OpenAIChatTarget()

image_path = pathlib.Path(".") / ".." / ".." / ".." / "assets" / "pyrit_architecture.png"
Expand All @@ -115,17 +119,14 @@
normalizer_request = NormalizerRequest(
request_pieces=[
NormalizerRequestPiece(
prompt_value="Describe this picture:",
prompt_data_type="text",
prompt_value="Describe this picture:", prompt_data_type="text", labels={"harm": "sample_harm_category"}
),
NormalizerRequestPiece(
prompt_value=str(image_path),
prompt_data_type="image_path",
prompt_value=str(image_path), prompt_data_type="image_path", labels={"harm": "sample_other_harm_category"}
),
]
)


with PromptSendingOrchestrator(objective_target=azure_openai_gpt4o_chat_target) as orchestrator:
await orchestrator.send_normalizer_requests_async(prompt_request_list=[normalizer_request]) # type: ignore
memory = orchestrator.get_memory()
Expand Down
2 changes: 1 addition & 1 deletion pyrit/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
16 changes: 16 additions & 0 deletions pyrit/common/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


def combine_dict(existing_dict: dict[str, str] = None, new_dict: dict[str, str] = None) -> dict[str, str]:
    """
    Combines two dictionaries containing string keys and values into one.

    Neither input is modified: the result is always a newly created dictionary,
    so callers can safely pass long-lived dictionaries (for example, labels
    stored on an object) without them accumulating entries across calls.

    Args:
        existing_dict: Dictionary with existing values. None is treated as empty.
        new_dict: Dictionary with new values to be added to the existing values.
            None is treated as empty. Note if there's a key clash, the value in
            new_dict will be used.

    Returns:
        A new dictionary containing the entries of both inputs.
    """
    # Build a fresh dict rather than aliasing existing_dict; calling update()
    # on the caller's object would leak new_dict's entries back into it.
    return {**(existing_dict or {}), **(new_dict or {})}
3 changes: 2 additions & 1 deletion pyrit/orchestrator/multi_turn/crescendo_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional
from uuid import uuid4

from pyrit.common.utils import combine_dict
from pyrit.common.path import DATASETS_PATH
from pyrit.exceptions import (
InvalidJsonException,
Expand Down Expand Up @@ -178,7 +179,7 @@ async def run_attack_async(
adversarial_chat_conversation_id = str(uuid4())
objective_target_conversation_id = str(uuid4())

updated_memory_labels = self._combine_with_global_memory_labels(memory_labels=memory_labels)
updated_memory_labels = combine_dict(existing_dict=self._global_memory_labels, new_dict=memory_labels)

adversarial_chat_system_prompt = self._adversarial_chat_system_seed_prompt.render_template_value(
objective=objective,
Expand Down
3 changes: 2 additions & 1 deletion pyrit/orchestrator/multi_turn/red_teaming_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Union
from uuid import uuid4

from pyrit.common.utils import combine_dict
from pyrit.common.path import RED_TEAM_ORCHESTRATOR_PATH
from pyrit.models import PromptRequestPiece, Score
from pyrit.orchestrator import MultiTurnOrchestrator, MultiTurnAttackResult
Expand Down Expand Up @@ -139,7 +140,7 @@ async def run_attack_async(
objective_target_conversation_id = str(uuid4())
adversarial_chat_conversation_id = str(uuid4())

updated_memory_labels = self._combine_with_global_memory_labels(memory_labels)
updated_memory_labels = combine_dict(existing_dict=self._global_memory_labels, new_dict=memory_labels)

# Prepare the conversation by adding any provided messages to memory.
# If there is no prepended conversation, the turn count is 1.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from treelib import Tree
from typing import Optional

from pyrit.common.utils import combine_dict
from pyrit.common.path import DATASETS_PATH
from pyrit.memory import MemoryInterface
from pyrit.models import SeedPrompt
Expand Down Expand Up @@ -143,7 +144,7 @@ async def run_attack_async(

best_conversation_id = None

updated_memory_labels = self._combine_with_global_memory_labels(memory_labels)
updated_memory_labels = combine_dict(existing_dict=self._global_memory_labels, new_dict=memory_labels)

for iteration in range(1, self._attack_depth + 1):
logger.info(f"Starting iteration number: {iteration}")
Expand Down
7 changes: 0 additions & 7 deletions pyrit/orchestrator/orchestrator_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ def _create_normalizer_request(
request = NormalizerRequest(request_pieces=[request_piece], conversation_id=conversation_id)
return request

def _combine_with_global_memory_labels(self, memory_labels: dict[str, str]) -> dict[str, str]:
"""
Combines the global memory labels with the provided memory labels.
The passed memory_labels take precedence with collisions.
"""
return {**(self._global_memory_labels or {}), **(memory_labels or {})}

def get_memory(self):
"""
Retrieves the memory associated with this orchestrator.
Expand Down
5 changes: 3 additions & 2 deletions pyrit/orchestrator/single_turn/prompt_sending_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from colorama import Fore, Style
from typing import Optional

from pyrit.common.utils import combine_dict
from pyrit.common.display_response import display_image_response
from pyrit.models import PromptDataType, PromptRequestResponse
from pyrit.prompt_normalizer import NormalizerRequest, PromptNormalizer
Expand Down Expand Up @@ -83,7 +84,7 @@ async def send_normalizer_requests_async(
responses: list[PromptRequestResponse] = await self._prompt_normalizer.send_prompt_batch_to_target_async(
requests=prompt_request_list,
target=self._prompt_target,
labels=self._combine_with_global_memory_labels(memory_labels),
labels=combine_dict(existing_dict=self._global_memory_labels, new_dict=memory_labels),
orchestrator_identifier=self.get_identifier(),
batch_size=self._batch_size,
)
Expand Down Expand Up @@ -139,7 +140,7 @@ async def send_prompts_async(

return await self.send_normalizer_requests_async(
prompt_request_list=requests,
memory_labels=self._combine_with_global_memory_labels(memory_labels),
memory_labels=combine_dict(existing_dict=self._global_memory_labels, new_dict=memory_labels),
)

async def print_conversations_async(self):
Expand Down
4 changes: 4 additions & 0 deletions pyrit/prompt_normalizer/normalizer_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT license.

import abc
from typing import Optional
from pyrit.models import data_serializer_factory, PromptDataType
from pyrit.prompt_converter import PromptConverter
from pyrit.prompt_normalizer.prompt_response_converter_configuration import PromptResponseConverterConfiguration
Expand All @@ -15,6 +16,7 @@ def __init__(
prompt_value: str,
prompt_data_type: PromptDataType,
request_converters: list[PromptConverter] = [],
labels: Optional[dict[str, str]] = None,
metadata: str = None,
) -> None:
"""
Expand All @@ -27,6 +29,7 @@ def __init__(
request_converters (list[PromptConverter]): A list of PromptConverter objects.
prompt_value (str): The prompt value.
prompt_data_type (PromptDataType): The data type of the prompt.
labels (Optional[dict[str, str]]): The labels to apply to the prompt. Defaults to None.
metadata (str, Optional): Additional metadata. Defaults to None.
Raises:
Expand All @@ -37,6 +40,7 @@ def __init__(
self.request_converters = request_converters
self.prompt_value = prompt_value
self.prompt_data_type = prompt_data_type
self.labels = labels
self.metadata = metadata

self.validate()
Expand Down
5 changes: 4 additions & 1 deletion pyrit/prompt_normalizer/prompt_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional
from uuid import uuid4

from pyrit.common.utils import combine_dict
from pyrit.common.batch_helper import batch_task_async
from pyrit.exceptions import EmptyResponseException
from pyrit.memory import MemoryInterface, CentralMemory
Expand Down Expand Up @@ -201,14 +202,16 @@ async def _build_prompt_request_response(
prompt_data_type=request_piece.prompt_data_type,
)

combined_memory_labels = combine_dict(existing_dict=labels, new_dict=request_piece.labels)

converter_identifiers = [converter.get_identifier() for converter in request_piece.request_converters]
prompt_request_piece = PromptRequestPiece(
role="user",
original_value=request_piece.prompt_value,
converted_value=converted_prompt_text,
conversation_id=conversation_id,
sequence=sequence,
labels=labels,
labels=combined_memory_labels,
prompt_metadata=request_piece.metadata,
converter_identifiers=converter_identifiers,
prompt_target_identifier=target.get_identifier(),
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/common/test_helper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.common.utils import combine_dict


def test_combine_non_empty_dict():
    """Entries from two dictionaries with disjoint keys all appear in the result."""
    existing = {"a": "b"}
    incoming = {"c": "d"}
    merged = combine_dict(existing, incoming)
    assert merged == {"a": "b", "c": "d"}


def test_combine_empty_dict():
    """Combining two empty dictionaries yields an empty dictionary."""
    existing = {}
    incoming = {}
    merged = combine_dict(existing, incoming)
    assert merged == {}


def test_combine_first_empty_dict():
    """A populated first dictionary combined with an empty second one is returned unchanged in content."""
    # NOTE(review): despite the test name, it is the second argument that is
    # empty here; confirm whether the name or the data should be adjusted.
    existing = {"a": "b"}
    incoming = {}
    merged = combine_dict(existing, incoming)
    assert merged == {"a": "b"}


def test_combine_second_empty_dict():
    """An empty first dictionary combined with a populated second one yields the second one's entries."""
    # NOTE(review): despite the test name, it is the first argument that is
    # empty here; confirm whether the name or the data should be adjusted.
    existing = {}
    incoming = {"c": "d"}
    merged = combine_dict(existing, incoming)
    assert merged == {"c": "d"}


def test_combine_dict_same_keys():
    """When both dictionaries share a key, the value from the second dictionary wins."""
    existing = {"c": "b"}
    incoming = {"c": "d"}
    merged = combine_dict(existing, incoming)
    assert merged == {"c": "d"}

0 comments on commit af72dc4

Please sign in to comment.