diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml
index 7b02c043..9f22d96c 100644
--- a/.github/workflows/gpu-ci.yml
+++ b/.github/workflows/gpu-ci.yml
@@ -57,7 +57,7 @@ jobs:
         if: ${{ github.event_name == 'pull_request' }}
         shell: bash
         run: |
-          echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut cut -c1-8)" >> ${GITHUB_ENV}
+          echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-8)" >> ${GITHUB_ENV}
 
       - name: Setup Environment (Push)
         if: ${{ github.event_name == 'push' }}
diff --git a/crossfit/backend/cudf/array.py b/crossfit/backend/cudf/array.py
index 57350415..1c98cbda 100644
--- a/crossfit/backend/cudf/array.py
+++ b/crossfit/backend/cudf/array.py
@@ -29,7 +29,7 @@ def __init__(self):
     def concatenate(self, series_list, *, axis=None):
         return cudf.concat(series_list, axis=axis or 0)
 
-np_backend_dispatch.register((cudf.Series, cudf.GenericIndex))(CudfBackend())
+np_backend_dispatch.register((cudf.Series, cudf.Index))(CudfBackend())
 
 
 @conversion.dispatch_to_dlpack.register_lazy("cudf")
diff --git a/crossfit/backend/cudf/series.py b/crossfit/backend/cudf/series.py
index 9b2d0070..941c8910 100644
--- a/crossfit/backend/cudf/series.py
+++ b/crossfit/backend/cudf/series.py
@@ -30,7 +30,7 @@ def create_list_series_from_1d_or_2d_ar(ar, index):
         return RuntimeError(f"Unexpected input shape: {ar.shape}")
     data = as_column(ar.flatten())
     offset_col = as_column(cp.arange(start=0, stop=len(data) + 1, step=n_cols), dtype="int32")
-    mask_col = cp.full(shape=n_rows, fill_value=True)
+    mask_col = cp.full(shape=n_rows, fill_value=cp.bool_(True))
     mask = cudf._lib.transform.bools_to_mask(as_column(mask_col))
     lc = cudf.core.column.ListColumn(
         size=n_rows,
diff --git a/crossfit/backend/torch/hf/model.py b/crossfit/backend/torch/hf/model.py
index d48eaa80..93b39ea0 100644
--- a/crossfit/backend/torch/hf/model.py
+++ b/crossfit/backend/torch/hf/model.py
@@ -30,7 +30,7 @@
 class HFModel(Model):
     def __init__(self, path_or_name: str, max_mem_gb: int = 16, training=False):
         super().__init__(path_or_name, max_mem_gb)
-
+
         if not training:
             with torch.no_grad():
                 self.fit_memory_estimate_curve()
@@ -129,7 +129,14 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
         return predicted_memory[0] / 1024  # Convert from MB to GB
 
     def max_seq_length(self) -> int:
-        return self.load_cfg().max_position_embeddings
+        max_seq_length = self.load_tokenizer().model_max_length
+        # Guard against the HF bug that sets model_max_length
+        # to a very large sentinel value for some models
+        if max_seq_length > 1e5:
+            max_seq_length = AutoConfig.from_pretrained(
+                self.model_name_or_path
+            ).max_position_embeddings
+        return max_seq_length
 
 
 class SentenceTransformerModel(HFModel):
diff --git a/crossfit/backend/torch/loader.py b/crossfit/backend/torch/loader.py
index 63c7ef79..444c57da 100644
--- a/crossfit/backend/torch/loader.py
+++ b/crossfit/backend/torch/loader.py
@@ -1,4 +1,4 @@
-# Copyright 2023 NVIDIA CORPORATION
+# Copyright 2023-2024 NVIDIA CORPORATION
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
 
 import warnings
 from itertools import islice
-from typing import Dict, overload
+from typing import Dict, Optional, overload
 
 import torch
 
@@ -35,7 +35,7 @@ def __init__(self, data: Dict[str, torch.Tensor], batch_size: int, progress_bar
     def __init__(self, data: CrossFrame, batch_size: int, progress_bar=None):
         ...
 
-    def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
+    def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None, padding_side: str = "right"):
         self.data = CrossFrame(data).cast(torch.Tensor)
         self.tensor_dict = self.data.to_dict()
         self._batch_size = batch_size
@@ -44,6 +44,7 @@ def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
         self._to_map = []
         self.progress_bar = progress_bar
         self.max_seq_len = max_seq_len
+        self.padding_side = padding_side
 
     def map(self, fn):
         self._to_map.append(fn)
@@ -65,8 +66,10 @@ def __next__(self):
         batch = {key: val[self.current_idx : end] for key, val in self.tensor_dict.items()}
 
         if self.max_seq_len is not None:
-            batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
-
+            if self.padding_side == "right":
+                batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
+            else:
+                batch = {key: val[:, -self.max_seq_len :] for key, val in batch.items()}
         self.current_idx += self.batch_size
 
         for fn in self._to_map:
@@ -96,15 +99,17 @@ def __init__(
         self.to_ignore = to_ignore or []
         self.to_ignore.append("seq_length")
         self.model = model
+        self.pad_token_id = self.model.load_tokenizer().pad_token_id
+        self.padding_side = self.model.load_tokenizer().padding_side
 
         frame = CrossFrame(data).cast(torch.Tensor)
-        seq_length = (frame[sort_key] != 0).sum(axis=1)
+        seq_length = (frame[sort_key] != self.pad_token_id).sum(axis=1)
         self.sorted_indices = seq_length.argsort(descending=True)
         frame = frame.apply(lambda x: x[self.sorted_indices])
         frame = frame.assign(seq_length=seq_length[self.sorted_indices])
 
         super().__init__(frame, initial_batch_size, progress_bar=progress_bar)
         self.splits = self._find_optimal_splits()
 
     def sort_column(self, col):
         indices = convert_array(self.sorted_indices, type(col))
@@ -138,7 +143,10 @@ def __next__(self):
             if key not in self.to_ignore
         }
         clip_len = min(max(_tokens[start], _tokens[end - 1]), self.model.max_seq_length())
-        batch = {key: val[:, :clip_len] for key, val in batch.items()}
+        if self.padding_side == "right":
+            batch = {key: val[:, :clip_len] for key, val in batch.items()}
+        else:
+            batch = {key: val[:, -clip_len:] for key, val in batch.items()}
 
         for fn in self._to_map:
             batch = fn(batch)
diff --git a/crossfit/backend/torch/op/base.py b/crossfit/backend/torch/op/base.py
index 28528f64..c83e985a 100644
--- a/crossfit/backend/torch/op/base.py
+++ b/crossfit/backend/torch/op/base.py
@@ -66,6 +66,7 @@ def call(self, data, partition_info=None):
             loader = InMemoryLoader(
                 data[["input_ids", "attention_mask"]],
                 batch_size=self.batch_size,
+                padding_side=self.model.load_tokenizer().padding_side,
                 progress_bar=self.create_progress_bar(len(data), partition_info),
                 max_seq_len=self.model.max_seq_length(),
             )
diff --git a/crossfit/op/tokenize.py b/crossfit/op/tokenize.py
index e4d659ef..b14613d7 100644
--- a/crossfit/op/tokenize.py
+++ b/crossfit/op/tokenize.py
@@ -1,4 +1,4 @@
-# Copyright 2023 NVIDIA CORPORATION
+# Copyright 2023-2024 NVIDIA CORPORATION
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@
 class TokenizerType(Enum):
     SUBWORD = 1
     SENTENCE_PIECE = 2
+    DEFAULT = 3
 
 
 class Tokenizer(Op):
@@ -55,8 +56,10 @@ def __init__(
             GPUTokenizer.from_pretrained(self.model)
 
     def tokenize_strings(self, sentences, max_length=None):
-        if self.tokenizer_type == TokenizerType.SENTENCE_PIECE:
+        if self.tokenizer_type in [TokenizerType.SENTENCE_PIECE, TokenizerType.DEFAULT]:
             tokenizer = self.model.load_tokenizer()
+            self.padding_side = tokenizer.padding_side
+            self.pad_token_id = tokenizer.pad_token_id
 
             if isinstance(sentences, cudf.Series):
                 sentences = sentences.to_arrow().to_pylist()
@@ -81,6 +84,8 @@ def tokenize_strings(self, sentences, max_length=None):
                 tokenizer = GPUTokenizer.from_pretrained(self.model)
                 worker.tokenizer = tokenizer
 
+            self.padding_side = tokenizer.padding_side
+            self.pad_token_id = tokenizer.pad_token_id
             return worker.tokenizer(
                 sentences,
                 max_length=max_length or self.max_length,
@@ -110,7 +115,13 @@ def call_column(self, data):
             text = text.str.slice(0, self.max_chars)
 
         tokenized_data = self.tokenize_strings(text).copy()
-        tokenized_data = clip_tokens(tokenized_data, max_length=self.max_length, return_type="cp")
+        tokenized_data = clip_tokens(
+            tokenized_data,
+            max_length=self.max_length,
+            padding_side=self.padding_side,
+            pad_token_id=self.pad_token_id,
+            return_type="cp",
+        )
 
         input_ids = create_list_series_from_1d_or_2d_ar(
             tokenized_data["input_ids"].astype("int32"), data.index
@@ -173,6 +184,8 @@ def _convert_to_tokenizer_type(
         tokenizer_type = TokenizerType.SENTENCE_PIECE
     elif tokenizer_type in ["subword", "bert", TokenizerType.SUBWORD]:
         tokenizer_type = TokenizerType.SUBWORD
+    elif tokenizer_type in ["default", TokenizerType.DEFAULT]:
+        tokenizer_type = TokenizerType.DEFAULT
 
     return tokenizer_type
 
@@ -180,6 +193,16 @@ class GPUTokenizer(SubwordTokenizer):
     def __init__(self, hash_file: str, do_lower_case: bool = True, config=None):
         super().__init__(str(hash_file), do_lower_case=do_lower_case)
         self.config = config or {"_name_or_path": hash_file}
+        self.padding_side = self.config.get("padding_side", "right")
+        self.pad_token_id = self.config.get("pad_token_id", 0)
+        if self.padding_side != "right":
+            raise ValueError(
+                f"Only right padding is supported for GPUTokenizer, got {self.padding_side}"
+            )
+        if self.pad_token_id != 0:
+            raise ValueError(
+                f"Only pad_token_id=0 is supported for GPUTokenizer, got {self.pad_token_id}"
+            )
 
     @classmethod
     def get_tokenizer_config(cls, name):
@@ -224,17 +247,30 @@ def from_pretrained(cls, name, cache_dir=None):
         return cls(hashed_vocab_path, config=config)
 
 
-def clip_tokens(token_o, max_length, return_type="pt"):
+def clip_tokens(token_o, max_length, padding_side, pad_token_id, return_type="pt"):
     if not isinstance(token_o["input_ids"], cp.ndarray):
         token_o = {k: cp.asarray(v) for k, v in token_o.items()}
 
-    clip_len = max_length - int((token_o["input_ids"][:, ::-1] != 0).argmax(1).min())
-    token_o["input_ids"] = _cast_to_appropriate_type(
-        token_o["input_ids"][:, :clip_len], return_type
-    )
-    token_o["attention_mask"] = _cast_to_appropriate_type(
-        token_o["attention_mask"][:, :clip_len], return_type
-    )
+    if padding_side == "right":
+        # Pads sit at the end of each row: reversing the rows lets argmax find
+        # the shortest pad run, then the leftmost clip_len columns are kept.
+        clip_len = max_length - int((token_o["input_ids"][:, ::-1] != pad_token_id).argmax(1).min())
+        token_o["input_ids"] = _cast_to_appropriate_type(
+            token_o["input_ids"][:, :clip_len], return_type
+        )
+        token_o["attention_mask"] = _cast_to_appropriate_type(
+            token_o["attention_mask"][:, :clip_len], return_type
+        )
+    else:
+        # Pads sit at the start of each row: count them from the left and keep
+        # the rightmost clip_len columns.
+        clip_len = max_length - int((token_o["input_ids"] != pad_token_id).argmax(1).min())
+        token_o["input_ids"] = _cast_to_appropriate_type(
+            token_o["input_ids"][:, -clip_len:], return_type
+        )
+        token_o["attention_mask"] = _cast_to_appropriate_type(
+            token_o["attention_mask"][:, -clip_len:], return_type
+        )
 
     if "metadata" in token_o:
         del token_o["metadata"]
diff --git a/tests/op/test_tokenize.py b/tests/op/test_tokenize.py
index bc79bc2f..6b01494f 100644
--- a/tests/op/test_tokenize.py
+++ b/tests/op/test_tokenize.py
@@ -44,3 +44,19 @@ def test_tokenizer_max_chars(model_name="sentence-transformers/all-MiniLM-L6-v2
 
     assert results1["input_ids"][0] == results2["input_ids"][0]
     assert results1["input_ids"][1] == results2["input_ids"][1]
+
+
+def test_tokenizer_padded(model_name="microsoft/deberta-v3-base"):
+    model = cf.HFModel(model_name)
+    tokenizer = op.Tokenizer(model, cols=["text"], tokenizer_type="spm")
+    ddf = dask_cudf.from_cudf(
+        cudf.DataFrame({"text": ["hello world", "this is a sentence"]}),
+        npartitions=2,
+    )
+    results = tokenizer(ddf)
+    results = results.compute()
+
+    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+    assert isinstance(results, cudf.DataFrame)
+    assert results["input_ids"][0] == hf_tokenizer(["hello world"])["input_ids"][0]
+    assert results["input_ids"][1] == hf_tokenizer(["this is a sentence"])["input_ids"][0]
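
Reviewer note on the shared pattern: clip_tokens and both loaders reduce padded batches to the longest real sequence in the batch, slicing from the left for right-padded input and from the right for left-padded input. Below is a minimal, self-contained sketch of that clipping logic with NumPy standing in for CuPy; clip_padded is a hypothetical name introduced for illustration only and is not part of this patch.

# Minimal sketch of the pad-clipping idea, assuming pad_token_id marks padding
# and every row is padded out to the same width.
import numpy as np

def clip_padded(ids: np.ndarray, pad_token_id: int, padding_side: str) -> np.ndarray:
    max_length = ids.shape[1]
    if padding_side == "right":
        # Pads are at the end: reversing rows turns each pad run into a
        # leading run, so argmax gives its length per row.
        pad_run = (ids[:, ::-1] != pad_token_id).argmax(1)
        clip_len = max_length - int(pad_run.min())
        return ids[:, :clip_len]
    # Pads are at the start: count them directly from the left.
    pad_run = (ids != pad_token_id).argmax(1)
    clip_len = max_length - int(pad_run.min())
    return ids[:, -clip_len:]

right_padded = np.array([[5, 6, 0, 0], [5, 6, 7, 0]])
left_padded = np.array([[0, 0, 5, 6], [0, 5, 6, 7]])
print(clip_padded(right_padded, 0, "right"))  # [[5 6 0] [5 6 7]]
print(clip_padded(left_padded, 0, "left"))    # [[0 5 6] [5 6 7]]

Because the attention_mask is sliced identically, the clipped pad columns were already ignored by the model; the clip only removes wasted compute, never real tokens.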