
Commit 961e7a3

Merge pull request #1 from guardrails-ai/jc/add_attack_type_scaling
Rescale the different sub-classifiers to make the scores more aligned.
2 parents: 13b00ff + e22dc55

2 files changed: +53 -25 lines

validator/main.py (53 additions & 24 deletions)
@@ -1,5 +1,4 @@
-import importlib
-import os
+import math
 from typing import Callable, Optional, Union
 
 import torch
@@ -45,16 +44,22 @@ class DetectJailbreak(Validator):
     TEXT_CLASSIFIER_NAME = "zhx123/ftrobertallm"
     TEXT_CLASSIFIER_PASS_LABEL = 0
     TEXT_CLASSIFIER_FAIL_LABEL = 1
+
     EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
     DEFAULT_KNOWN_PROMPT_MATCH_THRESHOLD = 0.9
     MALICIOUS_EMBEDDINGS = KNOWN_ATTACKS
-    SATURATION_CLASSIFIER_NAME = "prompt_saturation_detector_v3_1_final.pth"
+
     SATURATION_CLASSIFIER_PASS_LABEL = "safe"
     SATURATION_CLASSIFIER_FAIL_LABEL = "jailbreak"
 
+    # These were found with a basic low-granularity beam search.
+    DEFAULT_KNOWN_ATTACK_SCALE_FACTORS = (0.5, 2.0)
+    DEFAULT_SATURATION_ATTACK_SCALE_FACTORS = (3.5, 2.5)
+    DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS = (3.0, 2.5)
+
     def __init__(
         self,
-        threshold: float = 0.515,
+        threshold: float = 0.81,
         device: str = "cpu",
         on_fail: Optional[Callable] = None,
     ):
@@ -79,6 +84,15 @@ def __init__(
         ).to(device)
         self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
 
+        # These _are_ modifyable, but not explicitly advertised.
+        self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
+        self.saturation_attack_scales = DetectJailbreak.DEFAULT_SATURATION_ATTACK_SCALE_FACTORS
+        self.text_attack_scales = DetectJailbreak.DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS
+
+    @staticmethod
+    def _rescale(x: float, a: float = 1.0, b: float = 1.0):
+        return 1.0 / (1.0 + (a*math.exp(-b*x)))
+
     @staticmethod
     def _mean_pool(model_output, attention_mask):
         """Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2."""
@@ -124,7 +138,10 @@ def _match_known_malicious_prompts(
             prompt_embeddings = prompts
         # These are already normalized. We don't need to divide by magnitudes again.
         distances = prompt_embeddings @ self.known_malicious_embeddings.T
-        return torch.max(distances, axis=1).values.tolist()
+        return [
+            DetectJailbreak._rescale(s, *self.known_attack_scales)
+            for s in (torch.max(distances, axis=1).values).tolist()
+        ]
 
     def _predict_and_remap(
         self,
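
To make the _match_known_malicious_prompts change concrete: the embeddings are already unit-normalized, so the matrix product is a cosine-similarity matrix, the per-prompt maximum is the closest known attack, and that similarity is now pushed through the logistic rescale. A minimal sketch with random stand-in embeddings (real ones come from all-MiniLM-L6-v2, and the _rescale here mirrors the new staticmethod):

import math
import torch

def _rescale(x: float, a: float = 1.0, b: float = 1.0) -> float:
    return 1.0 / (1.0 + (a * math.exp(-b * x)))

known_attack_scales = (0.5, 2.0)  # DEFAULT_KNOWN_ATTACK_SCALE_FACTORS

# Stand-in unit vectors: 2 prompts and 3 known attacks, 4 dimensions each.
prompt_embeddings = torch.nn.functional.normalize(torch.randn(2, 4), dim=-1)
known_malicious_embeddings = torch.nn.functional.normalize(torch.randn(3, 4), dim=-1)

# Rows are unit length, so the matrix product gives cosine similarities.
distances = prompt_embeddings @ known_malicious_embeddings.T  # shape (2, 3)
closest = torch.max(distances, dim=1).values                  # best match per prompt

known_attack_scores = [_rescale(s, *known_attack_scales) for s in closest.tolist()]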
@@ -143,36 +160,49 @@ def _predict_and_remap(
             assert pred[label_field] in {safe_case, fail_case} \
                 and 0.0 <= old_score <= 1.0
             if is_safe:
-                scores.append(0.5 - (old_score * 0.5))
+                new_score = 0.5 - (old_score * 0.5)
             else:
-                scores.append(0.5 + (old_score * 0.5))
+                new_score = 0.5 + (old_score * 0.5)
+            scores.append(new_score)
         return scores
 
     def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
-        return self._predict_and_remap(
-            self.text_classifier,
-            prompts,
-            "label",
-            "score",
-            self.TEXT_CLASSIFIER_PASS_LABEL,
-            self.TEXT_CLASSIFIER_FAIL_LABEL,
-        )
+        return [
+            DetectJailbreak._rescale(s, *self.text_attack_scales)
+            for s in self._predict_and_remap(
+                self.text_classifier,
+                prompts,
+                "label",
+                "score",
+                self.TEXT_CLASSIFIER_PASS_LABEL,
+                self.TEXT_CLASSIFIER_FAIL_LABEL,
+            )
+        ]
 
     def _predict_saturation(self, prompts: list[str]) -> list[float]:
-        return self._predict_and_remap(
-            self.saturation_attack_detector,
-            prompts,
-            "label",
-            "score",
-            self.SATURATION_CLASSIFIER_PASS_LABEL,
-            self.SATURATION_CLASSIFIER_FAIL_LABEL,
-        )
+        return [
+            DetectJailbreak._rescale(
+                s,
+                self.saturation_attack_scales[0],
+                self.saturation_attack_scales[1],
+            ) for s in self._predict_and_remap(
+                self.saturation_attack_detector,
+                prompts,
+                "label",
+                "score",
+                self.SATURATION_CLASSIFIER_PASS_LABEL,
+                self.SATURATION_CLASSIFIER_FAIL_LABEL,
+            )
+        ]
 
     def predict_jailbreak(
         self,
         prompts: list[str],
         reduction_function: Optional[Callable] = max,
     ) -> Union[list[float], list[dict]]:
+        if isinstance(prompts, str):
+            print("WARN: predict_jailbreak should be called with a list of strings.")
+            prompts = [prompts, ]
         known_attack_scores = self._match_known_malicious_prompts(prompts)
         saturation_scores = self._predict_saturation(prompts)
         predicted_scores = self._predict_jailbreak(prompts)
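
As a worked example of the remap-then-rescale path for the text classifier: _predict_and_remap folds a HuggingFace-pipeline-style {"label", "score"} prediction into [0, 1] (confident safe near 0, confident jailbreak near 1), and the new list comprehension then applies the logistic rescale. The remap helper name and the sample numbers below are illustrative, not part of the commit:

import math

def _rescale(x: float, a: float = 1.0, b: float = 1.0) -> float:
    return 1.0 / (1.0 + (a * math.exp(-b * x)))

def remap(pred: dict, pass_label, fail_label) -> float:
    # Mirrors _predict_and_remap for a single prediction dict.
    if pred["label"] == pass_label:
        return 0.5 - (pred["score"] * 0.5)
    return 0.5 + (pred["score"] * 0.5)

text_scales = (3.0, 2.5)  # DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS

attack_pred = {"label": 1, "score": 0.98}  # TEXT_CLASSIFIER_FAIL_LABEL = 1
safe_pred = {"label": 0, "score": 0.98}    # TEXT_CLASSIFIER_PASS_LABEL = 0

print(f"{_rescale(remap(attack_pred, 0, 1), *text_scales):.2f}")  # 0.80
print(f"{_rescale(remap(safe_pred, 0, 1), *text_scales):.2f}")    # 0.25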
@@ -191,7 +221,6 @@ def predict_jailbreak(
             zip(known_attack_scores, saturation_scores, predicted_scores)
         ]
 
-
     def validate(
         self,
         value: Union[str, list[str]],
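
Downstream, predict_jailbreak zips the three per-prompt score lists together (visible in the unchanged context above) and reduces them with reduction_function, which defaults to max. The exact reduction and the comparison against the threshold are not shown in this diff, so the sketch below is an assumption about how the pieces plausibly fit with the new 0.81 default; all score values are made up.

# Illustrative per-prompt scores for two prompts.
known_attack_scores = [0.92, 0.31]  # from _match_known_malicious_prompts
saturation_scores = [0.44, 0.29]    # from _predict_saturation
predicted_scores = [0.80, 0.25]     # from _predict_jailbreak

threshold = 0.81  # new default in __init__

# Assumed reduction: take the largest sub-score per prompt.
combined = [
    max(scores)
    for scores in zip(known_attack_scores, saturation_scores, predicted_scores)
]
verdicts = ["fail" if score > threshold else "pass" for score in combined]
# combined == [0.92, 0.31], verdicts == ["fail", "pass"]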

validator/models.py (0 additions & 1 deletion)
@@ -1,4 +1,3 @@
-import os
 from typing import Optional, Union
 
 import numpy
