52 changes: 48 additions & 4 deletions inference/serving-non-optimized.py
@@ -1,3 +1,4 @@
from typing import Optional
import modal

MODEL_ALIAS = "gemma2"
@@ -129,11 +130,27 @@ def tgi_app():

from typing import List
from pydantic import BaseModel
import logging

TOKEN = os.getenv("TOKEN")
if TOKEN is None:
raise ValueError("Please set the TOKEN environment variable")

# Create a logger
logger = logging.getLogger(MODEL_ALIAS)
logger.setLevel(logging.DEBUG)

# Create a handler for logging to stdout
stdout_handler = logging.StreamHandler()
stdout_handler.setLevel(logging.DEBUG)

# Create a formatter for the log messages
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
stdout_handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(stdout_handler)

volume.reload() # ensure we have the latest version of the weights

app = fastapi.FastAPI()
@@ -157,6 +174,24 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
detail="Invalid authentication credentials",
)
return {"username": "authenticated_user"}

@app.exception_handler(Exception)
def error_handler(request, exc):
status_code = 500
detail = "Internal Server Error"
logger.exception(exc)
if isinstance(exc, fastapi.HTTPException):
status_code = exc.status_code
detail = exc.detail
return fastapi.responses.JSONResponse(
status_code=status_code,
content={
"status": status_code,
"response": {
"detail": detail,
}
},
)

router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])

@@ -165,22 +200,31 @@ class ChatMessages(BaseModel):
content: str

class ChatClassificationRequestBody(BaseModel):
score_threshold: Optional[float] = None
policies: Optional[List[str]] = None
chat: List[ChatMessages]


@router.post("/v1/chat/classification")
async def chat_classification_response(body: ChatClassificationRequestBody):
policies = body.policies
score_threshold = body.score_threshold if body.score_threshold is not None else 0.5
chat = body.model_dump().get("chat",[])

print("Serving request for chat classification...")
print(f"Chat: {chat}")
score = Model().generate.remote(chat)
score = Model().generate.remote(chat, enforce_policies=policies)

is_unsafe = score > 0.5
is_unsafe = score > score_threshold

return {
"class": "unsafe" if is_unsafe else "safe",
"score": score,
"status": 200,
"response": {
"class": "unsafe" if is_unsafe else "safe",
"score": score,
"applied_policies": policies,
"score_threshold": score_threshold
}
}
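
For reference, a minimal client-side sketch of calling the /v1/chat/classification endpoint added above. The deployment URL is a placeholder, the bearer token is assumed to be the same TOKEN the server validates in is_authenticated, and the requests package is assumed to be available:

import os
import requests

API_URL = "https://<your-modal-deployment>.modal.run"  # placeholder URL for the deployed app
TOKEN = os.environ["TOKEN"]  # same bearer token the server checks

resp = requests.post(
    f"{API_URL}/v1/chat/classification",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "policies": ["NO_HARASSMENT"],  # optional; omit to skip policy enforcement
        "score_threshold": 0.7,  # optional; the server falls back to 0.5
        "chat": [{"role": "user", "content": "Message to screen for safety"}],
    },
    timeout=60,
)
resp.raise_for_status()
result = resp.json()["response"]
print(result["class"], result["score"], result["applied_policies"])

Errors raised while serving this request are wrapped by the new exception handler into the same {"status": ..., "response": {"detail": ...}} envelope that _inference_remote in validator/main.py inspects.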


117 changes: 97 additions & 20 deletions validator/main.py
@@ -1,4 +1,8 @@
from typing import Any, Callable, Dict, Optional

import json
from typing import Any, Callable, Dict, List, Optional
from enum import Enum
from guardrails.validator_base import ErrorSpan

from guardrails.validator_base import (
FailResult,
@@ -7,42 +11,115 @@
Validator,
register_validator,
)
from guardrails.logger import logger


class Policies(str, Enum):
NO_DANGEROUS_CONTENT = "NO_DANGEROUS_CONTENT"
NO_HARASSMENT = "NO_HARASSMENT"
NO_HATE_SPEECH = "NO_HATE_SPEECH"
NO_SEXUAL_CONTENT = "NO_SEXUAL_CONTENT"

@register_validator(name="guardrails/validator_template", data_type="string")
class ValidatorTemplate(Validator):
"""Validates that {fill in how you validator interacts with the passed value}.

@register_validator(name="guardrails/shieldgemma_2b", data_type="string")
class ShieldGemma2B(Validator):
"""
Classifies model inputs or outputs as "safe" or "unsafe" against the content policies supported by the ShieldGemma-2B model.

**Key Properties**

| Property | Description |
| ----------------------------- | --------------------------------- |
| Name for `format` attribute | `guardrails/validator_template` |
| Name for `format` attribute | `guardrails/shieldgemma_2b` |
| Supported data types | `string` |
| Programmatic fix | {If you support programmatic fixes, explain it here. Otherwise `None`} |
| Programmatic fix | None |

Args:
arg_1 (string): {Description of the argument here}
arg_2 (string): {Description of the argument here}
policies (List[Policies]): List of Policies enum values to enforce.
score_threshold (float): Threshold score for the classification. If the score is above this threshold, the input is considered unsafe.
""" # noqa

# If you don't have any init args, you can omit the __init__ method.
Policies = Policies

def __init__(
self,
arg_1: str,
arg_2: str,
policies: Optional[List[Policies]] = None,
validation_method: Optional[str] = "full",
score_threshold: Optional[float] = None,
on_fail: Optional[Callable] = None,
):
super().__init__(on_fail=on_fail, arg_1=arg_1, arg_2=arg_2)
self._arg_1 = arg_1
self._arg_2 = arg_2

super().__init__(
on_fail=on_fail,
validation_method=validation_method,
)

try:
self._policies = [policy.value for policy in policies] if policies else []
except AttributeError as e:
raise ValueError("Invalid policies provided. Please provide a list of ShieldGemma2B.Policies enum values.") from e

self.score_threshold = score_threshold

if policies and isinstance(policies, list) and len(policies) > 1:
logger.warning(
"ShieldGemma2B Validator: Multiple policies are not officially supported; "
"it is recommended to use one policy at a time. "
"Refer to: https://huggingface.co/google/shieldgemma-2b/discussions/11"
)

def validate(self, value: Any, metadata: Dict = {}) -> ValidationResult:
"""Validates that {fill in how you validator interacts with the passed value}."""
# Add your custom validator logic here and return a PassResult or FailResult accordingly.
if value != "pass": # FIXME

if not value:
raise ValueError("Value cannot be empty.")

(classification, score) = self._inference(value)

is_unsafe = classification == "unsafe"

if is_unsafe:
error_span = ErrorSpan(
start=0,
end=len(value),
reason=f"Unsafe content: {value}",
)
return FailResult(
error_message="{A descriptive but concise error message about why validation failed}",
fix_value="{The programmtic fix if applicable, otherwise remove this kwarg.}",
error_message=(
f"Prompt contains unsafe content. Classification: {classification}, Score: {score}"
),
error_spans=[error_span],
)
return PassResult()
else:
return PassResult()


def _inference_local(self, value: str):
raise NotImplementedError("Local inference is not supported for ShieldGemma2B validator.")

def _inference_remote(self, value: str) -> tuple:
"""Remote inference method for this validator."""
request_body = {
"policies": self._policies,
"score_threshold": self.score_threshold,
"chat": [
{
"role": "user",
"content": value
}
]
}

response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)

status = response.get("status")
if status != 200:
detail = response.get("response",{}).get("detail", "Unknown error")
raise ValueError(f"Failed to get valid response from ShieldGemma-2B model. Status: {status}. Detail: {detail}")

response_data = response.get("response")

classification = response_data.get("class")
score = response_data.get("score")

return (classification, score)
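
For context, a minimal usage sketch of the new validator through Guardrails. It assumes the validator has been installed (the import path below is illustrative), that remote inference against the endpoint above is configured for it, and that the string on_fail value is handled by the Guardrails base Validator:

from guardrails import Guard
from validator.main import ShieldGemma2B  # illustrative import path

guard = Guard().use(
    ShieldGemma2B(
        policies=[ShieldGemma2B.Policies.NO_HARASSMENT],
        score_threshold=0.5,
        on_fail="exception",  # raise instead of returning a FailResult
    )
)

# Safe text passes through; unsafe text raises a validation error
# built from the FailResult and its ErrorSpan.
guard.validate("Hello, how are you today?")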

5 changes: 1 addition & 4 deletions validator/post-install.py
@@ -1,4 +1 @@
print("post-install starting...")
print("This is where you would do things like download nltk tokenizers or login to the HuggingFace hub...")
print("post-install complete!")
# If you don't have anything to add here you should delete this file.
# No post install script