add ops for running a custom pytorch classifier (#38)
* add ops for running a custom pytorch classifier

* update

* lint

* increase num partitions

* catch out of memory error

* decrease pool size

* try moving model to inside context

* comment out test for now

* minor updates
edknv authored Jan 10, 2024
1 parent 2a7eefe commit bbaadae
Showing 8 changed files with 423 additions and 73 deletions.
35 changes: 35 additions & 0 deletions crossfit/backend/cudf/series.py
@@ -36,3 +36,38 @@ def create_list_series_from_2d_ar(ar, index):
)

return cudf.Series(lc, index=index)


def create_nested_list_series_from_3d_ar(ar, index):
"""
Create a cudf list of lists series from 3d arrays
"""
n_slices, n_rows, n_cols = ar.shape
flattened_data = ar.reshape(-1) # Flatten the 3-D array into 1-D

# Inner list offsets (for each row in 2D slices)
inner_offsets = cp.arange(
start=0, stop=n_cols * n_rows * n_slices + 1, step=n_cols, dtype="int32"
)
inner_list_data = as_column(flattened_data)
inner_list_offsets = as_column(inner_offsets)

# Outer list offsets (for each 2D slice in the 3D array)
outer_offsets = cp.arange(start=0, stop=n_slices + 1, step=1, dtype="int32") * n_rows
outer_list_offsets = as_column(outer_offsets)

# Constructing the nested ListColumn
lc = cudf.core.column.ListColumn(
size=n_slices,
dtype=cudf.ListDtype(inner_list_data.dtype),
children=(
outer_list_offsets,
cudf.core.column.ListColumn(
size=inner_offsets.size - 1,
dtype=cudf.ListDtype(inner_list_data.dtype),
children=(inner_list_offsets, inner_list_data),
),
),
)

return cudf.Series(lc, index=index)
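
A minimal usage sketch for the new helper — the array values and index are hypothetical, and only the imports already at the top of this module are assumed:

import cudf
import cupy as cp

# A (2, 3, 4) array: 2 slices, each a 3x4 matrix.
ar = cp.arange(24, dtype="float32").reshape(2, 3, 4)
index = cudf.RangeIndex(2)

series = create_nested_list_series_from_3d_ar(ar, index)
# series[0] is the first slice as a 3x4 list of lists, e.g.
# [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0]]
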
2 changes: 1 addition & 1 deletion crossfit/backend/torch/hf/model.py
@@ -102,7 +102,7 @@ def fit_memory_estimate_curve(self, model=None):
y.append(memory_used)

except RuntimeError as e:
if "out of memory" in str(e):
if "out of memory" in str(e) or "out_of_memory" in str(e):
pass
else:
raise e
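
The broadened check matters because different allocators word their failures differently: PyTorch's caching allocator raises RuntimeError messages containing "out of memory", while RMM-backed allocations tend to surface "out_of_memory". A sketch of the probe-and-skip pattern this method relies on — the helper, model call, and batch sizes are illustrative, not the module's actual loop:

import torch

def probe_batch_sizes(model, make_batch, sizes):
    """Try each batch size, skipping ones that exhaust GPU memory."""
    fits = []
    for size in sizes:
        try:
            with torch.no_grad():
                model(make_batch(size))
            fits.append(size)
        except RuntimeError as e:
            # Swallow only OOM errors, in either spelling; re-raise the rest.
            if "out of memory" in str(e) or "out_of_memory" in str(e):
                torch.cuda.empty_cache()
            else:
                raise
    return fits
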
99 changes: 99 additions & 0 deletions crossfit/backend/torch/op/base.py
@@ -0,0 +1,99 @@
# Copyright 2023 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
from typing import Optional

import cudf
import cupy as cp
import torch

from crossfit.backend.cudf.series import (
create_list_series_from_2d_ar,
create_nested_list_series_from_3d_ar,
)
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
from crossfit.backend.torch.model import Model
from crossfit.op.base import Op


class Predictor(Op):
def __init__(
self,
model: Model,
pre=None,
cols=False,
keep_cols=None,
batch_size: int = DEFAULT_BATCH_SIZE,
max_mem: str = "16GB",
sorted_data_loader: bool = True,
model_output_col: Optional[str] = None,
pred_output_col: str = "preds",
):
super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
self.model = model
self.batch_size = batch_size
self.max_mem = max_mem
self.max_mem_gb = int(self.max_mem.split("GB")[0]) / 2.5
self.sorted_data_loader = sorted_data_loader
self.model_output_col = model_output_col
self.pred_output_col = pred_output_col

def setup(self):
self.model.load_on_worker(self)

@torch.no_grad()
def call(self, data, partition_info=None):
index = data.index
if self.sorted_data_loader:
loader = SortedSeqLoader(
data[["input_ids", "attention_mask"]],
self.model,
progress_bar=self.create_progress_bar(len(data), partition_info),
initial_batch_size=self.batch_size,
)
else:
loader = InMemoryLoader(
data[["input_ids", "attention_mask"]],
batch_size=self.batch_size,
progress_bar=self.create_progress_bar(len(data), partition_info),
max_seq_len=self.model.max_seq_length(),
)

all_outputs_ls = []
for output in loader.map(self.model.get_model(self)):
if isinstance(output, dict):
if self.model_output_col not in output:
raise ValueError(f"Column '{self.model_outupt_col}' not found in model output.")
all_outputs_ls.append(output[self.model_output_col])
else:
all_outputs_ls.append(output)

out = cudf.DataFrame(index=index)
outputs = cp.asarray(torch.vstack(all_outputs_ls))
_index = loader.sort_column(index.values) if self.sorted_data_loader else index
if len(outputs.shape) == 2:
out[self.pred_output_col] = create_list_series_from_2d_ar(outputs, _index)
elif len(outputs.shape) == 3:
out[self.pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index)
else:
raise RuntimeError(f"Unexpected output shape: {outputs.shape}")

gc.collect()
torch.cuda.empty_cache()

return out

def meta(self):
return {self.pred_output_col: "float32"}
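
A hedged usage sketch for the new op — `my_model` stands in for a crossfit Model subclass and is not part of this commit, and the data is assumed to already carry the "input_ids"/"attention_mask" columns that `call` selects:

from crossfit.backend.torch.op.base import Predictor

# my_model: a Model whose forward pass returns a tensor or a dict of
# tensors. For dict outputs, model_output_col picks the key to keep.
classifier = Predictor(
    my_model,
    batch_size=64,
    model_output_col="logits",  # hypothetical key in the output dict
    pred_output_col="preds",
)
# Applied to a tokenized dask-cuDF DataFrame in a crossfit pipeline, each
# output row holds a list (or list of lists) of scores under "preds".
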
73 changes: 16 additions & 57 deletions crossfit/backend/torch/op/embed.py
@@ -12,19 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import cudf
import cupy as cp
import torch

from crossfit.backend.cudf.series import create_list_series_from_2d_ar
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE
from crossfit.backend.torch.model import Model
from crossfit.op.base import Op
from crossfit.backend.torch.op.base import Predictor


class Embedder(Op):
class Embedder(Predictor):
def __init__(
self,
model: Model,
@@ -34,51 +34,17 @@ def __init__(
batch_size: int = DEFAULT_BATCH_SIZE,
max_mem: str = "16GB",
sorted_data_loader: bool = True,
model_output_col: str = "sentence_embedding",
pred_output_col: str = "embedding",
):
super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
self.model = model
self.batch_size = batch_size
self.max_mem = max_mem
self.max_mem_gb = int(self.max_mem.split("GB")[0]) / 2.5
self.sorted_data_loader = sorted_data_loader

def setup(self):
self.model.load_on_worker(self)

def teardown(self):
self.model.unload_from_worker(self)

@torch.no_grad()
def call(self, data, partition_info=None):
index = data.index
if self.sorted_data_loader:
loader = SortedSeqLoader(
data[["input_ids", "attention_mask"]],
self.model,
progress_bar=self.create_progress_bar(len(data), partition_info),
initial_batch_size=self.batch_size,
)
else:
loader = InMemoryLoader(
data[["input_ids", "attention_mask"]],
batch_size=self.batch_size,
progress_bar=self.create_progress_bar(len(data), partition_info),
max_seq_len=self.model.max_seq_length(),
)

all_embeddings_ls = []
for output in loader.map(self.model.get_model(self)):
all_embeddings_ls.append(output["sentence_embedding"])

out = cudf.DataFrame(index=index)
embedding = cp.asarray(torch.vstack(all_embeddings_ls))
_index = loader.sort_column(index.values) if self.sorted_data_loader else index
out["embedding"] = create_list_series_from_2d_ar(embedding, _index)

gc.collect()
torch.cuda.empty_cache()

return out

def meta(self):
return {"embedding": "float32"}
super().__init__(
model,
pre=pre,
cols=cols,
keep_cols=keep_cols,
batch_size=batch_size,
max_mem=max_mem,
sorted_data_loader=sorted_data_loader,
model_output_col=model_output_col,
pred_output_col=pred_output_col,
)
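
After this refactor, Embedder is a thin wrapper that only pins the two output-column defaults, so the following constructions should be equivalent (the model argument is illustrative):

from crossfit.backend.torch.op.base import Predictor
from crossfit.backend.torch.op.embed import Embedder

emb = Embedder(my_model, batch_size=64)
# ...is intended to behave like:
pred = Predictor(
    my_model,
    batch_size=64,
    model_output_col="sentence_embedding",
    pred_output_col="embedding",
)
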
16 changes: 16 additions & 0 deletions crossfit/op/__init__.py
@@ -30,6 +30,14 @@
pass


try:
from crossfit.backend.torch.op.base import Predictor

__all__.append("Predictor")
except ImportError:
pass


try:
from crossfit.op.tokenize import Tokenizer

@@ -38,6 +46,14 @@
pass


try:
from crossfit.op.label import Labeler

__all__.append("Labeler")
except ImportError:
pass


try:
from crossfit.op.vector_search import CuMLANNSearch, CuMLExactSearch, RaftExactSearch

76 changes: 76 additions & 0 deletions crossfit/op/label.py
@@ -0,0 +1,76 @@
from typing import List, Union

import cudf

from crossfit.op.base import Op


class Labeler(Op):
def __init__(
self,
labels: List[str],
cols=None,
keep_cols=None,
pre=None,
keep_prob: bool = False,
suffix: str = "labels",
):
super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
self.labels = labels
self.keep_prob = keep_prob
self.suffix = suffix

def call_column(self, data: cudf.Series) -> cudf.Series:
if isinstance(data, cudf.DataFrame):
raise ValueError(
"data must be a Series, got DataFrame. Add a pre step to convert to Series"
)

num_labels = len(data.iloc[0])
if len(self.labels) != num_labels:
raise ValueError(
f"The number of provided labels is {len(self.labels)} "
f"but there are {num_labels} in data."
)

scores = data.list.leaves.values.reshape(-1, num_labels)
classes = scores.argmax(1)
labels_map = {i: self.labels[i] for i in range(len(self.labels))}

return cudf.Series(classes).map(labels_map)

def call(self, data: Union[cudf.Series, cudf.DataFrame]) -> Union[cudf.Series, cudf.DataFrame]:
output = cudf.DataFrame()

if self.cols is None:
if not isinstance(data, cudf.Series):
raise ValueError("data must be a cudf Series")

return self.call_column(data)

for col in self.cols:
if col not in data.columns:
raise ValueError(f"Column {col} not found in data")

labels = self.call_column(data[col])
output[self._construct_name(col, self.suffix)] = labels

return output

def meta(self):
labeled = {"labels": "string"}

if len(self.cols) > 1:
labeled = {
self._construct_name(col, suffix): dtype
for col in self.cols
for suffix, dtype in labeled.items()
}

return labeled

def _construct_name(self, col_name, suffix):
if len(self.cols) == 1:
return suffix

return f"{col_name}_{suffix}"