Allow HF and sentence-transformer models #63

Merged · 3 commits · Jul 24, 2024
3 changes: 2 additions & 1 deletion crossfit/backend/torch/hf/model.py
@@ -25,6 +25,7 @@
 
 from crossfit.backend.torch.model import Model
 from crossfit.dataset.home import CF_HOME
+from crossfit.utils.model_adapter import adapt_model_input
 
 
 class HFModel(Model):
@@ -96,7 +97,7 @@ def fit_memory_estimate_curve(self, model=None):
                 }
 
                 try:
-                    _ = model(**batch)
+                    _ = adapt_model_input(model, batch)
                     memory_used = torch.cuda.max_memory_allocated() / (1024**2)  # Convert to MB
                     X.append([batch_size, seq_len, seq_len**2])
                     y.append(memory_used)
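For context on this hunk: fit_memory_estimate_curve probes peak CUDA memory across (batch_size, seq_len) combinations and fits a curve over the [batch_size, seq_len, seq_len**2] features collected above; routing the forward pass through adapt_model_input lets the probe run on models that reject keyword arguments. A minimal sketch of the curve-fitting idea, with made-up measurements and a plain least-squares fit (the actual estimator in crossfit may differ):

    import numpy as np

    # Features per probe: [batch_size, seq_len, seq_len**2]; target: peak memory in MB.
    X = np.array([[8, 128, 128**2],
                  [16, 128, 128**2],
                  [8, 256, 256**2]], dtype=float)
    y = np.array([310.0, 600.0, 640.0])  # made-up measurements, in MB

    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    predicted_mb = np.array([32, 256, 256**2], dtype=float) @ coef  # estimate a new config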
5 changes: 3 additions & 2 deletions crossfit/backend/torch/loader.py
@@ -22,6 +22,7 @@
 from crossfit.data.array.conversion import convert_array
 from crossfit.data.array.dispatch import crossarray
 from crossfit.data.dataframe.dispatch import CrossFrame
+from crossfit.utils.model_adapter import adapt_model_input
 
 DEFAULT_BATCH_SIZE = 512
 
@@ -70,7 +71,7 @@ def __next__(self):
         self.current_idx += self.batch_size
 
         for fn in self._to_map:
-            batch = fn(batch)
+            batch = adapt_model_input(fn, batch)
 
         if self.progress_bar is not None:
             self.progress_bar.update(batch_size)
@@ -141,7 +142,7 @@ def __next__(self):
             batch = {key: val[:, :clip_len] for key, val in batch.items()}
 
             for fn in self._to_map:
-                batch = fn(batch)
+                batch = adapt_model_input(fn, batch)
 
             break
         except torch.cuda.OutOfMemoryError:
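Both __next__ implementations map the queued callables over each batch, and _to_map can hold preprocessing steps alongside the model's forward pass, so routing every call through the adapter keeps the loop uniform regardless of calling convention. A minimal sketch of that pattern with toy stand-ins (hypothetical names, not the real loader):

    from crossfit.utils.model_adapter import adapt_model_input

    def clip_batch(input_ids=None, attention_mask=None):
        # Keyword-argument style step, like a Hugging Face model
        return {"input_ids": input_ids[:2], "attention_mask": attention_mask[:2]}

    def pool(features):
        # Single-dict style step, like a Sentence Transformer
        return sum(features["input_ids"])

    batch = {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]}
    for fn in [clip_batch, pool]:
        batch = adapt_model_input(fn, batch)

    print(batch)  # 11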
2 changes: 1 addition & 1 deletion crossfit/backend/torch/op/base.py
@@ -74,7 +74,7 @@ def call(self, data, partition_info=None):
         for output in loader.map(self.model.get_model(self.get_worker())):
             if isinstance(output, dict):
                 if self.model_output_col not in output:
-                    raise ValueError(f"Column '{self.model_outupt_col}' not found in model output.")
+                    raise ValueError(f"Column '{self.model_output_col}' not found in model output.")
                 output = output[self.model_output_col]
 
             if self.post is not None:
23 changes: 23 additions & 0 deletions crossfit/utils/model_adapter.py
@@ -0,0 +1,23 @@
+from typing import Any, Callable
+
+
+def adapt_model_input(model: Callable, encoded_input: dict) -> Any:
+    """
+    Adapt the encoded input to the model, handling both single and multiple argument cases.
+
+    This function allows flexible calling of different model types:
+    - Models expecting keyword arguments (e.g., Hugging Face models)
+    - Models expecting a single dictionary input (e.g., Sentence Transformers)
+
+    :param model: The model function to apply
+    :param encoded_input: The encoded input to pass to the model
+    :return: The output of the model
+    """
+    try:
+        # First, try to call the model with keyword arguments
+        # For standard Hugging Face models
+        return model(**encoded_input)
+    except TypeError:
+        # If that fails, try calling it with a single argument
+        # This is useful for models like Sentence Transformers
+        return model(encoded_input)

Review thread on adapt_model_input:
Collaborator: Function looks good 👍
Member Author: Thanks a lot @sarahyurick for the reviews
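A quick usage sketch of the helper's dispatch with two toy callables (hypothetical names, purely illustrative):

    from crossfit.utils.model_adapter import adapt_model_input

    def kwargs_model(input_ids=None, attention_mask=None):
        return input_ids                  # accepts the batch as keyword arguments

    def dict_model(features):
        return features["input_ids"]      # accepts the batch as one dictionary

    batch = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}

    # Both conventions resolve to the same result through the adapter:
    assert adapt_model_input(kwargs_model, batch) == adapt_model_input(dict_model, batch)

One consequence of the try/except dispatch worth noting: a TypeError raised inside the model body also triggers the single-argument retry, which can mask the original error; for the Hugging Face and Sentence Transformers cases this PR targets, the first call either binds cleanly or fails fast on the signature.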
52 changes: 52 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,52 @@
+# Copyright 2024 NVIDIA CORPORATION
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from crossfit.utils.model_adapter import adapt_model_input
+
+torch = pytest.importorskip("torch")
+sentence_transformers = pytest.importorskip("sentence_transformers")
+transformers = pytest.importorskip("transformers")
+
+
+def test_adapt_model_input_hf():
+    from transformers import AutoTokenizer, DistilBertModel
+
+    with torch.no_grad():
+        model_hf = DistilBertModel.from_pretrained("distilbert-base-uncased")
+        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+        inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+
+        # Hugging Face model output
+        outputs_hf = model_hf(**inputs)
+        adapted_inputs_hf = adapt_model_input(model_hf, inputs)
+        assert torch.equal(adapted_inputs_hf.last_hidden_state, outputs_hf.last_hidden_state)
+
+
+def test_adapt_model_input_sentence_transformers():
+    from transformers import AutoTokenizer
+
+    with torch.no_grad():
+        model_st = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2").to("cpu")
+        tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+
+        inputs = tokenizer(
+            ["Hello", "my dog is cute"], return_tensors="pt", padding=True, truncation=True
+        )
+        # Sentence Transformers model output
+        expected_output = model_st(inputs)
+        adapted_output_st = adapt_model_input(model_st, inputs)
+
+        assert torch.equal(adapted_output_st.sentence_embedding, expected_output.sentence_embedding)
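A brief note on running this file: all three heavy dependencies are gated behind pytest.importorskip, so the tests skip cleanly where torch, sentence_transformers, or transformers is absent; where they are installed, pytest tests/test_utils.py will fetch the two pretrained checkpoints from the Hugging Face Hub on first run.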