Skip to content

Commit c81e06e

Browse files
committed
Add the device_dtype init parameter to the transformers model
1 parent b393c88 commit c81e06e

File tree

5 files changed

+56
-9
lines changed

5 files changed

+56
-9
lines changed

docs/features/models/transformers.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ title: Transformers
1212

1313
## Model Initialization
1414

15-
To load the model, you can use the `from_transformers` function. It takes 2 arguments:
15+
To load the model, you can use the `from_transformers` function. It takes 3 arguments:
1616

1717
- `model`: a `transformers` model (created with `AutoModelForCausalLM` for instance)
1818
- `tokenizer_or_processor`: a `transformers` tokenizer (created with `AutoTokenizer` for instance, it must be an instance of either `PreTrainedTokenizer` or `PreTrainedTokenizerFast`)
19+
- `device_dtype` (optional): the tensor dtype to use for inference. If not provided, the model will use the default dtype.
1920

2021
For instance:
2122

docs/features/models/transformers_multimodal.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ To load the model, you can use the `from_transformers` function. It takes 2 argu
1212

1313
- `model`: a `transformers` model (created with `AutoModelForImageTextToText` for instance)
1414
- `tokenizer_or_processor`: a `transformers` processor (created with `AutoProcessor` for instance, it must be an instance of `ProcessorMixin`)
15+
- `device_dtype` (optional): the tensor dtype to use for inference. If not provided, the model will use the default dtype.
1516

1617
For instance:
1718

