Skip to content

Commit

Permalink
batch longer sequence first in sorted seq loader (#18)
Browse files Browse the repository at this point in the history
* batch longer sequence first in sorted seq loader

* allow batch size change in user api

* reduce batch size in tests

* retry smaller batch if we run out of memory
  • Loading branch information
edknv committed Nov 4, 2023
1 parent df5e74b commit 7edbb44
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 36 deletions.
31 changes: 24 additions & 7 deletions crossfit/backend/torch/hf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,49 @@ def fit_memory_estimate_curve(self, model=None):
y = []

max_seq = self.max_seq_length()
for batch_size in tqdm(range(1, 2048, 256)):
for seq_len in list(range(16, max_seq, 64)) + [max_seq]:
for batch_size in tqdm(range(2048, 0, -256)):
if batch_size <= 0:
continue

for seq_len in range(max_seq, 0, -64):
if seq_len <= 0:
continue

torch.cuda.reset_peak_memory_stats()

batch = {
"input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device),
"attention_mask": torch.ones((batch_size, seq_len)).to(device=device),
"input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(
device=device
),
"attention_mask": torch.ones((batch_size, seq_len)).to(
device=device
),
}

try:
outputs = model(batch)
memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
memory_used = torch.cuda.max_memory_allocated() / (
1024**2
) # Convert to MB
X.append([batch_size, seq_len, seq_len**2])
y.append(memory_used)

except RuntimeError as e:
if "out of memory" in str(e):
torch.cuda.empty_cache()
pass
else:
raise e
finally:
del batch
if "outputs" in vars():
del outputs
gc.collect()
torch.cuda.empty_cache()

self.mem = LinearRegression().fit(np.array(X), np.array(y))
os.makedirs(cache_dir, exist_ok=True)
joblib.dump(self.mem, mem_model_path)

del outputs
if remove_model:
del model
gc.collect()
Expand Down
46 changes: 33 additions & 13 deletions crossfit/backend/torch/loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, overload
from itertools import islice
import warnings

import torch

Expand All @@ -9,6 +10,9 @@
from crossfit.data.array.conversion import convert_array


DEFAULT_BATCH_SIZE = 512


