
Improved Model Loading #24

Merged
merged 2 commits on Nov 19, 2024
17 changes: 10 additions & 7 deletions chat/app.py
@@ -195,13 +195,16 @@ def load_model_and_cache(ref):
)
tokenizer.chat_template = chat_template
else:
chat_template = tokenizer.chat_template or (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
chat_template = (
tokenizer.chat_template
or (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
)
)

supports_system_role = "system role not supported" not in chat_template.lower()
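Aside, not part of the PR: the fallback string above is a standard ChatML-style Jinja template. A minimal sketch of exercising it through the Hugging Face `apply_chat_template` API, assuming a tokenizer (here `gpt2`, chosen only as an example) that ships without a template of its own:

```python
from transformers import AutoTokenizer

# Hypothetical check of the fallback ChatML template from the diff above.
# "gpt2" is only an example of a tokenizer without a built-in chat template.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|im_start|>assistant\n' }}"
    "{% endif %}"
)

messages = [{"role": "user", "content": "Hello!"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # <|im_start|>user ... <|im_end|> ... <|im_start|>assistant
```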
34 changes: 34 additions & 0 deletions examples/bert/sentence_transformers.py
@@ -0,0 +1,34 @@
import mlx.core as mx
import numpy as np

from transformers import AutoConfig, AutoTokenizer
from mlx_transformers.models import BertModel as MLXBertModel


def _mean_pooling(last_hidden_state: mx.array, attention_mask: mx.array):
    token_embeddings = last_hidden_state
    input_mask_expanded = mx.expand_dims(attention_mask, -1)
    input_mask_expanded = mx.broadcast_to(
        input_mask_expanded, token_embeddings.shape
    ).astype(mx.float32)
    sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = mx.clip(input_mask_expanded.sum(axis=1), 1e-9, None)
    return sum_embeddings / sum_mask


sentences = ["This is an example sentence", "Each sentence is converted"]

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

model = MLXBertModel(config)
model.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

inputs = tokenizer(sentences, return_tensors="np", padding=True, truncation=True)
inputs = {key: mx.array(v) for key, v in inputs.items()}

outputs = model(**inputs)

sentence_embeddings = _mean_pooling(outputs.last_hidden_state, inputs["attention_mask"])

print(sentence_embeddings)
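A possible follow-up to the new example, not part of the PR: the pooled embeddings can be compared directly with cosine similarity. A minimal sketch, assuming `sentence_embeddings` and `mx` from the script above:

```python
# Hypothetical follow-up: cosine similarity between the two pooled embeddings.
a, b = sentence_embeddings[0], sentence_embeddings[1]
cosine = mx.sum(a * b) / (mx.sqrt(mx.sum(a * a)) * mx.sqrt(mx.sum(b * b)))
print(cosine.item())
```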
@@ -43,7 +43,7 @@ def load_model(
model_name,
huggingface_model_architecture="AutoModelForCausalLM",
trust_remote_code=True,
fp16=fp16,
float16=fp16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
2 changes: 1 addition & 1 deletion examples/text_generation/phi_generation.py
@@ -39,7 +39,7 @@ def load_model(
os.path.dirname(os.path.realpath(__file__))

model = mlx_model_class(config)
model.from_pretrained(model_name, fp16=fp16)
model.from_pretrained(model_name, float16=fp16)

tokenizer = AutoTokenizer.from_pretrained(model_name)

107 changes: 70 additions & 37 deletions src/mlx_transformers/models/base.py
@@ -1,65 +1,98 @@
import importlib
import os
import logging
from typing import Callable, Optional
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from huggingface_hub import snapshot_download
from huggingface_hub import HfFileSystem

import mlx.core as mx
from mlx.utils import tree_unflatten

CONFIG_FILE = "config.json"
WEIGHTS_FILE_NAME = "model.safetensors"

logger = logging.getLogger(__name__)
fs = HfFileSystem()

HF_TOKEN = os.getenv("HF_TOKEN", None)


@dataclass
class ModelLoadingConfig:
"""Configuration for model loading parameters."""

model_name_or_path: str
cache_dir: Optional[str] = None
revision: str = "main"
float16: bool = False
trust_remote_code: bool = False
max_workers: int = 4


class MlxPretrainedMixin:
"""Mixin class for loading pretrained models in MLX format."""

def from_pretrained(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
revision: Optional[str] = "main",
revision: str = "main",
float16: bool = False,
huggingface_model_architecture: Optional[Callable] = None,
trust_remote_code: bool = False,
max_workers: int = 4,
):
if huggingface_model_architecture:
architecture = huggingface_model_architecture
elif hasattr(self.config, "architectures"):
architecture = self.config.architectures[0]
else:
raise ValueError("No architecture found for loading this model")
) -> "MlxPretrainedMixin":
"""
Load a pretrained model from HuggingFace Hub or local path.

transformers_module = importlib.import_module("transformers")
_class = getattr(transformers_module, architecture, None)
Args:
model_name_or_path: HuggingFace model name or path to local model
cache_dir: Directory to store downloaded models
revision: Git revision to use when downloading
float16: Whether to convert model to float16
huggingface_model_architecture: Custom model architecture class
trust_remote_code: Whether to trust remote code when loading
max_workers: Number of worker threads for tensor conversion

if not _class:
raise ValueError(f"Could not find the class for {architecture}")
Returns:
Self with loaded model weights
"""
config = ModelLoadingConfig(
model_name_or_path=model_name_or_path,
cache_dir=cache_dir,
revision=revision,
float16=float16,
trust_remote_code=trust_remote_code,
max_workers=max_workers,
)

dtype = mx.float16 if float16 else mx.float32
logger.info(f"Loading model using the following configuration {config}")

logger.info(f"Loading model from {model_name_or_path}")
model = _class.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code
safe_tensor_files = fs.glob(
f"{config.model_name_or_path}/*.safetensors",
**{"revision": config.revision},
)
safe_tensor_files = [f.split("/")[-1] for f in safe_tensor_files]

# # save the tensors
logger.info("Converting model tensors to Mx arrays")
import concurrent.futures
if not safe_tensor_files:
raise ValueError("No safe tensor files found for this model")

def convert_tensor(key, tensor, dtype):
return key, mx.array(tensor.numpy()).astype(dtype)
download_path = snapshot_download(
repo_id=config.model_name_or_path,
allow_patterns="*.safetensors",
max_workers=config.max_workers,
revision=config.revision,
token=HF_TOKEN,
)
dtype = mx.float16 if config.float16 else mx.float32

tensors = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for key, tensor in model.state_dict().items():
future = executor.submit(convert_tensor, key, tensor, dtype)
futures.append(future)
for file in safe_tensor_files:
file_path = Path(download_path) / file
with file_path.open("rb") as f:
state_dict = mx.load(f)

for future in concurrent.futures.as_completed(futures):
key, converted_tensor = future.result()
tensors[key] = converted_tensor
tensors.update(state_dict)

tensors = [(key, tensor) for key, tensor in tensors.items()]
tensors = {k: v.astype(dtype) for k, v in tensors.items()}

self.update(tree_unflatten(tensors))
# Update model weights
self.update(tree_unflatten(list(tensors.items())))
return self
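For orientation, not part of the diff: after this refactor, `from_pretrained` lists the repo's `*.safetensors` files, downloads them with `snapshot_download`, loads them via `mx.load`, and casts them to the requested dtype, rather than round-tripping a torch `state_dict`. A minimal usage sketch under those assumptions, reusing the MiniLM checkpoint from the example above; the parameter values are illustrative:

```python
from transformers import AutoConfig
from mlx_transformers.models import BertModel as MLXBertModel

# Sketch of calling the refactored loader; values below are illustrative.
config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = MLXBertModel(config)
model.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2",
    revision="main",   # git revision resolved on the Hub
    float16=False,     # cast weights to mx.float16 when True
    max_workers=4,     # download workers passed to snapshot_download
)
```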
6 changes: 3 additions & 3 deletions src/mlx_transformers/models/bert.py
@@ -166,9 +166,9 @@ def __call__(
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
outputs = (
attention_output,
) + self_outputs[1:] # add attentions if we output them
return outputs


4 changes: 1 addition & 3 deletions src/mlx_transformers/models/cache.py
@@ -74,9 +74,7 @@ class DynamicCache(Cache):
def __init__(self) -> None:
self.key_cache: List[mx.array] = []
self.value_cache: List[mx.array] = []
self._seen_tokens = (
0 # Used in `generate` to keep tally of how many tokens the cache has seen
)
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[mx.array]]:
"""
6 changes: 3 additions & 3 deletions src/mlx_transformers/models/roberta.py
@@ -211,9 +211,9 @@ def __call__(
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
outputs = (
attention_output,
) + self_outputs[1:] # add attentions if we output them
return outputs


8 changes: 5 additions & 3 deletions src/mlx_transformers/models/xlm_roberta.py
@@ -210,9 +210,9 @@ def __call__(
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
outputs = (
attention_output,
) + self_outputs[1:] # add attentions if we output them
return outputs


@@ -425,6 +425,7 @@ def __call__(
position_ids=position_ids,
token_type_ids=token_type_ids,
)

encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
@@ -507,6 +508,7 @@ def __call__(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs.last_hidden_state
logits = self.classifier(sequence_output)

38 changes: 28 additions & 10 deletions tests/test_bert.py
@@ -6,14 +6,14 @@
AutoConfig,
AutoTokenizer,
BertConfig,
BertForMaskedLM,
BertModel,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertTokenizer,
)

from src.mlx_transformers.models import BertForMaskedLM as MlxBertForMaskedLM
from src.mlx_transformers.models import BertModel as MlxBertModel
from src.mlx_transformers.models import (
BertForQuestionAnswering as MlxBertForQuestionAnswering,
)
@@ -33,14 +33,13 @@ def load_hgf_model(model_name: str, hgf_model_class):
class TestMlxBert(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.model_name = "bert-base-uncased"
cls.model_name = "sentence-transformers/all-MiniLM-L6-v2"
cls.config = BertConfig.from_pretrained(cls.model_name)
cls.tokenizer = BertTokenizer.from_pretrained(cls.model_name)
cls.hgf_model_class = BertForMaskedLM
cls.hgf_model_class = BertModel

# cls.model_class = MlxBertForMaskedLM
cls.model = MlxBertForMaskedLM(cls.config)
cls.model.from_pretrained(cls.model_name)
cls.model = MlxBertModel(cls.config)
cls.model.from_pretrained(cls.model_name, revision="main")

cls.input_text = "Hello, my dog is cute"

@@ -51,14 +50,14 @@ def test_model_output_hgf(self):

inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
outputs_mlx = self.model(**inputs_mlx)
outputs_mlx = np.array(outputs_mlx.logits)
outputs_mlx = np.array(outputs_mlx.last_hidden_state)

inputs_hgf = self.tokenizer(
self.input_text, return_tensors="pt", padding=True, truncation=True
)
hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
outputs_hgf = hgf_model(**inputs_hgf)
outputs_hgf = outputs_hgf.logits.detach().numpy()
outputs_hgf = outputs_hgf.last_hidden_state.detach().numpy()

self.assertTrue(np.allclose(outputs_mlx, outputs_hgf, atol=1e-4))

@@ -72,7 +71,7 @@ def setUpClass(cls) -> None:

cls.hgf_model_class = BertForSequenceClassification
cls.model = MlxBertForSequenceClassification(cls.config)
cls.model.from_pretrained(cls.model_name)
cls.model.from_pretrained(cls.model_name, revision="refs/pr/1")

cls.input_text = "Hello, my dog is cute"

@@ -91,6 +90,7 @@ def test_model_output_hgf(self):
)

inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}

outputs_mlx = self.model(**inputs_mlx)
outputs_mlx = np.array(outputs_mlx.logits)
predicted_class_id = outputs_mlx.argmax().item()
@@ -100,6 +100,7 @@
self.input_text, return_tensors="pt", padding=True, truncation=True
)
hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)

outputs_hgf = hgf_model(**inputs_hgf)
outputs_hgf = outputs_hgf.logits

@@ -171,6 +172,9 @@ def test_model_output_hgf(self):
]

self.assertEqual(mlx_predicted_tokens_classes, hgf_predicted_tokens_classes)
self.assertTrue(
np.allclose(np.array(outputs_mlx), outputs_hgf.detach().numpy(), atol=1e-4)
)


class TestMlxBertForQuestionAnswering(unittest.TestCase):
@@ -232,6 +236,20 @@ def test_model_output_hgf(self):
)

self.assertEqual(mlx_answer, hgf_answer)
self.assertTrue(
np.allclose(
np.array(outputs_mlx.start_logits),
outputs_hgf.start_logits.detach().numpy(),
atol=1e-4,
)
)
self.assertTrue(
np.allclose(
np.array(outputs_mlx.end_logits),
outputs_hgf.end_logits.detach().numpy(),
atol=1e-4,
)
)


if __name__ == "__main__":