Skip to content

Commit 5963120

Browse files
authored
Merge pull request #2 from guardrails-ai/aaravnavani-patch-1
Update __init__.py
2 parents fa7cf96 + 2764cef commit 5963120

File tree

5 files changed

+13
-8
lines changed

5 files changed

+13
-8
lines changed

inference/download_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939
def download_model(model_name, model_revision, force_download=False):
4040
from huggingface_hub import snapshot_download
41-
from transformers import AutoTokenizer, AutoModelForCausalLM
41+
from transformers import AutoTokenizer
4242
import json
4343
import os
4444

inference/serving-non-optimized.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def generate(self, chat, is_user_prompt=True, enforce_policies=None):
5656
tokenizer = self.tokenizer
5757
model = self.model
5858

59-
print(f"Model: Loaded on device")
60-
print(f"Model: Chat {chat}")
59+
print("Model: Loaded on device")
60+
print("Model: Chat {chat}")
6161

6262
INPUT_POLICIES = {
6363
"NO_DANGEROUS_CONTENT": "\"No Dangerous Content\": The prompt shall not contain or seek generation of content that harming oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).",

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,9 @@ testpaths = [
2727
]
2828

2929
[tool.pyright]
30-
include = ["validator"]
30+
include = ["validator"]
31+
32+
[tool.setuptools]
33+
packages = [
34+
"validator"
35+
]

tests/test_validator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
from guardrails import Guard
55
import pytest
6-
from validator import ValidatorTemplate
6+
from validator import ShieldGemma2B
77

88
# We use 'exception' as the validator's fail action,
99
# so we expect failures to always raise an Exception
1010
# Learn more about corrective actions here:
1111
# https://www.guardrailsai.com/docs/concepts/output/#%EF%B8%8F-specifying-corrective-actions
12-
guard = Guard.from_string(validators=[ValidatorTemplate(arg_1="arg_1", arg_2="arg_2", on_fail="exception")])
12+
guard = Guard.from_string(validators=[ShieldGemma2B(arg_1="arg_1", arg_2="arg_2", on_fail="exception")])
1313

1414
def test_pass():
1515
test_output = "pass"

validator/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .main import ValidatorTemplate
1+
from .main import ShieldGemma2B
22

3-
__all__ = ["ValidatorTemplate"]
3+
__all__ = ["ShieldGemma2B"]

0 commit comments

Comments
 (0)