Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions app_inference_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# app_inference_spec.py
# Forked from spec:
# github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints
import os
from typing import Optional

from fastapi import HTTPException
from pydantic import BaseModel
from models_host.base_inference_spec import BaseInferenceSpec

from validator import DetectJailbreak


class InputRequest(BaseModel):
    """Request payload accepted by the jailbreak-detection endpoint."""

    # Text to be scored by the model.
    message: str
    # Optional per-request score cutoff; when omitted (None) the inference
    # code falls back to its own default.
    threshold: Optional[float] = None


class OutputResponse(BaseModel):
    """Response payload returned by the jailbreak-detection endpoint."""

    # Label assigned to the message: "jailbreak" or "safe".
    classification: str
    # Score produced by the model for the message.
    score: float
    # True when the score exceeded the threshold in effect for the request.
    is_jailbreak: bool


# Using same nomenclature as in Sagemaker classes
class InferenceSpec(BaseInferenceSpec):
    """Inference spec exposing the DetectJailbreak model as an endpoint.

    Lifecycle as implemented here: ``load`` builds the model,
    ``process_request`` turns a validated ``InputRequest`` into the
    ``(args, kwargs)`` pair passed on to ``infer``, and ``infer`` scores a
    single message and wraps the result in an ``OutputResponse``.
    """

    # Score cutoff applied by `infer` when the request supplies no threshold.
    DEFAULT_THRESHOLD = 0.81

    def __init__(self):
        # The model is only instantiated in `load`, never at construction.
        self.model = None

    @property
    def device_name(self):
        """Return the torch device name selected via the ``env`` variable.

        Returns:
            "cuda" when the ``env`` environment variable equals "prod";
            "cpu" otherwise (``env`` defaults to "dev" when unset).
        """
        # JC: Legacy usage of 'env' as a device.
        env = os.environ.get("env", "dev")
        return "cuda" if env == "prod" else "cpu"

    def load(self):
        """Instantiate DetectJailbreak on ``device_name`` and keep it on self."""
        print(f"Loading model DetectJailbreak and moving to {self.device_name}...")
        self.model = DetectJailbreak(device=self.device_name)

    def process_request(self, input_request: InputRequest):
        """Validate the request and build the arguments for ``infer``.

        Args:
            input_request: The parsed request body.

        Returns:
            A ``(args, kwargs)`` pair: ``args`` is a 1-tuple holding the
            message; ``kwargs`` contains ``"threshold"`` only when the
            request supplied one.

        Raises:
            HTTPException: With status 400 when a supplied threshold is
                outside the closed interval [0.0, 1.0].
        """
        threshold = input_request.threshold
        kwargs = {}
        if threshold is not None:
            # Validate before storing so an invalid value is never recorded.
            # NaN fails this chained comparison and is rejected as well.
            if not 0.0 <= threshold <= 1.0:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        "Threshold must be between 0.0 and 1.0. "
                        f"Got {threshold}"
                    ),
                )
            kwargs["threshold"] = threshold
        return (input_request.message,), kwargs

    def infer(self, message: str, threshold: Optional[float] = None) -> OutputResponse:
        """Score one message and classify it against the threshold.

        Args:
            message: The text to score.
            threshold: Score cutoff; ``DEFAULT_THRESHOLD`` is used when None.

        Returns:
            An ``OutputResponse`` whose classification is "jailbreak" when
            the score is strictly greater than the threshold, else "safe".

        Raises:
            RuntimeError: If called before ``load`` has created the model.
        """
        if self.model is None:
            # Fail with an explicit message instead of an AttributeError
            # raised from calling a method on None.
            raise RuntimeError(
                "DetectJailbreak model is not loaded; call load() before infer()."
            )
        if threshold is None:
            threshold = self.DEFAULT_THRESHOLD

        # The model is called with a one-element batch; take its only score.
        # float() normalizes the value to a plain Python float — presumably
        # predict_jailbreak yields numpy/torch scalars; TODO confirm.
        score = float(self.model.predict_jailbreak([message])[0])
        is_jailbreak = score > threshold

        return OutputResponse(
            classification="jailbreak" if is_jailbreak else "safe",
            score=score,
            is_jailbreak=is_jailbreak,
        )
50 changes: 35 additions & 15 deletions validator/models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import os
from pathlib import Path
from typing import Optional, Union

