
Conversation


@JiaqingFu JiaqingFu commented Oct 11, 2025

Summary by CodeRabbit

  • New Features

    • Optional local-cluster group returned from distributed initialization for finer local-group control.
    • Configurable rendezvous endpoint via environment variables in the distributed launcher (deterministic endpoint).
    • New distributed example demonstrating overlapping all-gather with NVSHMEM and rank-aware intra/inter-node flows.
  • Bug Fixes

    • Fixed NVSHMEM header includes in CUDA code generation.
    • Allocator buffer pointers now allocated on the CUDA device to prevent device mismatches.

Copilot AI review requested due to automatic review settings October 11, 2025 11:43
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


Copilot AI left a comment


Pull Request Overview

This PR adds support for inter-node copy combined with intra-node copy in the TileLang distributed framework. The changes enable efficient multi-node distributed operations by pairing NVSHMEM-based inter-node communication with intra-node copies.

  • Adds local communication group creation alongside tensor processor groups
  • Fixes include statement syntax errors and improves CUDA device allocation
  • Introduces a comprehensive example demonstrating overlapping allgather operations across nodes

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Summary per file:
  • tilelang/utils/allocator.py: Adds CUDA device specification for buffer pointer allocation
  • tilelang/distributed/utils.py: Extends initialization to support local communication groups and sets device_id
  • tilelang/distributed/launch.sh: Updates master address/port configuration for Arnold worker environments
  • src/target/codegen_cuda.cc: Fixes malformed include statements for NVSHMEM headers
  • examples/distributed/example_overlapping_allgather.py: Adds a new example demonstrating inter-node and intra-node copy operations

Tip: Customize your code reviews with copilot-instructions.md.

local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
base = (RANK // local_world_size) * local_world_size
LC_GROUP = torch.distributed.new_group(list(range(base, base + local_world_size)), backend="nccl")
print(local_world_size,LC_GROUP,TP_GROUP)

Copilot AI Oct 11, 2025


Debug print statement should be removed or replaced with proper logging for production code.
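For example, a hedged sketch of swapping the print for standard Python logging (the TILELANG_DEBUG gate and the placeholder values are illustrative, not part of the PR):

import logging
import os

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG if os.environ.get("TILELANG_DEBUG") else logging.INFO)

# Stand-ins for the values computed in the reviewed code above.
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
lc_group = tp_group = None  # the process groups created by init_distributed

# Emitted only when debug logging is enabled, instead of printing unconditionally.
logger.debug("local_world_size=%d LC_GROUP=%s TP_GROUP=%s", local_world_size, lc_group, tp_group)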

            tid = T.get_thread_binding()
            if tid == 0:
                T.print(T.cast(rank[0], "int32"), msg="signal")
                T.print(T.cast(num_rank[0], "int32"), msg="signal")

Copilot AI Oct 11, 2025


Debug print statements should be removed or replaced with proper logging in production code.

Suggested change
-                T.print(T.cast(num_rank[0], "int32"), msg="signal")


coderabbitai bot commented Oct 11, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a new NVSHMEM-based overlapping allgather example, corrects NVSHMEM includes in CUDA codegen, makes the distributed launcher use environment-driven rendezvous endpoints, extends init_distributed to optionally return a local cluster group and bind device_id, and ensures allocator table tensors are placed on CUDA.

Changes

Cohort / File(s): Summary
  • New NVSHMEM allgather example (examples/distributed/example_overlapping_allgather.py): Adds inter-node (internode_gather) and intra-node (intranode_gather) TileLang kernels, example setup for the NVSHMEM-backed allocator, distributed init and use of LOCAL_WORLD_SIZE/LOCAL_RANK, synchronization, kernel compilation/execution, and optional kernel source printing.
  • CUDA codegen NVSHMEM include fix (src/target/codegen_cuda.cc): Fixes malformed NVSHMEM include lines so codegen emits #include <nvshmem.h> and #include <nvshmemx.h> when NVSHMEM is enabled.
  • Distributed launcher rendezvous via env (tilelang/distributed/launch.sh): Replaces the random master port/address with env-driven ARNOLD_WORKER_0_HOST/ARNOLD_WORKER_0_PORT, sets --rdzv_endpoint=${master_addr}:${master_port}, and removes random port selection.
  • Distributed init with LC_GROUP (tilelang/distributed/utils.py): Adds a return_lc_group parameter, derives device_id from LOCAL_RANK for dist.init_process_group, computes local_world_size and constructs LC_GROUP when requested, and adjusts the return tuple to include LC_GROUP when enabled.
  • Allocator device placement (tilelang/utils/allocator.py): Allocates buffer_ptrs on the CUDA device (device='cuda') in _init_table to align allocator IPC/synchronization with CUDA memory.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as User Script
  participant DU as init_distributed
  participant PT as torch.distributed
  participant NV as NVSHMEM

  U->>DU: init_distributed(return_tp_group?, return_lc_group?)
  DU->>PT: init_process_group(backend, device_id=LOCAL_RANK)
  alt return_lc_group
    DU->>DU: compute local_world_size, base, create LC_GROUP
    DU-->>U: WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP
  else return_tp_group
    DU-->>U: WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP
  else
    DU-->>U: WORLD_SIZE, RANK, LOCAL_RANK
  end
  U->>NV: NVSHMEM alloc / barrier as needed
sequenceDiagram
  autonumber
  participant S as Script
  participant TL as TileLang Compiler
  participant CG as CUDA Codegen

  S->>TL: compile(kernel, use_nvshmem=true)
  TL->>CG: generate CUDA source
  CG->>CG: emit includes: <nvshmem.h>, <nvshmemx.h>
  CG-->>TL: CUDA source
  TL-->>S: compiled module

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Possibly related PRs

Suggested reviewers

  • chengyupku
  • tzj-fxz

Poem

A rabbit compiles under moonlit glow,
Tiles hop between ranks in a quiet row.
NVSHMEM bridges, headers neat and true,
Local rings hum, rendezvous set too.
Hooray — kernels run; the bytes all flew! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title concisely describes the primary feature added by this pull request, namely support for both inter-node and intra-node copy operations, and aligns with the main functionality implemented in the changeset.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d8b26c7 and 6ca8c15.

📒 Files selected for processing (2)
  • examples/distributed/example_overlapping_allgather.py (1 hunks)
  • tilelang/distributed/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/distributed/utils.py (1)
tilelang/profiler/__init__.py (1)
  • init_distributed (67-88)
examples/distributed/example_overlapping_allgather.py (10)
tilelang/distributed/utils.py (1)
  • init_distributed (59-90)
tilelang/language/distributed/multi_device/nvshmem.py (2)
  • get_pe (6-8)
  • putmem_nbi_block (108-116)
tilelang/language/tir/op.py (1)
  • address_of (463-479)
tilelang/language/distributed/common.py (3)
  • get_rank (8-11)
  • get_num_ranks (14-17)
  • put_block (70-87)
tilelang/env.py (2)
  • disable_cache (247-248)
  • get (136-139)
tilelang/utils/allocator.py (2)
  • get_allocator (237-249)
  • device (114-115)
tilelang/jit/__init__.py (1)
  • compile (32-81)
tilelang/utils/tensor.py (1)
  • tensor (45-58)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-85)
tilelang/jit/kernel.py (1)
  • initialize (400-409)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
tilelang/distributed/utils.py (1)

80-85: Remove debug print statement.

Line 84 contains a debug print statement that should be removed or replaced with proper logging for production code.

Apply this diff to remove the debug print:

     if return_lc_group:
         local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
         base = (RANK // local_world_size) * local_world_size
         LC_GROUP = torch.distributed.new_group(list(range(base, base + local_world_size)), backend="nccl")
-        print(local_world_size,LC_GROUP,TP_GROUP)
         return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP
examples/distributed/example_overlapping_allgather.py (1)

52-53: Remove debug print statements.

Lines 52-53 contain debug print statements that should be removed or replaced with proper logging for production code.

🧹 Nitpick comments (1)
examples/distributed/example_overlapping_allgather.py (1)

128-128: Avoid modifying global environment state in examples.

Modifying env.USE_NVSHMEM at runtime creates a problematic side effect that could affect other parts of the system. Consider passing this as a parameter or using a context manager instead.

A better approach would be to pass the configuration to the compile function:

# Instead of modifying global state
intrakernel = tilelang.compile(
    intranode_gather(M, WORLD_SIZE, M, 128),
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_RDC: True,
        # Consider adding a pass config for NVSHMEM if needed
    }
)

Or document that this example requires specific environment configuration.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2cc1a2c and d8b26c7.

📒 Files selected for processing (5)
  • examples/distributed/example_overlapping_allgather.py (1 hunks)
  • src/target/codegen_cuda.cc (1 hunks)
  • tilelang/distributed/launch.sh (1 hunks)
  • tilelang/distributed/utils.py (2 hunks)
  • tilelang/utils/allocator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/distributed/utils.py (1)
tilelang/profiler/__init__.py (1)
  • init_distributed (67-88)
examples/distributed/example_overlapping_allgather.py (9)
tilelang/distributed/utils.py (2)
  • init_distributed (59-89)
  • init_dist (34-56)
tilelang/language/distributed/multi_device/nvshmem.py (2)
  • get_pe (6-8)
  • putmem_nbi_block (108-116)
tilelang/language/tir/op.py (1)
  • address_of (463-479)
tilelang/language/distributed/common.py (3)
  • get_rank (8-11)
  • get_num_ranks (14-17)
  • put_block (70-87)
tilelang/env.py (2)
  • disable_cache (247-248)
  • get (136-139)
tilelang/utils/allocator.py (2)
  • get_allocator (237-249)
  • device (114-115)
tilelang/jit/__init__.py (1)
  • compile (32-81)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-85)
tilelang/jit/kernel.py (1)
  • initialize (400-409)
🪛 GitHub Actions: CI
examples/distributed/example_overlapping_allgather.py

[error] 1-1: F401: 'argparse' imported but unused.


[error] 8-8: F401: 'tilelang.distributed.utils.dtype_map' imported but unused.


[error] 8-8: F401: 'tilelang.distributed.utils.perf_fn' imported but unused.


[error] 9-9: F401: 'tilelang.distributed.utils.init_dist' imported but unused.


[error] 19-19: F401: 'cuda.cuda' imported but unused.

🔇 Additional comments (5)
src/target/codegen_cuda.cc (1)

184-185: LGTM! Critical fix for NVSHMEM include directives.

The corrected include statements now use proper preprocessor syntax. Previously, the stray ">>" characters would have caused compilation errors when NVSHMEM is enabled.

tilelang/distributed/launch.sh (1)

23-29: LGTM! Environment-driven rendezvous endpoint configuration.

The change from hardcoded values to ARNOLD_WORKER_0_HOST and ARNOLD_WORKER_0_PORT environment variables enables flexible distributed launcher configuration. The port parsing logic correctly extracts the first comma-separated field, and sensible defaults are provided for local execution.

tilelang/utils/allocator.py (1)

152-152: LGTM! Ensures buffer_ptrs is allocated on CUDA device.

Adding device='cuda' ensures the buffer_ptrs tensor is created on the CUDA device, which is required for proper NVSHMEM and distributed IPC operations. This aligns with the distributed CUDA workflow introduced elsewhere in the PR.

tilelang/distributed/utils.py (1)

66-66: LGTM! Explicit device_id initialization.

Adding the device_id parameter to init_process_group ensures proper device context initialization for distributed operations. This aligns with PyTorch distributed best practices and coordinates with the LOCAL_RANK-based device management.
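For reference, a minimal sketch of the pattern; the exact call site in init_distributed may differ, and device_id is only accepted by recent PyTorch releases, so treat its availability as something to verify:

import os

import torch
import torch.distributed as dist

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

# Binding the process group to a specific CUDA device up front avoids lazy device
# selection inside collectives and matches the LOCAL_RANK-based device management.
dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{local_rank}"))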

examples/distributed/example_overlapping_allgather.py (1)

22-36: Verify rank calculation logic for multi-node scenarios.

The rank calculation at line 31 (T.get_pe()+local_world_size)%(2*local_world_size) assumes exactly 2 nodes. Ensure this logic is documented and handles edge cases correctly.

A small script can verify that the rank calculation produces the expected mappings; a sketch follows below.
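A minimal sketch, assuming exactly 2 nodes and 8 GPUs per node (the concrete numbers and the mapping check are illustrative, not CodeRabbit's original script):

# Illustrative check of the peer mapping used by internode_gather.
local_world_size = 8          # GPUs per node (assumption for illustration)
world_size = 2 * local_world_size

for pe in range(world_size):
    peer = (pe + local_world_size) % (2 * local_world_size)
    # Every PE should map to the PE with the same local index on the other node.
    assert peer // local_world_size != pe // local_world_size
    assert peer % local_world_size == pe % local_world_size
    print(f"PE {pe} (node {pe // local_world_size}) -> peer {peer} (node {peer // local_world_size})")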


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (2)
examples/distributed/example_overlapping_allgather.py (2)

1-1: Remove unused imports flagged by the linter.

The following imports are unused and cause pipeline failures:

  • Line 1: argparse
  • Line 8: dtype_map, perf_fn
  • Line 9: init_dist
  • Line 16: cuda (from cuda.bindings.driver)
  • Line 19: cuda (from cuda module)

Apply this diff:

-import argparse
 import torch
 import torch.distributed as dist
 import pynvshmem
 import tilelang
 import tilelang.language as T
 import os
-from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
-from tilelang.distributed.utils import init_dist
+from tilelang.distributed.utils import init_distributed
 from tilelang.env import env
 from packaging import version
 import importlib.metadata
 cuda_python_version = importlib.metadata.version("cuda-python")
 if version.parse(cuda_python_version) >= version.parse("12.8.0"):
-    from cuda.bindings import driver as cuda
     from cuda.bindings import runtime as cudart
 else:
-    from cuda import cuda, cudart
+    from cuda import cudart

Also applies to: 8-9, 16-16, 19-19


54-56: Remove or conditionalize debug print statements in production kernel.

The kernel contains debug print statements that should be removed or made conditional:

if tid == 0:
    T.print(T.cast(rank[0], "int32"), msg="signal")
    T.print(T.cast(num_rank[0], "int32"), msg="signal")

Apply this diff to remove the debug prints:

             rank[0] = T.get_rank()
             num_rank[0] = T.get_num_ranks()
-            tid = T.get_thread_binding()
-            if tid == 0:
-                T.print(T.cast(rank[0], "int32"), msg="signal")
-                T.print(T.cast(num_rank[0], "int32"), msg="signal")
             for k in T.serial(world_size // 2):  # 2 node
🧹 Nitpick comments (3)
examples/distributed/example_overlapping_allgather.py (3)

83-83: Remove redundant LOCAL_RANK assignment.

LOCAL_RANK is already obtained from init_distributed on line 80:

WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(...)

Reassigning it from the environment is redundant.

Apply this diff:

     WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
         return_tp_group=True, return_lc_group=True)
     local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
-    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))

148-148: Add rank context to print statement.

The final print lacks rank information, making output from multiple processes difficult to interpret.

Use the existing dist_print utility for rank-aware output:

+    from tilelang.distributed.utils import dist_print
+
-    print(dst_intra)
+    dist_print(f"dst_intra: {dst_intra}", prefix=True, allowed_ranks="all")

135-135: Avoid global modification of env.USE_NVSHMEM
Disabling NVSHMEM here (examples/distributed/example_overlapping_allgather.py:135) also affects tilelang/jit/adapter and profiler. Scope this change with a context manager or save and restore the original value, or extend the API to accept a disable-NVSHMEM flag.
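For illustration, a minimal sketch of the save-and-restore approach suggested here, assuming env.USE_NVSHMEM is an ordinary writable attribute (the helper name nvshmem_disabled is made up for this sketch):

from contextlib import contextmanager

from tilelang.env import env  # the same `env` object the example already imports


@contextmanager
def nvshmem_disabled():
    """Temporarily disable NVSHMEM and restore the previous value on exit."""
    previous = env.USE_NVSHMEM
    env.USE_NVSHMEM = False
    try:
        yield
    finally:
        env.USE_NVSHMEM = previous


# Usage in the example, instead of mutating the flag permanently:
# with nvshmem_disabled():
#     intrakernel = tilelang.compile(intranode_gather(M, WORLD_SIZE, M, 128), ...)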

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d8b26c7 and 1e0e208.

📒 Files selected for processing (2)
  • examples/distributed/example_overlapping_allgather.py (1 hunks)
  • tilelang/distributed/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/distributed/utils.py (1)
tilelang/profiler/__init__.py (1)
  • init_distributed (67-88)
examples/distributed/example_overlapping_allgather.py (11)
tilelang/distributed/utils.py (3)
  • init_distributed (59-90)
  • perf_fn (225-246)
  • init_dist (34-56)
tilelang/language/distributed/multi_device/nvshmem.py (2)
  • get_pe (6-8)
  • putmem_nbi_block (108-116)
tilelang/language/tir/op.py (1)
  • address_of (463-479)
tilelang/language/distributed/common.py (3)
  • get_rank (8-11)
  • get_num_ranks (14-17)
  • put_block (70-87)
tilelang/env.py (2)
  • disable_cache (247-248)
  • get (136-139)
tilelang/utils/allocator.py (2)
  • get_allocator (237-249)
  • device (114-115)
tilelang/jit/__init__.py (1)
  • compile (32-81)
tilelang/distributed/pynvshmem/python/_pynvshmem/__init__.pyi (1)
  • nvshmem_barrier_all (58-59)
tilelang/utils/tensor.py (1)
  • tensor (45-58)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-85)
tilelang/jit/kernel.py (1)
  • initialize (400-409)
🪛 GitHub Actions: CI
examples/distributed/example_overlapping_allgather.py

[error] 1-1: Ruff: F401 'argparse' imported but unused. Remove unused import.


[error] 8-8: Ruff: F401 'tilelang.distributed.utils.dtype_map' imported but unused. Remove unused import.


[error] 8-8: Ruff: F401 'tilelang.distributed.utils.perf_fn' imported but unused. Remove unused import.


[error] 9-9: Ruff: F401 'tilelang.distributed.utils.init_dist' imported but unused. Remove unused import.


[error] 19-19: Ruff: F401 'cuda.cuda' imported but unused. Remove unused import.

🔇 Additional comments (3)
tilelang/distributed/utils.py (3)

59-59: LGTM! Backward-compatible signature extension.

The new return_lc_group parameter enables returning the local communication group while maintaining backward compatibility by defaulting to False.


80-86: LGTM! LC_GROUP construction logic is correct.

The local group construction correctly computes the base rank and creates groups for intra-node communication. The logic properly partitions global ranks into local groups based on LOCAL_WORLD_SIZE.
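As a worked example of that partitioning (assuming 2 nodes with 8 GPUs each, i.e. 16 global ranks):

# Each rank derives the same LC_GROUP member list as every other rank on its node.
LOCAL_WORLD_SIZE = 8

for rank in range(16):
    base = (rank // LOCAL_WORLD_SIZE) * LOCAL_WORLD_SIZE
    group_ranks = list(range(base, base + LOCAL_WORLD_SIZE))
    # Ranks 0-7 share the group [0..7]; ranks 8-15 share the group [8..15].
    expected = list(range(8)) if rank < 8 else list(range(8, 16))
    assert group_ranks == expected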


66-66: Verify device_id parameter support in your PyTorch environment
The sandbox couldn’t import torch, so please manually confirm whether your minimum supported PyTorch version’s dist.init_process_group signature includes device_id. If it does, document this version requirement; otherwise, guard the kwarg via an inspect.signature check as shown above.
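A sketch of such a guard, with the LOCAL_RANK handling assumed for illustration:

import inspect
import os

import torch
import torch.distributed as dist

local_rank = int(os.environ.get("LOCAL_RANK", 0))
init_kwargs = {}
# Only pass device_id when the installed PyTorch's init_process_group accepts it.
if "device_id" in inspect.signature(dist.init_process_group).parameters:
    init_kwargs["device_id"] = torch.device(f"cuda:{local_rank}")
dist.init_process_group(backend="nccl", **init_kwargs)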

    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
            rank = T.alloc_local([1], "uint64")
            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)  # 2 nodes


⚠️ Potential issue | 🟡 Minor

Hardcoded 2-node assumption limits flexibility.

The rank computation assumes exactly 2 nodes:

rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)

Consider either:

  1. Parameterizing the number of nodes
  2. Adding a clear assertion/comment that this example requires exactly 2 nodes
  3. Computing the number of nodes dynamically from environment variables

Example parameterization:

-def internode_gather(M, local_world_size, block_M, threads):
+def internode_gather(M, local_world_size, num_nodes, block_M, threads):
     @T.prim_func
     def main(
             dst: T.Tensor((M), "float32"),
             src: T.Tensor((M), "float32"),
     ):
         with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
             rank = T.alloc_local([1], "uint64")
-            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)  # 2 nodes
+            rank[0] = (T.get_pe() + local_world_size) % (num_nodes * local_world_size)
🤖 Prompt for AI Agents
In examples/distributed/example_overlapping_allgather.py around line 33, the
rank calculation hardcodes a 2-node assumption using (2 * local_world_size),
which limits flexibility; replace this with a configurable or dynamically
computed node count (e.g., derive num_nodes from environment or from
total_world_size // local_world_size) or add an explicit assertion/comment
stating the example requires exactly 2 nodes; update rank computation to use
num_nodes instead of 2 and validate inputs (raise/assert on mismatch) so the
example works for arbitrary node counts or clearly documents its 2-node
requirement.

            if tid == 0:
                T.print(T.cast(rank[0], "int32"), msg="signal")
                T.print(T.cast(num_rank[0], "int32"), msg="signal")
            for k in T.serial(world_size // 2):  # 2 node


⚠️ Potential issue | 🟡 Minor

Hardcoded 2-node assumption in loop.

Similar to the internode kernel, this loop assumes exactly 2 nodes:

for k in T.serial(world_size // 2):  # 2 node

Consider parameterizing or documenting this constraint clearly. If this example is specifically designed for 2 nodes, add an early assertion:

if __name__ == '__main__':
    num_nodes = int(os.environ.get('NODES', 2))
    assert num_nodes == 2, "This example requires exactly 2 nodes"
🤖 Prompt for AI Agents
In examples/distributed/example_overlapping_allgather.py around line 57, the
loop uses a hardcoded 2-node assumption ("for k in T.serial(world_size // 2):  #
2 node"), which should be made explicit or guarded; either parameterize the
behavior to derive the loop bound from a configurable num_nodes/worker_count
variable or add an early runtime assertion that the example requires exactly 2
nodes (readable from an env var or CLI) and fail fast with a clear message;
update top-of-script argument/env parsing or add the assert in the main entry so
the loop remains correct for the intended number of nodes.

Comment on lines +125 to +133
        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
    else:
        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)

⚠️ Potential issue | 🟠 Major

Add error checking for CUDA memory operations.

The cudaMemcpy calls don't check return codes. CUDA operations can fail silently, leading to incorrect results.

The codebase uses CUDA_CHECK for error handling (see tilelang/distributed/utils.py lines 249-257). Apply error checking:

+    from tilelang.distributed.utils import CUDA_CHECK
+
     if RANK < WORLD_SIZE / 2:
-        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
+                                      cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
+                                      cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
     else:
-        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
+                                      cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
+                                      cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-    else:
-        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+    from tilelang.distributed.utils import CUDA_CHECK
+    if RANK < WORLD_SIZE / 2:
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
+                                     cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
+                                     cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
+    else:
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
+                                     cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
+                                     cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
🤖 Prompt for AI Agents
In examples/distributed/example_overlapping_allgather.py around lines 125 to
133, the cudart.cudaMemcpy calls are missing error checking; wrap each
cudaMemcpy invocation with the project's CUDA_CHECK helper (the same pattern
used in tilelang/distributed/utils.py lines ~249-257) so that the return code is
validated and any CUDA errors are surfaced. Replace each raw
cudart.cudaMemcpy(...) call with a CUDA_CHECK invocation that passes the
cudaMemcpy call and preserves the same arguments and intent.

@chengyupku chengyupku merged commit aa916f4 into tile-ai:main Oct 29, 2025
3 of 4 checks passed
