From e925b254323277ae0f38d7a3a277e750af46b193 Mon Sep 17 00:00:00 2001 From: edknv Date: Wed, 17 Jan 2024 21:28:36 -0800 Subject: [PATCH] handle 1d outputs from custom pytorch models --- crossfit/backend/cudf/series.py | 10 ++++++++-- crossfit/backend/torch/op/base.py | 19 ++++++++++++------- crossfit/op/label.py | 2 +- crossfit/op/tokenize.py | 10 +++++----- crossfit/op/vector_search.py | 16 +++++++++------- 5 files changed, 35 insertions(+), 22 deletions(-) diff --git a/crossfit/backend/cudf/series.py b/crossfit/backend/cudf/series.py index 9d557c33..9b2d0070 100644 --- a/crossfit/backend/cudf/series.py +++ b/crossfit/backend/cudf/series.py @@ -17,11 +17,17 @@ from cudf.core.column import as_column -def create_list_series_from_2d_ar(ar, index): +def create_list_series_from_1d_or_2d_ar(ar, index): """ Create a cudf list series from 2d arrays """ - n_rows, n_cols = ar.shape + if len(ar.shape) == 1: + n_rows, *_ = ar.shape + n_cols = 1 + elif len(ar.shape) == 2: + n_rows, n_cols = ar.shape + else: + return RuntimeError(f"Unexpected input shape: {ar.shape}") data = as_column(ar.flatten()) offset_col = as_column(cp.arange(start=0, stop=len(data) + 1, step=n_cols), dtype="int32") mask_col = cp.full(shape=n_rows, fill_value=True) diff --git a/crossfit/backend/torch/op/base.py b/crossfit/backend/torch/op/base.py index 6ba60e95..28528f64 100644 --- a/crossfit/backend/torch/op/base.py +++ b/crossfit/backend/torch/op/base.py @@ -20,7 +20,7 @@ import torch from crossfit.backend.cudf.series import ( - create_list_series_from_2d_ar, + create_list_series_from_1d_or_2d_ar, create_nested_list_series_from_3d_ar, ) from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader @@ -33,6 +33,7 @@ def __init__( self, model: Model, pre=None, + post=None, cols=False, keep_cols=None, batch_size: int = DEFAULT_BATCH_SIZE, @@ -43,6 +44,7 @@ def __init__( ): super().__init__(pre=pre, cols=cols, keep_cols=keep_cols) self.model = model + self.post = post self.batch_size = batch_size self.max_mem = max_mem self.max_mem_gb = int(self.max_mem.split("GB")[0]) / 2.5 @@ -73,15 +75,18 @@ def call(self, data, partition_info=None): if isinstance(output, dict): if self.model_output_col not in output: raise ValueError(f"Column '{self.model_outupt_col}' not found in model output.") - all_outputs_ls.append(output[self.model_output_col]) - else: - all_outputs_ls.append(output) + output = output[self.model_output_col] + + if self.post is not None: + output = self.post(output) + + all_outputs_ls.append(output) out = cudf.DataFrame(index=index) - outputs = cp.asarray(torch.vstack(all_outputs_ls)) + outputs = cp.asarray(torch.cat(all_outputs_ls, dim=0)) _index = loader.sort_column(index.values) if self.sorted_data_loader else index - if len(outputs.shape) == 2: - out[self.pred_output_col] = create_list_series_from_2d_ar(outputs, _index) + if len(outputs.shape) <= 2: + out[self.pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index) elif len(outputs.shape) == 3: out[self.pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index) else: diff --git a/crossfit/op/label.py b/crossfit/op/label.py index 60104f1d..3b8af2a3 100644 --- a/crossfit/op/label.py +++ b/crossfit/op/label.py @@ -34,7 +34,7 @@ def call_column(self, data: cudf.Series) -> cudf.Series: ) scores = data.list.leaves.values.reshape(-1, num_labels) - classes = scores.argmax(1) + classes = scores.argmax(-1) labels_map = {i: self.labels[i] for i in range(len(self.labels))} return cudf.Series(classes).map(labels_map) diff --git a/crossfit/op/tokenize.py b/crossfit/op/tokenize.py index 10207012..9d73f00c 100644 --- a/crossfit/op/tokenize.py +++ b/crossfit/op/tokenize.py @@ -22,7 +22,7 @@ from cudf.utils.hash_vocab_utils import hash_vocab from transformers import AutoConfig, AutoTokenizer -from crossfit.backend.cudf.series import create_list_series_from_2d_ar +from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar from crossfit.backend.torch.model import Model from crossfit.dataset.home import CF_HOME from crossfit.op.base import Op @@ -63,10 +63,10 @@ def tokenize_strings(self, sentences, max_length=None): tokenized_data = tokenizer.batch_encode_plus( sentences, max_length=max_length or self.max_length, - return_tensors="pt", - add_special_tokens=True, padding="max_length", + return_tensors="pt", truncation=True, + add_special_tokens=True, return_token_type_ids=False, ) return tokenized_data @@ -106,10 +106,10 @@ def call_column(self, data): tokenized_data = self.tokenize_strings(text).copy() tokenized_data = clip_tokens(tokenized_data, max_length=self.max_length, return_type="cp") - input_ids = create_list_series_from_2d_ar( + input_ids = create_list_series_from_1d_or_2d_ar( tokenized_data["input_ids"].astype("int32"), data.index ) - attention_mask = create_list_series_from_2d_ar( + attention_mask = create_list_series_from_1d_or_2d_ar( tokenized_data["attention_mask"].astype("int32"), data.index ) diff --git a/crossfit/op/vector_search.py b/crossfit/op/vector_search.py index 1785138f..b799d67e 100644 --- a/crossfit/op/vector_search.py +++ b/crossfit/op/vector_search.py @@ -22,7 +22,7 @@ from dask_cudf import from_delayed from pylibraft.neighbors.brute_force import knn -from crossfit.backend.cudf.series import create_list_series_from_2d_ar +from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar from crossfit.backend.dask.cluster import global_dask_client from crossfit.dataset.base import EmbeddingDatataset from crossfit.op.base import Op @@ -82,8 +82,10 @@ def call(self, queries, items): df = cudf.DataFrame(index=queries.index) df["query-id"] = queries["_id"] df["query-index"] = queries["index"] - df["corpus-index"] = create_list_series_from_2d_ar(items["index"].values[indices], df.index) - df["score"] = create_list_series_from_2d_ar(results, df.index) + df["corpus-index"] = create_list_series_from_1d_or_2d_ar( + items["index"].values[indices], df.index + ) + df["score"] = create_list_series_from_1d_or_2d_ar(results, df.index) return df @@ -105,8 +107,8 @@ def reduce(self, grouped): reduced = cudf.DataFrame(index=grouped.index) reduced["query-index"] = grouped["query-index"] reduced["query-id"] = grouped["query-id"] - reduced["score"] = create_list_series_from_2d_ar(topk_scores, reduced.index) - reduced["corpus-index"] = create_list_series_from_2d_ar(topk_indices, reduced.index) + reduced["score"] = create_list_series_from_1d_or_2d_ar(topk_scores, reduced.index) + reduced["corpus-index"] = create_list_series_from_1d_or_2d_ar(topk_indices, reduced.index) reduced = reduced.set_index("query-index", drop=False) @@ -235,8 +237,8 @@ def join_map(part, n_neighbors: int): df = cudf.DataFrame() df.index = part["index"].values - df["corpus-index"] = create_list_series_from_2d_ar(indices, df.index) - df["score"] = create_list_series_from_2d_ar(distances, df.index) + df["corpus-index"] = create_list_series_from_1d_or_2d_ar(indices, df.index) + df["score"] = create_list_series_from_1d_or_2d_ar(distances, df.index) return df