diff --git a/validator/main.py b/validator/main.py
index 22f6a5a..c695d47 100644
--- a/validator/main.py
+++ b/validator/main.py
@@ -1,5 +1,5 @@
 import math
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 from torch.nn import functional as F
@@ -140,7 +140,7 @@ def _mean_pool(model_output, attention_mask):
             input_mask_expanded.sum(1), min=1e-9
         )
 
-    def _embed(self, prompts: list[str]):
+    def _embed(self, prompts: List[str]):
         """Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
         We use the long-form to avoid a dependency on sentence transformers.
         This method returns the maximum of the matches against all known attacks.
@@ -160,8 +160,8 @@ def _embed(self, prompts: list[str]):
 
     def _match_known_malicious_prompts(
             self,
-            prompts: list[str] | torch.Tensor,
-    ) -> list[float]:
+            prompts: Union[List[str], torch.Tensor],
+    ) -> List[float]:
         """Returns an array of floats, one per prompt, with the max match to known
         attacks.  If prompts is a list of strings, embeddings will be generated.  If
         embeddings are passed, they will be used."""
@@ -179,7 +179,7 @@ def _match_known_malicious_prompts(
     def _predict_and_remap(
            self,
            model,
-            prompts: list[str],
+            prompts: List[str],
            label_field: str,
            score_field: str,
            safe_case: str,
@@ -199,7 +199,7 @@ def _predict_and_remap(
             scores.append(new_score)
         return scores
 
-    def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
+    def _predict_jailbreak(self, prompts: List[str]) -> List[float]:
         return [
             DetectJailbreak._rescale(s, *self.text_attack_scales)
             for s in self._predict_and_remap(
@@ -212,7 +212,7 @@ def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
             )
         ]
 
-    def _predict_saturation(self, prompts: list[str]) -> list[float]:
+    def _predict_saturation(self, prompts: List[str]) -> List[float]:
         return [
             DetectJailbreak._rescale(
                 s,
@@ -230,9 +230,9 @@ def _predict_saturation(self, prompts: list[str]) -> list[float]:
 
     def predict_jailbreak(
            self,
-            prompts: list[str],
+            prompts: List[str],
            reduction_function: Optional[Callable] = max,
-    ) -> Union[list[float], list[dict]]:
+    ) -> Union[List[float], List[dict]]:
         if isinstance(prompts, str):
             print("WARN: predict_jailbreak should be called with a list of strings.")
             prompts = [prompts, ]
@@ -256,7 +256,7 @@ def predict_jailbreak(
 
     def validate(
            self,
-            value: Union[str, list[str]],
+            value: Union[str, List[str]],
            metadata: Optional[dict] = None,
     ) -> ValidationResult:
         """Validates that will return a failure if the value is a jailbreak attempt.
diff --git a/validator/models.py b/validator/models.py
index 10a08d6..2bdd3ad 100644
--- a/validator/models.py
+++ b/validator/models.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import List, Tuple, Optional, Union
 
 import numpy
 import torch
@@ -8,7 +8,7 @@
 
 
 def string_to_one_hot_tensor(
-        text: Union[str, list[str], tuple[str]],
+        text: Union[str, List[str], Tuple[str]],
         max_length: int = 2048,
         left_truncate: bool = True,
 ) -> torch.Tensor:
@@ -71,7 +71,7 @@ def get_current_device(self):
 
     def forward(
            self,
-            x: Union[str, list[str], numpy.ndarray, torch.Tensor]
+            x: Union[str, List[str], numpy.ndarray, torch.Tensor]
     ) -> torch.Tensor:
         if isinstance(x, str) or isinstance(x, list) or isinstance(x, tuple):
             x = string_to_one_hot_tensor(x).to(self.get_current_device())
@@ -113,7 +113,7 @@ def get_current_device(self):
 
     def forward(
            self,
-            x: Union[str, list[str], numpy.ndarray, torch.Tensor],
+            x: Union[str, List[str], numpy.ndarray, torch.Tensor],
            y: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -209,5 +209,5 @@ def __init__(
             device=device,
         )
 
-    def __call__(self, text: Union[str, list[str]]) -> list[dict]:
+    def __call__(self, text: Union[str, List[str]]) -> List[dict]:
         return self.pipe(text)
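
Note (not part of the patch): every hunk above swaps PEP 585 built-in generics (list[str], tuple[str]) and the PEP 604 union operator (list[str] | torch.Tensor) for their typing-module equivalents. The diff does not state the motivation, but the usual reason is keeping the module importable on Python 3.8/3.9, where the newer annotation syntax fails at import time. Below is a minimal sketch of that difference; the function names are hypothetical and only illustrate the annotation styles, they are not part of the validator.

# Minimal sketch (assumed context, not taken from the patch): why the typing
# aliases are used instead of the newer builtin-generic / "|" union syntax.
from typing import List, Union

import torch


def predict_scores(prompts: List[str]) -> List[float]:  # works on Python 3.8+
    return [0.0 for _ in prompts]


# The pre-patch annotations require newer interpreters:
#   def predict_scores(prompts: list[str]) -> list[float]: ...
#       -> on Python 3.8 raises TypeError: 'type' object is not subscriptable
#   def match(prompts: list[str] | torch.Tensor) -> list[float]: ...
#       -> PEP 604 "X | Y" unions need Python 3.10+
def match(prompts: Union[List[str], torch.Tensor]) -> List[float]:  # 3.8-compatible
    return [0.0 for _ in prompts]


if __name__ == "__main__":
    print(predict_scores(["hello"]))           # [0.0]
    print(match(["ignore all instructions"]))  # [0.0]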