Skip to content

Commit c300ade

Browse files
Initial validator
1 parent e7c161a commit c300ade

File tree

3 files changed

+148
-28
lines changed

3 files changed

+148
-28
lines changed

inference/serving-non-optimized.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Optional
12
import modal
23

34
MODEL_ALIAS = "gemma2"
@@ -129,11 +130,27 @@ def tgi_app():
129130

130131
from typing import List
131132
from pydantic import BaseModel
133+
import logging
132134

133135
TOKEN = os.getenv("TOKEN")
134136
if TOKEN is None:
135137
raise ValueError("Please set the TOKEN environment variable")
136138

139+
# Create a logger
140+
logger = logging.getLogger(MODEL_ALIAS)
141+
logger.setLevel(logging.DEBUG)
142+
143+
# Create a handler for logging to stdout
144+
stdout_handler = logging.StreamHandler()
145+
stdout_handler.setLevel(logging.DEBUG)
146+
147+
# Create a formatter for the log messages
148+
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
149+
stdout_handler.setFormatter(formatter)
150+
151+
# Add the handler to the logger
152+
logger.addHandler(stdout_handler)
153+
137154
volume.reload() # ensure we have the latest version of the weights
138155

139156
app = fastapi.FastAPI()
@@ -157,6 +174,24 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
157174
detail="Invalid authentication credentials",
158175
)
159176
return {"username": "authenticated_user"}
177+
178+
@app.exception_handler(Exception)
def error_handler(request, exc):
    """Catch-all exception handler: log the error and return a uniform
    JSON envelope of the shape {"status": ..., "response": {"detail": ...}}.

    HTTPExceptions keep their own status code and detail; anything else is
    reported as a generic 500.
    """
    # Log the full traceback before deciding what to surface to the client.
    logger.exception(exc)

    if isinstance(exc, fastapi.HTTPException):
        status_code, detail = exc.status_code, exc.detail
    else:
        status_code, detail = 500, "Internal Server Error"

    payload = {
        "status": status_code,
        "response": {
            "detail": detail,
        }
    }
    return fastapi.responses.JSONResponse(status_code=status_code, content=payload)
160195

161196
router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])
162197

@@ -165,22 +200,31 @@ class ChatMessages(BaseModel):
165200
content: str
166201

167202
class ChatClassificationRequestBody(BaseModel):
    """Request payload for the /v1/chat/classification endpoint.

    # score_threshold: optional cutoff above which content is "unsafe";
    #   the endpoint falls back to 0.5 when omitted.
    # policies: optional list of policy names forwarded to the model.
    # chat: the conversation (list of role/content messages) to classify.
    """
    score_threshold: Optional[float] = None
    policies: Optional[List[str]] = None
    chat: List[ChatMessages]
169206

170207

171208
@router.post("/v1/chat/classification")
async def chat_classification_response(body: ChatClassificationRequestBody):
    """Classify a chat as "safe"/"unsafe" with the remote ShieldGemma model.

    Returns an envelope: {"status": 200, "response": {"class", "score",
    "applied_policies", "score_threshold"}}.
    """
    policies = body.policies
    # BUG FIX: the original `body.score_threshold or 0.5` treated any falsy
    # value — including an explicit 0.0 — as "unset" and silently replaced it
    # with 0.5. Only apply the default when the field was actually omitted.
    score_threshold = 0.5 if body.score_threshold is None else body.score_threshold
    chat = body.model_dump().get("chat", [])

    # Use the module logger (configured above) instead of bare print so these
    # messages carry timestamps/levels like the rest of the app's output.
    logger.info("Serving request for chat classification...")
    logger.info("Chat: %s", chat)

    score = Model().generate.remote(chat, enforce_policies=policies)

    is_unsafe = score > score_threshold

    return {
        "status": 200,
        "response": {
            "class": "unsafe" if is_unsafe else "safe",
            "score": score,
            "applied_policies": policies,
            "score_threshold": score_threshold,
        },
    }
185229

186230

validator/main.py

Lines changed: 99 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
from typing import Any, Callable, Dict, Optional
1+
2+
import json
3+
from typing import Any, Callable, Dict, List, Optional
4+
from enum import Enum
5+
from guardrails.validator_base import ErrorSpan
26

37
from guardrails.validator_base import (
48
FailResult,
@@ -7,42 +11,117 @@
711
Validator,
812
register_validator,
913
)
14+
from guardrails.logger import logger
15+
1016

17+
class Policies(str, Enum):
    """Content-safety policies understood by the ShieldGemma-2B endpoint.

    Members are str-valued so they serialize directly into JSON request
    bodies.
    """

    NO_DANGEROUS_CONTENT = "NO_DANGEROUS_CONTENT"
    NO_HARASSMENT = "NO_HARASSMENT"
    NO_HATE_SPEECH = "NO_HATE_SPEECH"
    NO_SEXUAL_CONTENT = "NO_SEXUAL_CONTENT"

    # NOTE(review): "SUP" does not look like a real safety policy — possibly
    # a debugging leftover. Kept because removing it would shrink the public
    # enum surface; confirm with the author before deleting.
    SUP = "SUP"
