diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index a29ed598..77894d80 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -320,16 +320,7 @@ Buffer::get_dispatch_layout(
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {num_tokens_per_rdma_rank}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -606,32 +597,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        is_token_in_rank,
-                        rank_prefix_matrix,
-                        channel_prefix_matrix,
-                        recv_x,
-                        recv_src_idx,
-                        recv_channel_prefix_matrix,
-                        send_head}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {x_scales,
-                         topk_idx,
-                         topk_weights,
-                         num_tokens_per_rank,
-                         num_tokens_per_expert,
-                         cached_channel_prefix_matrix,
-                         cached_rank_prefix_matrix,
-                         recv_topk_idx,
-                         recv_topk_weights,
-                         recv_x_scales}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -774,16 +740,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -1121,39 +1078,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        is_token_in_rank,
-                        recv_x,
-                        rdma_channel_prefix_matrix,
-                        recv_rdma_rank_prefix_sum,
-                        gbl_channel_prefix_matrix,
-                        recv_gbl_rank_prefix_sum}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {x_scales,
-                         topk_idx,
-                         topk_weights,
-                         num_tokens_per_rank,
-                         num_tokens_per_rdma_rank,
-                         num_tokens_per_expert,
-                         cached_rdma_channel_prefix_matrix,
-                         cached_recv_rdma_rank_prefix_sum,
-                         cached_gbl_channel_prefix_matrix,
-                         cached_recv_gbl_rank_prefix_sum,
-                         recv_topk_idx,
-                         recv_topk_weights,
-                         recv_x_scales,
-                         recv_rdma_channel_prefix_matrix,
-                         recv_gbl_channel_prefix_matrix,
-                         send_rdma_head,
-                         send_nvl_head,
-                         recv_src_meta}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -1338,24 +1263,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        src_meta,
-                        is_combined_token_in_rank,
-                        rdma_channel_prefix_matrix,
-                        rdma_rank_prefix_sum,
-                        gbl_channel_prefix_matrix,
-                        combined_x,
-                        combined_rdma_head,
-                        combined_nvl_head}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {topk_weights, combined_topk_weights, bias_0, bias_1}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index bdf26e8e..eb0aa6c2 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -314,7 +314,9 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
         num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
             self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
                                              async_finish, allocate_on_comm_stream)
-        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
+        tensors_to_record = (topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank) if async_finish else None
+
+        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -386,7 +388,9 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                 x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
                 expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+
+            tensors_to_record = (x, x_scales, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_x_scales, recv_src_idx) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
@@ -395,10 +399,10 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 expert_alignment, num_worst_tokens, config,
                 getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            tensors_to_record = (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert,
+                                 is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix,
+                                 recv_x, recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, send_head) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def combine(self, x: torch.Tensor, handle: Tuple,
@@ -446,7 +450,8 @@ def combine(self, x: torch.Tensor, handle: Tuple,
                                                                           channel_prefix_matrix, send_head, config,
                                                                           getattr(previous_event, 'event', None),
                                                                           async_finish, allocate_on_comm_stream)
-        return recv_x, recv_topk_weights, EventOverlap(event)
+        tensors_to_record = (x, topk_weights, bias_0, bias_1, src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, recv_x, recv_topk_weights) if async_finish else None
+        return recv_x, recv_topk_weights, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -479,7 +484,11 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Te
                 x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None,
                 num_recv_tokens, num_rdma_recv_tokens, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
                 expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+
+            tensors_to_record = (x, x_scales, is_token_in_rank, recv_x, recv_x_scales,
+                                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                 recv_rdma_channel_prefix_matrix, recv_src_meta, send_rdma_head, send_nvl_head) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
@@ -494,10 +503,15 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Te
             handle = (is_token_in_rank,
                       rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
                       recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head, send_nvl_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            tensors_to_record = (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
+                                 is_token_in_rank, recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
+                                 rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
+                                 recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
+                                 recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                 recv_src_meta, send_rdma_head, send_nvl_head) if async_finish else None
+
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
+
 
     # noinspection PyTypeChecker
     def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
@@ -527,7 +541,10 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                                                                              send_rdma_head, send_nvl_head, config,
                                                                              getattr(previous_event, 'event', None),
                                                                              async_finish, allocate_on_comm_stream)
-        return combined_x, combined_topk_weights, EventOverlap(event)
+        tensors_to_record = (x, topk_weights, bias_0, bias_1, src_meta, is_combined_token_in_rank,
+                             rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
+                             send_rdma_head, send_nvl_head, combined_x, combined_topk_weights) if async_finish else None
+        return combined_x, combined_topk_weights, EventOverlap(event, tensors_to_record)
 
     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
         """
diff --git a/deep_ep/utils.py b/deep_ep/utils.py
index e61a2c5b..4222df8e 100644
--- a/deep_ep/utils.py
+++ b/deep_ep/utils.py
@@ -33,9 +33,12 @@ def __init__(self, event: Optional[EventHandle] = None, extra_tensors: Optional[
     def current_stream_wait(self) -> None:
         """
         The current stream `torch.cuda.current_stream()` waits for the event to be finished.
+        After synchronization completes, tensor references are released to allow memory reuse.
         """
         assert self.event is not None
         self.event.current_stream_wait()
+        # Release tensor references after synchronization is complete
+        self.extra_tensors = None
 
     def __enter__(self) -> Any:
         """
@@ -56,9 +59,10 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         Utility for overlapping and Python `with` syntax.
 
         Please follow the example in the `__enter__` function.
+        After synchronization completes, tensor references are released to allow memory reuse.
         """
         if self.event is not None:
-            self.event.current_stream_wait()
+            self.current_stream_wait()
 
 
 def check_nvlink_connections(group: dist.ProcessGroup):
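Usage sketch (not part of the patch): with this change, calling an async API with async_finish=True returns an EventOverlap that keeps the relevant input/output tensors alive via extra_tensors, instead of the C++ side calling record_stream on them. The caller only needs to hold the EventOverlap until it calls current_stream_wait(), which makes the current stream wait on the communication event and then drops the references so the allocator can reuse that memory. The sketch below uses get_dispatch_layout, whose call shape appears in the hunks above; it assumes `buffer` is an already-constructed deep_ep.Buffer and that topk_idx/num_experts come from the usual MoE gating step, and the same pattern applies to dispatch/combine.

import torch
from deep_ep import Buffer

def layout_with_overlap(buffer: Buffer, topk_idx: torch.Tensor, num_experts: int):
    # Launch the layout computation on the communication stream; with async_finish=True the
    # returned EventOverlap holds topk_idx and the four outputs in extra_tensors.
    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
        buffer.get_dispatch_layout(topk_idx, num_experts, async_finish=True)

    # ... enqueue independent work on the current (compute) stream here to overlap ...

    # Make the current stream wait for the comm-stream event; afterwards the EventOverlap
    # sets extra_tensors to None, releasing the references it was holding.
    event.current_stream_wait()
    return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank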