Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ops for running a custom pytorch classifier #38

Merged
merged 10 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions crossfit/backend/cudf/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,38 @@ def create_list_series_from_2d_ar(ar, index):
)

return cudf.Series(lc, index=index)


def create_nested_list_series_from_3d_ar(ar, index):
    """
    Create a cudf list-of-lists series from a 3d array.

    Each of the ``ar.shape[0]`` slices becomes one row of the series; that row
    holds ``ar.shape[1]`` inner lists with ``ar.shape[2]`` leaf values each.

    Parameters
    ----------
    ar
        Array exposing ``.shape`` (exactly three dimensions) and
        ``.reshape``. Presumably a cupy array -- confirm against callers.
    index
        Index passed through unchanged to the resulting ``cudf.Series``.

    Returns
    -------
    cudf.Series
        Series backed by a nested list column.

    Raises
    ------
    ValueError
        From tuple unpacking when ``ar`` does not have exactly three
        dimensions.
    """
    n_slices, n_rows, n_cols = ar.shape
    n_inner_lists = n_slices * n_rows

    # Leaf values: the whole array flattened to 1-D.
    leaf_data = as_column(ar.reshape(-1))

    # Inner offsets mark where each row of every 2-D slice starts in the leaf
    # data. Built as ``arange * n_cols`` (rather than ``arange(step=n_cols)``)
    # so a zero-width last axis yields all-zero offsets instead of an error.
    inner_offsets = cp.arange(n_inner_lists + 1, dtype="int32") * n_cols

    # Outer offsets mark where each 2-D slice starts among the inner lists.
    outer_offsets = cp.arange(n_slices + 1, dtype="int32") * n_rows

    inner_lc = cudf.core.column.ListColumn(
        size=n_inner_lists,
        dtype=cudf.ListDtype(leaf_data.dtype),
        children=(as_column(inner_offsets), leaf_data),
    )

    # The outer column's elements are themselves lists, so its dtype has to
    # wrap the inner list dtype rather than the leaf dtype.
    lc = cudf.core.column.ListColumn(
        size=n_slices,
        dtype=cudf.ListDtype(inner_lc.dtype),
        children=(as_column(outer_offsets), inner_lc),
    )

    return cudf.Series(lc, index=index)
2 changes: 1 addition & 1 deletion crossfit/backend/torch/hf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def fit_memory_estimate_curve(self, model=None):
y.append(memory_used)

except RuntimeError as e:
if "out of memory" in str(e):
if "out of memory" in str(e) or "out_of_memory" in str(e):
pass
else:
raise e
Expand Down
99 changes: 99 additions & 0 deletions crossfit/backend/torch/op/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
from typing import Optional

import cudf
import cupy as cp
import torch

from crossfit.backend.cudf.series import (
create_list_series_from_2d_ar,
create_nested_list_series_from_3d_ar,
)
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
from crossfit.backend.torch.model import Model
from crossfit.op.base import Op


