Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 53 additions & 24 deletions validator/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import importlib
import os
import math
from typing import Callable, Optional, Union

import torch
Expand Down Expand Up @@ -45,16 +44,22 @@ class DetectJailbreak(Validator):
TEXT_CLASSIFIER_NAME = "zhx123/ftrobertallm"
TEXT_CLASSIFIER_PASS_LABEL = 0
TEXT_CLASSIFIER_FAIL_LABEL = 1

EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_KNOWN_PROMPT_MATCH_THRESHOLD = 0.9
MALICIOUS_EMBEDDINGS = KNOWN_ATTACKS
SATURATION_CLASSIFIER_NAME = "prompt_saturation_detector_v3_1_final.pth"

SATURATION_CLASSIFIER_PASS_LABEL = "safe"
SATURATION_CLASSIFIER_FAIL_LABEL = "jailbreak"

# These were found with a basic low-granularity beam search.
DEFAULT_KNOWN_ATTACK_SCALE_FACTORS = (0.5, 2.0)
DEFAULT_SATURATION_ATTACK_SCALE_FACTORS = (3.5, 2.5)
DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS = (3.0, 2.5)

def __init__(
self,
threshold: float = 0.515,
threshold: float = 0.81,
device: str = "cpu",
on_fail: Optional[Callable] = None,
):
Expand All @@ -79,6 +84,15 @@ def __init__(
).to(device)
self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)

# These _are_ modifyable, but not explicitly advertised.
self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
self.saturation_attack_scales = DetectJailbreak.DEFAULT_SATURATION_ATTACK_SCALE_FACTORS
self.text_attack_scales = DetectJailbreak.DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS

@staticmethod
def _rescale(x: float, a: float = 1.0, b: float = 1.0):
return 1.0 / (1.0 + (a*math.exp(-b*x)))

@staticmethod
def _mean_pool(model_output, attention_mask):
"""Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2."""
Expand Down Expand Up @@ -124,7 +138,10 @@ def _match_known_malicious_prompts(
prompt_embeddings = prompts
# These are already normalized. We don't need to divide by magnitudes again.
distances = prompt_embeddings @ self.known_malicious_embeddings.T
return torch.max(distances, axis=1).values.tolist()
return [
DetectJailbreak._rescale(s, *self.known_attack_scales)
for s in (torch.max(distances, axis=1).values).tolist()
]

def _predict_and_remap(
self,
Expand All @@ -143,36 +160,49 @@ def _predict_and_remap(
assert pred[label_field] in {safe_case, fail_case} \
and 0.0 <= old_score <= 1.0
if is_safe:
scores.append(0.5 - (old_score * 0.5))
new_score = 0.5 - (old_score * 0.5)
else:
scores.append(0.5 + (old_score * 0.5))
new_score = 0.5 + (old_score * 0.5)
scores.append(new_score)
return scores

def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
return self._predict_and_remap(
self.text_classifier,
prompts,
"label",
"score",
self.TEXT_CLASSIFIER_PASS_LABEL,
self.TEXT_CLASSIFIER_FAIL_LABEL,
)
return [
DetectJailbreak._rescale(s, *self.text_attack_scales)
for s in self._predict_and_remap(
self.text_classifier,
prompts,
"label",
"score",
self.TEXT_CLASSIFIER_PASS_LABEL,
self.TEXT_CLASSIFIER_FAIL_LABEL,
)
]

def _predict_saturation(self, prompts: list[str]) -> list[float]:
return self._predict_and_remap(
self.saturation_attack_detector,
prompts,
"label",
"score",
self.SATURATION_CLASSIFIER_PASS_LABEL,
self.SATURATION_CLASSIFIER_FAIL_LABEL,
)
return [
DetectJailbreak._rescale(
s,
self.saturation_attack_scales[0],
self.saturation_attack_scales[1],
) for s in self._predict_and_remap(
self.saturation_attack_detector,
prompts,
"label",
"score",
self.SATURATION_CLASSIFIER_PASS_LABEL,
self.SATURATION_CLASSIFIER_FAIL_LABEL,
)
]

def predict_jailbreak(
self,
prompts: list[str],
reduction_function: Optional[Callable] = max,
) -> Union[list[float], list[dict]]:
if isinstance(prompts, str):
print("WARN: predict_jailbreak should be called with a list of strings.")
prompts = [prompts, ]
known_attack_scores = self._match_known_malicious_prompts(prompts)
saturation_scores = self._predict_saturation(prompts)
predicted_scores = self._predict_jailbreak(prompts)
Expand All @@ -191,7 +221,6 @@ def predict_jailbreak(
zip(known_attack_scores, saturation_scores, predicted_scores)
]


def validate(
self,
value: Union[str, list[str]],
Expand Down
1 change: 0 additions & 1 deletion validator/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from typing import Optional, Union

import numpy
Expand Down
Loading