Skip to content

Conversation

eternalNight
Copy link
Contributor

Motivation

PyTorch provides min_cut_rematerialization_partition() to partition a joint graph while respecting recomputation annotation. That algorithm forms a data-flow-like graph from the joint graph, adds to edges weights from some recomputation-cost-related heuristics and applies the min-cut algorithm to determine which nodes to recompute. Users can force recomputation of a node by annotating its node.meta["recompute"] to MUST_RECOMPUTE or PREFER_RECOMPUTE, as is implemented in [1].

While originally designed for activation checkpointing, min_cut_rematerialization can also be used to recompute param aliases. When partitioning a joint graph, we don't want to save for backward the gathered parameters and values computed from them via aliasing ops, as that essentially means the gathered parameter will be saved. Instead of customizing the partitioner or patching choose_saved_values_set, we can achieve that by annotating such nodes to be MUST_RECOMPUTE.

Both eager and inductor backends can use min_cut_rematerialization easily. The eager backend can use min-cut by customizing the partition_fn for aot_module_simplified, and is already using that for graphs with activation checkpointing enabled. The inductor backend uses that algorithm since torch 2.0.0 [2] and is still the default after the inductor partitioner is made configurable a few weeks ago [3].

That approach also helps DeepCompile + torch autocast nicely. When autocast is enabled, downcasted parameters are preferred to be recomputed. It suffices to mark such casting nodes as must-recompute.

[1] https://github.com/pytorch/pytorch/blob/main/torch/_functorch/partitioners.py#L1813
[2] https://github.com/pytorch/pytorch/blob/v2.0.0/torch/_inductor/compile_fx.py#L459
[3] pytorch/pytorch#157580

Proposal

Motivated by the flexibility and the requirement for optimizing DeepCompile + autocast, I propose to switch to the min-cut-based partitioner for both backends. This PR implements that switch, cleans up dead code and also recomputes downcasted parameters in the backward.

Preliminary Evaluation

Here's a summary of the tests using https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 on a 8x RTX 5090 node.

Configuration Base Time (ms) Base Mem (GB) Time with this PR (ms) Mem with this PR (GB)
eager + autocast 551.92 12.07 571.24 9.96
eager + bf16 419.87 9.47 445.76 7.30
inductor + autocast 546.97 12.84 570.09 13.04
inductor + bf16 444.03 10.01 444.70 10.19

Reduced memory with eager backend

The initial goal of this PR is to reduce peak memory usage when torch autocast is enabled. That is achieved according to the first row of the table, but in two different ways simultaneously.

  1. Downcasted parameters during forward are throwed away and recomputed (by the fused cast + allgather) in the backward pass.
  2. Without this PR, fast_free_schedule will arange most allgather at the beginning of the graph. That leads to a even higher peak during forward, but is no longer seen with PR.
  3. By diffing the graphs passed to add_z3_gather_release, I noticed that recomputations selected by min-cut is slightly different (that test script has activation checkpointing enabled for the LLM module). That can also impact computation time and memory usage.

Here's the shape of memory usage before this PR with eager backend + torch autocast. eager + BF16 shows similar shapes. Numbers reported in the table are peak during forward. The peak memory usage during backend reduces ~0.7GB in both cases.

image

After this PR:

image

Similar memory with inductor backend

Unlike eager backend, the inductor backend uses similar memory with or without this PR. The memory usage pattern is as follows, which requires further analysis.

Before this PR:

image

After this PR:

image

Copy link
Contributor

@tohtana tohtana left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @eternalNight, this is definitely a significant improvement.
Even though we currently don't see performance improvement for inductor cases, the benefit is very clear for the performance with eager and cleanness of the code.

@tohtana
Copy link
Contributor

tohtana commented Oct 2, 2025

@eternalNight As some of our CI env use very old PyTorch, they fail with loading deepspeed. Can you add a guard to avoid the error?
I think DeepCompile now works with v2.6 or later, and this PR should work with these versions. But I'm okay to limit it to v2.7 or later if necessary.

@loadams @sfc-gh-truwase This doesn't need to be addressed for this PR, but do we have a plan to update the supported PyTorch version? python tests fail because the PyTorch version there was very old. The top page says

For full feature support we recommend a version of PyTorch that is >= 1.9 and ideally the latest PyTorch stable release.

but it would be okay to limit to 2.0 or later.

@eternalNight eternalNight force-pushed the eternalNight/partition_by_min_cut_rematerialization branch from a0bd752 to ff830e7 Compare October 3, 2025 03:07
PyTorch provides `min_cut_rematerialization_partition()` to partition a
joint graph while respecting recomputation annotation. That algorithm
forms a data-flow-like graph from the joint graph, adds to edges weights
from some recomputation-cost-related heuristics and applies the min-cut
algorithm to determine which nodes to recompute. Users can force
recomputation of a node by annotating its `node.meta["recompute"]` to
MUST_RECOMPUTE or PREFER_RECOMPUTE, as is implemented in [1].

While originally designed for activation checkpointing,
min_cut_rematerialization suits our needs quite well. When partitioning
a joint graph, we don't want to save for backward the gathered
parameters and values computed from them via aliasing ops, as that
essentially means the gathered parameter will be saved. Instead of
customizing the partitioner or patching choose_saved_values_set, we can
achieve that by annotating such nodes to be MUST_RECOMPUTE.

Both eager and inductor backends can use min_cut_rematerialization
easily. The eager backend can use min-cut by customizing the
partition_fn for `aot_module_simplified`, and is already using that for
graphs with activation checkpointing enabled. The inductor backend uses
that algorithm since torch 2.0.0 and is still the default after the
inductor partitioner is made configurable a few weeks ago [3].

That approach also helps DeepCompile + torch autocast nicely. When
autocast is enabled, downcasted parameters are preferred to be
recomputed. It suffices to mark such casting nodes as must-recompute.

Motivated by the flexibility and the requirement for optimizing
DeepCompile + autocast, I propose to switch to the min-cut-based
partitioner for both backends. This PR implements that switch, cleans up
dead code and also recomputes downcasted parameters in the backward.

[1] https://github.com/pytorch/pytorch/blob/main/torch/_functorch/partitioners.py#L1813

[2] https://github.com/pytorch/pytorch/blob/v2.8.0/torch/_inductor/compile_fx.py#L2281

[3] pytorch/pytorch#157580

Signed-off-by: Junjie Mao <[email protected]>
@eternalNight eternalNight force-pushed the eternalNight/partition_by_min_cut_rematerialization branch from ff830e7 to aac7b75 Compare October 3, 2025 03:08
@eternalNight
Copy link
Contributor Author

@eternalNight As some of our CI env use very old PyTorch, they fail with loading deepspeed. Can you add a guard to avoid the error? I think DeepCompile now works with v2.6 or later, and this PR should work with these versions. But I'm okay to limit it to v2.7 or later if necessary.

I just added a try ... except ImportError block around the pytorch imports in partitioner.py. That CheckpointPolicy enum is available at least since torch v2.6 (ref), so I think we can keep that version constraint for now.

@tohtana tohtana enabled auto-merge (squash) October 3, 2025 03:16
@tohtana tohtana merged commit 2a76988 into deepspeedai:master Oct 3, 2025
12 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants