diff --git a/app_inference_spec.py b/app_inference_spec.py new file mode 100644 index 0000000..a57888d --- /dev/null +++ b/app_inference_spec.py @@ -0,0 +1,73 @@ +# app_inference_spec.py +# Forked from spec: +# github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints +import os +from typing import Optional + +from fastapi import HTTPException +from pydantic import BaseModel +from models_host.base_inference_spec import BaseInferenceSpec + +from validator import DetectJailbreak + + +class InputRequest(BaseModel): + message: str + threshold: Optional[float] = None + + +class OutputResponse(BaseModel): + classification: str + score: float + is_jailbreak: bool + + +# Using same nomenclature as in Sagemaker classes +class InferenceSpec(BaseInferenceSpec): + def __init__(self): + self.model = None + + @property + def device_name(self): + env = os.environ.get("env", "dev") + # JC: Legacy usage of 'env' as a device. + torch_device = "cuda" if env == "prod" else "cpu" + return torch_device + + def load(self): + print(f"Loading model DetectJailbreak and moving to {self.device_name}...") + self.model = DetectJailbreak(device=self.device_name) + + def process_request(self, input_request: InputRequest): + message = input_request.message + # If needed, sanity check. + # raise HTTPException(status_code=400, detail="Invalid input format") + args = (message,) + kwargs = {} + if input_request.threshold is not None: + kwargs["threshold"] = input_request.threshold + if not 0.0 <= input_request.threshold <= 1.0: + raise HTTPException( + status_code=400, + detail=f"Threshold must be between 0.0 and 1.0. " + f"Got {input_request.threshold}" + ) + return args, kwargs + + def infer(self, message: str, threshold: Optional[float] = None) -> OutputResponse: + if threshold is None: + threshold = 0.81 + + score = self.model.predict_jailbreak([message,])[0] + if score > threshold: + classification = "jailbreak" + is_jailbreak = True + else: + classification = "safe" + is_jailbreak = False + + return OutputResponse( + classification=classification, + score=score, + is_jailbreak=is_jailbreak, + ) diff --git a/validator/models.py b/validator/models.py index 08376f4..10a08d6 100644 --- a/validator/models.py +++ b/validator/models.py @@ -1,17 +1,14 @@ -import os -from pathlib import Path from typing import Optional, Union import numpy import torch import torch.nn as nn -from cached_path import cached_path from .resources import get_tokenizer_and_model_by_path def string_to_one_hot_tensor( - text: Union[str, list[str]], + text: Union[str, list[str], tuple[str]], max_length: int = 2048, left_truncate: bool = True, ) -> torch.Tensor: @@ -32,10 +29,14 @@ def string_to_one_hot_tensor( for idx, t in enumerate(text): if left_truncate: t = t[-max_length:] - out[idx, -len(t):, :] = string_to_one_hot_tensor(t, max_length, left_truncate)[0, :, :] + out[idx, -len(t):, :] = string_to_one_hot_tensor( + t, max_length, left_truncate + )[0, :, :] else: t = t[:max_length] - out[idx, :len(t), :] = string_to_one_hot_tensor(t, max_length, left_truncate)[0, :, :] + out[idx, :len(t), :] = string_to_one_hot_tensor( + t, max_length, left_truncate + )[0, :, :] else: raise Exception("Input was neither a string nor a list of strings.") return out @@ -80,7 +81,7 @@ def forward( x = self.fan_in(x) x = self.lstm1(x)[0] x = self.output_head(x) - x = x[:,-1,0] + x = x[:, -1, 0] x = self.output_activation(x) return x @@ -124,9 +125,14 @@ def forward( longest_sequence = len(x[0]) x = torch.LongTensor(x).to(self.get_current_device()) # TODO: is 1 masked or unmasked? - attention_mask = torch.LongTensor([1] * longest_sequence).to(self.get_current_device()) + attention_mask = torch.LongTensor( + [1] * longest_sequence + ).to(self.get_current_device()) elif isinstance(x, list) or isinstance(x, tuple): - sequences = [self.tokenizer.encode(text, add_special_tokens=True)[-max_size:] for text in x] + sequences = [ + self.tokenizer.encode(text, add_special_tokens=True)[-max_size:] + for text in x + ] for token_list in sequences: longest_sequence = max(longest_sequence, len(token_list)) x = list() @@ -135,16 +141,28 @@ def forward( x.append( ([self.pad_token] * (longest_sequence - len(sequence))) + sequence ) - attention_mask.append([0] * (longest_sequence - len(sequence)) + [1] * len(sequence)) + attention_mask.append( + [0] * (longest_sequence - len(sequence)) + [1] * len(sequence) + ) x = torch.LongTensor(x).to(self.get_current_device()) attention_mask = torch.tensor(attention_mask).to(self.get_current_device()) - #segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] - segments_tensors = torch.zeros(x.shape, dtype=torch.int).to(self.get_current_device()) + # segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + segments_tensors = torch.zeros(x.shape, dtype=torch.int) \ + .to(self.get_current_device()) if y is not None: - return self.transformer(x, token_type_ids=segments_tensors, attention_mask=attention_mask, labels=y) + return self.transformer( + x, + token_type_ids=segments_tensors, + attention_mask=attention_mask, + labels=y + ) else: - return self.transformer(x, token_type_ids=segments_tensors, attention_mask=attention_mask).logits + return self.transformer( + x, + token_type_ids=segments_tensors, + attention_mask=attention_mask + ).logits class PromptSaturationDetectorV3: # Note: Not nn.Module. @@ -155,7 +173,9 @@ def __init__( device: torch.device = torch.device('cpu'), model_path_override: str = "" ): - from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification + from transformers import ( + pipeline, AutoTokenizer, AutoModelForSequenceClassification + ) if not model_path_override: self.model = AutoModelForSequenceClassification.from_pretrained( "GuardrailsAI/prompt-saturation-attack-detector", diff --git a/validator/post-install.py b/validator/post-install.py index 228c33b..134f013 100644 --- a/validator/post-install.py +++ b/validator/post-install.py @@ -1,8 +1,9 @@ +from transformers import pipeline, AutoTokenizer, AutoModel + print("post-install starting...") # TODO: It's not clear if the DetectJailbreak will be on the path yet. # If we can import Detect Jailbreak, it will be safer to read the names of the models # from the composite model as exposed by DetectJailbreak.XYZ. -from transformers import pipeline, AutoTokenizer, AutoModel print("Fetching model 1 of 3 (Saturation)") AutoModel.from_pretrained("GuardrailsAI/prompt-saturation-attack-detector") AutoTokenizer.from_pretrained("google-bert/bert-base-cased")