DeepCompile: Use min_cut_rematerialization for partitioning joint graphs #7609
+72
−132
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Motivation
PyTorch provides
min_cut_rematerialization_partition()
to partition a joint graph while respecting recomputation annotation. That algorithm forms a data-flow-like graph from the joint graph, adds to edges weights from some recomputation-cost-related heuristics and applies the min-cut algorithm to determine which nodes to recompute. Users can force recomputation of a node by annotating itsnode.meta["recompute"]
to MUST_RECOMPUTE or PREFER_RECOMPUTE, as is implemented in [1].While originally designed for activation checkpointing, min_cut_rematerialization can also be used to recompute param aliases. When partitioning a joint graph, we don't want to save for backward the gathered parameters and values computed from them via aliasing ops, as that essentially means the gathered parameter will be saved. Instead of customizing the partitioner or patching
choose_saved_values_set
, we can achieve that by annotating such nodes to be MUST_RECOMPUTE.Both eager and inductor backends can use min_cut_rematerialization easily. The eager backend can use min-cut by customizing the partition_fn for
aot_module_simplified
, and is already using that for graphs with activation checkpointing enabled. The inductor backend uses that algorithm since torch 2.0.0 [2] and is still the default after the inductor partitioner is made configurable a few weeks ago [3].That approach also helps DeepCompile + torch autocast nicely. When autocast is enabled, downcasted parameters are preferred to be recomputed. It suffices to mark such casting nodes as must-recompute.
[1] https://github.com/pytorch/pytorch/blob/main/torch/_functorch/partitioners.py#L1813
[2] https://github.com/pytorch/pytorch/blob/v2.0.0/torch/_inductor/compile_fx.py#L459
[3] pytorch/pytorch#157580
Proposal
Motivated by the flexibility and the requirement for optimizing DeepCompile + autocast, I propose to switch to the min-cut-based partitioner for both backends. This PR implements that switch, cleans up dead code and also recomputes downcasted parameters in the backward.
Preliminary Evaluation
Here's a summary of the tests using https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 on a 8x RTX 5090 node.
Reduced memory with eager backend
The initial goal of this PR is to reduce peak memory usage when torch autocast is enabled. That is achieved according to the first row of the table, but in two different ways simultaneously.
fast_free_schedule
will arange most allgather at the beginning of the graph. That leads to a even higher peak during forward, but is no longer seen with PR.add_z3_gather_release
, I noticed that recomputations selected by min-cut is slightly different (that test script has activation checkpointing enabled for the LLM module). That can also impact computation time and memory usage.Here's the shape of memory usage before this PR with eager backend + torch autocast. eager + BF16 shows similar shapes. Numbers reported in the table are peak during forward. The peak memory usage during backend reduces ~0.7GB in both cases.
After this PR:
Similar memory with inductor backend
Unlike eager backend, the inductor backend uses similar memory with or without this PR. The memory usage pattern is as follows, which requires further analysis.
Before this PR:
After this PR: