-from typing import Any, Callable, Dict, Optional
+
+import json
+from typing import Any, Callable, Dict, List, Optional
+from enum import Enum
+from guardrails.validator_base import ErrorSpan
 
 from guardrails.validator_base import (
     FailResult,
     PassResult,
     ValidationResult,
     Validator,
     register_validator,
 )
+from guardrails.logger import logger
+
 
+class Policies(str, Enum):
+    NO_DANGEROUS_CONTENT = "NO_DANGEROUS_CONTENT"
+    NO_HARASSMENT = "NO_HARASSMENT"
+    NO_HATE_SPEECH = "NO_HATE_SPEECH"
+    NO_SEXUAL_CONTENT = "NO_SEXUAL_CONTENT"
 
-@register_validator(name="guardrails/validator_template", data_type="string")
-class ValidatorTemplate(Validator):
-    """Validates that {fill in how you validator interacts with the passed value}.
+SUP = "SUP"
 
+
+@register_validator(name="guardrails/shieldgemma_2b", data_type="string")
+class ShieldGemma2B(Validator):
+    """
+    Classifies model inputs or outputs as "safe" or "unsafe" based on certain policies defined by the ShieldGemma-2B model.
+
     **Key Properties**
 
     | Property                      | Description                       |
     | ----------------------------- | --------------------------------- |
-    | Name for `format` attribute   | `guardrails/validator_template`   |
+    | Name for `format` attribute   | `guardrails/shieldgemma_2b`       |
     | Supported data types          | `string`                          |
-    | Programmatic fix              | {If you support programmatic fixes, explain it here. Otherwise `None`} |
+    | Programmatic fix              | None                              |
 
     Args:
-        arg_1 (string): {Description of the argument here}
-        arg_2 (string): {Description of the argument here}
+        policies (List[Policies]): List of Policies enum values to enforce.
+        score_threshold (float): Threshold score for the classification. If the score is above this threshold, the input is considered unsafe.
     """  # noqa
 
-    # If you don't have any init args, you can omit the __init__ method.
+    Policies = Policies
+
     def __init__(
         self,
-        arg_1: str,
-        arg_2: str,
+        policies: Optional[List[Policies]] = None,
+        validation_method: Optional[str] = "full",
+        score_threshold: Optional[float] = None,
         on_fail: Optional[Callable] = None,
     ):
-        super().__init__(on_fail=on_fail, arg_1=arg_1, arg_2=arg_2)
-        self._arg_1 = arg_1
-        self._arg_2 = arg_2
+
+        super().__init__(
+            on_fail=on_fail,
+            validation_method=validation_method,
+        )
+
+        try:
+            self._policies = [policy.value for policy in policies] if policies else []
+        except AttributeError as e:
+            raise ValueError("Invalid policies provided. Please provide a list of ShieldGemma2B.Policies enum values.") from e
+
+        self.score_threshold = score_threshold
+
+        if policies and isinstance(policies, list) and len(policies) > 1:
+            logger.warning((
+                "ShieldGemma2B Validator: Multiple policies are not officially supported; "
+                "it is recommended to use one policy at a time. "
+                "Refer to: https://huggingface.co/google/shieldgemma-2b/discussions/11"
+            ))
 
     def validate(self, value: Any, metadata: Dict = {}) -> ValidationResult:
-        """Validates that {fill in how you validator interacts with the passed value}."""
-        # Add your custom validator logic here and return a PassResult or FailResult accordingly.
-        if value != "pass":  # FIXME
+
+        if not value:
+            raise ValueError("Value cannot be empty.")
+
+        (classification, score) = self._inference(value)
+
+        is_unsafe = classification == "unsafe"
+
+        if is_unsafe:
+            error_span = ErrorSpan(
+                start=0,
+                end=len(value),
+                reason=f"Unsafe content: {value}",
+            )
             return FailResult(
-                error_message="{A descriptive but concise error message about why validation failed}",
-                fix_value="{The programmtic fix if applicable, otherwise remove this kwarg.}",
+                error_message=(
+                    f"Prompt contains unsafe content. Classification: {classification}, Score: {score}"
+                ),
+                error_spans=[error_span],
             )
-        return PassResult()
+        else:
+            return PassResult()
+
+
+    def _inference_local(self, value: str):
+        raise NotImplementedError("Local inference is not supported for ShieldGemma2B validator.")
+
+    def _inference_remote(self, value: str) -> Any:
+        """Remote inference method for this validator."""
+        request_body = {
+            "policies": self._policies,
+            "score_threshold": self.score_threshold,
+            "chat": [
+                {
+                    "role": "user",
+                    "content": value
+                }
+            ]
+        }
+
+        response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)
+
+        status = response.get("status")
+        if status != 200:
+            detail = response.get("response", {}).get("detail", "Unknown error")
+            raise ValueError(f"Failed to get valid response from ShieldGemma-2B model. Status: {status}. Detail: {detail}")
+
+        response_data = response.get("response")
+
+        classification = response_data.get("class")
+        score = response_data.get("score")
+
+        return (classification, score)
+
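For reference, a minimal usage sketch of the new validator (not part of the diff): it assumes the module is installed from the Guardrails Hub and importable under the path shown, and that remote inference is configured, since `_inference_local` raises `NotImplementedError`. The import path, threshold, and example prompt are illustrative.

```python
# Minimal usage sketch -- illustrative only. The import path below is an
# assumption; adjust it to wherever this validator module is installed.
from guardrails import Guard

from validator import ShieldGemma2B  # hypothetical import path

# One policy at a time is recommended (see the warning in __init__).
guard = Guard().use(
    ShieldGemma2B(
        policies=[ShieldGemma2B.Policies.NO_DANGEROUS_CONTENT],
        score_threshold=0.5,
        on_fail="exception",
    )
)

# Safe input passes; unsafe input raises because on_fail="exception".
guard.validate("What is a good recipe for banana bread?")
```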