import numpy
import torch
import torch.nn as nn
from cached_path import cached_path

from .resources import get_tokenizer_and_model_by_path


def string_to_one_hot_tensor(
text: Union[str, list[str]],
text: Union[str, list[str], tuple[str]],
max_length: int = 2048,
left_truncate: bool = True,
) -> torch.Tensor:
Expand All @@ -32,10 +29,14 @@ def string_to_one_hot_tensor(
for idx, t in enumerate(text):
if left_truncate:
t = t[-max_length:]
out[idx, -len(t):, :] = string_to_one_hot_tensor(t, max_length, left_truncate)[0, :, :]
out[idx, -len(t):, :] = string_to_one_hot_tensor(
t, max_length, left_truncate
)[0, :, :]
else:
t = t[:max_length]
out[idx, :len(t), :] = string_to_one_hot_tensor(t, max_length, left_truncate)[0, :, :]
out[idx, :len(t), :] = string_to_one_hot_tensor(
t, max_length, left_truncate
)[0, :, :]
else:
raise Exception("Input was neither a string nor a list of strings.")
return out
Expand Down Expand Up @@ -80,7 +81,7 @@ def forward(
x = self.fan_in(x)
x = self.lstm1(x)[0]
x = self.output_head(x)
x = x[:,-1,0]
x = x[:, -1, 0]
x = self.output_activation(x)
return x

Expand Down Expand Up @@ -124,9 +125,14 @@ def forward(
longest_sequence = len(x[0])
x = torch.LongTensor(x).to(self.get_current_device())
# TODO: is 1 masked or unmasked?
attention_mask = torch.LongTensor([1] * longest_sequence).to(self.get_current_device())
attention_mask = torch.LongTensor(
[1] * longest_sequence
).to(self.get_current_device())
elif isinstance(x, list) or isinstance(x, tuple):
sequences = [self.tokenizer.encode(text, add_special_tokens=True)[-max_size:] for text in x]
sequences = [
self.tokenizer.encode(text, add_special_tokens=True)[-max_size:]
for text in x
]
for token_list in sequences:
longest_sequence = max(longest_sequence, len(token_list))
x = list()
Expand All @@ -135,16 +141,28 @@ def forward(
x.append(
([self.pad_token] * (longest_sequence - len(sequence))) + sequence
)
attention_mask.append([0] * (longest_sequence - len(sequence)) + [1] * len(sequence))
attention_mask.append(
[0] * (longest_sequence - len(sequence)) + [1] * len(sequence)
)
x = torch.LongTensor(x).to(self.get_current_device())
attention_mask = torch.tensor(attention_mask).to(self.get_current_device())

#segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
segments_tensors = torch.zeros(x.shape, dtype=torch.int).to(self.get_current_device())
# segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
segments_tensors = torch.zeros(x.shape, dtype=torch.int) \
.to(self.get_current_device())
if y is not None:
return self.transformer(x, token_type_ids=segments_tensors, attention_mask=attention_mask, labels=y)
return self.transformer(
x,
token_type_ids=segments_tensors,
attention_mask=attention_mask,
labels=y
)
else:
return self.transformer(x, token_type_ids=segments_tensors, attention_mask=attention_mask).logits
return self.transformer(
x,
token_type_ids=segments_tensors,
attention_mask=attention_mask
).logits


class PromptSaturationDetectorV3: # Note: Not nn.Module.
Expand All @@ -155,7 +173,9 @@ def __init__(
device: torch.device = torch.device('cpu'),
model_path_override: str = ""
):
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from transformers import (
pipeline, AutoTokenizer, AutoModelForSequenceClassification
)
if not model_path_override:
self.model = AutoModelForSequenceClassification.from_pretrained(
"GuardrailsAI/prompt-saturation-attack-detector",
Expand Down
3 changes: 2 additions & 1 deletion validator/post-install.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from transformers import pipeline, AutoTokenizer, AutoModel

print("post-install starting...")
# TODO: It's not clear whether DetectJailbreak will be importable at this point.
# If it can be imported, it would be safer to read the model names from the
# composite model as exposed by DetectJailbreak.XYZ instead of hard-coding them.
from transformers import pipeline, AutoTokenizer, AutoModel
print("Fetching model 1 of 3 (Saturation)")
AutoModel.from_pretrained("GuardrailsAI/prompt-saturation-attack-detector")
AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
Expand Down
Loading