-import importlib
-import os
+import math
 from typing import Callable, Optional, Union
 
 import torch
@@ -45,16 +44,22 @@ class DetectJailbreak(Validator):
     TEXT_CLASSIFIER_NAME = "zhx123/ftrobertallm"
     TEXT_CLASSIFIER_PASS_LABEL = 0
     TEXT_CLASSIFIER_FAIL_LABEL = 1
+
     EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
     DEFAULT_KNOWN_PROMPT_MATCH_THRESHOLD = 0.9
     MALICIOUS_EMBEDDINGS = KNOWN_ATTACKS
-    SATURATION_CLASSIFIER_NAME = "prompt_saturation_detector_v3_1_final.pth"
+
     SATURATION_CLASSIFIER_PASS_LABEL = "safe"
     SATURATION_CLASSIFIER_FAIL_LABEL = "jailbreak"
 
+    # These were found with a basic low-granularity beam search.
+    DEFAULT_KNOWN_ATTACK_SCALE_FACTORS = (0.5, 2.0)
+    DEFAULT_SATURATION_ATTACK_SCALE_FACTORS = (3.5, 2.5)
+    DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS = (3.0, 2.5)
+
     def __init__(
         self,
-        threshold: float = 0.515,
+        threshold: float = 0.81,
         device: str = "cpu",
         on_fail: Optional[Callable] = None,
     ):
@@ -79,6 +84,15 @@ def __init__(
         ).to(device)
         self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
 
+        # These _are_ modifiable, but not explicitly advertised.
+        self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
+        self.saturation_attack_scales = DetectJailbreak.DEFAULT_SATURATION_ATTACK_SCALE_FACTORS
+        self.text_attack_scales = DetectJailbreak.DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS
+
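+    # A two-parameter logistic: maps a raw detector score x to 1 / (1 + a * exp(-b * x)).
+    # `a` sets the output at x = 0 (to 1 / (1 + a)) and `b` controls the steepness, so each
+    # detector's score can be calibrated independently via the scale factors above.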
+    @staticmethod
+    def _rescale(x: float, a: float = 1.0, b: float = 1.0):
+        return 1.0 / (1.0 + (a * math.exp(-b * x)))
+
     @staticmethod
     def _mean_pool(model_output, attention_mask):
         """Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2."""
@@ -124,7 +138,10 @@ def _match_known_malicious_prompts(
             prompt_embeddings = prompts
         # These are already normalized. We don't need to divide by magnitudes again.
         distances = prompt_embeddings @ self.known_malicious_embeddings.T
-        return torch.max(distances, axis=1).values.tolist()
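+        # Rescale the closest-match cosine similarity with the known-attack logistic
+        # parameters so it is on the same footing as the other detector scores.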
+        return [
+            DetectJailbreak._rescale(s, *self.known_attack_scales)
+            for s in (torch.max(distances, axis=1).values).tolist()
+        ]
 
     def _predict_and_remap(
         self,
@@ -143,36 +160,49 @@ def _predict_and_remap(
             assert pred[label_field] in {safe_case, fail_case} \
                 and 0.0 <= old_score <= 1.0
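+            # Remap classifier confidence onto [0, 1]: safe predictions land below 0.5,
+            # jailbreak predictions above it.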
             if is_safe:
-                scores.append(0.5 - (old_score * 0.5))
+                new_score = 0.5 - (old_score * 0.5)
             else:
-                scores.append(0.5 + (old_score * 0.5))
+                new_score = 0.5 + (old_score * 0.5)
+            scores.append(new_score)
         return scores
 
     def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
-        return self._predict_and_remap(
-            self.text_classifier,
-            prompts,
-            "label",
-            "score",
-            self.TEXT_CLASSIFIER_PASS_LABEL,
-            self.TEXT_CLASSIFIER_FAIL_LABEL,
-        )
+        return [
+            DetectJailbreak._rescale(s, *self.text_attack_scales)
+            for s in self._predict_and_remap(
+                self.text_classifier,
+                prompts,
+                "label",
+                "score",
+                self.TEXT_CLASSIFIER_PASS_LABEL,
+                self.TEXT_CLASSIFIER_FAIL_LABEL,
+            )
+        ]
 
     def _predict_saturation(self, prompts: list[str]) -> list[float]:
-        return self._predict_and_remap(
-            self.saturation_attack_detector,
-            prompts,
-            "label",
-            "score",
-            self.SATURATION_CLASSIFIER_PASS_LABEL,
-            self.SATURATION_CLASSIFIER_FAIL_LABEL,
-        )
+        return [
+            DetectJailbreak._rescale(
+                s,
+                self.saturation_attack_scales[0],
+                self.saturation_attack_scales[1],
+            ) for s in self._predict_and_remap(
+                self.saturation_attack_detector,
+                prompts,
+                "label",
+                "score",
+                self.SATURATION_CLASSIFIER_PASS_LABEL,
+                self.SATURATION_CLASSIFIER_FAIL_LABEL,
+            )
+        ]
 
     def predict_jailbreak(
         self,
         prompts: list[str],
         reduction_function: Optional[Callable] = max,
     ) -> Union[list[float], list[dict]]:
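+        # Accept a bare string for convenience, but warn: the public API expects a list.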
+        if isinstance(prompts, str):
+            print("WARN: predict_jailbreak should be called with a list of strings.")
+            prompts = [prompts, ]
         known_attack_scores = self._match_known_malicious_prompts(prompts)
         saturation_scores = self._predict_saturation(prompts)
         predicted_scores = self._predict_jailbreak(prompts)
@@ -191,7 +221,6 @@ def predict_jailbreak(
             zip(known_attack_scores, saturation_scores, predicted_scores)
         ]
 
-
     def validate(
         self,
         value: Union[str, list[str]],