diff --git a/comps/cores/mega/utils.py b/comps/cores/mega/utils.py index 4d8dc5b3eb..001b9ce4f5 100644 --- a/comps/cores/mega/utils.py +++ b/comps/cores/mega/utils.py @@ -181,3 +181,20 @@ def handle_message(messages): return prompt, images else: return prompt + + +def sanitize_env(value: Optional[str]) -> Optional[str]: + """Remove quotes from a configuration value if present. + + Args: + value (str): The configuration value to sanitize. + Returns: + str: The sanitized configuration value. + """ + if value is None: + return None + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + elif value.startswith("'") and value.endswith("'"): + value = value[1:-1] + return value diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py index 2e8df1466a..2e8cc66bc9 100644 --- a/comps/cores/proto/docarray.py +++ b/comps/cores/proto/docarray.py @@ -1,7 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from docarray import BaseDoc, DocList @@ -163,16 +163,257 @@ class LVMSearchedMultimodalDoc(SearchedMultimodalDoc): ) -class GeneratedDoc(BaseDoc): - text: str - prompt: str - - class RerankedDoc(BaseDoc): reranked_docs: DocList[TextDoc] initial_query: str +class AnonymizeModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + hidden_names: Optional[List[str]] = None + allowed_names: Optional[List[str]] = None + entity_types: Optional[List[str]] = None + preamble: Optional[str] = None + regex_patterns: Optional[List[str]] = None + use_faker: Optional[bool] = None + recognizer_conf: Optional[str] = None + threshold: Optional[float] = None + language: Optional[str] = None + + +class BanCodeModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + + +class BanCompetitorsModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + competitors: List[str] = ["Competitor1", "Competitor2", "Competitor3"] + model: Optional[str] = None + threshold: Optional[float] = None + redact: Optional[bool] = None + + +class BanSubstringsModel(BaseDoc): + enabled: bool = False + substrings: List[str] = ["backdoor", "malware", "virus"] + match_type: Optional[str] = None + case_sensitive: bool = False + redact: Optional[bool] = None + contains_all: Optional[bool] = None + + +class BanTopicsModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + topics: List[str] = ["violence", "attack", "war"] + threshold: Optional[float] = None + model: Optional[str] = None + + +class CodeModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + languages: List[str] = ["Java", "Python"] + model: Optional[str] = None + is_blocked: Optional[bool] = None + threshold: Optional[float] = None + + +class GibberishModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + + +class InvisibleText(BaseDoc): + enabled: bool = False + + +class LanguageModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + valid_languages: List[str] = ["en", "es"] + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + + +class PromptInjectionModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] 
= None + + +class RegexModel(BaseDoc): + enabled: bool = False + patterns: List[str] = ["Bearer [A-Za-z0-9-._~+/]+"] + is_blocked: Optional[bool] = None + match_type: Optional[str] = None + redact: Optional[bool] = None + + +class SecretsModel(BaseDoc): + enabled: bool = False + redact_mode: Optional[str] = None + + +class SentimentModel(BaseDoc): + enabled: bool = False + threshold: Optional[float] = None + lexicon: Optional[str] = None + + +class TokenLimitModel(BaseDoc): + enabled: bool = False + limit: Optional[int] = None + encoding_name: Optional[str] = None + model_name: Optional[str] = None + + +class ToxicityModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + + +class BiasModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + + +class DeanonymizeModel(BaseDoc): + enabled: bool = False + matching_strategy: Optional[str] = None + + +class JSONModel(BaseDoc): + enabled: bool = False + required_elements: Optional[int] = None + repair: Optional[bool] = None + + +class LanguageSameModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + + +class MaliciousURLsModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + + +class NoRefusalModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + + +class NoRefusalLightModel(BaseDoc): + enabled: bool = False + + +class ReadingTimeModel(BaseDoc): + enabled: bool = False + max_time: float = 0.5 + truncate: Optional[bool] = None + + +class FactualConsistencyModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + minimum_score: Optional[float] = None + + +class RelevanceModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + + +class SensitiveModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + entity_types: Optional[List[str]] = None + regex_patterns: Optional[List[str]] = None + redact: Optional[bool] = None + recognizer_conf: Optional[str] = None + threshold: Optional[float] = None + + +class URLReachabilityModel(BaseDoc): + enabled: bool = False + success_status_codes: Optional[List[int]] = None + timeout: Optional[int] = None + + +class LLMGuardInputGuardrailParams(BaseDoc): + anonymize: AnonymizeModel = AnonymizeModel() + ban_code: BanCodeModel = BanCodeModel() + ban_competitors: BanCompetitorsModel = BanCompetitorsModel() + ban_substrings: BanSubstringsModel = BanSubstringsModel() + ban_topics: BanTopicsModel = BanTopicsModel() + code: CodeModel = CodeModel() + gibberish: GibberishModel = GibberishModel() + invisible_text: InvisibleText = InvisibleText() + language: LanguageModel = LanguageModel() + prompt_injection: PromptInjectionModel = PromptInjectionModel() + regex: RegexModel = RegexModel() + secrets: SecretsModel = SecretsModel() + sentiment: SentimentModel = SentimentModel() + token_limit: TokenLimitModel = TokenLimitModel() + toxicity: ToxicityModel = ToxicityModel() + + +class LLMGuardOutputGuardrailParams(BaseDoc): + ban_code: BanCodeModel = BanCodeModel() + ban_competitors: BanCompetitorsModel = BanCompetitorsModel() + 
ban_substrings: BanSubstringsModel = BanSubstringsModel() + ban_topics: BanTopicsModel = BanTopicsModel() + bias: BiasModel = BiasModel() + code: CodeModel = CodeModel() + deanonymize: DeanonymizeModel = DeanonymizeModel() + json_scanner: JSONModel = JSONModel() + language: LanguageModel = LanguageModel() + language_same: LanguageSameModel = LanguageSameModel() + malicious_urls: MaliciousURLsModel = MaliciousURLsModel() + no_refusal: NoRefusalModel = NoRefusalModel() + no_refusal_light: NoRefusalLightModel = NoRefusalLightModel() + reading_time: ReadingTimeModel = ReadingTimeModel() + factual_consistency: FactualConsistencyModel = FactualConsistencyModel() + gibberish: GibberishModel = GibberishModel() + regex: RegexModel = RegexModel() + relevance: RelevanceModel = RelevanceModel() + sensitive: SensitiveModel = SensitiveModel() + sentiment: SentimentModel = SentimentModel() + toxicity: ToxicityModel = ToxicityModel() + url_reachability: URLReachabilityModel = URLReachabilityModel() + anonymize_vault: Optional[List[Tuple]] = ( + None # the only parameter not available in fingerprint. Used to transmit vault + ) + + class LLMParamsDoc(BaseDoc): model: Optional[str] = None # for openai and ollama query: str @@ -187,6 +428,8 @@ class LLMParamsDoc(BaseDoc): repetition_penalty: NonNegativeFloat = 1.03 stream: bool = True language: str = "auto" # can be "en", "zh" + input_guardrail_params: Optional[LLMGuardInputGuardrailParams] = None + output_guardrail_params: Optional[LLMGuardOutputGuardrailParams] = None chat_template: Optional[str] = Field( default=None, @@ -213,6 +456,12 @@ def chat_template_must_contain_variables(cls, v): return v +class GeneratedDoc(BaseDoc): + text: str + prompt: str + output_guardrail_params: Optional[LLMGuardOutputGuardrailParams] = None + + class LLMParams(BaseDoc): model: Optional[str] = None max_tokens: int = 1024 diff --git a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py index 90ec1c5bc9..ac178badd7 100644 --- a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py +++ b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py @@ -1,17 +1,26 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import asyncio import os import time from typing import Union +from dotenv import dotenv_values +from fastapi import HTTPException +from fastapi.responses import StreamingResponse from integrations.llamaguard import OpeaGuardrailsLlamaGuard from integrations.wildguard import OpeaGuardrailsWildGuard +from pydantic import ValidationError +from utils.llm_guard_input_guardrail import OPEALLMGuardInputGuardrail +from utils.llm_guard_output_guardrail import OPEALLMGuardOutputGuardrail from comps import ( CustomLogger, GeneratedDoc, + LLMParamsDoc, OpeaComponentLoader, + SearchedDoc, ServiceType, TextDoc, opea_microservices, @@ -23,6 +32,10 @@ logger = CustomLogger("opea_guardrails_microservice") logflag = os.getenv("LOGFLAG", False) +input_usvc_config = {**dotenv_values("utils/.input_env"), **os.environ} + +output_usvc_config = {**dotenv_values("utils/.output_env"), **os.environ} + guardrails_component_name = os.getenv("GUARDRAILS_COMPONENT_NAME", "OPEA_LLAMA_GUARD") # Initialize OpeaComponentLoader loader = OpeaComponentLoader( @@ -31,6 +44,9 @@ description=f"OPEA Guardrails Component: {guardrails_component_name}", ) +input_guardrail = OPEALLMGuardInputGuardrail(input_usvc_config) +output_guardrail = OPEALLMGuardOutputGuardrail(output_usvc_config) + 
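# --- Illustrative usage sketch (not part of the patch above) ---
# A minimal example of how a client might exercise the extended /v1/guardrails
# endpoint: a payload shaped like LLMParamsDoc (with "query") is routed through
# the input scanners, while a payload shaped like GeneratedDoc (with "text" and
# "prompt") is routed through the output scanners. The URL assumes a local
# deployment on the default port 9090; how the Union request body is resolved
# depends on the pydantic parsing, so treat this only as a sketch.
import requests

GUARDRAILS_URL = "http://localhost:9090/v1/guardrails"

# Input-side scan: "query" matches the required field of LLMParamsDoc.
input_payload = {"query": "My name is John Doe and my card number is 4111 1111 1111 1111."}
print(requests.post(GUARDRAILS_URL, json=input_payload, timeout=30).json())

# Output-side scan: "text" and "prompt" match the required fields of GeneratedDoc.
output_payload = {"prompt": "Summarize the quarterly report.", "text": "Here is the summary..."}
print(requests.post(GUARDRAILS_URL, json=output_payload, timeout=30).json())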
@register_microservice( name="opea_service@guardrails", @@ -38,24 +54,36 @@ endpoint="/v1/guardrails", host="0.0.0.0", port=9090, - input_datatype=Union[GeneratedDoc, TextDoc], - output_datatype=TextDoc, + input_datatype=Union[LLMParamsDoc, GeneratedDoc, TextDoc], + output_datatype=Union[TextDoc, GeneratedDoc], ) @register_statistics(names=["opea_service@guardrails"]) -async def safety_guard(input: Union[GeneratedDoc, TextDoc]) -> TextDoc: +async def safety_guard(input: Union[LLMParamsDoc, GeneratedDoc, TextDoc]) -> Union[TextDoc, GeneratedDoc]: start = time.time() - # Log the input if logging is enabled if logflag: - logger.info(f"Input received: {input}") + logger.info(f"Received input: {input}") try: - # Use the loader to invoke the component - guardrails_response = await loader.invoke(input) + if isinstance(input, LLMParamsDoc): + processed = input_guardrail.scan_llm_input(input) + if logflag: + logger.info(f"Input guard passed: {processed}") - # Log the result if logging is enabled - if logflag: - logger.info(f"Output received: {guardrails_response}") + elif isinstance(input, GeneratedDoc): + try: + doc = input + except Exception as e: + logger.error(f"Problem using input as GeneratedDoc: {e}") + raise HTTPException(status_code=500, detail=f"{e}") from e + scanned_output = output_guardrail.scan_llm_output(doc) + + processed = scanned_output + else: + processed = input + + # Use the loader to invoke the component + guardrails_response = await loader.invoke(processed) # Record statistics statistics_dict["opea_service@guardrails"].append_latency(time.time() - start, None) diff --git a/comps/guardrails/src/guardrails/requirements.txt b/comps/guardrails/src/guardrails/requirements.txt index e299b4ab9f..88fd3618ae 100644 --- a/comps/guardrails/src/guardrails/requirements.txt +++ b/comps/guardrails/src/guardrails/requirements.txt @@ -5,9 +5,11 @@ huggingface-hub<=0.24.0 langchain-community langchain-huggingface langchain-openai +llm_guard opentelemetry-api opentelemetry-exporter-otlp opentelemetry-sdk +presidio_anonymizer prometheus-fastapi-instrumentator sentencepiece shortuuid diff --git a/comps/guardrails/src/guardrails/utils/.input_env b/comps/guardrails/src/guardrails/utils/.input_env new file mode 100644 index 0000000000..e15a6d87ce --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/.input_env @@ -0,0 +1,113 @@ +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +## LLM Guard Input Guardrail Microservice Settings +## Singular input scanners settings +## Anonymize scanner settings +ANONYMIZE_ENABLED=false +ANONYMIZE_USE_ONNX=false +ANONYMIZE_HIDDEN_NAMES +ANONYMIZE_ALLOWED_NAMES +ANONYMIZE_ENTITY_TYPES +ANONYMIZE_PREAMBLE +ANONYMIZE_REGEX_PATTERMS +ANONYMIZE_USE_FAKER +ANONYMIZE_RECOGNIZER_CONF +ANONYMIZE_THRESHOLD +ANONYMIZE_LANGUAGE + + +## BanCode scanner settings +BAN_CODE_ENABLED=false +BAN_CODE_USE_ONNX=false +BAN_CODE_MODEL +BAN_CODE_THRESHOLD + +## BanCompetitors scanner settings +BAN_COMPETITORS_ENABLED=false +BAN_COMPETITORS_USE_ONNX=false +BAN_COMPETITORS_COMPETITORS="Competitor1,Competitor2,Competitor3" +BAN_COMPETITORS_THRESHOLD +BAN_COMPETITORS_REDACT +BAN_COMPETITORS_MODEL + +## BanSubstrings scanner settings +BAN_SUBSTRINGS_ENABLED=false +BAN_SUBSTRINGS_SUBSTRINGS="backdoor,malware,virus" +BAN_SUBSTRINGS_MATCH_TYPE +BAN_SUBSTRINGS_CASE_SENSITIVE +BAN_SUBSTRINGS_REDACT +BAN_SUBSTRINGS_CONTAINS_ALL + +## BanTopics scanner settings +BAN_TOPICS_ENABLED=false +BAN_TOPICS_USE_ONNX=false +BAN_TOPICS_TOPICS="violence,attack,war" 
+BAN_TOPICS_THRESHOLD
+BAN_TOPICS_MODEL
+
+## Code scanner settings
+CODE_ENABLED=false
+CODE_USE_ONNX=false
+CODE_LANGUAGES="Java,Python"
+CODE_MODEL
+CODE_IS_BLOCKED
+CODE_THRESHOLD
+
+## Gibberish scanner settings
+GIBBERISH_ENABLED=false
+GIBBERISH_USE_ONNX=false
+GIBBERISH_MODEL
+GIBBERISH_THRESHOLD
+GIBBERISH_MATCH_TYPE
+
+## Invisible Text scanner settings
+INVISIBLE_TEXT_ENABLED=false
+
+## Language scanner settings
+LANGUAGE_ENABLED=false
+LANGUAGE_USE_ONNX=false
+LANGUAGE_VALID_LANGUAGES="en,es"
+LANGUAGE_MODEL
+LANGUAGE_THRESHOLD
+LANGUAGE_MATCH_TYPE
+
+## Prompt Injection scanner settings
+PROMPT_INJECTION_ENABLED=false
+PROMPT_INJECTION_USE_ONNX=false
+PROMPT_INJECTION_MODEL
+PROMPT_INJECTION_THRESHOLD
+PROMPT_INJECTION_MATCH_TYPE
+
+## Regex scanner settings
+REGEX_ENABLED=false
+REGEX_PATTERNS="Bearer [A-Za-z0-9-._~+/]+"
+REGEX_IS_BLOCKED
+REGEX_MATCH_TYPE
+REGEX_REDACT
+
+## Secrets scanner settings
+SECRETS_ENABLED=false
+SECRETS_REDACT_MODE
+
+## Sentiment scanner settings
+SENTIMENT_ENABLED=false
+SENTIMENT_THRESHOLD
+SENTIMENT_LEXICON
+
+## TokenLimit scanner settings
+TOKEN_LIMIT_ENABLED=false
+TOKEN_LIMIT_LIMIT
+TOKEN_LIMIT_ENCODING_NAME
+TOKEN_LIMIT_MODEL_NAME
+
+## Toxicity scanner settings
+TOXICITY_ENABLED=false
+TOXICITY_USE_ONNX=false
+TOXICITY_MODEL
+TOXICITY_THRESHOLD
+TOXICITY_MATCH_TYPE
+
+## Uncomment to change the microservice settings
+# LLM_GUARD_INPUT_SCANNER_USVC_PORT=8050
+# OPEA_LOGGER_LEVEL="INFO"
diff --git a/comps/guardrails/src/guardrails/utils/.output_env b/comps/guardrails/src/guardrails/utils/.output_env
new file mode 100644
index 0000000000..92fb49ad6a
--- /dev/null
+++ b/comps/guardrails/src/guardrails/utils/.output_env
@@ -0,0 +1,144 @@
+# Copyright (C) 2024-2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+## LLM Guard Output Guardrail Microservice Settings
+## Singular output scanners settings
+## BanCode scanner settings
+BAN_CODE_ENABLED=false
+BAN_CODE_USE_ONNX=false
+BAN_CODE_MODEL
+BAN_CODE_THRESHOLD
+
+## BanCompetitors scanner settings
+BAN_COMPETITORS_ENABLED=false
+BAN_COMPETITORS_USE_ONNX=false
+BAN_COMPETITORS_COMPETITORS="Competitor1,Competitor2,Competitor3"
+BAN_COMPETITORS_THRESHOLD
+BAN_COMPETITORS_REDACT
+BAN_COMPETITORS_MODEL
+
+## BanSubstrings scanner settings
+BAN_SUBSTRINGS_ENABLED=false
+BAN_SUBSTRINGS_SUBSTRINGS="backdoor,malware,virus"
+BAN_SUBSTRINGS_MATCH_TYPE
+BAN_SUBSTRINGS_CASE_SENSITIVE
+BAN_SUBSTRINGS_REDACT=true
+BAN_SUBSTRINGS_CONTAINS_ALL
+
+## BanTopics scanner settings
+BAN_TOPICS_ENABLED=false
+BAN_TOPICS_USE_ONNX=false
+BAN_TOPICS_TOPICS="violence,attack,war"
+BAN_TOPICS_THRESHOLD
+BAN_TOPICS_MODEL
+
+## Bias scanner settings
+BIAS_ENABLED=false
+BIAS_USE_ONNX=false
+BIAS_MODEL
+BIAS_THRESHOLD
+BIAS_MATCH_TYPE
+
+## Code scanner settings
+CODE_ENABLED=false
+CODE_USE_ONNX=false
+CODE_LANGUAGES="Java,Python"
+CODE_MODEL
+CODE_IS_BLOCKED
+CODE_THRESHOLD
+
+## Deanonymize scanner settings
+DEANONYMIZE_ENABLED=false
+DEANONYMIZE_MATCHING_STRATEGY
+
+## JSON scanner settings
+JSON_SCANNER_ENABLED=false
+JSON_SCANNER_REQUIRED_ELEMENTS
+JSON_SCANNER_REPAIR
+
+## Language scanner settings
+LANGUAGE_ENABLED=false
+LANGUAGE_USE_ONNX=false
+LANGUAGE_VALID_LANGUAGES="en,es"
+LANGUAGE_MODEL
+LANGUAGE_THRESHOLD
+LANGUAGE_MATCH_TYPE
+
+## LanguageSame scanner settings
+LANGUAGE_SAME_ENABLED=false
+LANGUAGE_SAME_USE_ONNX=false
+LANGUAGE_SAME_MODEL
+LANGUAGE_SAME_THRESHOLD
+
+## MaliciousURLs scanner settings
+MALICIOUS_URLS_ENABLED=false
+MALICIOUS_URLS_USE_ONNX=false
+MALICIOUS_URLS_MODEL
+MALICIOUS_URLS_THRESHOLD
+
+## NoRefusal scanner settings
+NO_REFUSAL_ENABLED=false
+NO_REFUSAL_USE_ONNX=false
+NO_REFUSAL_MODEL
+NO_REFUSAL_THRESHOLD
+NO_REFUSAL_MATCH_TYPE
+
+## NoRefusalLight scanner settings
+NO_REFUSAL_LIGHT_ENABLED=false
+
+## ReadingTime scanner settings
+READING_TIME_ENABLED=false
+READING_TIME_MAX_TIME=0.5
+READING_TIME_TRUNCATE
+
+## FactualConsistency scanner settings
+FACTUAL_CONSISTENCY_ENABLED=false
+FACTUAL_CONSISTENCY_USE_ONNX=false
+FACTUAL_CONSISTENCY_MODEL
+FACTUAL_CONSISTENCY_MINIMUM_SCORE
+
+## Gibberish scanner settings
+GIBBERISH_ENABLED=false
+GIBBERISH_USE_ONNX=false
+GIBBERISH_MODEL
+GIBBERISH_THRESHOLD
+GIBBERISH_MATCH_TYPE
+
+## Regex scanner settings
+REGEX_ENABLED=false
+REGEX_PATTERNS="Bearer [A-Za-z0-9-._~+/]+"
+REGEX_IS_BLOCKED
+REGEX_MATCH_TYPE
+REGEX_REDACT
+
+## Relevance scanner settings
+RELEVANCE_ENABLED=false
+RELEVANCE_USE_ONNX=false
+RELEVANCE_MODEL
+RELEVANCE_THRESHOLD
+
+## Sensitive scanner settings
+SENSITIVE_ENABLED=false
+SENSITIVE_USE_ONNX=false
+SENSITIVE_ENTITY_TYPES
+SENSITIVE_REGEX_PATTERNS
+SENSITIVE_REDACT
+SENSITIVE_RECOGNIZER_CONF
+SENSITIVE_THRESHOLD
+
+## Sentiment scanner settings
+SENTIMENT_ENABLED=false
+SENTIMENT_THRESHOLD
+SENTIMENT_LEXICON
+
+## Toxicity scanner settings
+TOXICITY_ENABLED=false
+TOXICITY_USE_ONNX=false
+TOXICITY_MODEL
+TOXICITY_THRESHOLD
+TOXICITY_MATCH_TYPE
+
+## URLReachability
+URL_REACHABILITY_ENABLED=false
+URL_REACHABILITY_SUCCESS_STATUS_CODES
+URL_REACHABILITY_TIMEOUT
\ No newline at end of file
diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py
new file mode 100644
index 0000000000..1e058beb6f
--- /dev/null
+++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py
@@ -0,0 +1,87 @@
+# Copyright (C) 2024-2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from fastapi import HTTPException
+from llm_guard import scan_prompt
+from utils.llm_guard_input_scanners import InputScannersConfig
+
+from comps import CustomLogger, LLMParamsDoc
+
+logger = CustomLogger("opea_llm_guard_input_guardrail_microservice")
+
+
+class OPEALLMGuardInputGuardrail:
+    """OPEALLMGuardInputGuardrail is responsible for scanning and sanitizing LLM input prompts
+    using various input scanners provided by LLM Guard."""
+
+    def __init__(self, usv_config: dict):
+        try:
+            self._scanners_config = InputScannersConfig(usv_config)
+            self._scanners = self._scanners_config.create_enabled_input_scanners()
+        except ValueError as e:
+            logger.exception(f"Value Error during scanner initialization: {e}")
+            raise
+        except Exception as e:
+            logger.exception(f"Unexpected error during scanner initialization: {e}")
+            raise
+
+    def _get_anonymize_vault(self):
+        for item in self._scanners:
+            if type(item).__name__ == "Anonymize":
+                return item._vault.get()
+        return None
+
+    def _recreate_anonymize_scanner_if_exists(self):
+        for item in self._scanners:
+            if type(item).__name__ == "Anonymize":
+                logger.info("Recreating Anonymize scanner to clear Vault.")
+                self._scanners.remove(item)
+                self._scanners.append(self._scanners_config._create_anonymize_scanner())
+                break
+
+    def _analyze_scan_outputs(self, prompt, results_valid, results_score):
+        filtered_results = {
+            key: value
+            for key, value in results_valid.items()
+            if key != "Anonymize"
+            and not (
+                type(scanner := next((s for s in self._scanners if type(s).__name__ == key), None)).__name__
+                in {"BanCompetitors", "BanSubstrings", "OPEABanSubstrings", "Regex", 
"OPEARegexScanner"} + and getattr(scanner, "_redact", False) + ) + } + + if False in filtered_results.values(): + msg = f"Prompt '{prompt}' is not valid, scores: {results_score}" + logger.error(msg) + raise HTTPException(status_code=466, detail="I'm sorry, I cannot assist you with your prompt.") + + def scan_llm_input(self, input_doc: LLMParamsDoc) -> LLMParamsDoc: + fresh_scanners = False + + if input_doc.input_guardrail_params is not None: + if self._scanners_config.changed(input_doc.input_guardrail_params.dict()): + self._scanners = self._scanners_config.create_enabled_input_scanners() + fresh_scanners = True + else: + logger.warning("Input guardrail params not found.") + + if not self._scanners: + logger.info("No scanners enabled. Skipping input scan.") + return input_doc + + if not fresh_scanners: + self._recreate_anonymize_scanner_if_exists() + + user_prompt = input_doc.query + sanitized_user_prompt, results_valid, results_score = scan_prompt(self._scanners, user_prompt) + self._analyze_scan_outputs(user_prompt, results_valid, results_score) + + input_doc.query = sanitized_user_prompt + + if input_doc.output_guardrail_params is not None and "Anonymize" in results_valid: + input_doc.output_guardrail_params.anonymize_vault = self._get_anonymize_vault() + elif input_doc.output_guardrail_params is None and "Anonymize" in results_valid: + logger.warning("Anonymize scanner result exists, but output_guardrail_params is missing.") + + return input_doc diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py new file mode 100644 index 0000000000..2f916f03df --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py @@ -0,0 +1,951 @@ +# ruff: noqa: F401 +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from llm_guard.input_scanners import ( + Anonymize, + BanCode, + BanCompetitors, + BanTopics, + Code, + Gibberish, + InvisibleText, + Language, + PromptInjection, + Secrets, + Sentiment, + TokenLimit, + Toxicity, +) + +# import models definition +from llm_guard.input_scanners.ban_code import MODEL_SM as BANCODE_MODEL_SM +from llm_guard.input_scanners.ban_code import MODEL_TINY as BANCODE_MODEL_TINY +from llm_guard.input_scanners.ban_competitors import MODEL_V1 as BANCOMPETITORS_MODEL_V1 +from llm_guard.input_scanners.ban_topics import MODEL_BGE_M3_V2 as BANTOPICS_MODEL_BGE_M3_V2 +from llm_guard.input_scanners.ban_topics import MODEL_DEBERTA_BASE_V2 as BANTOPICS_MODEL_DEBERTA_BASE_V2 +from llm_guard.input_scanners.ban_topics import MODEL_DEBERTA_LARGE_V2 as BANTOPICS_MODEL_DEBERTA_LARGE_V2 +from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_BASE_C_V2 as BANTOPICS_MODEL_ROBERTA_BASE_C_V2 +from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_LARGE_C_V2 as BANTOPICS_MODEL_ROBERTA_LARGE_C_V2 +from llm_guard.input_scanners.code import DEFAULT_MODEL as CODE_DEFAULT_MODEL +from llm_guard.input_scanners.gibberish import DEFAULT_MODEL as GIBBERISH_DEFAULT_MODEL +from llm_guard.input_scanners.language import DEFAULT_MODEL as LANGUAGE_DEFAULT_MODEL +from llm_guard.input_scanners.prompt_injection import V1_MODEL as PROMPTINJECTION_V1_MODEL +from llm_guard.input_scanners.prompt_injection import V2_MODEL as PROMPTINJECTION_V2_MODEL +from llm_guard.input_scanners.prompt_injection import V2_SMALL_MODEL as PROMPTINJECTION_V2_SMALL_MODEL +from llm_guard.input_scanners.toxicity import DEFAULT_MODEL as TOXICITY_DEFAULT_MODEL +from llm_guard.vault 
import Vault + +ENABLED_SCANNERS = [ + "anonymize", + "ban_code", + "ban_competitors", + "ban_substrings", + "ban_topics", + "code", + "gibberish", + "invisible_text", + "language", + "prompt_injection", + "regex", + "secrets", + "sentiment", + "token_limit", + "toxicity", +] + +from comps import CustomLogger +from comps.cores.mega.utils import sanitize_env +from comps.guardrails.src.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner + +logger = CustomLogger("opea_llm_guard_input_guardrail_microservice") + + +class InputScannersConfig: + + def __init__(self, config_dict): + self._input_scanners_config = { + **self._get_anonymize_config_from_env(config_dict), + **self._get_ban_code_config_from_env(config_dict), + **self._get_ban_competitors_config_from_env(config_dict), + **self._get_ban_substrings_config_from_env(config_dict), + **self._get_ban_topics_config_from_env(config_dict), + **self._get_code_config_from_env(config_dict), + **self._get_gibberish_config_from_env(config_dict), + **self._get_invisible_text_config_from_env(config_dict), + **self._get_language_config_from_env(config_dict), + **self._get_prompt_injection_config_from_env(config_dict), + **self._get_regex_config_from_env(config_dict), + **self._get_secrets_config_from_env(config_dict), + **self._get_sentiment_config_from_env(config_dict), + **self._get_token_limit_config_from_env(config_dict), + **self._get_toxicity_config_from_env(config_dict), + } + + #### METHODS FOR VALIDATING CONFIGS + + def _validate_value(self, value): + """Validate and convert the input value. + + Args: + value (str): The value to be validated and converted. + + Returns: + bool | int | str: The validated and converted value. + """ + if value is None: + return None + elif value.isdigit(): + return float(value) + elif value.lower() == "true": + return True + elif value.lower() == "false": + return False + return value + + def _get_anonymize_config_from_env(self, config_dict): + """Get the Anonymize scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The anonymize scanner configuration. + """ + return { + "anonymize": { + k.replace("ANONYMIZE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("ANONYMIZE_") + } + } + + def _get_ban_code_config_from_env(self, config_dict): + """Get the BanCode scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanCode scanner configuration. + """ + return { + "ban_code": { + k.replace("BAN_CODE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BAN_CODE_") + } + } + + def _get_ban_competitors_config_from_env(self, config_dict): + """Get the BanCompetitors scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanCompetitors scanner configuration. + """ + return { + "ban_competitors": { + k.replace("BAN_COMPETITORS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BAN_COMPETITORS_") + } + } + + def _get_ban_substrings_config_from_env(self, config_dict): + """Get the BanSubstrings scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanSubstrings scanner configuration. 
+ """ + return { + "ban_substrings": { + k.replace("BAN_SUBSTRINGS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BAN_SUBSTRINGS_") + } + } + + def _get_ban_topics_config_from_env(self, config_dict): + """Get the BanTopics scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanTopics scanner configuration. + """ + return { + "ban_topics": { + k.replace("BAN_TOPICS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BAN_TOPICS_") + } + } + + def _get_code_config_from_env(self, config_dict): + """Get the Code scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Code scanner configuration. + """ + return { + "code": { + k.replace("CODE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("CODE_") + } + } + + def _get_gibberish_config_from_env(self, config_dict): + """Get the Gibberish scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Gibberish scanner configuration. + """ + return { + "gibberish": { + k.replace("GIBBERISH_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("GIBBERISH_") + } + } + + def _get_invisible_text_config_from_env(self, config_dict): + """Get the InvisibleText scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The InvisibleText scanner configuration. + """ + return { + "invisible_text": { + k.replace("INVISIBLE_TEXT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("INVISIBLE_TEXT_") + } + } + + def _get_language_config_from_env(self, config_dict): + """Get the Language scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Language scanner configuration. + """ + return { + "language": { + k.replace("LANGUAGE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("LANGUAGE_") + } + } + + def _get_prompt_injection_config_from_env(self, config_dict): + """Get the PromptInjection scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The PromptInjection scanner configuration. + """ + return { + "prompt_injection": { + k.replace("PROMPT_INJECTION_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("PROMPT_INJECTION_") + } + } + + def _get_regex_config_from_env(self, config_dict): + """Get the Regex scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Regex scanner configuration. + """ + return { + "regex": { + k.replace("REGEX_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("REGEX_") + } + } + + def _get_secrets_config_from_env(self, config_dict): + """Get the Secrets scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Secrets scanner configuration. 
+ """ + return { + "secrets": { + k.replace("SECRETS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("SECRETS_") + } + } + + def _get_sentiment_config_from_env(self, config_dict): + """Get the Secrets scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Sentiment scanner configuration. + """ + return { + "sentiment": { + k.replace("SENTIMENT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("SENTIMENT_") + } + } + + def _get_token_limit_config_from_env(self, config_dict): + """Get the TokenLimit scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The TokenLimit scanner configuration. + """ + return { + "token_limit": { + k.replace("TOKEN_LIMIT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("TOKEN_LIMIT_") + } + } + + def _get_toxicity_config_from_env(self, config_dict): + """Get the Toxicity scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Toxicity scanner configuration. + """ + return { + "toxicity": { + k.replace("TOXICITY_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("TOXICITY_") + } + } + + #### METHODS FOR CREATING SCANNERS + + def _create_anonymize_scanner(self, scanner_config=None): + if scanner_config is None: + logger.warning( + "_create_anonymize_scanner was invoked without scanner_config. Recreating with saved config to clear the Vault." + ) + if hasattr(self, "_anonymize_params") and self._anonymize_params is not None: + scanner_config = self._anonymize_params + else: + raise ValueError( + "_create_anonymize_scanner was invoked without scanner_config but no self._anonymize_params were saved. Such action is not allowed." 
+ ) + vault = Vault() + anonymize_params = {"vault": vault, "use_onnx": scanner_config.get("use_onnx", False)} + hidden_names = scanner_config.get("hidden_names", None) + allowed_names = scanner_config.get("allowed_names", None) + entity_types = scanner_config.get("entity_types", None) + preamble = scanner_config.get("preamble", None) + regex_patterns = scanner_config.get("regex_patterns", None) + use_faker = scanner_config.get("use_faker", None) + recognizer_conf = scanner_config.get("recognizer_conf", None) + threshold = scanner_config.get("threshold", None) + language = scanner_config.get("language", None) + + if isinstance(hidden_names, str): + hidden_names = sanitize_env(hidden_names) + + if isinstance(allowed_names, str): + allowed_names = sanitize_env(allowed_names) + + if isinstance(entity_types, str): + entity_types = sanitize_env(entity_types) + + if isinstance(regex_patterns, str): + regex_patterns = sanitize_env(regex_patterns) + + if hidden_names is not None: + if isinstance(hidden_names, str): + artifacts = set([",", "", "."]) + anonymize_params["hidden_names"] = list(set(hidden_names.split(",")) - artifacts) + elif isinstance(hidden_names, list): + anonymize_params["hidden_names"] = hidden_names + else: + logger.error("Provided type is not valid for Anonymize scanner") + raise ValueError("Provided type is not valid for Anonymize scanner") + if allowed_names is not None: + if isinstance(allowed_names, str): + artifacts = set([",", "", "."]) + anonymize_params["allowed_names"] = list(set(allowed_names.split(",")) - artifacts) + elif isinstance(hidden_names, list): + anonymize_params["allowed_names"] = allowed_names + else: + logger.error("Provided type is not valid for Anonymize scanner") + raise ValueError("Provided type is not valid for Anonymize scanner") + if entity_types is not None: + if isinstance(entity_types, str): + artifacts = set([",", "", "."]) + anonymize_params["entity_types"] = list(set(entity_types.split(",")) - artifacts) + elif isinstance(hidden_names, list): + anonymize_params["entity_types"] = entity_types + else: + logger.error("Provided type is not valid for Anonymize scanner") + raise ValueError("Provided type is not valid for Anonymize scanner") + if preamble is not None: + anonymize_params["preamble"] = preamble + if regex_patterns is not None: + if isinstance(regex_patterns, str): + artifacts = set([",", "", "."]) + anonymize_params["regex_patterns"] = list(set(regex_patterns.split(",")) - artifacts) + elif isinstance(hidden_names, list): + anonymize_params["regex_patterns"] = regex_patterns + else: + logger.error("Provided type is not valid for Anonymize scanner") + raise ValueError("Provided type is not valid for Anonymize scanner") + if use_faker is not None: + anonymize_params["use_faker"] = use_faker + if recognizer_conf is not None: + anonymize_params["recognizer_conf"] = recognizer_conf + if threshold is not None: + anonymize_params["threshold"] = threshold + if language is not None: + anonymize_params["language"] = language + logger.info(f"Creating Anonymize scanner with params: {anonymize_params}") + self._anonymize_params = {key: value for key, value in anonymize_params.items() if key != "vault"} + return Anonymize(**anonymize_params) + + def _create_ban_code_scanner(self, scanner_config): + enabled_models = {"MODEL_SM": BANCODE_MODEL_SM, "MODEL_TINY": BANCODE_MODEL_TINY} + bancode_params = {"use_onnx": scanner_config.get("use_onnx", False)} # by default we don't want to use onnx + + model_name = scanner_config.get("model", None) + 
threshold = scanner_config.get("threshold", None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanCode scanner: {model_name}") + bancode_params["model"] = enabled_models[model_name] # Model class from LLM Guard + else: + err_msg = f"Model name is not valid for BanCode scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + bancode_params["threshold"] = threshold # float + logger.info(f"Creating BanCode scanner with params: {bancode_params}") + return BanCode(**bancode_params) + + def _create_ban_competitors_scanner(self, scanner_config): + enabled_models = {"MODEL_V1": BANCOMPETITORS_MODEL_V1} + ban_competitors_params = { + "use_onnx": scanner_config.get("use_onnx", False) + } # by default we don't want to use onnx + + competitors = scanner_config.get("competitors", None) + threshold = scanner_config.get("threshold", None) + redact = scanner_config.get("redact", None) + model_name = scanner_config.get("model", None) + + if isinstance(competitors, str): + competitors = sanitize_env(competitors) + + if competitors: + if isinstance(competitors, str): + artifacts = set([",", "", "."]) + ban_competitors_params["competitors"] = list(set(competitors.split(",")) - artifacts) # list + elif isinstance(competitors, list): + ban_competitors_params["competitors"] = competitors + else: + logger.error("Provided type is not valid for BanCompetitors scanner") + raise ValueError("Provided type is not valid for BanCompetitors scanner") + else: + logger.error( + "Competitors list is required for BanCompetitors scanner. Please provide a list of competitors." + ) + raise TypeError( + "Competitors list is required for BanCompetitors scanner. Please provide a list of competitors." + ) + if threshold is not None: + ban_competitors_params["threshold"] = threshold # float + if redact is not None: + ban_competitors_params["redact"] = redact + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanCompetitors scanner: {model_name}") + ban_competitors_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for BanCompetitors scanner. Please provide a valid model name. Provided model: {model_name}. 
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + logger.info(f"Creating BanCompetitors scanner with params: {ban_competitors_params}") + return BanCompetitors(**ban_competitors_params) + + def _create_ban_substrings_scanner(self, scanner_config): + available_match_types = ["str", "word"] + ban_substrings_params = {} + + substrings = scanner_config.get("substrings", None) + match_type = scanner_config.get("match_type", None) + case_sensitive = scanner_config.get("case_sensitive", None) + redact = scanner_config.get("redact", None) + contains_all = scanner_config.get("contains_all", None) + + if isinstance(substrings, str): + substrings = sanitize_env(substrings) + + if substrings: + if isinstance(substrings, str): + artifacts = set([",", "", "."]) + ban_substrings_params["substrings"] = list(set(substrings.split(",")) - artifacts) # list + elif substrings and isinstance(substrings, list): + ban_substrings_params["substrings"] = substrings + else: + logger.error("Provided type is not valid for BanSubstrings scanner") + raise ValueError("Provided type is not valid for BanSubstrings scanner") + else: + logger.error("Substrings list is required for BanSubstrings scanner") + raise TypeError("Substrings list is required for BanSubstrings scanner") + if match_type is not None and match_type in available_match_types: + ban_substrings_params["match_type"] = match_type # MatchType + if case_sensitive is not None: + ban_substrings_params["case_sensitive"] = case_sensitive # bool + if redact is not None: + ban_substrings_params["redact"] = redact # bool + if contains_all is not None: + ban_substrings_params["contains_all"] = contains_all # bool + logger.info(f"Creating BanSubstrings scanner with params: {ban_substrings_params}") + return OPEABanSubstrings(**ban_substrings_params) + + def _create_ban_topics_scanner(self, scanner_config): + enabled_models = { + "MODEL_DEBERTA_LARGE_V2": BANTOPICS_MODEL_DEBERTA_LARGE_V2, + "MODEL_DEBERTA_BASE_V2": BANTOPICS_MODEL_DEBERTA_BASE_V2, + "MODEL_BGE_M3_V2": BANTOPICS_MODEL_BGE_M3_V2, + "MODEL_ROBERTA_LARGE_C_V2": BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, + "MODEL_ROBERTA_BASE_C_V2": BANTOPICS_MODEL_ROBERTA_BASE_C_V2, + } + ban_topics_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + topics = scanner_config.get("topics", None) + threshold = scanner_config.get("threshold", None) + model_name = scanner_config.get("model", None) + + if isinstance(topics, str): + topics = sanitize_env(topics) + + if topics: + if isinstance(topics, str): + artifacts = set([",", "", "."]) + ban_topics_params["topics"] = list(set(topics.split(",")) - artifacts) + elif isinstance(topics, list): + ban_topics_params["topics"] = topics + else: + logger.error("Provided type is not valid for BanTopics scanner") + raise ValueError("Provided type is not valid for BanTopics scanner") + else: + logger.error("Topics list is required for BanTopics scanner") + raise TypeError("Topics list is required for BanTopics scanner") + if threshold is not None: + ban_topics_params["threshold"] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanTopics scanner: {model_name}") + ban_topics_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for BanTopics scanner. Please provide a valid model name. Provided model: {model_name}. 
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + logger.info(f"Creating BanTopics scanner with params: {ban_topics_params}") + return BanTopics(**ban_topics_params) + + def _create_code_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": CODE_DEFAULT_MODEL} + code_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + languages = scanner_config.get("languages", None) + model_name = scanner_config.get("model", None) + is_blocked = scanner_config.get("is_blocked", None) + threshold = scanner_config.get("threshold", None) + + if isinstance(languages, str): + languages = sanitize_env(languages) + + if languages: + if isinstance(languages, str): + artifacts = set([",", "", "."]) + code_params["languages"] = list(set(languages.split(",")) - artifacts) + elif isinstance(languages, list): + code_params["languages"] = languages + else: + logger.error("Provided type is not valid for Code scanner") + raise ValueError("Provided type is not valid for Code scanner") + else: + logger.error("Languages list is required for Code scanner") + raise TypeError("Languages list is required for Code scanner") + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Code scanner: {model_name}") + code_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Code scanner. Please provide a valid model name. Provided model: {model_name}" + logger.error(err_msg) + raise ValueError(err_msg) + if is_blocked is not None: + code_params["is_blocked"] = is_blocked + if threshold is not None: + code_params["threshold"] = threshold + logger.info(f"Creating Code scanner with params: {code_params}") + return Code(**code_params) + + def _create_gibberish_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": GIBBERISH_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + gibberish_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) + + if match_type == "sentence": + import nltk + + nltk.download("punkt_tab") + + if threshold is not None: + gibberish_params["threshold"] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Gibberish scanner: {model_name}") + gibberish_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Gibberish scanner. Please provide a valid model name. Provided model: {model_name}. 
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if match_type is not None and match_type in enabled_match_types: + gibberish_params["match_type"] = match_type + + logger.info(f"Creating Gibberish scanner with params: {gibberish_params}") + return Gibberish(**gibberish_params) + + def _create_invisible_text_scanner(self): + return InvisibleText() + + def _create_language_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": LANGUAGE_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + language_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + valid_languages = scanner_config.get("valid_languages", None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) + + if isinstance(valid_languages, str): + valid_languages = sanitize_env(valid_languages) + + if valid_languages: + if isinstance(valid_languages, str): + artifacts = set([",", "", "."]) + language_params["valid_languages"] = list(set(valid_languages.split(",")) - artifacts) + elif isinstance(valid_languages, list): + language_params["valid_languages"] = valid_languages + else: + logger.error("Provided type is not valid for Language scanner") + raise ValueError("Provided type is not valid for Language scanner") + else: + logger.error("Valid languages list is required for Language scanner") + raise TypeError("Valid languages list is required for Language scanner") + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Language scanner: {model_name}") + language_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Language scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + language_params["threshold"] = threshold + if match_type is not None and match_type in enabled_match_types: + language_params["match_type"] = match_type + logger.info(f"Creating Language scanner with params: {language_params}") + return Language(**language_params) + + def _create_prompt_injection_scanner(self, scanner_config): + enabled_models = { + "V1_MODEL": PROMPTINJECTION_V1_MODEL, + "V2_MODEL": PROMPTINJECTION_V2_MODEL, + "V2_SMALL_MODEL": PROMPTINJECTION_V2_SMALL_MODEL, + } + enabled_match_types = ["sentence", "full", "truncate_token_head_tail", "truncate_head_tail", "chunks"] + prompt_injection_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) + + if match_type == "sentence": + import nltk + + nltk.download("punkt_tab") + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for PromptInjection scanner: {model_name}") + prompt_injection_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for PromptInjection scanner. Please provide a valid model name. Provided model: {model_name}. 
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + prompt_injection_params["threshold"] = threshold + if match_type is not None and match_type in enabled_match_types: + prompt_injection_params["match_type"] = match_type + logger.info(f"Creating PromptInjection scanner with params: {prompt_injection_params}") + return PromptInjection(**prompt_injection_params) + + def _create_regex_scanner(self, scanner_config): + enabled_match_types = ["search", "fullmatch"] + regex_params = {} + + patterns = scanner_config.get("patterns", None) + is_blocked = scanner_config.get("is_blocked", None) + match_type = scanner_config.get("match_type", None) + redact = scanner_config.get("redact", None) + + if isinstance(patterns, str): + patterns = sanitize_env(patterns) + + if patterns: + if isinstance(patterns, str): + artifacts = set([",", "", "."]) + regex_params["patterns"] = list(set(patterns.split(",")) - artifacts) + elif isinstance(patterns, list): + regex_params["patterns"] = patterns + else: + logger.error("Provided type is not valid for Regex scanner") + raise ValueError("Provided type is not valid for Regex scanner") + else: + logger.error("Patterns list is required for Regex scanner") + raise TypeError("Patterns list is required for Regex scanner") + if is_blocked is not None: + regex_params["is_blocked"] = is_blocked + if match_type is not None and match_type in enabled_match_types: + regex_params["match_type"] = match_type + if redact is not None: + regex_params["redact"] = redact + + logger.info(f"Creating Regex scanner with params: {regex_params}") + return OPEARegexScanner(**regex_params) + + def _create_secrets_scanner(self, scanner_config): + enabled_redact_types = ["partial", "all", "hash"] + secrets_params = {} + + redact = scanner_config.get("redact", None) + + if redact is not None and redact in enabled_redact_types: + secrets_params["redact"] = redact + + logger.info(f"Creating Secrets scanner with params: {secrets_params}") + return Secrets(**secrets_params) + + def _create_sentiment_scanner(self, scanner_config): + enabled_lexicons = ["vader_lexicon"] + sentiment_params = {} + + threshold = scanner_config.get("threshold", None) + lexicon = scanner_config.get("lexicon", None) + + if threshold is not None: + sentiment_params["threshold"] = threshold + if lexicon is not None and lexicon in enabled_lexicons: + sentiment_params["lexicon"] = lexicon + + logger.info(f"Creating Sentiment scanner with params: {sentiment_params}") + return Sentiment(**sentiment_params) + + def _create_token_limit_scanner(self, scanner_config): + enabled_encodings = ["cl100k_base"] # TODO: test more encoding from tiktoken + token_limit_params = {} + + limit = int(scanner_config.get("limit", None)) + encoding_name = scanner_config.get("encoding", None) + + if limit is not None: + token_limit_params["limit"] = limit + if encoding_name is not None and encoding_name in enabled_encodings: + token_limit_params["encoding_name"] = encoding_name + + logger.info(f"Creating TokenLimit scanner with params: {token_limit_params}") + return TokenLimit(**token_limit_params) + + def _create_toxicity_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": TOXICITY_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + toxicity_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", 
None) + + if match_type == "sentence": + import nltk + + nltk.download("punkt_tab") + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Toxicity scanner: {model_name}") + toxicity_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Toxicity scanner. Please provide a valid model name. Provided model: {model_name}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + toxicity_params["threshold"] = threshold + if match_type is not None and match_type in enabled_match_types: + toxicity_params["match_type"] = match_type + + logger.info(f"Creating Toxicity scanner with params: {toxicity_params}") + return Toxicity(**toxicity_params) + + def _create_input_scanner(self, scanner_name, scanner_config): + if scanner_name not in ENABLED_SCANNERS: + logger.error(f"Scanner {scanner_name} is not supported. Enabled scanners are: {ENABLED_SCANNERS}") + raise ValueError(f"Scanner {scanner_name} is not supported") + if scanner_name == "anonymize": + return self._create_anonymize_scanner(scanner_config) + elif scanner_name == "ban_code": + return self._create_ban_code_scanner(scanner_config) + elif scanner_name == "ban_competitors": + return self._create_ban_competitors_scanner(scanner_config) + elif scanner_name == "ban_substrings": + return self._create_ban_substrings_scanner(scanner_config) + elif scanner_name == "ban_topics": + return self._create_ban_topics_scanner(scanner_config) + elif scanner_name == "code": + return self._create_code_scanner(scanner_config) + elif scanner_name == "gibberish": + return self._create_gibberish_scanner(scanner_config) + elif scanner_name == "invisible_text": + return self._create_invisible_text_scanner() + elif scanner_name == "language": + return self._create_language_scanner(scanner_config) + elif scanner_name == "prompt_injection": + return self._create_prompt_injection_scanner(scanner_config) + elif scanner_name == "regex": + return self._create_regex_scanner(scanner_config) + elif scanner_name == "secrets": + return self._create_secrets_scanner(scanner_config) + elif scanner_name == "sentiment": + return self._create_sentiment_scanner(scanner_config) + elif scanner_name == "token_limit": + return self._create_token_limit_scanner(scanner_config) + elif scanner_name == "toxicity": + return self._create_toxicity_scanner(scanner_config) + return None + + def create_enabled_input_scanners(self): + """Create and return a list of enabled scanners based on the global configuration. + + Returns: + list: A list of enabled scanner instances. 
+ """ + enabled_scanners_names_and_configs = {k: v for k, v in self._input_scanners_config.items() if v.get("enabled")} + enabled_scanners_objects = [] + + err_msgs = {} # list for all erroneous scanners + only_validation_errors = True + for scanner_name, scanner_config in enabled_scanners_names_and_configs.items(): + try: + logger.info(f"Attempting to create scanner: {scanner_name}") + scanner_object = self._create_input_scanner(scanner_name, scanner_config) + enabled_scanners_objects.append(scanner_object) + except ValueError as e: + err_msg = f"A ValueError occurred during creating input scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._input_scanners_config[scanner_name]["enabled"] = False + continue + except TypeError as e: + err_msg = f"A TypeError occurred during creating input scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._input_scanners_config[scanner_name]["enabled"] = False + continue + except Exception as e: + err_msg = f"An unexpected error occurred during creating input scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + only_validation_errors = False + self._input_scanners_config[scanner_name]["enabled"] = False + continue + + if err_msgs: + if only_validation_errors: + raise ValueError( + f"Some scanners failed to be created due to validation errors. The details: {err_msgs}" + ) + else: + raise Exception( + f"Some scanners failed to be created due to validation or unexpected errors. The details: {err_msgs}" + ) + + return [s for s in enabled_scanners_objects if s is not None] + + def changed(self, new_scanners_config): + """Check if the scanners configuration has changed. + + Args: + new_scanners_config (dict): The current scanners configuration. + + Returns: + bool: True if the configuration has changed, False otherwise. + """ + del new_scanners_config["id"] + newly_enabled_scanners = { + k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} + for k, v in new_scanners_config.items() + if v.get("enabled") + } + previously_enabled_scanners = {k: v for k, v in self._input_scanners_config.items() if v.get("enabled")} + if newly_enabled_scanners == previously_enabled_scanners: # if the enables scanners are the same we do nothing + logger.info("No changes in list for enabled scanners. 
Checking configuration changes...") + return False + else: + logger.warning("Scanners configuration has been changed, re-creating scanners") + self._input_scanners_config.clear() + stripped_new_scanners_config = { + k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} for k, v in new_scanners_config.items() + } + self._input_scanners_config.update(stripped_new_scanners_config) + return True diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py new file mode 100644 index 0000000000..c7c993a214 --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py @@ -0,0 +1,98 @@ +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from fastapi import HTTPException +from llm_guard import scan_output +from utils.llm_guard_output_scanners import OutputScannersConfig + +from comps import CustomLogger, GeneratedDoc + +logger = CustomLogger("opea_llm_guard_output_guardrail_microservice") + + +class OPEALLMGuardOutputGuardrail: + """OPEALLMGuardOutputGuardrail is responsible for scanning and sanitizing LLM output responses + using various output scanners provided by LLM Guard. + + This class initializes the output scanners based on the provided configuration and + scans the output responses to ensure they meet the required guardrail criteria. + + Attributes: + _scanners (list): A list of enabled scanners. + + Methods: + __init__(usv_config: list): + Initializes the OPEALLMGuardOutputGuardrail with the provided configuration. + + scan_llm_output(output_doc: GeneratedDoc) -> str: + Scans the output from an LLM output document and returns the sanitized output. + """ + + def __init__(self, usv_config: list): + """Initializes the OPEALLMGuardOutputGuardrail with the provided configuration. + + Args: + usv_config (list): The configuration list for initializing the output scanners. + + Raises: + Exception: If an unexpected error occurs during the initialization of scanners. + """ + try: + self._scanners_config = OutputScannersConfig(usv_config) + self._scanners = self._scanners_config.create_enabled_output_scanners() + except Exception as e: + logger.exception( + f"An unexpected error occurred during initializing \ + LLM Guard Output Guardrail scanners: {e}" + ) + raise + + def scan_llm_output(self, output_doc: GeneratedDoc) -> str: + """Scans the output from an LLM output document. + + Args: + output_doc (GeneratedDoc): The output document containing the response to be scanned. + + Returns: + str: The sanitized output. + + Raises: + HTTPException: If the output is not valid based on the scanner results. + Exception: If an unexpected error occurs during scanning. + """ + try: + if output_doc.output_guardrail_params is not None: + self._scanners_config.vault = output_doc.output_guardrail_params.anonymize_vault + if self._scanners_config.changed(output_doc.output_guardrail_params.dict()): + self._scanners = self._scanners_config.create_enabled_output_scanners() + else: + logger.warning("Output guardrail params not found in input document.") + if self._scanners: + sanitized_output, results_valid, results_score = scan_output( + self._scanners, output_doc.prompt, output_doc.text + ) + if False in results_valid.values(): + msg = f"LLM Output {output_doc.text} is not valid, scores: {results_score}" + logger.error(msg) + usr_msg = "I'm sorry, but the model output is not valid according to the policies."
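+ # If any configured output scanner can redact or truncate, scan_output above already produced a sanitized version of the text, so it is returned in the HTTP 466 detail instead of the generic refusal message.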
+ redact_or_truncated = [ + c.get("redact", False) or c.get("truncate", False) + for _, c in self._scanners_config._output_scanners_config.items() + ] # to see if sanitized output available + if any(redact_or_truncated): + usr_msg = f"We sanitized the answer due to the guardrails policies: {sanitized_output}" + raise HTTPException(status_code=466, detail=usr_msg) + return sanitized_output + else: + logger.warning("No output scanners enabled. Skipping scanning.") + return output_doc.text + except HTTPException as e: + raise e + except ValueError as e: + error_msg = f"Validation Error occurred while initializing LLM Guard Output Guardrail scanners: {e}" + logger.exception(error_msg) + raise HTTPException(status_code=400, detail=error_msg) + except Exception as e: + error_msg = f"An unexpected error occurred during scanning prompt with LLM Guard Output Guardrail: {e}" + logger.exception(error_msg) + raise HTTPException(status_code=500, detail=error_msg) diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py new file mode 100644 index 0000000000..af5dcea20a --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py @@ -0,0 +1,1216 @@ +# ruff: noqa: F401 +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# import models definition +from llm_guard.input_scanners.ban_code import ( + MODEL_SM as BANCODE_MODEL_SM, # input, because the same scanner to input and output +) +from llm_guard.input_scanners.ban_code import MODEL_TINY as BANCODE_MODEL_TINY +from llm_guard.input_scanners.ban_competitors import ( + MODEL_V1 as BANCOMPETITORS_MODEL_V1, # input, because the same scanner to input and output +) +from llm_guard.input_scanners.ban_topics import MODEL_BGE_M3_V2 as BANTOPICS_MODEL_BGE_M3_V2 +from llm_guard.input_scanners.ban_topics import MODEL_DEBERTA_BASE_V2 as BANTOPICS_MODEL_DEBERTA_BASE_V2 +from llm_guard.input_scanners.ban_topics import ( + MODEL_DEBERTA_LARGE_V2 as BANTOPICS_MODEL_DEBERTA_LARGE_V2, # input, because the same scanner to input and output +) +from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_BASE_C_V2 as BANTOPICS_MODEL_ROBERTA_BASE_C_V2 +from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_LARGE_C_V2 as BANTOPICS_MODEL_ROBERTA_LARGE_C_V2 +from llm_guard.input_scanners.code import DEFAULT_MODEL as CODE_DEFAULT_MODEL +from llm_guard.input_scanners.gibberish import DEFAULT_MODEL as GIBBERISH_DEFAULT_MODEL +from llm_guard.input_scanners.language import DEFAULT_MODEL as LANGUAGE_DEFAULT_MODEL +from llm_guard.input_scanners.toxicity import DEFAULT_MODEL as TOXICITY_DEFAULT_MODEL +from llm_guard.output_scanners import ( + JSON, + BanCode, + BanCompetitors, + BanTopics, + Bias, + Code, + Deanonymize, + FactualConsistency, + Gibberish, + Language, + LanguageSame, + MaliciousURLs, + NoRefusal, + NoRefusalLight, + ReadingTime, + Relevance, + Sensitive, + Sentiment, + Toxicity, + URLReachability, +) +from llm_guard.output_scanners.bias import DEFAULT_MODEL as BIAS_DEFAULT_MODEL +from llm_guard.output_scanners.malicious_urls import DEFAULT_MODEL as MALICIOUS_URLS_DEFAULT_MODEL +from llm_guard.output_scanners.no_refusal import DEFAULT_MODEL as NO_REFUSAL_DEFAULT_MODEL +from llm_guard.output_scanners.relevance import MODEL_EN_BGE_BASE as RELEVANCE_MODEL_EN_BGE_BASE +from llm_guard.output_scanners.relevance import MODEL_EN_BGE_LARGE as RELEVANCE_MODEL_EN_BGE_LARGE +from llm_guard.output_scanners.relevance import 
MODEL_EN_BGE_SMALL as RELEVANCE_MODEL_EN_BGE_SMALL +from llm_guard.vault import Vault + +ENABLED_SCANNERS = [ + "ban_code", + "ban_competitors", + "ban_substrings", + "ban_topics", + "bias", + "code", + "deanonymize", + "json_scanner", + "language", + "language_same", + "malicious_urls", + "no_refusal", + "no_refusal_light", + "reading_time", + "factual_consistency", + "gibberish", + "regex", + "relevance", + "sensitive", + "sentiment", + "toxicity", + "url_reachability", +] + +from comps import CustomLogger +from comps.cores.mega.utils import sanitize_env +from comps.guardrails.src.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner + +logger = CustomLogger("opea_llm_guard_output_guardrail_microservice") + + +class OutputScannersConfig: + def __init__(self, config_dict): + self._output_scanners_config = { + **self._get_ban_code_config_from_env(config_dict), + **self._get_ban_competitors_config_from_env(config_dict), + **self._get_ban_substrings_config_from_env(config_dict), + **self._get_ban_topics_config_from_env(config_dict), + **self._get_bias_config_from_env(config_dict), + **self._get_code_config_from_env(config_dict), + **self._get_deanonymize_config_from_env(config_dict), + **self._get_json_scanner_config_from_env(config_dict), + **self._get_language_config_from_env(config_dict), + **self._get_language_same_config_from_env(config_dict), + **self._get_malicious_urls_config_from_env(config_dict), + **self._get_no_refusal_config_from_env(config_dict), + **self._get_no_refusal_light_config_from_env(config_dict), + **self._get_reading_time_config_from_env(config_dict), + **self._get_factual_consistency_config_from_env(config_dict), + **self._get_gibberish_config_from_env(config_dict), + **self._get_regex_config_from_env(config_dict), + **self._get_relevance_config_from_env(config_dict), + **self._get_sensitive_config_from_env(config_dict), + **self._get_sentiment_config_from_env(config_dict), + **self._get_toxicity_config_from_env(config_dict), + **self._get_url_reachability_config_from_env(config_dict), + } + self.vault = None + + #### METHODS FOR VALIDATING CONFIGS + + def _validate_value(self, value): + """Validate and convert the input value. + + Args: + value (str): The value to be validated and converted. + + Returns: + bool | int | str: The validated and converted value. + """ + if value is None: + return None + elif value.isdigit(): + return float(value) + elif value.lower() == "true": + return True + elif value.lower() == "false": + return False + return value + + def _get_ban_code_config_from_env(self, config_dict): + """Get the BanCode scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanCode scanner configuration. + """ + return { + "ban_code": { + k.replace("BAN_CODE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BAN_CODE_") + } + } + + def _get_ban_competitors_config_from_env(self, config_dict): + """Get the BanCompetitors scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanCompetitors scanner configuration. + """ + return { + "ban_competitors": { + k.replace("BAN_COMPETITORS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BAN_COMPETITORS_") + } + } + + def _get_ban_substrings_config_from_env(self, config_dict): + """Get the BanSubstrings scanner configuration from the environment. 
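+ + The substrings value (e.g. from BAN_SUBSTRINGS_SUBSTRINGS) may be a comma-separated string; it is split into a list when the scanner is created.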
+ + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanSubstrings scanner configuration. + """ + return { + "ban_substrings": { + k.replace("BAN_SUBSTRINGS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BAN_SUBSTRINGS_") + } + } + + def _get_ban_topics_config_from_env(self, config_dict): + """Get the BanTopics scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanTopics scanner configuration. + """ + return { + "ban_topics": { + k.replace("BAN_TOPICS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BAN_TOPICS_") + } + } + + def _get_bias_config_from_env(self, config_dict): + """Get the Bias scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Bias scanner configuration. + """ + return { + "bias": { + k.replace("BIAS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("BIAS_") + } + } + + def _get_code_config_from_env(self, config_dict): + """Get the Code scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Code scanner configuration. + """ + return { + "code": { + k.replace("CODE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("CODE_") + } + } + + def _get_deanonymize_config_from_env(self, config_dict): + """Get the Deanonymize scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The deanonymize scanner configuration. + """ + return { + "deanonymize": { + k.replace("DEANONYMIZE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("DEANONYMIZE_") + } + } + + def _get_json_scanner_config_from_env(self, config_dict): + """Get the JSON scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The JSON scanner configuration. + """ + return { + "json_scanner": { + k.replace("JSON_SCANNER_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("JSON_SCANNER_") + } + } + + def _get_language_config_from_env(self, config_dict): + """Get the Language scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Language scanner configuration. + """ + return { + "language": { + k.replace("LANGUAGE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("LANGUAGE_") + } + } + + def _get_language_same_config_from_env(self, config_dict): + """Get the LanguageSame scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The LanguageSame scanner configuration. + """ + return { + "language_same": { + k.replace("LANGUAGE_SAME_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("LANGUAGE_SAME_") + } + } + + def _get_malicious_urls_config_from_env(self, config_dict): + """Get the MaliciousURLs scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The MaliciousURLs scanner configuration. 
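+ + Example (illustrative keys): {"MALICIOUS_URLS_ENABLED": "true", "MALICIOUS_URLS_USE_ONNX": "false"} maps to {"malicious_urls": {"enabled": True, "use_onnx": False}}.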
+ """ + return { + "malicious_urls": { + k.replace("MALICIOUS_URLS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("MALICIOUS_URLS_") + } + } + + def _get_no_refusal_config_from_env(self, config_dict): + """Get the NoRefusal scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The NoRefusal scanner configuration. + """ + return { + "no_refusal": { + k.replace("NO_REFUSAL_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("NO_REFUSAL_") + } + } + + def _get_no_refusal_light_config_from_env(self, config_dict): + """Get the NoRefusalLight scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The NoRefusalLight scanner configuration. + """ + return { + "no_refusal_light": { + k.replace("NO_REFUSAL_LIGHT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("NO_REFUSAL_LIGHT_") + } + } + + def _get_reading_time_config_from_env(self, config_dict): + """Get the ReadingTime scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The ReadingTime scanner configuration. + """ + return { + "reading_time": { + k.replace("READING_TIME_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("READING_TIME_") + } + } + + def _get_factual_consistency_config_from_env(self, config_dict): + """Get the FactualConsitency scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The FactualConsitency scanner configuration. + """ + return { + "factual_consistency": { + k.replace("FACTUAL_CONSISTENCY_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("FACTUAL_CONSISTENCY_") + } + } + + def _get_gibberish_config_from_env(self, config_dict): + """Get the Gibberish scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Gibberish scanner configuration. + """ + return { + "gibberish": { + k.replace("GIBBERISH_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("GIBBERISH_") + } + } + + def _get_regex_config_from_env(self, config_dict): + """Get the Regex scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Regex scanner configuration. + """ + return { + "regex": { + k.replace("REGEX_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("REGEX_") + } + } + + def _get_relevance_config_from_env(self, config_dict): + """Get the Relevance scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Relevance scanner configuration. + """ + return { + "relevance": { + k.replace("RELEVANCE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("RELEVANCE_") + } + } + + def _get_sensitive_config_from_env(self, config_dict): + """Get the Sensitive scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Sensitive scanner configuration. 
+ """ + return { + "sensitive": { + k.replace("SENSITIVE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("SENSITIVE_") + } + } + + def _get_sentiment_config_from_env(self, config_dict): + """Get the Sentiment scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Sentiment scanner configuration. + """ + return { + "sentiment": { + k.replace("SENTIMENT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("SENTIMENT_") + } + } + + def _get_toxicity_config_from_env(self, config_dict): + """Get the Toxicity scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Toxicity scanner configuration. + """ + return { + "toxicity": { + k.replace("TOXICITY_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("TOXICITY_") + } + } + + def _get_url_reachability_config_from_env(self, config_dict): + """Get the URLReachability scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The URLReachability scanner configuration. + """ + return { + "url_reachability": { + k.replace("URL_REACHABILITY_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() + if k.startswith("URL_REACHABILITY_") + } + } + + #### METHODS FOR CREATING SCANNERS + + def _create_ban_code_scanner(self, scanner_config): + enabled_models = {"MODEL_SM": BANCODE_MODEL_SM, "MODEL_TINY": BANCODE_MODEL_TINY} + bancode_params = {"use_onnx": scanner_config.get("use_onnx", False)} # by default we don't want to use onnx + + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanCode scanner: {model_name}") + bancode_params["model"] = enabled_models[model_name] # Model class from LLM Guard + else: + err_msg = f"Model name is not valid for BanCode scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + bancode_params["threshold"] = threshold + logger.info(f"Creating BanCode scanner with params: {bancode_params}") + return BanCode(**bancode_params) + + def _create_ban_competitors_scanner(self, scanner_config): + enabled_models = {"MODEL_V1": BANCOMPETITORS_MODEL_V1} + ban_competitors_params = { + "use_onnx": scanner_config.get("use_onnx", False) + } # by default we want don't to use onnx + + competitors = scanner_config.get("competitors", None) + threshold = scanner_config.get("threshold", None) + redact = scanner_config.get("redact", None) + model_name = scanner_config.get("model", None) + + if isinstance(competitors, str): + competitors = sanitize_env(competitors) + + if competitors: + if isinstance(competitors, str): + artifacts = set([",", "", "."]) + ban_competitors_params["competitors"] = list(set(competitors.split(",")) - artifacts) + elif isinstance(competitors, list): + ban_competitors_params["competitors"] = competitors + else: + logger.error("Provided type is not valid for BanCompetitors scanner") + raise ValueError("Provided type is not valid for BanCompetitors scanner") + else: + logger.error( + "Competitors list is required for BanCompetitors scanner. 
Please provide a list of competitors." + ) + raise TypeError( + "Competitors list is required for BanCompetitors scanner. Please provide a list of competitors." + ) + if threshold is not None: + ban_competitors_params["threshold"] = threshold + if redact is not None: + ban_competitors_params["redact"] = redact + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanCompetitors scanner: {model_name}") + ban_competitors_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for BanCompetitors scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + logger.info(f"Creating BanCompetitors scanner with params: {ban_competitors_params}") + return BanCompetitors(**ban_competitors_params) + + def _create_ban_substrings_scanner(self, scanner_config): + available_match_types = ["str", "word"] + ban_substrings_params = {} + + substrings = scanner_config.get("substrings", None) + match_type = scanner_config.get("match_type", None) + case_sensitive = scanner_config.get("case_sensitive", None) + redact = scanner_config.get("redact", None) + contains_all = scanner_config.get("contains_all", None) + + if isinstance(substrings, str): + substrings = sanitize_env(substrings) + + if substrings: + if isinstance(substrings, str): + artifacts = set([",", "", "."]) + ban_substrings_params["substrings"] = list(set(substrings.split(",")) - artifacts) + elif substrings and isinstance(substrings, list): + ban_substrings_params["substrings"] = substrings + else: + logger.error("Provided type is not valid for BanSubstrings scanner") + raise ValueError("Provided type is not valid for BanSubstrings scanner") + else: + logger.error("Substrings list is required for BanSubstrings scanner") + raise TypeError("Substrings list is required for BanSubstrings scanner") + if match_type is not None and match_type in available_match_types: + ban_substrings_params["match_type"] = match_type + if case_sensitive is not None: + ban_substrings_params["case_sensitive"] = case_sensitive + if redact is not None: + ban_substrings_params["redact"] = redact + if contains_all is not None: + ban_substrings_params["contains_all"] = contains_all + logger.info(f"Creating BanSubstrings scanner with params: {ban_substrings_params}") + return OPEABanSubstrings(**ban_substrings_params) + + def _create_ban_topics_scanner(self, scanner_config): + enabled_models = { + "MODEL_DEBERTA_LARGE_V2": BANTOPICS_MODEL_DEBERTA_LARGE_V2, + "MODEL_DEBERTA_BASE_V2": BANTOPICS_MODEL_DEBERTA_BASE_V2, + "MODEL_BGE_M3_V2": BANTOPICS_MODEL_BGE_M3_V2, + "MODEL_ROBERTA_LARGE_C_V2": BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, + "MODEL_ROBERTA_BASE_C_V2": BANTOPICS_MODEL_ROBERTA_BASE_C_V2, + } + ban_topics_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + topics = scanner_config.get("topics", None) + threshold = scanner_config.get("threshold", None) + model_name = scanner_config.get("model", None) + + if isinstance(topics, str): + topics = sanitize_env(topics) + + if topics: + if isinstance(topics, str): + artifacts = set([",", "", "."]) + ban_topics_params["topics"] = list(set(topics.split(",")) - artifacts) + elif isinstance(topics, list): + ban_topics_params["topics"] = topics + else: + logger.error("Provided type is not valid for BanTopics scanner") + raise ValueError("Provided type is not valid for BanTopics scanner") + else: + logger.error("Topics list is 
required for BanTopics scanner") + raise TypeError("Topics list is required for BanTopics scanner") + if threshold is not None: + ban_topics_params["threshold"] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanTopics scanner: {model_name}") + ban_topics_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for BanTopics scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + logger.info(f"Creating BanTopics scanner with params: {ban_topics_params}") + return BanTopics(**ban_topics_params) + + def _create_bias_scanner(self, scanner_config): + available_match_types = ["str", "word"] + enabled_models = {"DEFAULT_MODEL": BIAS_DEFAULT_MODEL} + bias_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) + model_name = scanner_config.get("model", None) + + if threshold is not None: + bias_params["threshold"] = threshold + if match_type is not None and match_type in available_match_types: + bias_params["match_type"] = match_type + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Bias scanner: {model_name}") + bias_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Bias scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + + logger.info(f"Creating Bias scanner with params: {bias_params}") + return Bias(**bias_params) + + def _create_code_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": CODE_DEFAULT_MODEL} + code_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + languages = scanner_config.get("languages", None) + model_name = scanner_config.get("model", None) + is_blocked = scanner_config.get("is_blocked", None) + threshold = scanner_config.get("threshold", None) + + if isinstance(languages, str): + languages = sanitize_env(languages) + + if languages: + if isinstance(languages, str): + artifacts = set([",", "", "."]) + code_params["languages"] = list(set(languages.split(",")) - artifacts) + elif isinstance(languages, list): + code_params["languages"] = languages + else: + logger.error("Provided type is not valid for Code scanner") + raise ValueError("Provided type is not valid for Code scanner") + else: + logger.error("Languages list is required for Code scanner") + raise TypeError("Languages list is required for Code scanner") + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Code scanner: {model_name}") + code_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Code scanner. Please provide a valid model name. Provided model: {model_name}. 
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if is_blocked is not None: + code_params["is_blocked"] = is_blocked + if threshold is not None: + code_params["threshold"] = threshold + logger.info(f"Creating Code scanner with params: {code_params}") + return Code(**code_params) + + def _create_deanonymize_scanner(self, scanner_config, vault): + if not vault: + raise Exception("Vault is required for Deanonymize scanner") + deanonymize_params = {"vault": vault} + + matching_strategy = scanner_config.get("matching_strategy", None) + if matching_strategy is not None: + deanonymize_params["matching_strategy"] = matching_strategy + + logger.info(f"Creating Deanonymize scanner with params: {deanonymize_params}") + return Deanonymize(**deanonymize_params) + + def _create_json_scanner(self, scanner_config): + json_scanner_params = {} + + required_elements = scanner_config.get("required_elements", None) + repair = scanner_config.get("repair", None) + + if required_elements is not None: + json_scanner_params["required_elements"] = required_elements + if repair is not None: + json_scanner_params["repair"] = repair + + logger.info(f"Creating JSON scanner with params: {json_scanner_params}") + return JSON(**json_scanner_params) + + def _create_language_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": LANGUAGE_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + language_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + valid_languages = scanner_config.get("valid_languages", None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) + + if isinstance(valid_languages, str): + valid_languages = sanitize_env(valid_languages) + + if valid_languages: + if isinstance(valid_languages, str): + artifacts = set([",", "", "."]) + language_params["valid_languages"] = list(set(valid_languages.split(",")) - artifacts) + elif isinstance(valid_languages, list): + language_params["valid_languages"] = valid_languages + else: + logger.error("Provided type is not valid for Language scanner") + raise ValueError("Provided type is not valid for Language scanner") + else: + logger.error("Valid languages list is required for Language scanner") + raise TypeError("Valid languages list is required for Language scanner") + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Language scanner: {model_name}") + language_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Language scanner. Please provide a valid model name. Provided model: {model_name}. 
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + language_params["threshold"] = threshold + if match_type is not None and match_type in enabled_match_types: + language_params["match_type"] = match_type + logger.info(f"Creating Language scanner with params: {language_params}") + return Language(**language_params) + + def _create_language_same_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": LANGUAGE_DEFAULT_MODEL} + language_same_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for LanguageSame scanner: {model_name}") + language_same_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for LanguageSame scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + language_same_params["threshold"] = threshold + + logger.info(f"Creating LanguageSame scanner with params: {language_same_params}") + return LanguageSame(**language_same_params) + + def _create_malicious_urls_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": MALICIOUS_URLS_DEFAULT_MODEL} + malicious_urls_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + threshold = scanner_config.get("threshold", None) + model_name = scanner_config.get("model", None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for MaliciousURLs scanner: {model_name}") + malicious_urls_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for MaliciousURLs scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + malicious_urls_params["threshold"] = threshold + + logger.info(f"Creating MaliciousURLs scanner with params: {malicious_urls_params}") + return MaliciousURLs(**malicious_urls_params) + + def _create_no_refusal_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": NO_REFUSAL_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + no_refusal_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + threshold = scanner_config.get("threshold", None) + model_name = scanner_config.get("model", None) + match_type = scanner_config.get("match_type", None) + + if threshold is not None: + no_refusal_params["threshold"] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for NoRefusal scanner: {model_name}") + no_refusal_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for NoRefusal scanner. Please provide a valid model name. Provided model: {model_name}. 
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if match_type is not None and match_type in enabled_match_types: + no_refusal_params["match_type"] = match_type + + logger.info(f"Creating NoRefusal scanner with params: {no_refusal_params}") + return NoRefusal(**no_refusal_params) + + def _create_no_refusal_light_scanner(self): + logger.info("Creating NoRefusalLight scanner.") + return NoRefusalLight() + + def _create_reading_time_scanner(self, scanner_config): + reading_time_params = {} + + max_time = scanner_config.get("max_time", None) + truncate = scanner_config.get("truncate", None) + + if max_time is not None: + reading_time_params["max_time"] = float(max_time) + else: + logger.error("Max time is required for ReadingTime scanner") + raise TypeError("Max time is required for ReadingTime scanner") + if truncate is not None: + reading_time_params["truncate"] = truncate + + logger.info(f"Creating ReadingTime scanner with params: {reading_time_params}") + return ReadingTime(**reading_time_params) + + def _create_factual_consistency_scanner(self, scanner_config): + enabled_models = { + "DEFAULT_MODEL": BANTOPICS_MODEL_DEBERTA_BASE_V2 + } # BanTopics model is used as default in FactualConsistency + factual_consistency_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + model_name = scanner_config.get("model_name", None) + minimum_score = scanner_config.get("minimum_score", None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for FactualConsistency scanner: {model_name}") + factual_consistency_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for FactualConsistency scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if minimum_score is not None: + factual_consistency_params["minimum_score"] = minimum_score + + logger.info(f"Creating FactualConsistency scanner with params: {factual_consistency_params}") + return FactualConsistency(**factual_consistency_params) + + def _create_gibberish_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": GIBBERISH_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + gibberish_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) + + if match_type == "sentence": + import nltk + + nltk.download("punkt_tab") + + if threshold is not None: + gibberish_params["threshold"] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Gibberish scanner: {model_name}") + gibberish_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Gibberish scanner. Please provide a valid model name. Provided model: {model_name}.
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if match_type is not None and match_type in enabled_match_types: + gibberish_params["match_type"] = match_type + + logger.info(f"Creating Gibberish scanner with params: {gibberish_params}") + return Gibberish(**gibberish_params) + + def _create_regex_scanner(self, scanner_config): + enabled_match_types = ["search", "fullmatch"] + regex_params = {} + + patterns = scanner_config.get("patterns", None) + is_blocked = scanner_config.get("is_blocked", None) + match_type = scanner_config.get("match_type", None) + redact = scanner_config.get("redact", None) + + if isinstance(patterns, str): + patterns = sanitize_env(patterns) + + if patterns: + if isinstance(patterns, str): + artifacts = set([",", "", "."]) + regex_params["patterns"] = list(set(patterns.split(",")) - artifacts) + elif isinstance(patterns, list): + regex_params["patterns"] = patterns + else: + logger.error("Provided type is not valid for Regex scanner") + raise ValueError("Provided type is not valid for Regex scanner") + else: + logger.error("Patterns list is required for Regex scanner") + raise TypeError("Patterns list is required for Regex scanner") + if is_blocked is not None: + regex_params["is_blocked"] = is_blocked + if match_type is not None and match_type in enabled_match_types: + regex_params["match_type"] = match_type + if redact is not None: + regex_params["redact"] = redact + + logger.info(f"Creating Regex scanner with params: {regex_params}") + return OPEARegexScanner(**regex_params) + + def _create_relevance_scanner(self, scanner_config): + enabled_models = { + "MODEL_EN_BGE_BASE": RELEVANCE_MODEL_EN_BGE_BASE, + "MODEL_EN_BGE_LARGE": RELEVANCE_MODEL_EN_BGE_LARGE, + "MODEL_EN_BGE_SMALL": RELEVANCE_MODEL_EN_BGE_SMALL, + } + relevance_params = { + "use_onnx": scanner_config.get("use_onnx", False) + } # TODO: onnx off, because of bug on LLM Guard side + + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Relevance scanner: {model_name}") + relevance_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Relevance scanner. Please provide a valid model name. Provided model: {model_name}.
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + relevance_params["threshold"] = threshold + + logger.info(f"Creating Relevance scanner with params: {relevance_params}") + return Relevance(**relevance_params) + + def _create_sensitive_scanner(self, scanner_config): + sensitive_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + entity_types = scanner_config.get("entity_types", None) + regex_patterns = scanner_config.get("regex_patterns", None) + redact = scanner_config.get("redact", None) + recognizer_conf = scanner_config.get("recognizer_conf", None) + threshold = scanner_config.get("threshold", None) + + if entity_types is not None: + if isinstance(entity_types, str): + entity_types = sanitize_env(entity_types) + + if entity_types: + if isinstance(entity_types, str): + artifacts = set([",", "", "."]) + sensitive_params["entity_types"] = list(set(entity_types.split(",")) - artifacts) + elif isinstance(entity_types, list): + sensitive_params["entity_types"] = entity_types + else: + logger.error("Provided type is not valid for Sensitive scanner") + raise ValueError("Provided type is not valid for Sensitive scanner") + + if regex_patterns is not None: + sensitive_params["regex_patterns"] = regex_patterns + if redact is not None: + sensitive_params["redact"] = redact + if recognizer_conf is not None: + sensitive_params["recognizer_conf"] = recognizer_conf + if threshold is not None: + sensitive_params["threshold"] = threshold + + logger.info(f"Creating Sensitive scanner with params: {sensitive_params}") + return Sensitive(**sensitive_params) + + def _create_sentiment_scanner(self, scanner_config): + enabled_lexicons = ["vader_lexicon"] + sentiment_params = {} + + threshold = scanner_config.get("threshold", None) + lexicon = scanner_config.get("lexicon", None) + + if threshold is not None: + sentiment_params["threshold"] = threshold + if lexicon is not None and lexicon in enabled_lexicons: + sentiment_params["lexicon"] = lexicon + + logger.info(f"Creating Sentiment scanner with params: {sentiment_params}") + return Sentiment(**sentiment_params) + + def _create_toxicity_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": TOXICITY_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + toxicity_params = {"use_onnx": scanner_config.get("use_onnx", False)} + + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) + + if match_type == "sentence": + import nltk + + nltk.download("punkt_tab") + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Toxicity scanner: {model_name}") + toxicity_params["model"] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Toxicity scanner. Please provide a valid model name. Provided model: {model_name}. 
Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + toxicity_params["threshold"] = threshold + if match_type is not None and match_type in enabled_match_types: + toxicity_params["match_type"] = match_type + + logger.info(f"Creating Toxicity scanner with params: {toxicity_params}") + return Toxicity(**toxicity_params) + + def _create_url_reachability_scanner(self, scanner_config): + url_reachability_params = {} + + success_status_codes = scanner_config.get("success_status_codes", None) + timeout = scanner_config.get("timeout", None) + + if success_status_codes is not None: + if isinstance(success_status_codes, str): + artifacts = set([",", "", "."]) + url_reachability_params["success_status_codes"] = list(set(success_status_codes.split(",")) - artifacts) + elif isinstance(success_status_codes, list): + url_reachability_params["success_status_codes"] = success_status_codes + else: + logger.error("Provided type is not valid for URLReachability scanner") + raise ValueError("Provided type is not valid for URLReachability scanner") + if timeout is not None: + url_reachability_params["timeout"] = timeout + + logger.info(f"Creating URLReachability scanner with params: {url_reachability_params}") + return URLReachability(**url_reachability_params) + + def _create_output_scanner(self, scanner_name, scanner_config, vault=None): + if scanner_name not in ENABLED_SCANNERS: + logger.error(f"Scanner {scanner_name} is not supported") + raise Exception(f"Scanner {scanner_name} is not supported. Enabled scanners are: {ENABLED_SCANNERS}") + if scanner_name == "ban_code": + return self._create_ban_code_scanner(scanner_config) + elif scanner_name == "ban_competitors": + return self._create_ban_competitors_scanner(scanner_config) + elif scanner_name == "ban_substrings": + return self._create_ban_substrings_scanner(scanner_config) + elif scanner_name == "ban_topics": + return self._create_ban_topics_scanner(scanner_config) + elif scanner_name == "bias": + return self._create_bias_scanner(scanner_config) + elif scanner_name == "code": + return self._create_code_scanner(scanner_config) + elif scanner_name == "deanonymize": + return self._create_deanonymize_scanner(scanner_config, vault) + elif scanner_name == "json_scanner": + return self._create_json_scanner(scanner_config) + elif scanner_name == "language": + return self._create_language_scanner(scanner_config) + elif scanner_name == "language_same": + return self._create_language_same_scanner(scanner_config) + elif scanner_name == "malicious_urls": + return self._create_malicious_urls_scanner(scanner_config) + elif scanner_name == "no_refusal": + return self._create_no_refusal_scanner(scanner_config) + elif scanner_name == "no_refusal_light": + return self._create_no_refusal_light_scanner() + elif scanner_name == "reading_time": + return self._create_reading_time_scanner(scanner_config) + elif scanner_name == "factual_consistency": + return self._create_factual_consistency_scanner(scanner_config) + elif scanner_name == "gibberish": + return self._create_gibberish_scanner(scanner_config) + elif scanner_name == "regex": + return self._create_regex_scanner(scanner_config) + elif scanner_name == "relevance": + return self._create_relevance_scanner(scanner_config) + elif scanner_name == "sensitive": + return self._create_sensitive_scanner(scanner_config) + elif scanner_name == "sentiment": + return self._create_sentiment_scanner(scanner_config) + elif scanner_name == "toxicity": + return
self._create_toxicity_scanner(scanner_config) + elif scanner_name == "url_reachability": + return self._create_url_reachability_scanner(scanner_config) + return None + + def create_enabled_output_scanners(self): + """Create and return a list of enabled scanners based on the global configuration. + + Returns: + list: A list of enabled scanner instances. + """ + enabled_scanners_names_and_configs = { + k: v for k, v in self._output_scanners_config.items() if isinstance(v, dict) and v.get("enabled") + } + enabled_scanners_objects = [] + + err_msgs = {} # list for all erroneous scanners + only_validation_errors = True + for scanner_name, scanner_config in enabled_scanners_names_and_configs.items(): + try: + logger.info(f"Attempting to create scanner: {scanner_name}") + scanner_object = self._create_output_scanner(scanner_name, scanner_config, vault=self.vault) + enabled_scanners_objects.append(scanner_object) + except ValueError as e: + err_msg = f"A ValueError occurred during creating output scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._output_scanners_config[scanner_name]["enabled"] = False + continue + except TypeError as e: + err_msg = f"A TypeError occurred during creating output scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._output_scanners_config[scanner_name]["enabled"] = False + continue + except Exception as e: + err_msg = f"An unexpected error occurred during creating output scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._output_scanners_config[scanner_name]["enabled"] = False + only_validation_errors = False + continue + + if err_msgs: + if only_validation_errors: + raise ValueError( + f"Some scanners failed to be created due to validation errors. The details: {err_msgs}" + ) + else: + raise Exception( + f"Some scanners failed to be created due to validation or unexpected errors. The details: {err_msgs}" + ) + + return [s for s in enabled_scanners_objects if s is not None] + + def changed(self, new_scanners_config): + """Check if the scanners configuration has changed. + + Args: + new_scanners_config (dict): The current scanners configuration. + + Returns: + bool: True if the configuration has changed, False otherwise. + """ + del new_scanners_config["id"] + newly_enabled_scanners = { + k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} + for k, v in new_scanners_config.items() + if isinstance(v, dict) and v.get("enabled") + } + previously_enabled_scanners = { + k: v for k, v in self._output_scanners_config.items() if isinstance(v, dict) and v.get("enabled") + } + if newly_enabled_scanners == previously_enabled_scanners: # if the enabled scanners are the same we do nothing + logger.info("No changes in list for enabled scanners. 
Checking configuration changes...") + return False + else: + logger.warning("Scanners configuration has been changed, re-creating scanners") + self._output_scanners_config.clear() + stripped_new_scanners_config = { + k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} + for k, v in new_scanners_config.items() + if isinstance(v, dict) + } + self._output_scanners_config.update(stripped_new_scanners_config) + return True diff --git a/comps/guardrails/src/guardrails/utils/scanners.py b/comps/guardrails/src/guardrails/utils/scanners.py new file mode 100644 index 0000000000..1465bccfa3 --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/scanners.py @@ -0,0 +1,81 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import re +from collections.abc import Iterable + +from llm_guard.input_scanners import BanSubstrings, Regex +from llm_guard.input_scanners.regex import MatchType +from presidio_anonymizer.core.text_replace_builder import TextReplaceBuilder + +from comps import CustomLogger + +logger = CustomLogger("opea_llm_guard_utils_scanners") + + +# The bug is reported here: https://github.com/protectai/llm-guard/issues/210 +class OPEABanSubstrings(BanSubstrings): + def _redact_text(self, text: str, substrings: list[str]) -> str: + redacted_text = text + flags = 0 + if not self._case_sensitive: + flags = re.IGNORECASE + for s in substrings: + regex_redacted = re.compile(re.escape(s), flags) + redacted_text = regex_redacted.sub("[REDACTED]", redacted_text) + return redacted_text + + def scan(self, prompt: str, output: str = None) -> tuple[str, bool, float]: + if output is not None: + return super().scan(output) + return super().scan(prompt) + + +# LLM Guard's Regex Scanner doesn't replace all occurrences of found patterns.
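+# OPEARegexScanner below overrides scan() and uses TextReplaceBuilder to redact every match of each pattern, not only the first occurrence.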
+# The bug is reported here: https://github.com/protectai/llm-guard/issues/229 +class OPEARegexScanner(Regex): + def scan(self, prompt: str, output: str = None) -> tuple[str, bool, float]: + text_to_scan = "" + if output is not None: + text_to_scan = output + else: + text_to_scan = prompt + + text_replace_builder = TextReplaceBuilder(original_text=text_to_scan) + for pattern in self._patterns: + if self._match_type == MatchType.SEARCH: + matches = re.finditer(pattern, text_to_scan) + else: + matches = self._match_type.match(pattern, text_to_scan) + + if matches is None: + continue + elif isinstance(matches, Iterable): + matches = list(matches) + if len(matches) == 0: + continue + else: + matches = [matches] + + if self._is_blocked: + logger.warning(f"Pattern was detected in the text: {pattern}") + + if self._redact: + for match in reversed(matches): + text_replace_builder.replace_text_get_insertion_index( + "[REDACTED]", + match.start(), + match.end(), + ) + + return text_replace_builder.output_text, False, 1.0 + + logger.debug(f"Pattern matched the text: {pattern}") + return text_replace_builder.output_text, True, 0.0 + + if self._is_blocked: + logger.debug("None of the patterns were found in the text") + return text_replace_builder.output_text, True, 0.0 + + logger.warning("None of the patterns matched the text") + return text_replace_builder.output_text, False, 1.0 diff --git a/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh b/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh index bdf65346ca..8daa5cad93 100644 --- a/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh +++ b/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh @@ -52,7 +52,7 @@ function validate_microservice() { fi echo "test 2 - safe" result=$(http_proxy="" curl http://localhost:11303/v1/guardrails -X POST -d '{"text":"How do you buy a car in the US?","parameters":{"max_new_tokens":32}}' -H 'Content-Type: application/json') - if [[ $result == *"car"* ]]; then + if [[ $result == *"car"* ]]; then echo "Result correct." else echo "Result wrong." @@ -60,7 +60,57 @@ function validate_microservice() { docker logs llamaguard-guardrails-server exit 1 fi + echo "test 3 - input guard" + result=$(http_proxy="" curl http://localhost:11303/v1/guardrails -X POST -d '{ + "query": "What are virus and backdoor?", + "max_new_tokens": 17, + "top_k": 10, + "top_p": 0.95, + "typical_p": 0.95, + "temperature": 0.01, + "repetition_penalty": 1.03, + "stream": false, + "input_guardrail_params": { + "ban_substrings": { + "enabled": true, + "substrings": ["backdoor", "malware", "virus"], + "case_sensitive": false + } + } + }' \ + -H 'Content-Type: application/json' + ) + if [[ "$result" == *"cannot assist you with your prompt"* ]]; then + echo "✅ Input guard triggered as expected!" + else + echo "❌ Input guard failed to trigger!" + fi + + echo "test 4 - output guard" + result=$(http_proxy="" curl http://localhost:11303/v1/guardrails -X POST -d '{"prompt":"What is Deep Learning?", + "text": "To create a virus and backdoor you need to...", + "output_guardrail_params": + {"ban_substrings": + {"enabled":true, + "substrings":["backdoor","malware","virus"], + "match_type":null, + "case_sensitive":false, + "redact":null, + "contains_all":null} + } + }' \ + -H 'Content-Type: application/json' + ) + + echo "Output guard result:" + echo "$result" + if [[ "$result" != *"virus"* ]]; then + echo "✅ Output guard triggered successfully (virus removed or replaced)!" 
+ else + echo "❌ Output guard failed to trigger (virus still in output)!" + exit 1 + fi } function stop_docker() { diff --git a/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh b/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh index 34a674bfaa..401402fe39 100644 --- a/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh +++ b/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh @@ -60,6 +60,57 @@ function validate_microservice() { docker logs wildguard-guardrails-server exit 1 fi + echo "test 3 - input guard" + result=$(http_proxy="" curl http://localhost:11304/v1/guardrails -X POST -d '{ + "query": "What are virus and backdoor?", + "max_new_tokens": 17, + "top_k": 10, + "top_p": 0.95, + "typical_p": 0.95, + "temperature": 0.01, + "repetition_penalty": 1.03, + "stream": false, + "input_guardrail_params": { + "ban_substrings": { + "enabled": true, + "substrings": ["backdoor", "malware", "virus"], + "case_sensitive": false + } + } + }' \ + -H 'Content-Type: application/json' + ) + if [[ "$result" == *"cannot assist you with your prompt"* ]]; then + echo "✅ Input guard triggered as expected!" + else + echo "❌ Input guard failed to trigger!" + fi + + echo "test 4 - output guard" + result=$(http_proxy="" curl http://localhost:11304/v1/guardrails -X POST -d '{"prompt":"What is Deep Learning?", + "text": "To create a virus and backdoor you need to...", + "output_guardrail_params": + {"ban_substrings": + {"enabled":true, + "substrings":["backdoor","malware","virus"], + "match_type":null, + "case_sensitive":false, + "redact":null, + "contains_all":null} + } + }' \ + -H 'Content-Type: application/json' + ) + + echo "Output guard result:" + echo "$result" + + if [[ "$result" != *"virus"* ]]; then + echo "✅ Output guard triggered successfully (virus removed or replaced)!" + else + echo "❌ Output guard failed to trigger (virus still in output)!" + exit 1 + fi } function stop_docker() {