 see README.md for more details
 """

+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional

@@ -50,7 +51,7 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
     world_size: int = 2
     dim: int = 2048
     profile_dir: str = "."
-    num_benchmarks: int = 1
+    num_benchmarks: int = 2
     num_profiles: int = 2
     num_mul: int = 5
     num_concat: int = 100
@@ -94,6 +95,7 @@ def a2a_sync_base(
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -186,6 +188,7 @@ def a2a_async_twice(
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -254,13 +257,14 @@ def a2a_async_twice(
         assert checks1 and checks2


-# all_to_all_single with sync and single stream
+# LazyAwaitable
 def lazyawaitable(
     _batch_inputs: List[Dict[str, Any]],
     dim: int,
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -294,6 +298,183 @@ def lazyawaitable(
         assert check_awaitable.item()


+# multi-stream memory footprint
+def multi_stream_memory(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    multi_stream: bool = True,
+    **_kwargs: Dict[str, Any],
+) -> None:
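+    # benchmarks all_to_all_single with the host-to-device copy and the comms issued on
+    # separate CUDA streams; with multi_stream=False the side streams fall back to
+    # nullcontext(), so all work is issued on the current stream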
+    with record_function("## setup ##"):
+        main_stream = torch.cuda.current_stream()
+        data_copy_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+        data_dist_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+
+        # without `pin_memory()` the host-to-device data transfer would block CUDA execution
+        host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+
+    with record_function("## irrelevant compute before h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## copy data to device ##"):
+        # use a separate stream to copy data to the device so the main stream is not blocked
+        with data_copy_stream:
+            device_data = host_data.to(ctx.device, non_blocking=True)
+            # record the tensor on the main stream so it won't be freed accidentally
+            # by the data_copy_stream
+            device_data.record_stream(main_stream)
+
+    with record_function("## irrelevant compute after h2d ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## pre-comms compute ##"):
+        if isinstance(data_copy_stream, torch.cuda.Stream):
+            # make sure the data copy is done before the pre-comms compute
+            main_stream.wait_stream(data_copy_stream)
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+        )
+
+    # use a separate stream for the comms so the main stream is not blocked
+    with data_dist_stream:
+        with record_function("## all_to_all_single ##"):
+            if isinstance(data_dist_stream, torch.cuda.Stream):
+                # make sure the pre-comms compute is done before the comms
+                data_dist_stream.wait_stream(main_stream)
+            post_comms = torch.zeros_like(pre_comms)
+            req = dist.all_to_all_single(
+                output=post_comms,
+                input=pre_comms,
+                group=ctx.pg,
+                async_op=True,
+            )
+            # record the tensor on the main stream so it won't be freed accidentally
+            # by the data_dist_stream
+            post_comms.record_stream(main_stream)
+        with record_function("## a2a comm validation ##"):
+            # the comm validation is also done in this separate stream since
+            # there is no data dependency afterwards
+            req.wait()
+            checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## irrelevant compute after a2a ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## post-comms compute ##"):
+        req.wait()
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert checks.item()
+
+
+def single_stream_memory(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
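+    # single-stream baseline: the same benchmark as multi_stream_memory, but with
+    # multi_stream=False so no side streams are used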
+    return multi_stream_memory(
+        _batch_inputs=_batch_inputs,
+        dim=dim,
+        num_mul=num_mul,
+        num_concat=num_concat,
+        ctx=ctx,
+        multi_stream=False,
+    )
+
+
+# an optimized version of the multi-stream memory footprint benchmark
+def multi_stream_optimized(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
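+    # key differences from multi_stream_memory: the device buffers are pre-allocated on
+    # the main stream (so no record_stream bookkeeping is needed) and the comms rely on
+    # the async all_to_all_single instead of a manually managed comm stream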
+    with record_function("## setup ##"):
+        main_stream = torch.cuda.current_stream()
+        data_copy_stream = torch.cuda.Stream()
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+
+        # without `pin_memory()` the host-to-device data transfer would block CUDA execution
+        host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+        # pre-allocate memory on the device for the incoming data transfer from the host
+        device_data = torch.empty_like(host_data, device=ctx.device)
+
+    with record_function("## irrelevant compute before h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## copy data to device ##"):
+        with data_copy_stream:
+            # copy data to the device; this does not block the main stream
+            device_data.copy_(host_data, non_blocking=True)
+
+    with record_function("## irrelevant compute after h2d ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## pre-comms compute ##"):
+        # make sure the data copy is done before the pre-comms compute
+        main_stream.wait_stream(data_copy_stream)
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+        )
+
+    with record_function("## pre-allocate memory for a2a on main stream ##"):
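+        # allocating the a2a output on the main stream avoids the record_stream()
+        # bookkeeping that is needed when the buffer is allocated on a side stream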
+        post_comms = torch.zeros_like(pre_comms)
+
+    with record_function("## all_to_all_single ##"):
+        # all_to_all_single from torch.distributed supports async_op=True; it
+        # automatically runs the comms on a separate stream without introducing
+        # an extra memory footprint
+        req = dist.all_to_all_single(
+            output=post_comms,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## irrelevant compute after a2a ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## a2a comm validation ##"):
+        # this req.wait() can be wrapped into a LazyAwaitable
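+        # with NCCL and async_op=True, wait() blocks the current stream (not the host)
+        # until the communication has finished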
+        req.wait()
+        # still keep this compute on the main stream if possible
+        checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## post-comms compute ##"):
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert checks.item()
+
+
 # single-rank runner
 def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
     # Ensure GPUs are available and we have enough of them
@@ -308,7 +489,6 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
         backend="nccl",
         use_deterministic_algorithms=False,
     ) as ctx:
-
         if arg.name.startswith("a2a_sync_base"):
             func = a2a_sync_base
         elif arg.name.startswith("a2a_async_base"):
@@ -317,6 +497,12 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
             func = a2a_async_twice
         elif arg.name.startswith("lazyawaitable"):
             func = lazyawaitable
+        elif arg.name.startswith("multi_stream_memory"):
+            func = multi_stream_memory
+        elif arg.name.startswith("single_stream_memory"):
+            func = single_stream_memory
+        elif arg.name.startswith("multi_stream_optimized"):
+            func = multi_stream_optimized
         else:
             raise ValueError(f"Unknown benchmark name: {arg.name}")