outlines/models/transformers.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def __init__(
209209
self,
210210
model: "PreTrainedModel",
211211
tokenizer: "PreTrainedTokenizer",
212+
*,
213+
device_dtype: Optional["torch.dtype"] = None,
212214
):
213215
"""
214216
Parameters:
@@ -219,6 +221,9 @@ def __init__(
219221
tokenizer
220222
A `PreTrainedTokenizer`, or any tokenizer that is compatible with
221223
the `transformers` API for tokenizers.
224+
device_dtype
225+
The dtype to use for the model. If not provided, the model will use
226+
the default dtype.
222227
223228
"""
224229
# We need to handle the cases in which jax/flax or tensorflow
@@ -237,6 +242,7 @@ def __init__(
237242
self.model = model
238243
self.hf_tokenizer = tokenizer
239244
self.tokenizer = TransformerTokenizer(tokenizer)
245+
self.device_dtype = device_dtype
240246
self.type_adapter = TransformersTypeAdapter(tokenizer=tokenizer)
241247

242248
if (
@@ -287,7 +293,11 @@ def _prepare_model_inputs(
287293
input_ids, attention_mask = self.tokenizer.encode(prompts)
288294
inputs = {
289295
"input_ids": input_ids.to(self.model.device),
290-
"attention_mask": attention_mask.to(self.model.device),
296+
"attention_mask": (
297+
attention_mask.to(self.model.device, dtype=self.device_dtype)
298+
if self.device_dtype is not None
299+
else attention_mask.to(self.model.device)
300+
),
291301
}
292302

293303
return prompts, inputs
@@ -600,7 +610,13 @@ class TransformersMultiModal(Transformers):
600610
601611
"""
602612

603-
def __init__(self, model: "PreTrainedModel", processor):
613+
def __init__(
614+
self,
615+
model: "PreTrainedModel",
616+
processor,
617+
*,
618+
device_dtype: Optional["torch.dtype"] = None,
619+
):
604620
"""Create a TransformersMultiModal model instance
605621
606622
We rely on the `__init__` method of the `Transformers` class to handle
@@ -614,6 +630,9 @@ def __init__(self, model: "PreTrainedModel", processor):
614630
`transformers` API for models.
615631
processor
616632
A `ProcessorMixin` instance.
633+
device_dtype
634+
The dtype to use for the model. If not provided, the model will use
635+
the default dtype.
617636
618637
"""
619638
self.processor = processor
@@ -622,7 +641,7 @@ def __init__(self, model: "PreTrainedModel", processor):
622641

623642
tokenizer: "PreTrainedTokenizer" = self.processor.tokenizer
624643

625-
super().__init__(model, tokenizer)
644+
super().__init__(model, tokenizer, device_dtype=device_dtype)
626645

627646
self.type_adapter = TransformersMultiModalTypeAdapter(
628647
tokenizer=tokenizer
@@ -655,14 +674,20 @@ def _prepare_model_inputs(
655674

656675
inputs = self.processor(
657676
**merged_prompts, padding=True, return_tensors="pt"
658-
).to(self.model.device)
677+
)
678+
if self.device_dtype is not None:
679+
inputs = inputs.to(self.model.device, dtype=self.device_dtype)
680+
else:
681+
inputs = inputs.to(self.model.device)
659682

660683
return merged_prompts["text"], inputs
661684

662685

663686
def from_transformers(
664687
model: "PreTrainedModel",
665688
tokenizer_or_processor: Union["PreTrainedTokenizer", "ProcessorMixin"],
689+
*,
690+
device_dtype: Optional["torch.dtype"] = None,
666691
) -> Union[Transformers, TransformersMultiModal]:
667692
"""Create an Outlines `Transformers` or `TransformersMultiModal` model
668693
instance from a `PreTrainedModel` instance and a `PreTrainedTokenizer` or
@@ -679,6 +704,9 @@ def from_transformers(
679704
tokenizer_or_processor
680705
A `transformers.PreTrainedTokenizer` or
681706
`transformers.ProcessorMixin` instance.
707+
device_dtype
708+
The dtype to use for the model. If not provided, the model will use
709+
the default dtype.
682710
683711
Returns
684712
-------
@@ -693,10 +721,10 @@ def from_transformers(
693721
tokenizer_or_processor, (PreTrainedTokenizer, PreTrainedTokenizerFast)
694722
):
695723
tokenizer = tokenizer_or_processor
696-
return Transformers(model, tokenizer)
724+
return Transformers(model, tokenizer, device_dtype=device_dtype)
697725
elif isinstance(tokenizer_or_processor, ProcessorMixin):
698726
processor = tokenizer_or_processor
699-
return TransformersMultiModal(model, processor)
727+
return TransformersMultiModal(model, processor, device_dtype=device_dtype)
700728
else:
701729
raise ValueError(
702730
"We could not determine whether the model passed to `from_transformers`"

tests/models/test_transformers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pydantic import BaseModel
55
import pytest
6+
import torch
67
import transformers
78

89
import outlines
@@ -47,15 +48,17 @@ def test_transformers_instantiate_mamba():
4748
assert isinstance(model, Transformers)
4849

4950

50-
def test_transformers_instantiate_tokenizer_kwargs():
51+
def test_transformers_instantiate_tokenizer_kwargs_dtype():
5152
model = outlines.from_transformers(
5253
transformers.AutoModelForCausalLM.from_pretrained(TEST_MODEL),
5354
transformers.AutoTokenizer.from_pretrained(
5455
TEST_MODEL, additional_special_tokens=["<t1>", "<t2>"]
5556
),
57+
device_dtype=torch.bfloat16,
5658
)
5759
assert "<t1>" in model.tokenizer.special_tokens
5860
assert "<t2>" in model.tokenizer.special_tokens
61+
assert model.device_dtype == torch.bfloat16
5962

6063

6164
@pytest.fixture
@@ -88,6 +91,10 @@ def test_transformers_call(model, model_bart):
8891
result = model("Respond with one word. Not more.")
8992
assert isinstance(result, str)
9093

94+
model.device_dtype = torch.bfloat16
95+
result = model("Respond with one word. Not more.")
96+
assert isinstance(result, str)
97+
9198
result = model_bart("Respond with one word. Not more.")
9299
assert isinstance(result, str)
93100

tests/models/test_transformers_multimodal.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import io
44
import re
5+
import torch
56
from enum import Enum
67

78
import pytest
@@ -47,15 +48,17 @@ def model():
4748
return model
4849

4950

50-
def test_transformers_multimodal_instantiate_simple():
51+
def test_transformers_multimodal_instantiate():
5152
model = outlines.from_transformers(
5253
LlavaForConditionalGeneration.from_pretrained(TEST_MODEL),
5354
AutoProcessor.from_pretrained(TEST_MODEL),
55+
device_dtype=torch.bfloat16,
5456
)
5557
assert isinstance(model, TransformersMultiModal)
5658
assert isinstance(model.tokenizer, TransformerTokenizer)
5759
assert isinstance(model.type_adapter, TransformersMultiModalTypeAdapter)
5860
assert model.tensor_library_name == "torch"
61+
assert model.device_dtype == torch.bfloat16
5962

6063

6164
def test_transformers_multimodal_simple(model, image):
@@ -74,6 +77,13 @@ def test_transformers_multimodal_call(model, image):
7477
)
7578
assert isinstance(result, str)
7679

80+
model.device_dtype = torch.bfloat16
81+
result = model(
82+
["<image>Describe this image in one sentence:", Image(image)],
83+
max_new_tokens=2,
84+
)
85+
assert isinstance(result, str)
86+
7787

7888
def test_transformers_multimodal_wrong_number_image(model, image):
7989
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)