diff --git a/validator/gliner_recognizer.py b/validator/gliner_recognizer.py
index a1ce76a..a495b79 100644
--- a/validator/gliner_recognizer.py
+++ b/validator/gliner_recognizer.py
@@ -1,12 +1,15 @@
+import torch
 from presidio_analyzer import EntityRecognizer, RecognizerResult
 from gliner import GLiNER
 from .constants import PRESIDIO_TO_GLINER, GLINER_TO_PRESIDIO
 
 
 class GLiNERRecognizer(EntityRecognizer):
-    def __init__(self, supported_entities, model_name):
+    def __init__(self, supported_entities, model_name, use_gpu=True):
         self.model_name = model_name
         self.supported_entities = supported_entities
+        self.use_gpu = use_gpu
+        self.device = self._get_device()
 
         gliner_entities = set()
 
@@ -18,12 +21,28 @@ def __init__(self, supported_entities, model_name):
         super().__init__(supported_entities=supported_entities)
 
+    def _get_device(self):
+        """Determine the device to use for inference"""
+        if self.use_gpu and torch.cuda.is_available():
+            return torch.device("cuda")
+        return torch.device("cpu")
+
     def load(self) -> None:
-        """No loading required as the model is loaded in the constructor"""
+        """Load the model and move it to the appropriate device"""
         self.model = GLiNER.from_pretrained(self.model_name)
+        self.model = self.model.to(self.device)
+        # eval() disables dropout; required for inference on CPU as well as GPU
+        self.model.eval()
 
     def analyze(self, text, entities=None, nlp_artifacts=None):
-        results = self.model.predict_entities(text, self.gliner_entities)
+        """Analyze text using GPU-accelerated GLiNER model"""
+        # Ensure model is on correct device
+        if hasattr(self.model, 'device') and self.model.device != self.device:
+            self.model = self.model.to(self.device)
+
+        # Run inference with gradient disabled for efficiency
+        with torch.no_grad():
+            results = self.model.predict_entities(text, self.gliner_entities)
         return [
             RecognizerResult(
                 entity_type=GLINER_TO_PRESIDIO[entity["label"]],
diff --git a/validator/main.py b/validator/main.py
index f914106..84cd696 100644
--- a/validator/main.py
+++ b/validator/main.py
@@ -79,6 +79,7 @@ def __init__(
         get_entity_threshold: Callable = get_entity_threshold,
         on_fail: Optional[Callable] = None,
         use_local: bool = True,
+        use_gpu: bool = True,
         **kwargs,
     ):
         """Validates that the LLM-generated text does not contain Personally Identifiable Information (PII).
@@ -109,6 +110,7 @@ def __init__(
             entities=entities,
             get_entity_threshold=get_entity_threshold,
             use_local=use_local,
+            use_gpu=use_gpu,
             **kwargs,
         )
 
@@ -119,11 +121,13 @@ def __init__(
         self.entities = entities
        self.model_name = model_name
         self.get_entity_threshold = get_entity_threshold
+        self.use_gpu = use_gpu
 
         if self.use_local:
             self.gliner_recognizer = GLiNERRecognizer(
                 supported_entities=self.entities,
                 model_name=model_name,
+                use_gpu=use_gpu,
             )
             registry = RecognizerRegistry()
             registry.load_predefined_recognizers()