From 64df2feb6e5b0666eb98a216179b2f18066635d1 Mon Sep 17 00:00:00 2001 From: laochonlam Date: Sun, 22 Feb 2026 00:02:22 +0000 Subject: [PATCH 1/2] Add benchmarking option for FP8 dispatch. Enable benchmarking with either FP8 dispatch or direct BF16 dispatch through a runtime flag for apples-to-apples performance comparison. --- ep/bench/test_low_latency.py | 42 +++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/ep/bench/test_low_latency.py b/ep/bench/test_low_latency.py index 45667be72..535701711 100644 --- a/ep/bench/test_low_latency.py +++ b/ep/bench/test_low_latency.py @@ -85,6 +85,7 @@ def test_main( group: dist.ProcessGroup, buffer: Buffer, use_logfmt: bool = False, + dispatch_use_fp8: bool = True, seed: int = 0, skip_benchmark: bool = False, debug_hash: bool = False, @@ -152,14 +153,14 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): for current_x in x_list: for return_recv_hook in (False, True): - for dispatch_use_fp8 in (False, True): + for dispatch_use_fp8_case in (False, True): for round_scale in (False,): - for round_scale in (False, True) if dispatch_use_fp8 else (False,): + for round_scale in (False, True) if dispatch_use_fp8_case else (False,): for use_ue8m0 in (False, True) if round_scale else (False,): print( "Start experiment with settings:" f" return_recv_hook={return_recv_hook}" - f" dispatch_use_fp8={dispatch_use_fp8}" + f" dispatch_use_fp8={dispatch_use_fp8_case}" f" round_scale={round_scale}" f" use_ue8m0={use_ue8m0}", flush=True, @@ -180,7 +181,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): topk_idx, num_tokens, num_experts, - use_fp8=dispatch_use_fp8, + use_fp8=dispatch_use_fp8_case, round_scale=round_scale, use_ue8m0=use_ue8m0, cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, @@ -194,7 +195,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): ) packed_recv_x = ( (packed_recv_x[0], packed_recv_x[1].contiguous()) - if dispatch_use_fp8 + if dispatch_use_fp8_case else packed_recv_x ) simulated_gemm_x = ( @@ -202,7 +203,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128), ).view(packed_recv_x[0].shape) - if dispatch_use_fp8 + if dispatch_use_fp8_case else packed_recv_x.clone() ) all_topk_idx = torch.empty( @@ -219,7 +220,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): per_token_cast_back( packed_recv_x[0][i], packed_recv_x[1][i] ) - if dispatch_use_fp8 + if dispatch_use_fp8_case else packed_recv_x[i] ) recv_count, recv_src_info, recv_layout_range = ( @@ -284,11 +285,11 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): - j + rank_offset ).sum().item() == 0 - if dispatch_use_fp8: + if dispatch_use_fp8_case: tag = ( f"x={'x' if current_x is x else 'rand'}" f"|hook={return_recv_hook}" - f"|fp8={dispatch_use_fp8}" + f"|fp8={dispatch_use_fp8_case}" f"|rs={round_scale}" f"|ue={use_ue8m0}" f"|le={i}" @@ -306,7 +307,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): tag = ( f"x={'x' if current_x is x else 'rand'}" f"|hook={return_recv_hook}" - f"|fp8={dispatch_use_fp8}" + f"|fp8={dispatch_use_fp8_case}" f"|rs={round_scale}" f"|ue={use_ue8m0}" f"|le={i}" @@ -368,12 +369,12 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): ) assert torch.isnan(combined_x).sum().item() == 0 assert diff < ( - 9e-4 if dispatch_use_fp8 else 1e-5 - ), f"Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}" + 9e-4 if dispatch_use_fp8_case else 1e-5 + ), f"Error: {diff=}, {dispatch_use_fp8_case=}, {zero_copy=}" tag = ( f"x={'x' if current_x is x else 'rand'}" f"|hook={return_recv_hook}" - f"|fp8={dispatch_use_fp8}" + f"|fp8={dispatch_use_fp8_case}" f"|rs={round_scale}" f"|ue={use_ue8m0}" f"|zc={zero_copy}" @@ -396,7 +397,7 @@ def test_func(return_recv_hook: bool): num_tokens, num_experts, cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, - use_fp8=True, + use_fp8=dispatch_use_fp8, async_finish=False, return_recv_hook=return_recv_hook, ) @@ -422,7 +423,9 @@ def test_func(return_recv_hook: bool): num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 for i in range(num_tokens): num_selections = (topk_idx[i] != -1).sum().item() - num_dispatch_comm_bytes += num_fp8_bytes * num_selections + num_dispatch_comm_bytes += ( + num_fp8_bytes if dispatch_use_fp8 else num_bf16_bytes + ) * num_selections num_combine_comm_bytes += ( num_logfmt10_bytes if use_logfmt else num_bf16_bytes ) * num_selections @@ -494,6 +497,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): group, buffer, use_logfmt=args.use_logfmt, + dispatch_use_fp8=args.dispatch_use_fp8, seed=seed, skip_benchmark=args.pressure_test_mode == 1, debug_hash=args.debug_hash, @@ -521,6 +525,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): group, buffer, use_logfmt=args.use_logfmt, + dispatch_use_fp8=args.dispatch_use_fp8, seed=seed, skip_benchmark=args.pressure_test_mode == 1, debug_hash=args.debug_hash, @@ -601,6 +606,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): parser.add_argument( "--use-logfmt", action="store_true", help="Whether to test LogFMT combine" ) + parser.add_argument( + "--dispatch-use-fp8", + type=bool, + default=True, + action=argparse.BooleanOptionalAction, + help="Whether dispatch path uses FP8 casting (default: true).", + ) parser.add_argument( "--pressure-test-mode", type=int, From 477fc7b70155c0df2bd2e781ee390cdc602b9ccd Mon Sep 17 00:00:00 2001 From: laochonlam Date: Sun, 22 Feb 2026 00:50:16 +0000 Subject: [PATCH 2/2] Format low-latency benchmark loop expression for readability. Keep behavior unchanged while making the conditional tuple iteration clearer and Black-compliant. Co-authored-by: Cursor --- ep/bench/test_low_latency.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ep/bench/test_low_latency.py b/ep/bench/test_low_latency.py index 535701711..774fc8097 100644 --- a/ep/bench/test_low_latency.py +++ b/ep/bench/test_low_latency.py @@ -155,7 +155,9 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True): for return_recv_hook in (False, True): for dispatch_use_fp8_case in (False, True): for round_scale in (False,): - for round_scale in (False, True) if dispatch_use_fp8_case else (False,): + for round_scale in ( + (False, True) if dispatch_use_fp8_case else (False,) + ): for use_ue8m0 in (False, True) if round_scale else (False,): print( "Start experiment with settings:"