@register_validator(name="guardrails/shieldgemma_2b", data_type="string")
class ShieldGemma2B(Validator):
    """
    Classifies model inputs or outputs as "safe" or "unsafe" based on certain policies defined by the ShieldGemma-2B model.

    **Key Properties**

    | Property                      | Description                       |
    | ----------------------------- | --------------------------------- |
    | Name for `format` attribute   | `guardrails/shieldgemma_2b`       |
    | Supported data types          | `string`                          |
    | Programmatic fix              | None                              |

    Args:
        policies (List[Policies]): List of Policies enum values to enforce.
        score_threshold (float): Threshold score for the classification. If the score is above this threshold, the input is considered unsafe.
    """  # noqa

    # Re-exported so callers can write `ShieldGemma2B.Policies.NO_HARASSMENT`
    # without importing the enum separately.
    Policies = Policies

    def __init__(
        self,
        policies: Optional[List[Policies]] = None,
        validation_method: Optional[str] = "full",
        score_threshold: Optional[float] = None,
        on_fail: Optional[Callable] = None,
    ):

        super().__init__(
            on_fail=on_fail,
            validation_method=validation_method,
        )

        try:
            # Accept Policies members; `.value` raises AttributeError for
            # anything that is not enum-like, which we surface as ValueError.
            self._policies = [policy.value for policy in policies] if policies else []
        except AttributeError as e:
            raise ValueError("Invalid policies provided. Please provide a list of ShieldGemma2B.Policies enum values.") from e

        # May be None; the remote endpoint applies its own default (0.5).
        self.score_threshold = score_threshold

        if policies and isinstance(policies, list) and len(policies) > 1:
            # FIXES vs. original: `logger.warn` is a deprecated alias of
            # `logger.warning`; the message fragments were concatenated with
            # no separators ("...supportedit is recommended..."); and the
            # validator was misnamed "GemmaShield2B".
            logger.warning(
                "ShieldGemma2B Validator: Multiple policies are not officially "
                "supported; it is recommended to use one policy at a time. "
                "Refer to: https://huggingface.co/google/shieldgemma-2b/discussions/11"
            )

    def validate(self, value: Any, metadata: Dict = {}) -> ValidationResult:
        """Validate ``value`` against the configured policies.

        Args:
            value: Text to classify; must be non-empty.
            metadata: Unused; kept for Validator interface compatibility.

        Returns:
            PassResult when the model classifies the text as "safe";
            FailResult (with an ErrorSpan covering the whole value) otherwise.

        Raises:
            ValueError: If ``value`` is empty.
        """

        if not value:
            raise ValueError("Value cannot be empty.")

        (classification, score) = self._inference(value)

        is_unsafe = classification == "unsafe"

        if is_unsafe:
            error_span = ErrorSpan(
                start=0,
                end=len(value),
                reason=f"Unsafe content: {value}",
            )
            return FailResult(
                error_message=(
                    f"Prompt contains unsafe content. Classification: {classification}, Score: {score}"
                ),
                error_spans=[error_span],
            )
        else:
            return PassResult()

    def _inference_local(self, value: str):
        # This validator is remote-only; there is no local model path.
        raise NotImplementedError("Local inference is not supported for ShieldGemma2B validator.")

    def _inference_remote(self, value: str) -> tuple:
        """Remote inference method for this validator.

        Calls the hosted ShieldGemma-2B endpoint and returns a
        ``(classification, score)`` tuple consumed by ``validate``.

        FIX vs. original: the return annotation said ``ValidationResult`` but
        the method has always returned a tuple; the annotation is corrected.

        Raises:
            ValueError: If the endpoint responds with a non-200 status.
        """
        request_body = {
            "policies": self._policies,
            "score_threshold": self.score_threshold,
            "chat": [
                {
                    "role": "user",
                    "content": value
                }
            ]
        }

        response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)

        status = response.get("status")
        if status != 200:
            detail = response.get("response",{}).get("detail", "Unknown error")
            raise ValueError(f"Failed to get valid response from ShieldGemma-2B model. Status: {status}. Detail: {detail}")

        response_data = response.get("response")

        classification = response_data.get("class")
        score = response_data.get("score")

        return (classification, score)

validator/post-install.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
1-
print("post-install starting...")
2-
print("This is where you would do things like download nltk tokenizers or login to the HuggingFace hub...")
3-
print("post-install complete!")
4-
# If you don't have anything to add here you should delete this file.
1+
# No post install script

0 commit comments

Comments
 (0)