@@ -14,6 +14,9 @@
 
 import datasets
 import numpy as np
+import os
+
+PATH_GTE = 'models/gte'
 
 def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
     last_hidden = mx.multiply(last_hidden_state, attention_mask[..., None])
@@ -123,15 +126,17 @@ def __call__(
         y = self.encoder(x, attention_mask)
         return y, mx.tanh(self.pooler(y[:, 0]))
 
-
 class GteModel:
     def __init__(self) -> None:
-        model_path = snapshot_download(repo_id="vegaluisjose/mlx-rag")
+        model_path = PATH_GTE
+        if not os.path.exists(model_path):
+            snapshot_download(repo_id="vegaluisjose/mlx-rag", local_dir=model_path)
+            snapshot_download(repo_id="thenlper/gte-large", allow_patterns=["vocab.txt", "*.json"], local_dir=model_path)
         with open(f"{model_path}/config.json") as f:
             model_config = ModelConfig(**json.load(f))
         self.model = Bert(model_config)
         self.model.load_weights(f"{model_path}/model.npz")
-        self.tokenizer = BertTokenizer.from_pretrained("thenlper/gte-large")
+        self.tokenizer = BertTokenizer.from_pretrained(model_path)
 
     def __call__(self, input_text: List[str]) -> mx.array:
         tokens = self.tokenizer(input_text, return_tensors="np", padding=True)
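In effect, the first construction of GteModel downloads the model weights and the tokenizer files into a local models/gte folder, and every later construction loads from disk without touching the network. A minimal sketch of that behavior (assuming this file is importable as a module named model; the import path is an assumption):

```python
from model import GteModel  # hypothetical module name for the file being patched

gte = GteModel()  # first run: populates models/gte via snapshot_download;
                  # later runs: os.path.exists(PATH_GTE) is True, loads locally
embeddings = gte(["what is retrieval-augmented generation?"])  # returns an mx.array
```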
@@ -201,4 +206,4 @@ def __call__(self, text, n_topk=1):
         query_embed = self.embed(text)
         scores = mx.matmul(query_embed, self.list_embed.T)
         list_idx = mx.argsort(scores)[:, :-1 - n_topk:-1].tolist()
-        return [[self.list_api[j] for j in i] for i in list_idx]
+        return [[self.list_api[j] for j in i] for i in list_idx]
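As an aside on the last hunk: mx.argsort sorts ascending, so the slice [:, :-1 - n_topk:-1] walks the tail of that ordering backwards, yielding the indices of the n_topk highest scores in descending order. A tiny worked example with made-up values:

```python
import mlx.core as mx

scores = mx.array([[0.1, 0.9, 0.4, 0.7]])
order = mx.argsort(scores)   # ascending by score: [[0, 2, 3, 1]]
top2 = order[:, :-1 - 2:-1]  # last two entries, reversed: [[1, 3]]
# index 1 (score 0.9) and index 3 (score 0.7) are the two best matches
```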