@GeYuhong GeYuhong commented Aug 18, 2025

This feature offloads model activations to host memory in the forward pass and prefetches them back to the device in the backward pass.
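For context, here is a minimal sketch of the general idea in plain PyTorch (illustrative only, not this PR's implementation): activations saved for backward are copied to pinned host memory on a side stream during the forward pass and copied back when backward needs them.

```python
import torch

d2h_stream = torch.cuda.Stream()  # side stream for device-to-host copies

def pack(tensor):
    # Wait until the producer has finished writing the activation, then copy it
    # to pinned host memory asynchronously on the side stream.
    d2h_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(d2h_stream):
        host_copy = torch.empty(tensor.shape, dtype=tensor.dtype,
                                device="cpu", pin_memory=True)
        host_copy.copy_(tensor, non_blocking=True)
    tensor.record_stream(d2h_stream)  # keep the GPU buffer alive until the copy finishes
    done = torch.cuda.Event()
    done.record(d2h_stream)
    return host_copy, done

def unpack(packed):
    host_copy, done = packed
    # Make sure the offload copy has finished before reading the host buffer back.
    torch.cuda.current_stream().wait_event(done)
    return host_copy.to("cuda", non_blocking=True)

layer = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = layer(x).relu()
y.sum().backward()
```

The PR itself works at chunk granularity and prefetches ahead of the backward pass; the sketch only shows the basic offload/reload mechanics.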

Note: requires TransformerEngine 2.5 with the companion feature PR (NVIDIA/TransformerEngine#2145).

Currently this feature can be used in a few modules, such as core_attention and router_fc1; we will support more modules (including qkv_linear, router_fc2, and shared_experts) as soon as possible.

We rewrote _indices_to_multihot() in the token_dispatcher to remove all implicit synchronization without using fused ops, while keeping the results bitwise identical.
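To illustrate the kind of rewrite this refers to, here is a sketch of a synchronization-free indices-to-multihot conversion (the function name, shapes, and -1 padding convention are assumptions for illustration; this is not the PR's exact code). Everything stays on the device, and no data-dependent host reads such as nonzero() or .item() are issued, so no implicit device-to-host sync is triggered.

```python
import torch

def indices_to_multihot(indices, probs, num_experts):
    """indices: [num_tokens, topk] int64 expert ids, -1 marks an unused slot; probs: same shape."""
    num_tokens = indices.size(0)
    valid = indices != -1
    # Clamp -1 to 0 so scatter indices stay in range; invalid slots contribute zeros.
    safe_indices = indices.clamp(min=0)
    multihot = torch.zeros(num_tokens, num_experts, dtype=torch.int32, device=indices.device)
    multihot.scatter_add_(1, safe_indices, valid.to(torch.int32))
    multihot_probs = torch.zeros(num_tokens, num_experts, dtype=probs.dtype, device=probs.device)
    multihot_probs.scatter_add_(1, safe_indices, probs * valid)
    return multihot.bool(), multihot_probs
```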

The following are the experimental results (dp4tp1cp1ep4pp2vpp2), covering end-to-end performance and peak memory.

End-to-end performance:

| configuration | elapsed time per iteration (ms) |
| --- | --- |
| baseline | 1262 |
| baseline + new _indices_to_multihot | 1249.7 |
| offload_qkv | 1253.8 |

Peak memory ($R$ is the ratio of the actual decrease in peak memory to the theoretical value; the theoretical values for stage 0 and stage 1 are 1440 MiB and 800 MiB respectively):

| rank | base (bytes) | base + new _indices_to_multihot (bytes) | Δ between bases (MiB) | offload_qkv (bytes) | Δ offload vs base (MiB) | $R$ | Δ offload vs base-new (MiB) | $R$ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 43687144448 | 43689495552 | -2.24 | 42179546624 | 1437.76 | 99.84% | 1440 | 100% |
| 1 | 43687562240 | 43689913344 | -2.24 | 42179859968 | 1437.86 | 99.85% | 1440.1 | 100% |
| 2 | 43687014912 | 43689366016 | -2.24 | 42179417088 | 1437.76 | 99.84% | 1440 | 100% |
| 3 | 43686620672 | 43688971776 | -2.24 | 42179022848 | 1437.76 | 99.84% | 1440 | 100% |
| 4 | 44975166976 | 44977182208 | -1.92 | 44138519040 | 797.89 | 99.74% | 799.81 | 99.98% |
| 5 | 44975987712 | 44977182208 | -1.14 | 44138321920 | 798.86 | 99.86% | 800 | 100% |
| 6 | 44975987712 | 44977182208 | -1.14 | 44138716160 | 798.48 | 99.81% | 799.62 | 99.95% |
| 7 | 44973536256 | 44975551488 | -1.92 | 44136691200 | 798.08 | 99.76% | 800 | 100% |
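For example, for rank 0 the measured saving over the baseline is (43687144448 − 42179546624) bytes ≈ 1437.76 MiB against a theoretical 1440 MiB, giving $R$ = 1437.76 / 1440 ≈ 99.84%.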


copy-pr-bot bot commented Aug 18, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yanring yanring requested review from yanring and hxbai August 19, 2025 05:37
@hxbai hxbai (Collaborator) commented Aug 19, 2025

rank 1 | 0 1 2 0 1 2 3 4 3 4
"""

offload_mlp_input: bool = False

Do we still need this flag?

@GeYuhong GeYuhong (Author) Sep 8, 2025

Fixed; see 1555e6d.




class ChunkOffloadHandler:

We should try reusing the code of cpu_offload.py in TE as much as possible. IIUC, the class should derive from TE’s AsyncDoubleBufferGroupOffloadHandler().

@GeYuhong GeYuhong (Author)

Fixed (the PipelineOffload class was applied, though it achieves only limited reuse); see e845344.

tensor_on_device.record_stream(self.d2h_stream)
self._tensor_tag_to_state[tensor_tag] = state
self._offloaded_group_count = group_to_offload + 1
self._f_event.record(self.d2h_stream)

Maybe we can use stream synchronization instead to reduce the discrepancy from TE's AsyncDoubleBufferGroupOffloadHandler(). Event synchronization is lightweight, but I think it doesn't impact the perf a lot here.
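For reference, the two synchronization styles being compared look like this in plain PyTorch (a generic sketch, not TE or PR code):

```python
import torch

compute_stream = torch.cuda.current_stream()
d2h_stream = torch.cuda.Stream()

# Event-based: mark a point on the D2H stream and make the compute stream wait for it.
done = torch.cuda.Event()
done.record(d2h_stream)
compute_stream.wait_event(done)

# Stream-based: make the compute stream wait for all work currently queued on the D2H stream.
compute_stream.wait_stream(d2h_stream)
```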

@GeYuhong GeYuhong (Author)

Fixed; see e845344.

return GroupStartFunction.apply(tensor, cur_forward_chunk)


def offloading_checker(tensor):

Do we still need the checker?

@GeYuhong GeYuhong (Author)

Fixed; see 7168ccd.

return len(self._queue)

def reset_chunk_handler(self, num_layer, offload_mlp_input=True):
cur_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()

get_virtual_pipeline_model_parallel_rank() is deprecated now. The vpp rank (now named vp_stage) is passed at runtime. The MR is here.
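A sketch of the suggested direction, taking the stage as an explicit argument instead of querying parallel_state (the method and attribute names here are assumptions, not the PR's actual code):

```python
from typing import Optional

class PipelineOffloadManager:
    def reset_chunk_handler(self, num_layers: int,
                            vp_stage: Optional[int] = None,
                            offload_mlp_input: bool = True) -> None:
        # Callers running with virtual pipeline parallelism pass their vp_stage;
        # everyone else defaults to stage 0.
        self._cur_vpp_rank = vp_stage if vp_stage is not None else 0
        self._num_layers = num_layers
        self._offload_mlp_input = offload_mlp_input
```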

@GeYuhong GeYuhong (Author)

Fixed; see c9f00c7.

MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
else:
if config.offload_activation:
MoEPositiveAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

Why do we need an extra loss scaler?

@GeYuhong GeYuhong (Author)

Fixed; see 4b0d3f1.


return hidden_states

def _offload_qkv_linear_forward(

Can we make it a factory function to simplify the calling logic of registering and offloading tensors?
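A sketch of what such a factory could look like (illustrative only; the wrapper and its arguments are assumptions, though offloading_activation mirrors an attribute used elsewhere in this PR):

```python
def make_offloading_forward(forward_fn, pick_offload_inputs):
    """Return a forward that first marks selected inputs for activation offloading."""
    def wrapped(*args, **kwargs):
        for tensor in pick_offload_inputs(*args, **kwargs):
            tensor.offloading_activation = True  # picked up later by the offload handler
        return forward_fn(*args, **kwargs)
    return wrapped

# Usage sketch: qkv_forward = make_offloading_forward(self.linear_qkv, lambda hidden_states: [hidden_states])
```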

@GeYuhong GeYuhong (Author)

Fixed; see b00acbc.

@GeYuhong GeYuhong force-pushed the activation_offloading branch from 64d156d to af236f1 on September 3, 2025 06:02
@GeYuhong GeYuhong force-pushed the activation_offloading branch from a410128 to 1555e6d on September 3, 2025 15:56
@GeYuhong GeYuhong (Author) commented Sep 8, 2025

> @GeYuhong Is your known bug related to this? https://github.com/NVIDIA/TransformerEngine/blob/734bcedd9d86e4be30ce44f1ef67af5f69f3670d/transformer_engine/pytorch/module/linear.py#L402-L406

Yes, this is the bug we encountered and we have fixed it. Thank you!

@yspMing yspMing commented Sep 8, 2025

Is this PR ready for use? Or are there some limitations to applying this patch?

@GeYuhong GeYuhong (Author) commented Sep 9, 2025

> Is this PR ready for use? Or are there some limitations to applying this patch?

This feature is ready for core_attn offload and router_fc1 offload. We will support other modules in a few days, including router_fc2, linear_qkv, etc.

Hongbin Liu and others added 4 commits September 8, 2025 22:18
Hongbinl/activation offloading

add arguments.py and minor fix, OOTB runable now
support activation offloading at PP=1&PP&VPP
core_attn_out.offloading_activation = True
with PipelineOffloadManager.get_instance():
output, bias = self.linear_proj(core_attn_out)
output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out])

core_attn_out is also saved by the fused attn. It may not work as expected.
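To illustrate the concern in plain PyTorch (not the PR's code): once a tensor is saved for backward, dropping other references to it does not free its GPU memory until the autograd graph releases it.

```python
import torch

x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
y = x.sin()                           # sin() saves x for its backward pass
del x                                 # "releasing" the tensor elsewhere...
torch.cuda.synchronize()
print(torch.cuda.memory_allocated())  # ...still accounts for x: the graph holds a reference
y.sum().backward()                    # only now is the saved reference dropped
```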

@hxbai hxbai (Collaborator) commented Sep 18, 2025

I think the offloading for the input of the activation function (swiglu) should also be added.

self.bulk_offload_group(self._layer_index)
if len(release_tensors) > 0:
cur_stream = torch.cuda.current_stream()
for release_tensor in release_tensors:
@hxbai hxbai (Collaborator) Sep 18, 2025

This cannot work with FP8. Can the saved FP8 tensors be released in time?

)

if args.offload_activation:
assert not args.overlap_grad_reduce, "overlap_grad_reduce is not supported with offload_activation"

Why can't they be enabled together?

Hongbin Liu and others added 26 commits September 17, 2025 23:54
support mixed dense&moe layer and a2a overlap