1+ import torch
12from presidio_analyzer import EntityRecognizer , RecognizerResult
23from gliner import GLiNER
34from .constants import PRESIDIO_TO_GLINER , GLINER_TO_PRESIDIO
45
56
67class GLiNERRecognizer (EntityRecognizer ):
7- def __init__ (self , supported_entities , model_name ):
8+ def __init__ (self , supported_entities , model_name , use_gpu = True ):
89 self .model_name = model_name
910 self .supported_entities = supported_entities
11+ self .use_gpu = use_gpu
12+ self .device = self ._get_device ()
1013
1114 gliner_entities = set ()
1215
@@ -18,12 +21,28 @@ def __init__(self, supported_entities, model_name):
1821
1922 super ().__init__ (supported_entities = supported_entities )
2023
24+ def _get_device (self ):
25+ """Determine the device to use for inference"""
26+ if self .use_gpu and torch .cuda .is_available ():
27+ return torch .device ("cuda" )
28+ return torch .device ("cpu" )
29+
def load(self) -> None:
    """Load the GLiNER model and place it on the selected device.

    The model named by ``self.model_name`` is fetched with
    ``GLiNER.from_pretrained``, moved to ``self.device`` and switched to
    evaluation mode. The result is stored on ``self.model``.
    """
    self.model = GLiNER.from_pretrained(self.model_name)
    self.model = self.model.to(self.device)
    # Evaluation mode is wanted for inference on any device, not only on
    # CUDA; previously this call was skipped when running on CPU. It is a
    # no-op if the loader already left the model in eval mode.
    self.model.eval()
2436
2537 def analyze (self , text , entities = None , nlp_artifacts = None ):
26- results = self .model .predict_entities (text , self .gliner_entities )
38+ """Analyze text using GPU-accelerated GLiNER model"""
39+ # Ensure model is on correct device
40+ if hasattr (self .model , 'device' ) and self .model .device != self .device :
41+ self .model = self .model .to (self .device )
42+
43+ # Run inference with gradient disabled for efficiency
44+ with torch .no_grad ():
45+ results = self .model .predict_entities (text , self .gliner_entities )
2746 return [
2847 RecognizerResult (
2948 entity_type = GLINER_TO_PRESIDIO [entity ["label" ]],
0 commit comments