Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 89 additions & 16 deletions torchrec/distributed/benchmark/benchmark_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,74 @@ def multi_stream_optimized(
assert checks.item()


# an optimized version of muti-stream memory footprint
def non_blocking_copy(
_batch_inputs: List[Dict[str, Any]],
dim: int,
num_mul: int,
num_concat: int,
ctx: MultiProcessContext,
preallocated: bool = False,
**_kwargs: Dict[str, Any],
) -> None:
with record_function("## setup ##"):
main_stream = torch.cuda.current_stream()
data_copy_stream = torch.cuda.Stream()
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5

# the host to device data transfer will block cuda execution without the `pin_memory()`
host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
if preallocated:
# pre-allocate memory on the device for the incoming data transfer from the host
device_data = torch.empty_like(host_data, device=ctx.device)
else:
device_data = torch.empty(0, device=ctx.device)

with record_function("## irrelevant compute before h2d ##"):
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=1, ctx=ctx, x=irrelevant_data
)

with record_function("## copy data to device ##"):
with data_copy_stream:
if preallocated:
# copy data to device, this will not block the main stream
device_data.copy_(host_data, non_blocking=True)
else:
device_data = host_data.to(ctx.device, non_blocking=True)

with record_function("## irrelevant compute after h2d ##"):
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=1, ctx=ctx, x=irrelevant_data
)

with record_function("## pre-comms compute ##"):
# make sure the data copy is done before the pre-comms compute
main_stream.wait_stream(data_copy_stream)
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=1, ctx=ctx, x=device_data
)


def preallocated_non_blocking_copy(
_batch_inputs: List[Dict[str, Any]],
dim: int,
num_mul: int,
num_concat: int,
ctx: MultiProcessContext,
**_kwargs: Dict[str, Any],
) -> None:
return non_blocking_copy(
_batch_inputs=_batch_inputs,
dim=dim,
num_mul=num_mul,
num_concat=num_concat,
ctx=ctx,
preallocated=True,
)


# single-rank runner
def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
# Ensure GPUs are available and we have enough of them
Expand All @@ -489,22 +557,27 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
backend="nccl",
use_deterministic_algorithms=False,
) as ctx:
if arg.name.startswith("a2a_sync_base"):
func = a2a_sync_base
elif arg.name.startswith("a2a_async_base"):
func = a2a_async_base
elif arg.name.startswith("a2a_async_twice"):
func = a2a_async_twice
elif arg.name.startswith("lazyawaitable"):
func = lazyawaitable
elif arg.name.startswith("multi_stream_memory"):
func = multi_stream_memory
elif arg.name.startswith("single_stream_memory"):
func = single_stream_memory
elif arg.name.startswith("multi_stream_optimized"):
func = multi_stream_optimized
else:
raise ValueError(f"Unknown benchmark name: {arg.name}")
match arg.name.lower():
case "a2a_sync_base":
func = a2a_sync_base
case "a2a_async_base":
func = a2a_async_base
case "a2a_async_twice":
func = a2a_async_twice
case "lazyawaitable":
func = lazyawaitable
case "multi_stream_memory":
func = multi_stream_memory
case "single_stream_memory":
func = single_stream_memory
case "multi_stream_optimized":
func = multi_stream_optimized
case "non_blocking_copy":
func = non_blocking_copy
case "preallocated_non_blocking_copy":
func = preallocated_non_blocking_copy
case _:
raise ValueError(f"Unknown benchmark name: {arg.name}")

result = benchmark_func(
bench_inputs=[],
Expand Down
Loading