Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_hybrid_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.Process
dispatch_bf16_rdma_send_bytes = num_rdma_send * HIDDEN_DIM * 2
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes

dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS}
dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS, 'handle': handle}
t = bench(lambda: buffer.dispatch(**dispatch_args))[0]
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes
if NUM_OF_NODES > 1:
Expand All @@ -279,7 +279,7 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.Process
'''
Benchmark of the dispatch and combine with permute extension
'''
dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE}
dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE, 'handle': handle}
t = bench(lambda: buffer.dispatch_with_permute(**dispatch_with_permute_args))[0]
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes
print_in_order(f'[rank {rank}] HybridEP dispatch+permute torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
Expand Down