Fix pytorch memory curve estimation for rmm backed allocator (#94)
* fix pytorch memory curve estimation

Signed-off-by: Vibhu Jawa <[email protected]>

* Add test

Signed-off-by: Vibhu Jawa <[email protected]>

* Add test for rmm

Signed-off-by: Vibhu Jawa <[email protected]>

* move imports

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix based on Praateek's review

---------

Signed-off-by: Vibhu Jawa <[email protected]>
VibhuJawa authored Oct 10, 2024
1 parent d7e2643 commit a043f45
Showing 3 changed files with 90 additions and 5 deletions.
14 changes: 9 additions & 5 deletions crossfit/backend/torch/hf/memory_curve_utils.py
@@ -21,7 +21,11 @@
 from transformers import PreTrainedModel
 
 from crossfit.utils.model_adapter import adapt_model_input
-from crossfit.utils.torch_utils import cleanup_torch_cache
+from crossfit.utils.torch_utils import (
+    cleanup_torch_cache,
+    get_peak_memory_used,
+    reset_memory_tracking,
+)
 
 
 def fit_memory_estimate_curve(
@@ -37,7 +41,7 @@ def fit_memory_estimate_curve(
 ) -> LinearRegression:
     print(f"Fitting memory estimate curve for model: {path_or_name}")
 
-    device = next(model.parameters()).device
+    device = "cuda"
     X: list[list[int]] = []
     y: list[float] = []
 
@@ -51,16 +55,16 @@ def fit_memory_estimate_curve(
             leave=False,
         )
         for seq_len in seq_len_pbar:
-            torch.cuda.reset_peak_memory_stats()
-
+            reset_memory_tracking()
             batch = {
                 "input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device),
                 "attention_mask": torch.ones((batch_size, seq_len)).to(device=device),
             }
 
             try:
                 _ = adapt_model_input(model, batch)
-                memory_used = torch.cuda.max_memory_allocated() / (1024**2)  # Convert to MB
+                memory_used = get_peak_memory_used()
+                memory_used = memory_used / (1024**2)  # Convert to MB
                 X.append([batch_size, seq_len, seq_len**2])
                 y.append(memory_used)
 
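A minimal sketch of how a curve fitted on these [batch_size, seq_len, seq_len**2] features and MB targets could be queried afterwards, assuming it was saved with joblib to the mem_model_path passed into fit_memory_estimate_curve; the file path and the batch/sequence sizes below are illustrative, not taken from this commit:

import joblib
import numpy as np

# Load a previously fitted memory model (hypothetical path) and predict peak usage in MB.
mem_model = joblib.load("memory_model.joblib")

batch_size, seq_len = 64, 512
features = np.array([[batch_size, seq_len, seq_len**2]])  # same features the curve was fit on
estimated_mb = mem_model.predict(features)[0]
print(f"Estimated peak GPU memory for batch_size={batch_size}, seq_len={seq_len}: {estimated_mb:.0f} MB")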
52 changes: 52 additions & 0 deletions crossfit/utils/torch_utils.py
@@ -99,3 +99,55 @@ def cleanup_torch_cache() -> None:
    gc.collect()
    torch.cuda.empty_cache()
    return None


def reset_memory_tracking() -> None:
    """
    Resets memory counters.
    This function enables memory usage statistics tracking and resets the counters
    for peak memory usage. It handles both RMM (RAPIDS Memory Manager) and PyTorch's
    native CUDA memory tracking, depending on the current allocator backend.
    If RMM is being used as the allocator, it enables RMM statistics and pushes a new
    statistics context. If the default PyTorch allocator is being used, it resets the
    peak memory stats for CUDA.
    Returns:
        None
    """
    if is_torch_memory_rmm():
        import rmm

        rmm.statistics.enable_statistics()
        rmm.statistics.push_statistics()
    else:
        torch.cuda.reset_peak_memory_stats()


def get_peak_memory_used() -> int:
    """
    Get the peak memory usage in bytes.
    This function retrieves the peak memory usage, either from RMM statistics
    if the RMM allocator is being used, or from PyTorch's CUDA memory stats.
    Returns:
        int: Peak memory usage in bytes.
    """
    if is_torch_memory_rmm():
        import rmm

        stats = rmm.statistics.pop_statistics()
        return stats.peak_bytes
    else:
        return torch.cuda.max_memory_allocated()


def is_torch_memory_rmm():
    # TODO: This is hacky; we need a proper way to check whether the allocator is RMM
    # and to reset the peak memory stats. This can go away once PyTorch fixes:
    # https://github.com/pytorch/pytorch/issues/133281
    # https://github.com/pytorch/pytorch/issues/133280
    return torch.cuda.memory.get_allocator_backend() == "pluggable"
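A minimal usage sketch of the two new helpers, assuming they are imported from crossfit.utils.torch_utils as added above; the same pair works whether the allocator is RMM-backed or PyTorch's native one:

import torch

from crossfit.utils.torch_utils import get_peak_memory_used, reset_memory_tracking

reset_memory_tracking()  # pushes an RMM statistics context, or resets torch's peak stats

x = torch.randn((4096, 4096), device="cuda")  # some CUDA work to measure
y = x @ x.T

peak_bytes = get_peak_memory_used()  # pops RMM stats, or reads torch.cuda.max_memory_allocated()
print(f"Peak GPU memory: {peak_bytes / (1024 ** 2):.1f} MB")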
29 changes: 29 additions & 0 deletions tests/op/test_fit_memory_estimate.py
@@ -0,0 +1,29 @@
import pytest
from sklearn.linear_model import LinearRegression

transformers = pytest.importorskip("transformers")
torch = pytest.importorskip("torch")
rmm_torch_allocator = pytest.importorskip("rmm.allocators.torch").rmm_torch_allocator
fit_memory_estimate_curve = pytest.importorskip(
    "crossfit.backend.torch.hf.memory_curve_utils"
).fit_memory_estimate_curve

MODEL_NAME = "microsoft/deberta-v3-base"

# Have to do it globally
# TODO: Long term figure out a better way
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


def test_fit_memory_estimate_curve(tmp_path):
    # Setup
    mem_model_path = tmp_path / "test_memory_model.joblib"
    model = transformers.AutoModel.from_pretrained(MODEL_NAME).to("cuda")
    result = fit_memory_estimate_curve(
        model=model, path_or_name=MODEL_NAME, mem_model_path=str(mem_model_path)
    )
    # Assertions
    assert isinstance(result, LinearRegression)
    assert result.coef_.shape == (3,)  # [batch_size, seq_len, seq_len**2]
    assert isinstance(result.intercept_, float)
    assert mem_model_path.exists()
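The allocator swap in this test happens at module import time because change_current_allocator must run before PyTorch makes its first CUDA allocation. A standalone sketch of that setup and of the RMM statistics calls the new helpers rely on; the pool configuration shown is illustrative, not taken from this commit:

import rmm
import torch
from rmm.allocators.torch import rmm_torch_allocator

# Must run before any CUDA tensor is created by PyTorch.
rmm.reinitialize(pool_allocator=True)  # illustrative: back allocations with an RMM pool
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

rmm.statistics.enable_statistics()
rmm.statistics.push_statistics()
_ = torch.ones((2048, 2048), device="cuda")  # this allocation is now served by RMM
stats = rmm.statistics.pop_statistics()
print(f"Peak bytes tracked by RMM: {stats.peak_bytes}")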
