diff --git a/crossfit/backend/torch/hf/memory_curve_utils.py b/crossfit/backend/torch/hf/memory_curve_utils.py
index 0dfad8f..07081f2 100644
--- a/crossfit/backend/torch/hf/memory_curve_utils.py
+++ b/crossfit/backend/torch/hf/memory_curve_utils.py
@@ -21,7 +21,11 @@
 from transformers import PreTrainedModel
 
 from crossfit.utils.model_adapter import adapt_model_input
-from crossfit.utils.torch_utils import cleanup_torch_cache
+from crossfit.utils.torch_utils import (
+    cleanup_torch_cache,
+    get_peak_memory_used,
+    reset_memory_tracking,
+)
 
 
 def fit_memory_estimate_curve(
@@ -37,7 +41,7 @@ def fit_memory_estimate_curve(
 ) -> LinearRegression:
     print(f"Fitting memory estimate curve for model: {path_or_name}")
 
-    device = next(model.parameters()).device
+    device = "cuda"
     X: list[list[int]] = []
     y: list[float] = []
 
@@ -51,8 +55,7 @@ def fit_memory_estimate_curve(
         leave=False,
     )
     for seq_len in seq_len_pbar:
-        torch.cuda.reset_peak_memory_stats()
-
+        reset_memory_tracking()
         batch = {
             "input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device),
             "attention_mask": torch.ones((batch_size, seq_len)).to(device=device),
         }
@@ -60,7 +63,8 @@ def fit_memory_estimate_curve(
 
         try:
             _ = adapt_model_input(model, batch)
-            memory_used = torch.cuda.max_memory_allocated() / (1024**2)  # Convert to MB
+            memory_used = get_peak_memory_used()
+            memory_used = memory_used / (1024**2)  # Convert to MB
             X.append([batch_size, seq_len, seq_len**2])
             y.append(memory_used)
diff --git a/crossfit/utils/torch_utils.py b/crossfit/utils/torch_utils.py
index 5fc3ec0..8c4f63a 100644
--- a/crossfit/utils/torch_utils.py
+++ b/crossfit/utils/torch_utils.py
@@ -99,3 +99,55 @@ def cleanup_torch_cache() -> None:
     gc.collect()
     torch.cuda.empty_cache()
     return None
+
+
+def reset_memory_tracking() -> None:
+    """
+    Resets memory counters.
+
+    This function enables memory usage statistics tracking and resets the counters
+    for peak memory usage. It handles both RMM (RAPIDS Memory Manager) and PyTorch's
+    native CUDA memory tracking, depending on the current allocator backend.
+
+    If RMM is being used as the allocator, it enables RMM statistics and pushes a new
+    statistics context. If the default PyTorch allocator is being used, it resets the
+    peak memory stats for CUDA.
+
+    Returns:
+        None
+    """
+    if is_torch_memory_rmm():
+        import rmm
+
+        rmm.statistics.enable_statistics()
+        rmm.statistics.push_statistics()
+    else:
+        torch.cuda.reset_peak_memory_stats()
+
+
+def get_peak_memory_used() -> int:
+    """
+    Get the peak memory usage in bytes.
+
+    This function retrieves the peak memory usage, either from RMM statistics
+    if the RMM allocator is being used, or from PyTorch's CUDA memory stats.
+
+    Returns:
+        int: Peak memory usage in bytes.
+    """
+    if is_torch_memory_rmm():
+        import rmm
+
+        stats = rmm.statistics.pop_statistics()
+        return stats.peak_bytes
+    else:
+        return torch.cuda.max_memory_allocated()
+
+
+def is_torch_memory_rmm() -> bool:
+    # TODO: This is hacky: we need a proper way to check whether the allocator
+    # is RMM before deciding how to reset and read peak memory stats. Remove
+    # this workaround once the following PyTorch issues are resolved:
+    # https://github.com/pytorch/pytorch/issues/133281
+    # https://github.com/pytorch/pytorch/issues/133280
+    return torch.cuda.memory.get_allocator_backend() == "pluggable"
diff --git a/tests/op/test_fit_memory_estimate.py b/tests/op/test_fit_memory_estimate.py
new file mode 100644
index 0000000..38b45cf
--- /dev/null
+++ b/tests/op/test_fit_memory_estimate.py
@@ -0,0 +1,29 @@
+import pytest
+from sklearn.linear_model import LinearRegression
+
+transformers = pytest.importorskip("transformers")
+torch = pytest.importorskip("torch")
+rmm_torch_allocator = pytest.importorskip("rmm.allocators.torch").rmm_torch_allocator
+fit_memory_estimate_curve = pytest.importorskip(
+    "crossfit.backend.torch.hf.memory_curve_utils"
+).fit_memory_estimate_curve
+
+MODEL_NAME = "microsoft/deberta-v3-base"
+
+# The allocator has to be swapped globally, before any CUDA memory is allocated.
+# TODO: Long term, figure out a better way to scope this to the test.
+torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
+
+
+def test_fit_memory_estimate_curve(tmp_path):
+    # Setup
+    mem_model_path = tmp_path / "test_memory_model.joblib"
+    model = transformers.AutoModel.from_pretrained(MODEL_NAME).to("cuda")
+    result = fit_memory_estimate_curve(
+        model=model, path_or_name=MODEL_NAME, mem_model_path=str(mem_model_path)
+    )
+    # Assertions
+    assert isinstance(result, LinearRegression)
+    assert result.coef_.shape == (3,)  # [batch_size, seq_len, seq_len**2]
+    assert isinstance(result.intercept_, float)
+    assert mem_model_path.exists()
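Usage note (outside the patch): reset_memory_tracking() and get_peak_memory_used() are meant to be called as a balanced pair, because the RMM path pushes a statistics context in the first call and pops it in the second. A minimal sketch of the intended pattern, using a placeholder workload:

    import torch

    from crossfit.utils.torch_utils import get_peak_memory_used, reset_memory_tracking

    # Any CUDA allocation made between the two calls counts toward the peak.
    reset_memory_tracking()  # push an RMM stats context, or reset torch peak stats
    x = torch.randn(1024, 1024, device="cuda")  # placeholder workload
    y = x @ x
    peak_bytes = get_peak_memory_used()  # pop RMM stats / read torch peak, in bytes
    print(f"Peak memory used: {peak_bytes / (1024**2):.1f} MB")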