Commit d0f0f52

Change validator to respect local/remote inference. Also only push to PyPI on release tags.

1 parent f63ccee

File tree

3 files changed: +40 −33 lines changed

.github/workflows/publish_pypi.yml (+3 −2)
app_inference_spec.py (+2 −28)
validator/main.py (+35 −3)

.github/workflows/publish_pypi.yml

Lines changed: 3 additions & 2 deletions

@@ -3,8 +3,9 @@ name: Publish to Guardrails Hub
 on:
   workflow_dispatch:
   push:
-    branches:
-      - main
+    # Publish when new releases are tagged.
+    tags:
+      - '*'
 
 jobs:
   setup:
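
With the branch filter replaced by a tag filter ('*' matches any tag name), the publish job now runs only when a tag is pushed, for example with git tag 0.1.0 && git push origin 0.1.0; ordinary pushes to main no longer trigger a release, and the workflow_dispatch trigger still allows a manual run.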

app_inference_spec.py

Lines changed: 2 additions & 28 deletions

@@ -2,10 +2,8 @@
 # Forked from spec:
 # github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints
 import os
-from typing import Optional
 from logging import getLogger
 
-from fastapi import HTTPException
 from pydantic import BaseModel
 from models_host.base_inference_spec import BaseInferenceSpec
 
@@ -23,13 +21,10 @@
 
 class InputRequest(BaseModel):
     message: str
-    threshold: Optional[float] = None
 
 
 class OutputResponse(BaseModel):
-    classification: str
     score: float
-    is_jailbreak: bool
 
 
 # Using same nomenclature as in Sagemaker classes
@@ -69,30 +64,9 @@ def process_request(self, input_request: InputRequest):
         # raise HTTPException(status_code=400, detail="Invalid input format")
         args = (message,)
         kwargs = {}
-        if input_request.threshold is not None:
-            kwargs["threshold"] = input_request.threshold
-            if not 0.0 <= input_request.threshold <= 1.0:
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Threshold must be between 0.0 and 1.0. "
-                           f"Got {input_request.threshold}"
-                )
         return args, kwargs
 
-    def infer(self, message: str, threshold: Optional[float] = None) -> OutputResponse:
-        if threshold is None:
-            threshold = 0.81
-
-        score = self.model.predict_jailbreak([message,])[0]
-        if score > threshold:
-            classification = "jailbreak"
-            is_jailbreak = True
-        else:
-            classification = "safe"
-            is_jailbreak = False
-
+    def infer(self, message: str) -> OutputResponse:
         return OutputResponse(
-            classification=classification,
-            score=score,
-            is_jailbreak=is_jailbreak,
+            score=self.model.predict_jailbreak([message,])[0],
         )
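
The net effect: the endpoint no longer accepts a threshold or returns a classification; it returns only the raw score, and the jailbreak/safe decision moves to the caller. A minimal sketch of that client-side decision, assuming the caller keeps its own cutoff (0.81 mirrors the default the deleted server-side infer() hard-coded); this helper is illustrative, not part of the commit:

from typing import List

def classify(scores: List[float], threshold: float = 0.81) -> List[bool]:
    # Re-derive the removed is_jailbreak flag from the raw scores the
    # endpoint now returns. Illustrative only; not code from this commit.
    return [score > threshold for score in scores]

# classify([0.93, 0.12]) -> [True, False]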

validator/main.py

Lines changed: 35 additions & 3 deletions

@@ -1,5 +1,6 @@
+import json
 import math
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Union, Any
 
 import torch
 from torch.nn import functional as F
@@ -65,8 +66,9 @@ def __init__(
         device: str = "cpu",
         on_fail: Optional[Callable] = None,
         model_path_override: str = "",
+        **kwargs,
     ):
-        super().__init__(on_fail=on_fail)
+        super().__init__(on_fail=on_fail, **kwargs)
         self.device = device
         self.threshold = threshold
 
@@ -271,7 +273,9 @@ def validate(
         if isinstance(value, str):
             value = [value, ]
 
-        scores = self.predict_jailbreak(value)
+        # _inference is to support local/remote. It is equivalent to this:
+        # scores = self.predict_jailbreak(value)
+        scores = self._inference(value)
 
         failed_prompts = list()
         failed_scores = list()  # To help people calibrate their thresholds.
@@ -289,3 +293,31 @@ def validate(
                 error_message=failure_message
             )
         return PassResult()
+
+    # The rest of these methods are made for validator compatibility and may have
+    # some strange properties.
+
+    def _inference_local(self, model_input: List[str]) -> Any:
+        return self.predict_jailbreak(model_input)
+
+    def _inference_remote(self, model_input: List[str]) -> Any:
+        # This needs to be kept in-sync with app_inference_spec.
+        request_body = {
+            "inputs": [
+                {
+                    "name": "message",
+                    "shape": [len(model_input)],
+                    "data": model_input,
+                    "datatype": "BYTES"
+                }
+            ]
+        }
+        response = self._hub_inference_request(
+            json.dumps(request_body),
+            self.validation_endpoint
+        )
+        if not response or "outputs" not in response:
+            raise ValueError("Invalid response from remote inference", response)
+
+        data = [output["score"] for output in response["outputs"]]
+        return data
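
Note that validate() calls _inference rather than either of these methods directly. _inference is inherited from the guardrails base Validator and routes to _inference_local or _inference_remote depending on how the validator was constructed, which is why __init__ now forwards **kwargs to super().__init__. A rough sketch of that dispatch, with the caveat that the real implementation lives in the guardrails package and the use_local attribute name is an assumption here:

def _inference(self, model_input: Any) -> Any:
    # Approximation of the base-class routing that _inference_local and
    # _inference_remote plug into; use_local is an assumed attribute,
    # not code from this commit.
    if self.use_local:
        return self._inference_local(model_input)
    return self._inference_remote(model_input)

On the remote path, the hub response is expected to look like {"outputs": [{"score": 0.93}, ...]}, one entry per input message; anything else raises the ValueError above.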