class InMemoryLoader:
@overload
def __init__(
Expand Down Expand Up @@ -75,7 +79,7 @@ def __init__(
data: CrossFrame,
model: Model,
sort_key: str = "input_ids",
initial_batch_size: int = 512,
initial_batch_size: int = DEFAULT_BATCH_SIZE,
to_ignore=None,
progress_bar=None,
):
Expand All @@ -86,7 +90,7 @@ def __init__(

frame = CrossFrame(data).cast(torch.Tensor)
seq_length = (frame[sort_key] != 0).sum(axis=1)
self.sorted_indices = seq_length.argsort()
self.sorted_indices = seq_length.argsort(descending=True)
frame = frame.apply(lambda x: x[self.sorted_indices])
frame = frame.assign(seq_length=seq_length[self.sorted_indices])

Expand All @@ -113,22 +117,38 @@ def __next__(self):
start = 0
else:
start = self.splits[self.current_idx - 1]
end = min(self.splits[self.current_idx], self.num_rows)

_tokens = self.tensor_dict["seq_length"]

batch = {
key: val[start:end]
for key, val in self.tensor_dict.items()
if key not in self.to_ignore
}
clip_len = min(_tokens[end - 1], self.model.max_seq_length())
batch = {key: val[:, :clip_len] for key, val in batch.items()}
end = min(self.splits[self.current_idx], self.num_rows)
while end > start:
try:
batch = {
key: val[start:end]
for key, val in self.tensor_dict.items()
if key not in self.to_ignore
}
clip_len = min(
max(_tokens[start], _tokens[end - 1]), self.model.max_seq_length()
)
batch = {key: val[:, :clip_len] for key, val in batch.items()}

for fn in self._to_map:
batch = fn(batch)

break
except torch.cuda.OutOfMemoryError as e:
mid = start + (end - start) // 2
warnings.warn(
f"Not enough memeory for a batch size of {end - start}. "
f"Retrying with a new batch size of {mid - start}. "
f"Consider setting initial batch size to {mid - start}."
)
self.splits.insert(self.current_idx, mid)
end = min(self.splits[self.current_idx], self.num_rows)

self.current_idx += 1

for fn in self._to_map:
batch = fn(batch)

if self.progress_bar is not None:
self.progress_bar.update(end - start)

Expand Down
4 changes: 2 additions & 2 deletions crossfit/backend/torch/op/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from crossfit.op.base import Op
from crossfit.backend.cudf.series import create_list_series_from_2d_ar
from crossfit.backend.torch.model import Model
from crossfit.backend.torch.loader import SortedSeqLoader, InMemoryLoader
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, SortedSeqLoader, InMemoryLoader


class Embedder(Op):
Expand All @@ -17,7 +17,7 @@ def __init__(
pre=None,
cols=False,
keep_cols=None,
batch_size=1024,
batch_size: int = DEFAULT_BATCH_SIZE,
max_mem: str = "16GB",
sorted_data_loader: bool = True,
):
Expand Down
10 changes: 8 additions & 2 deletions crossfit/report/beir/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from crossfit.dataset.load import load_dataset
from crossfit.op.vector_search import VectorSearchOp
from crossfit.backend.torch.model import Model
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE


def embed(
Expand All @@ -19,14 +20,17 @@ def embed(
out_dir: Optional[str] = None,
tiny_sample: bool = False,
sorted_data_loader: bool = True,
batch_size: int = DEFAULT_BATCH_SIZE,
) -> EmbeddingDatataset:
dataset: IRDataset = load_dataset(
"beir/" + dataset_name, overwrite=overwrite, tiny_sample=tiny_sample
)

out_dir = out_dir or CF_HOME
processed_name = "processed-test" if tiny_sample else "processed"
emb_dir = os.path.join(out_dir, processed_name, "beir", dataset_name, "emb", model.path_or_name)
emb_dir = os.path.join(
out_dir, processed_name, "beir", dataset_name, "emb", model.path_or_name
)

if os.path.exists(emb_dir):
if overwrite:
Expand All @@ -53,7 +57,9 @@ def embed(

pipe = op.Sequential(
op.Tokenizer(model, cols=["text"]),
op.Embedder(model, sorted_data_loader=sorted_data_loader),
op.Embedder(
model, sorted_data_loader=sorted_data_loader, batch_size=batch_size
),
repartition=partitions,
keep_cols=["index", "_id"],
)
Expand Down
10 changes: 10 additions & 0 deletions crossfit/report/beir/report.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import gc
from typing import List, Optional

import cudf
import cupy as cp
import dask_cudf
from cuml.preprocessing import LabelEncoder
import numpy as np
import torch

from crossfit.backend.dask.aggregate import aggregate
from crossfit.data.sparse.dispatch import CrossSparse
Expand All @@ -17,6 +19,7 @@
from crossfit.report.base import Report
from crossfit.op.vector_search import VectorSearchOp
from crossfit.backend.torch.model import Model
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE


class BeirMetricAggregator(Aggregator):
Expand Down Expand Up @@ -163,6 +166,7 @@ def beir_report(
groupby=["split"],
tiny_sample=False,
sorted_data_loader: bool = True,
batch_size: int = DEFAULT_BATCH_SIZE,
) -> BeirReport:
embeddings: EmbeddingDatataset = embed(
dataset_name,
Expand All @@ -173,6 +177,7 @@ def beir_report(
vector_search=vector_search,
tiny_sample=tiny_sample,
sorted_data_loader=sorted_data_loader,
batch_size=batch_size,
)

observations = []
Expand All @@ -190,6 +195,11 @@ def beir_report(
data = dask_cudf.concat(observations)
joined = join_predictions(data, embeddings.predictions)

del data
del embeddings
gc.collect()
torch.cuda.empty_cache()

aggregator = BeirMetricAggregator(ks)
aggregator = Aggregator(aggregator, groupby=groupby, name="")

Expand Down
4 changes: 3 additions & 1 deletion tests/report/beir/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@


@pytest.mark.singlegpu
@pytest.mark.parametrize("dataset", ["nq"])
@pytest.mark.parametrize("dataset", ["fiqa", "hotpotqa", "nq"])
def test_embed_multi_gpu(
dataset,
model_name="sentence-transformers/all-MiniLM-L6-v2",
k=10,
batch_size=64,
):
model = cf.SentenceTransformerModel(model_name)
vector_search = cf.TorchExactSearch(k=k)
Expand All @@ -24,6 +25,7 @@ def test_embed_multi_gpu(
vector_search=vector_search,
overwrite=True,
tiny_sample=True,
batch_size=batch_size,
)
embeds = embeds.predictions.ddf().compute().to_pandas()

Expand Down
27 changes: 20 additions & 7 deletions tests/report/beir/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@
import crossfit as cf
from crossfit.data.sparse.ranking import SparseNumericLabels, SparseRankings
from crossfit.metric.ranking import NDCG
from crossfit.report.beir.report import (create_csr_matrix,
create_label_encoder,
join_predictions)
from crossfit.report.beir.report import (
create_csr_matrix,
create_label_encoder,
join_predictions,
)


@pytest.mark.singlegpu
@pytest.mark.parametrize("dataset", ["nq"])
@pytest.mark.parametrize("dataset", ["fiqa", "hotpotqa", "nq"])
def test_beir_report(
dataset, model_name="sentence-transformers/all-MiniLM-L6-v2", k=10
dataset,
model_name="sentence-transformers/all-MiniLM-L6-v2",
k=10,
batch_size=8,
):
model = cf.SentenceTransformerModel(model_name)
vector_search = cf.TorchExactSearch(k=k)
Expand All @@ -26,6 +31,7 @@ def test_beir_report(
vector_search=vector_search,
overwrite=True,
tiny_sample=True,
batch_size=batch_size,
)

expected_columns = [
Expand All @@ -45,8 +51,13 @@ def test_beir_report(


@pytest.mark.singlegpu
@pytest.mark.parametrize("dataset", ["hotpotqa"])
def test_no_invalid_scores(dataset, model_name="sentence-transformers/all-MiniLM-L6-v2", k=10):
@pytest.mark.parametrize("dataset", ["fiqa", "hotpotqa", "nq"])
def test_no_invalid_scores(
dataset,
model_name="sentence-transformers/all-MiniLM-L6-v2",
k=5,
batch_size=8,
):
model = cf.SentenceTransformerModel(model_name)
vector_search = cf.TorchExactSearch(k=k)
embeds = cf.embed(
Expand All @@ -55,9 +66,11 @@ def test_no_invalid_scores(dataset, model_name="sentence-transformers/all-MiniLM
vector_search=vector_search,
overwrite=True,
tiny_sample=True,
batch_size=batch_size,
)
test = embeds.data.test.ddf()
test["split"] = "test"

df = join_predictions(test, embeds.predictions).compute()

encoder = create_label_encoder(df, ["corpus-index-pred", "corpus-index-obs"])
Expand Down
10 changes: 6 additions & 4 deletions tests/report/data_overview/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import crossfit as cf
from crossfit.backend.dask.aggregate import aggregate
from crossfit.report.data_overview.report import (CategoricalMetrics,
ContinuousMetrics,
DataOverviewReport,
data_overview_report)
from crossfit.report.data_overview.report import (
CategoricalMetrics,
ContinuousMetrics,
DataOverviewReport,
data_overview_report,
)
from crossfit.report.data_overview.visualization.facets import FacetsOverview
from tests.utils import sample_df

Expand Down

0 comments on commit 7edbb44

Please sign in to comment.