diff --git a/validator/main.py b/validator/main.py
index 445a70a..e0588e5 100644
--- a/validator/main.py
+++ b/validator/main.py
@@ -1,5 +1,4 @@
-import importlib
-import os
+import math
 from typing import Callable, Optional, Union
 
 import torch
@@ -45,16 +44,22 @@ class DetectJailbreak(Validator):
     TEXT_CLASSIFIER_NAME = "zhx123/ftrobertallm"
     TEXT_CLASSIFIER_PASS_LABEL = 0
     TEXT_CLASSIFIER_FAIL_LABEL = 1
 
+    EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
     DEFAULT_KNOWN_PROMPT_MATCH_THRESHOLD = 0.9
     MALICIOUS_EMBEDDINGS = KNOWN_ATTACKS
 
-    SATURATION_CLASSIFIER_NAME = "prompt_saturation_detector_v3_1_final.pth"
+    SATURATION_CLASSIFIER_PASS_LABEL = "safe"
     SATURATION_CLASSIFIER_FAIL_LABEL = "jailbreak"
 
+    # These were found with a basic low-granularity beam search.
+    DEFAULT_KNOWN_ATTACK_SCALE_FACTORS = (0.5, 2.0)
+    DEFAULT_SATURATION_ATTACK_SCALE_FACTORS = (3.5, 2.5)
+    DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS = (3.0, 2.5)
+
     def __init__(
             self,
-            threshold: float = 0.515,
+            threshold: float = 0.81,
             device: str = "cpu",
             on_fail: Optional[Callable] = None,
     ):
@@ -79,6 +84,15 @@ def __init__(
         ).to(device)
         self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
 
+        # These _are_ modifyable, but not explicitly advertised.
+        self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
+        self.saturation_attack_scales = DetectJailbreak.DEFAULT_SATURATION_ATTACK_SCALE_FACTORS
+        self.text_attack_scales = DetectJailbreak.DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS
+
+    @staticmethod
+    def _rescale(x: float, a: float = 1.0, b: float = 1.0):
+        return 1.0 / (1.0 + (a*math.exp(-b*x)))
+
     @staticmethod
     def _mean_pool(model_output, attention_mask):
         """Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2."""
@@ -124,7 +138,10 @@ def _match_known_malicious_prompts(
             prompt_embeddings = prompts
         # These are already normalized. We don't need to divide by magnitudes again.
         distances = prompt_embeddings @ self.known_malicious_embeddings.T
-        return torch.max(distances, axis=1).values.tolist()
+        return [
+            DetectJailbreak._rescale(s, *self.known_attack_scales)
+            for s in (torch.max(distances, axis=1).values).tolist()
+        ]
 
     def _predict_and_remap(
             self,
@@ -143,36 +160,49 @@ def _predict_and_remap(
             assert pred[label_field] in {safe_case, fail_case} \
                 and 0.0 <= old_score <= 1.0
             if is_safe:
-                scores.append(0.5 - (old_score * 0.5))
+                new_score = 0.5 - (old_score * 0.5)
             else:
-                scores.append(0.5 + (old_score * 0.5))
+                new_score = 0.5 + (old_score * 0.5)
+            scores.append(new_score)
         return scores
 
     def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
-        return self._predict_and_remap(
-            self.text_classifier,
-            prompts,
-            "label",
-            "score",
-            self.TEXT_CLASSIFIER_PASS_LABEL,
-            self.TEXT_CLASSIFIER_FAIL_LABEL,
-        )
+        return [
+            DetectJailbreak._rescale(s, *self.text_attack_scales)
+            for s in self._predict_and_remap(
+                self.text_classifier,
+                prompts,
+                "label",
+                "score",
+                self.TEXT_CLASSIFIER_PASS_LABEL,
+                self.TEXT_CLASSIFIER_FAIL_LABEL,
+            )
+        ]
 
     def _predict_saturation(self, prompts: list[str]) -> list[float]:
-        return self._predict_and_remap(
-            self.saturation_attack_detector,
-            prompts,
-            "label",
-            "score",
-            self.SATURATION_CLASSIFIER_PASS_LABEL,
-            self.SATURATION_CLASSIFIER_FAIL_LABEL,
-        )
+        return [
+            DetectJailbreak._rescale(
+                s,
+                self.saturation_attack_scales[0],
+                self.saturation_attack_scales[1],
+            ) for s in self._predict_and_remap(
+                self.saturation_attack_detector,
+                prompts,
+                "label",
+                "score",
+                self.SATURATION_CLASSIFIER_PASS_LABEL,
+                self.SATURATION_CLASSIFIER_FAIL_LABEL,
+            )
+        ]
 
     def predict_jailbreak(
             self,
             prompts: list[str],
             reduction_function: Optional[Callable] = max,
     ) -> Union[list[float], list[dict]]:
+        if isinstance(prompts, str):
+            print("WARN: predict_jailbreak should be called with a list of strings.")
+            prompts = [prompts, ]
         known_attack_scores = self._match_known_malicious_prompts(prompts)
         saturation_scores = self._predict_saturation(prompts)
         predicted_scores = self._predict_jailbreak(prompts)
@@ -191,7 +221,6 @@ def predict_jailbreak(
             zip(known_attack_scores, saturation_scores, predicted_scores)
         ]
 
-
     def validate(
             self,
             value: Union[str, list[str]],
diff --git a/validator/models.py b/validator/models.py
index 2a6ee5f..0a73fba 100644
--- a/validator/models.py
+++ b/validator/models.py
@@ -1,4 +1,3 @@
-import os
 from typing import Optional, Union
 
 import numpy
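
Note (illustration only, not part of the patch): the new _rescale helper is a two-parameter logistic, 1 / (1 + a*exp(-b*x)). A minimal standalone sketch of how the default scale factors introduced in this diff map raw component scores, assuming only the (a, b) values shown above:

    import math

    def rescale(x: float, a: float = 1.0, b: float = 1.0) -> float:
        # Same functional form as DetectJailbreak._rescale in the patch above.
        return 1.0 / (1.0 + (a * math.exp(-b * x)))

    # Known-attack cosine similarities use (a, b) = (0.5, 2.0):
    print(rescale(0.0, 0.5, 2.0))  # ~0.667
    print(rescale(0.9, 0.5, 2.0))  # ~0.924
    # Saturation scores use (3.5, 2.5); text-classifier scores use (3.0, 2.5):
    print(rescale(0.5, 3.5, 2.5))  # ~0.50
    print(rescale(1.0, 3.0, 2.5))  # ~0.80

The rescaled outputs stay in (0, 1), which is why the default threshold moves from 0.515 to 0.81 in the same change.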