Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix clipping for models with non default padding id and direction #58

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/gpu-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
if: ${{ github.event_name == 'pull_request' }}
shell: bash
run: |
echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut cut -c1-8)" >> ${GITHUB_ENV}
echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-8)" >> ${GITHUB_ENV}

- name: Setup Environment (Push)
if: ${{ github.event_name == 'push' }}
Expand Down
2 changes: 1 addition & 1 deletion crossfit/backend/cudf/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self):
def concatenate(self, series_list, *, axis=None):
return cudf.concat(series_list, axis=axis or 0)

np_backend_dispatch.register((cudf.Series, cudf.GenericIndex))(CudfBackend())
np_backend_dispatch.register((cudf.Series, cudf.Index))(CudfBackend())


@conversion.dispatch_to_dlpack.register_lazy("cudf")
Expand Down
2 changes: 1 addition & 1 deletion crossfit/backend/cudf/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def create_list_series_from_1d_or_2d_ar(ar, index):
return RuntimeError(f"Unexpected input shape: {ar.shape}")
data = as_column(ar.flatten())
offset_col = as_column(cp.arange(start=0, stop=len(data) + 1, step=n_cols), dtype="int32")
mask_col = cp.full(shape=n_rows, fill_value=True)
mask_col = cp.full(shape=n_rows, fill_value=cp.bool_(True))
mask = cudf._lib.transform.bools_to_mask(as_column(mask_col))
lc = cudf.core.column.ListColumn(
size=n_rows,
Expand Down
11 changes: 9 additions & 2 deletions crossfit/backend/torch/hf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
class HFModel(Model):
def __init__(self, path_or_name: str, max_mem_gb: int = 16, training=False):
super().__init__(path_or_name, max_mem_gb)

if not training:
with torch.no_grad():
self.fit_memory_estimate_curve()
Expand Down Expand Up @@ -129,7 +129,14 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
return predicted_memory[0] / 1024 # Convert from MB to GB

def max_seq_length(self) -> int:
return self.load_cfg().max_position_embeddings
max_seq_length = self.load_tokenizer().model_max_length
# Gaurd against the HF bug
# which sets max_seq_length to max(int) for some models
if max_seq_length > 1e5:
max_seq_length = AutoConfig.from_pretrained(
self.model_name_or_path
).max_position_embeddings
return max_seq_length


class SentenceTransformerModel(HFModel):
Expand Down
25 changes: 18 additions & 7 deletions crossfit/backend/torch/loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2023-2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@

import warnings
from itertools import islice
from typing import Dict, overload
from typing import Dict, overload, Optional

import torch

Expand All @@ -35,7 +35,7 @@ def __init__(self, data: Dict[str, torch.Tensor], batch_size: int, progress_bar=
def __init__(self, data: CrossFrame, batch_size: int, progress_bar=None):
...

def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None, padding_side:str = "right"):
self.data = CrossFrame(data).cast(torch.Tensor)
self.tensor_dict = self.data.to_dict()
self._batch_size = batch_size
Expand All @@ -44,6 +44,7 @@ def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
self._to_map = []
self.progress_bar = progress_bar
self.max_seq_len = max_seq_len
self.padding_side = padding_side

def map(self, fn):
self._to_map.append(fn)
Expand All @@ -65,8 +66,10 @@ def __next__(self):

batch = {key: val[self.current_idx : end] for key, val in self.tensor_dict.items()}
if self.max_seq_len is not None:
batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}

if self.padding_side == "right":
batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
else:
batch = {key: val[:, -self.max_seq_len :] for key, val in batch.items()}
self.current_idx += self.batch_size

for fn in self._to_map:
Expand Down Expand Up @@ -96,15 +99,20 @@ def __init__(
self.to_ignore = to_ignore or []
self.to_ignore.append("seq_length")
self.model = model
self.pad_token_id = self.model.load_tokenizer().pad_token_id
self.padding_side = self.model.load_tokenizer().padding_side

frame = CrossFrame(data).cast(torch.Tensor)
seq_length = (frame[sort_key] != 0).sum(axis=1)
seq_length = (frame[sort_key] != self.pad_token_id).sum(axis=1)
self.sorted_indices = seq_length.argsort(descending=True)
frame = frame.apply(lambda x: x[self.sorted_indices])
frame = frame.assign(seq_length=seq_length[self.sorted_indices])

super().__init__(frame, initial_batch_size, progress_bar=progress_bar)
self.splits = self._find_optimal_splits()
# TODO: Debug PRINTS
print(f"Padding side: {self.padding_side}")
print(f"Pad token id: {self.pad_token_id}")

def sort_column(self, col):
indices = convert_array(self.sorted_indices, type(col))
Expand Down Expand Up @@ -138,7 +146,10 @@ def __next__(self):
if key not in self.to_ignore
}
clip_len = min(max(_tokens[start], _tokens[end - 1]), self.model.max_seq_length())
batch = {key: val[:, :clip_len] for key, val in batch.items()}
if self.padding_side == "right":
batch = {key: val[:, :clip_len] for key, val in batch.items()}
else:
batch = {key: val[:, -clip_len:] for key, val in batch.items()}

for fn in self._to_map:
batch = fn(batch)
Expand Down
1 change: 1 addition & 0 deletions crossfit/backend/torch/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def call(self, data, partition_info=None):
loader = InMemoryLoader(
data[["input_ids", "attention_mask"]],
batch_size=self.batch_size,
padding_side=self.model.load_tokenizer().padding_side,
progress_bar=self.create_progress_bar(len(data), partition_info),
max_seq_len=self.model.max_seq_length(),
)
Expand Down
49 changes: 38 additions & 11 deletions crossfit/op/tokenize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2023-2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@
class TokenizerType(Enum):
SUBWORD = 1
SENTENCE_PIECE = 2
DEFAULT = 3