class Predictor(Op):
    """Op that runs a model over tokenized data and stores the raw outputs.

    ``call`` reads the ``input_ids`` and ``attention_mask`` columns of the
    incoming frame, feeds them through the model in batches and returns a
    frame with a single list (2-D model output) or nested-list (3-D model
    output) column named ``pred_output_col``.

    Parameters
    ----------
    model
        Model wrapper; must provide ``load_on_worker``, ``get_model`` and
        ``max_seq_length``.
    pre, cols, keep_cols
        Forwarded unchanged to ``Op.__init__``.
    batch_size
        Batch size for ``InMemoryLoader``, or the initial batch size for
        ``SortedSeqLoader``.
    max_mem
        Memory budget written as ``"<integer>GB"`` (e.g. ``"16GB"``).
    sorted_data_loader
        Use ``SortedSeqLoader`` when true, ``InMemoryLoader`` otherwise.
    model_output_col
        Key to pick from the model output when the model returns a dict.
        Ignored when the model returns a tensor directly.
    pred_output_col
        Name of the column written to the result frame.
    """

    def __init__(
        self,
        model: Model,
        pre=None,
        cols=False,
        keep_cols=None,
        batch_size: int = DEFAULT_BATCH_SIZE,
        max_mem: str = "16GB",
        sorted_data_loader: bool = True,
        model_output_col: Optional[str] = None,
        pred_output_col: str = "preds",
    ):
        super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
        self.model = model
        self.batch_size = batch_size
        self.max_mem = max_mem
        # NOTE(review): the 2.5 divisor looks like a safety margin on the
        # stated budget -- confirm with whoever consumes ``max_mem_gb``.
        self.max_mem_gb = int(self.max_mem.split("GB")[0]) / 2.5
        self.sorted_data_loader = sorted_data_loader
        self.model_output_col = model_output_col
        self.pred_output_col = pred_output_col

    def setup(self):
        """Load the model on the current worker."""
        self.model.load_on_worker(self)

    @torch.no_grad()
    def call(self, data, partition_info=None):
        """Run the model over ``data`` and return its outputs as one column.

        Parameters
        ----------
        data
            Frame containing ``input_ids`` and ``attention_mask`` columns.
        partition_info
            Passed to ``create_progress_bar``.

        Returns
        -------
        cudf.DataFrame
            Frame indexed like ``data`` with the column ``pred_output_col``.

        Raises
        ------
        ValueError
            If the model returns a dict that lacks ``model_output_col``.
        RuntimeError
            If the stacked model output is neither 2-D nor 3-D.
        """
        index = data.index
        if self.sorted_data_loader:
            loader = SortedSeqLoader(
                data[["input_ids", "attention_mask"]],
                self.model,
                progress_bar=self.create_progress_bar(len(data), partition_info),
                initial_batch_size=self.batch_size,
            )
        else:
            loader = InMemoryLoader(
                data[["input_ids", "attention_mask"]],
                batch_size=self.batch_size,
                progress_bar=self.create_progress_bar(len(data), partition_info),
                max_seq_len=self.model.max_seq_length(),
            )

        all_outputs_ls = []
        for output in loader.map(self.model.get_model(self)):
            if isinstance(output, dict):
                if self.model_output_col not in output:
                    raise ValueError(f"Column '{self.model_output_col}' not found in model output.")
                all_outputs_ls.append(output[self.model_output_col])
            else:
                all_outputs_ls.append(output)

        out = cudf.DataFrame(index=index)
        outputs = cp.asarray(torch.vstack(all_outputs_ls))
        # Presumably the sorted loader yields rows in its own order, so the
        # index is permuted the same way before attaching it to the outputs
        # -- verify against SortedSeqLoader.sort_column.
        _index = loader.sort_column(index.values) if self.sorted_data_loader else index
        if len(outputs.shape) == 2:
            out[self.pred_output_col] = create_list_series_from_2d_ar(outputs, _index)
        elif len(outputs.shape) == 3:
            out[self.pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index)
        else:
            # Report the stacked array's shape, not the last batch item
            # (which may be a dict without a ``shape`` attribute).
            raise RuntimeError(f"Unexpected output shape: {outputs.shape}")

        # Release cached GPU memory held after inference.
        gc.collect()
        torch.cuda.empty_cache()

        return out

    def meta(self):
        """Return the output column name mapped to its dtype string."""
        return {self.pred_output_col: "float32"}
73 changes: 16 additions & 57 deletions crossfit/backend/torch/op/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import cudf
import cupy as cp
import torch

from crossfit.backend.cudf.series import create_list_series_from_2d_ar
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE
from crossfit.backend.torch.model import Model
from crossfit.op.base import Op
from crossfit.backend.torch.op.base import Predictor


