Skip to content

Commit 085b29b

Browse files
authored
mistral_ & bark_api
1 parent 916355b commit 085b29b

File tree

5 files changed

+108
-47
lines changed

5 files changed

+108
-47
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
phi3v/
2-
quantized_phi3v/
1+
models/
32
adapters/
43
*.egg-info
54
*.json

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ agent = Agent()
3636

3737
```python
3838
agent('What is shown in this image?', 'https://collectionapi.metmuseum.org/api/collection/v1/iiif/344291/725918/main-image')
39+
agent('What is the location?')
3940
agent.end()
4041
```
4142

@@ -192,4 +193,4 @@ This project is licensed under the [MIT License](LICENSE).
192193

193194
## Citation
194195

195-
<a href="https://zenodo.org/doi/10.5281/zenodo.11403221"><img src="https://zenodo.org/badge/806709541.svg" alt="DOI"></a>
196+
<a href="https://zenodo.org/doi/10.5281/zenodo.11403221"><img src="https://zenodo.org/badge/806709541.svg" alt="DOI"></a>

examples/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# Visual Question Answering (VQA)
1919
agent = Agent()
2020
agent('What is shown in this image?', 'https://collectionapi.metmuseum.org/api/collection/v1/iiif/344291/725918/main-image')
21+
agent('What is the location?')
2122
agent.end()
2223

2324
# Generative Feedback Loop

gte.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
import datasets
1616
import numpy as np
17+
import os
18+
19+
PATH_GTE = 'models/gte'
1720

1821
def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
1922
last_hidden = mx.multiply(last_hidden_state, attention_mask[..., None])
@@ -123,15 +126,17 @@ def __call__(
123126
y = self.encoder(x, attention_mask)
124127
return y, mx.tanh(self.pooler(y[:, 0]))
125128

126-
127129
class GteModel:
128130
def __init__(self) -> None:
129-
model_path = snapshot_download(repo_id="vegaluisjose/mlx-rag")
131+
model_path = PATH_GTE
132+
if not os.path.exists(model_path):
133+
snapshot_download(repo_id="vegaluisjose/mlx-rag", local_dir=model_path)
134+
snapshot_download(repo_id="thenlper/gte-large", allow_patterns=["vocab.txt", "*.json"], local_dir=model_path)
130135
with open(f"{model_path}/config.json") as f:
131136
model_config = ModelConfig(**json.load(f))
132137
self.model = Bert(model_config)
133138
self.model.load_weights(f"{model_path}/model.npz")
134-
self.tokenizer = BertTokenizer.from_pretrained("thenlper/gte-large")
139+
self.tokenizer = BertTokenizer.from_pretrained(model_path)
135140

136141
def __call__(self, input_text: List[str]) -> mx.array:
137142
tokens = self.tokenizer(input_text, return_tensors="np", padding=True)
@@ -201,4 +206,4 @@ def __call__(self, text, n_topk=1):
201206
query_embed = self.embed(text)
202207
scores = mx.matmul(query_embed, self.list_embed.T)
203208
list_idx = mx.argsort(scores)[:,:-1-n_topk:-1].tolist()
204-
return [[self.list_api[j] for j in i] for i in list_idx]
209+
return [[self.list_api[j] for j in i] for i in list_idx]

0 commit comments

Comments
 (0)