class Tokenizer(Op):
Expand All @@ -55,8 +56,10 @@ def __init__(
GPUTokenizer.from_pretrained(self.model)

def tokenize_strings(self, sentences, max_length=None):
if self.tokenizer_type == TokenizerType.SENTENCE_PIECE:
if self.tokenizer_type in [TokenizerType.SENTENCE_PIECE, TokenizerType.DEFAULT]:
tokenizer = self.model.load_tokenizer()
self.padding_side = tokenizer.padding_side
self.pad_token_id = tokenizer.pad_token_id

if isinstance(sentences, cudf.Series):
sentences = sentences.to_arrow().to_pylist()
Expand All @@ -81,6 +84,8 @@ def tokenize_strings(self, sentences, max_length=None):
tokenizer = GPUTokenizer.from_pretrained(self.model)
worker.tokenizer = tokenizer

self.padding_side = tokenizer.padding_side
self.pad_token_id = tokenizer.pad_token_id
return worker.tokenizer(
sentences,
max_length=max_length or self.max_length,
Expand Down Expand Up @@ -110,7 +115,11 @@ def call_column(self, data):
text = text.str.slice(0, self.max_chars)

tokenized_data = self.tokenize_strings(text).copy()
tokenized_data = clip_tokens(tokenized_data, max_length=self.max_length, return_type="cp")
tokenized_data = clip_tokens(tokenized_data,
max_length=self.max_length,
padding_side=self.padding_side,
pad_token_id=self.pad_token_id,
return_type="cp")

input_ids = create_list_series_from_1d_or_2d_ar(
tokenized_data["input_ids"].astype("int32"), data.index
Expand Down Expand Up @@ -173,13 +182,22 @@ def _convert_to_tokenizer_type(
tokenizer_type = TokenizerType.SENTENCE_PIECE
elif tokenizer_type in ["subword", "bert", TokenizerType.SUBWORD]:
tokenizer_type = TokenizerType.SUBWORD
elif tokenizer_type in ["default", TokenizerType.DEFAULT]:
tokenizer_type = TokenizerType.DEFAULT
return tokenizer_type


class GPUTokenizer(SubwordTokenizer):
def __init__(self, hash_file: str, do_lower_case: bool = True, config=None):
super().__init__(str(hash_file), do_lower_case=do_lower_case)
self.config = config or {"_name_or_path": hash_file}
self.padding_side = self.config.get("padding_side", "right")
self.pad_token_id = self.config.get("pad_token_id", 0)
if self.padding_side!="right":
raise ValueError(f"Only right padding is supported for GPUTokenizer, got {self.padding_side}")
if self.pad_token_id!=0:
raise ValueError(f"Only pad_token_id=0 is supported for GPUTokenizer, got {self.pad_token_id}")


@classmethod
def get_tokenizer_config(cls, name):
Expand Down Expand Up @@ -224,17 +242,26 @@ def from_pretrained(cls, name, cache_dir=None):
return cls(hashed_vocab_path, config=config)


def clip_tokens(token_o, max_length, return_type="pt"):
def clip_tokens(token_o, max_length, padding_side, pad_token_id, return_type="pt"):
if not isinstance(token_o["input_ids"], cp.ndarray):
token_o = {k: cp.asarray(v) for k, v in token_o.items()}

clip_len = max_length - int((token_o["input_ids"][:, ::-1] != 0).argmax(1).min())
token_o["input_ids"] = _cast_to_appropriate_type(
token_o["input_ids"][:, :clip_len], return_type
)
token_o["attention_mask"] = _cast_to_appropriate_type(
token_o["attention_mask"][:, :clip_len], return_type
)
clip_len = max_length - int((token_o["input_ids"][:, ::-1] != pad_token_id).argmax(1).min())

if padding_side == "right":
token_o["input_ids"] = _cast_to_appropriate_type(
token_o["input_ids"][:, :clip_len], return_type
)
token_o["attention_mask"] = _cast_to_appropriate_type(
token_o["attention_mask"][:, :clip_len], return_type
)
else:
token_o["input_ids"] = _cast_to_appropriate_type(
token_o["input_ids"][:, -clip_len:], return_type
)
token_o["attention_mask"] = _cast_to_appropriate_type(
token_o["attention_mask"][:, -clip_len:], return_type
)

if "metadata" in token_o:
del token_o["metadata"]
Expand Down
16 changes: 16 additions & 0 deletions tests/op/test_tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,19 @@ def test_tokenizer_max_chars(model_name="sentence-transformers/all-MiniLM-L6-v2"

assert results1["input_ids"][0] == results2["input_ids"][0]
assert results1["input_ids"][1] == results2["input_ids"][1]


def test_tokenizer_padded(model_name="microsoft/deberta-v3-base"):
model = cf.HFModel(model_name)
tokenizer = op.Tokenizer(model, cols=["text"], tokenizer_type="spm")
ddf = dask_cudf.from_cudf(
cudf.DataFrame({"text": ["hello world", "this is a sentence"]}),
npartitions=2,
)
results = tokenizer(ddf)
results = results.compute()

hf_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
assert isinstance(results, cudf.DataFrame)
assert results["input_ids"][0] == hf_tokenizer(["hello world"])["input_ids"][0]
assert results["input_ids"][1] == hf_tokenizer(["this is a sentence"])["input_ids"][0]
Loading