class Embedder(Op):
class Embedder(Predictor):
def __init__(
self,
model: Model,
Expand All @@ -34,51 +27,17 @@ def __init__(
batch_size: int = DEFAULT_BATCH_SIZE,
max_mem: str = "16GB",
sorted_data_loader: bool = True,
model_output_col: str = "sentence_embedding",
pred_output_col: str = "embedding",
):
super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
self.model = model
self.batch_size = batch_size
self.max_mem = max_mem
self.max_mem_gb = int(self.max_mem.split("GB")[0]) / 2.5
self.sorted_data_loader = sorted_data_loader

def setup(self):
self.model.load_on_worker(self)

def teardown(self):
self.model.unload_from_worker(self)

@torch.no_grad()
def call(self, data, partition_info=None):
index = data.index
if self.sorted_data_loader:
loader = SortedSeqLoader(
data[["input_ids", "attention_mask"]],
self.model,
progress_bar=self.create_progress_bar(len(data), partition_info),
initial_batch_size=self.batch_size,
)
else:
loader = InMemoryLoader(
data[["input_ids", "attention_mask"]],
batch_size=self.batch_size,
progress_bar=self.create_progress_bar(len(data), partition_info),
max_seq_len=self.model.max_seq_length(),
)

all_embeddings_ls = []
for output in loader.map(self.model.get_model(self)):
all_embeddings_ls.append(output["sentence_embedding"])

out = cudf.DataFrame(index=index)
embedding = cp.asarray(torch.vstack(all_embeddings_ls))
_index = loader.sort_column(index.values) if self.sorted_data_loader else index
out["embedding"] = create_list_series_from_2d_ar(embedding, _index)

gc.collect()
torch.cuda.empty_cache()

return out

def meta(self):
return {"embedding": "float32"}
super().__init__(
model,
pre=pre,
cols=cols,
keep_cols=keep_cols,
batch_size=batch_size,
max_mem=max_mem,
sorted_data_loader=sorted_data_loader,
model_output_col=model_output_col,
pred_output_col=pred_output_col,
)
16 changes: 16 additions & 0 deletions crossfit/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
pass


try:
from crossfit.backend.torch.op.base import Predictor

__all__.append("Predictor")
except ImportError:
pass


try:
from crossfit.op.tokenize import Tokenizer

Expand All @@ -38,6 +46,14 @@
pass


try:
from crossfit.op.label import Labeler

__all__.append("Labeler")
except ImportError:
pass


try:
from crossfit.op.vector_search import CuMLANNSearch, CuMLExactSearch, RaftExactSearch

Expand Down
76 changes: 76 additions & 0 deletions crossfit/op/label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import List, Union

import cudf

from crossfit.op.base import Op


class Labeler(Op):
    """Op that turns per-class score lists into the highest-scoring label.

    Parameters
    ----------
    labels
        Label names; position ``i`` names class ``i``.
    cols
        Columns to label when ``call`` receives a DataFrame. ``None`` means
        ``call`` expects a single Series.
    keep_cols, pre
        Forwarded unchanged to ``Op.__init__``.
    keep_prob
        Stored on the instance; not read anywhere in this class.
    suffix
        Output column name (single column) or column-name suffix (several
        columns).
    """

    def __init__(
        self,
        labels: List[str],
        cols=None,
        keep_cols=None,
        pre=None,
        keep_prob: bool = False,
        suffix: str = "labels",
    ):
        super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
        self.labels = labels
        self.keep_prob = keep_prob
        self.suffix = suffix

    def call_column(self, data: cudf.Series) -> cudf.Series:
        """Map each row of a list-of-scores series to its arg-max label.

        The width of the first row is taken as the number of classes and
        must equal ``len(self.labels)``.

        Raises
        ------
        ValueError
            If ``data`` is a DataFrame, or the number of scores per row
            differs from the number of labels.
        """
        if isinstance(data, cudf.DataFrame):
            raise ValueError(
                "data must be a Series, got DataFrame. Add a pre step to convert to Series"
            )

        num_labels = len(data.iloc[0])
        if len(self.labels) != num_labels:
            raise ValueError(
                f"The number of provided labels is {len(self.labels)} "
                f"but there are {num_labels} in data."
            )

        # Flatten all scores and view them as (n_rows, num_labels).
        scores = data.list.leaves.values.reshape(-1, num_labels)
        classes = scores.argmax(1)
        labels_map = dict(enumerate(self.labels))

        # Keep the caller's index so the labels stay aligned with the input
        # rows when they are combined with other columns of the same frame.
        return cudf.Series(classes, index=data.index).map(labels_map)

    def call(self, data: Union[cudf.Series, cudf.DataFrame]) -> Union[cudf.Series, cudf.DataFrame]:
        """Label a Series (``cols is None``) or each configured DataFrame column.

        Raises
        ------
        ValueError
            If ``cols`` is ``None`` and ``data`` is not a Series, or if a
            configured column is missing from ``data``.
        """
        if self.cols is None:
            if not isinstance(data, cudf.Series):
                raise ValueError("data must be a cudf Series")

            return self.call_column(data)

        output = cudf.DataFrame()
        for col in self.cols:
            if col not in data.columns:
                raise ValueError(f"Column {col} not found in data")

            output[self._construct_name(col, self.suffix)] = self.call_column(data[col])

        return output

    def meta(self):
        """Return output column names mapped to their dtype string.

        Names are derived from ``self.suffix`` exactly as ``call`` derives
        them, so the declared schema matches the produced columns.
        """
        if not self.cols:
            # No column list configured (``cols`` defaults to ``None``).
            return {self.suffix: "string"}

        return {self._construct_name(col, self.suffix): "string" for col in self.cols}

    def _construct_name(self, col_name, suffix):
        """Return ``suffix`` for zero or one column, else ``<col_name>_<suffix>``."""
        if not self.cols or len(self.cols) == 1:
            return suffix

        return f"{col_name}_{suffix}"
Loading
Loading