diff --git a/inference/download_model.py b/inference/download_model.py index 6e2751d..305eedb 100644 --- a/inference/download_model.py +++ b/inference/download_model.py @@ -38,7 +38,7 @@ ) def download_model(model_name, model_revision, force_download=False): from huggingface_hub import snapshot_download - from transformers import AutoTokenizer, AutoModelForCausalLM + from transformers import AutoTokenizer import json import os diff --git a/inference/serving-non-optimized.py b/inference/serving-non-optimized.py index cc5f0ac..82ca0ce 100644 --- a/inference/serving-non-optimized.py +++ b/inference/serving-non-optimized.py @@ -56,8 +56,8 @@ def generate(self, chat, is_user_prompt=True, enforce_policies=None): tokenizer = self.tokenizer model = self.model - print(f"Model: Loaded on device") - print(f"Model: Chat {chat}") + print("Model: Loaded on device") + print("Model: Chat {chat}") INPUT_POLICIES = { "NO_DANGEROUS_CONTENT": "\"No Dangerous Content\": The prompt shall not contain or seek generation of content that harming oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).", diff --git a/pyproject.toml b/pyproject.toml index d3574eb..5b0474e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,4 +27,9 @@ testpaths = [ ] [tool.pyright] -include = ["validator"] \ No newline at end of file +include = ["validator"] + +[tool.setuptools] +packages = [ + "validator" +] \ No newline at end of file diff --git a/tests/test_validator.py b/tests/test_validator.py index 77afb8a..63597cc 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -3,13 +3,13 @@ from guardrails import Guard import pytest -from validator import ValidatorTemplate +from validator import ShieldGemma2B # We use 'exception' as the validator's fail action, # so we expect failures to always raise an Exception # Learn more about corrective actions here: # https://www.guardrailsai.com/docs/concepts/output/#%EF%B8%8F-specifying-corrective-actions -guard = Guard.from_string(validators=[ValidatorTemplate(arg_1="arg_1", arg_2="arg_2", on_fail="exception")]) +guard = Guard.from_string(validators=[ShieldGemma2B(arg_1="arg_1", arg_2="arg_2", on_fail="exception")]) def test_pass(): test_output = "pass" diff --git a/validator/__init__.py b/validator/__init__.py index 72a2623..19f40a7 100644 --- a/validator/__init__.py +++ b/validator/__init__.py @@ -1,3 +1,3 @@ -from .main import ValidatorTemplate +from .main import ShieldGemma2B -__all__ = ["ValidatorTemplate"] +__all__ = ["ShieldGemma2B"]