[RFC] Intel GPU (XPU) Support with Triton Implementation #8679

Stonepia opened this issue Oct 11, 2024 · 2 comments


Stonepia commented Oct 11, 2024

🚀 1. The feature

This RFC proposes using the Triton language to implement Intel GPU (torch.xpu) kernels.

1.1. Motivation

This RFC aims to discuss the XPU kernel implementation with the Triton language. The Triton kernel will contain most of the device-agnostic logic, and the device-specific logic will be implemented separately.

The only package requirement for this is the triton package. However, since triton is already a default dependency of torch, torchvision does not need to add it as a new requirement.
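As an illustration only (the _HAS_TRITON flag and guarded import below are a hypothetical sketch, not part of this proposal), the XPU registration could be guarded so that torchvision still imports cleanly on platforms where triton is unavailable:

# Hypothetical sketch: register the XPU Triton kernels only when triton is importable.
try:
    import triton  # normally pulled in as a dependency of torch
    _HAS_TRITON = True
except ImportError:
    _HAS_TRITON = False

if _HAS_TRITON:
    # Importing the module runs the torch.library.register_kernel registration.
    from torchvision.ops.xpu import nms_kernel  # noqa: F401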

2. Op List

  • nms
  • roi pool
  • roi align
  • ps roi pool
  • ps roi align
  • deform_conv2d

3. File Structure

torchvision
├── csrc
│   └── ops   # Current C++ op registration; no changes here
│       └── nms.cpp
└── ops        # This folder is pure Python
  ├── triton # The shared Triton kernel implementation. It won't register the kernel.
  │   └── nms_kernel.py
  └── xpu    # The device-specific kernel implementation and registration
    └── nms_kernel.py

4. Details

4.1. Register Custom Kernel Implementation

The kernel implementation consists of two parts. Firstly, we aim to use Triton to implement device-agnostic logic, covering as much of the logic as possible. Secondly, we will use ATen operations for logic that Triton does not support, as well as for device-specific logic.

4.1.1. Triton Kernel Implementation

For generality, we propose implementing the main logic in the Triton language. This logic is device-agnostic: as long as the device/platform supports Triton, the Triton kernel should work.

# torchvision/ops/triton/nms_kernel.py
import triton
import triton.language as tl

@triton.jit
def triton_nms_kernel(boxes, scores, output_ptr, threshold, num_boxes, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one block of boxes.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    ...

4.1.2. XPU Kernel Implementation and Registration

This part implements device-specific logic. There are two aspects:

  1. Triton unsupported ops. For example, tensor initialization, argsort, etc.
  2. Device-specific Triton configuration. This would affect the performance of the kernel (e.g., pass Triton config to the kernel).

The developer will have to wrap the above logic into a custom kernel and register it with PyTorch using torch.library.register_kernel(op_name, device).

# torchvision/ops/xpu/nms_kernel.py
import torch
import triton
from torch import Tensor

from torchvision.ops.triton.nms_kernel import triton_nms_kernel


@torch.library.register_kernel("torchvision::nms", "xpu")
def xpu_custom_nms_triton_kernel(boxes: Tensor, scores: Tensor, threshold: float) -> Tensor:
    num_boxes = boxes.shape[0]

    # 1 : Perform operations that Triton doesn't support:
    # - Tensor initialization, which may be device-specific
    output = torch.ones(num_boxes, dtype=torch.int32, device=boxes.device)
    # - Use ATen calls for Triton-unsupported ops (argsort in this case)
    order = torch.argsort(scores, descending=True)
    boxes = boxes[order]
    ...

    # 2 : Call the Triton kernel
    # Hardware vendors can decide which configuration is optimal for their device,
    # e.g., BLOCK_SIZE in this case.
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(num_boxes, meta['BLOCK_SIZE']),)
    triton_nms_kernel[grid](boxes, scores, output, threshold, num_boxes, BLOCK_SIZE=BLOCK_SIZE)

    # 3 : Post-processing logic
    output = torch.narrow(...)

    return output

In this way, the user can call torchvision.ops.nms in both eager mode and with torch.compile.

5. Run The Kernel

5.1. Eager Mode

The user can directly call the kernel in eager mode. In this scenario, the nms call dispatches to the custom implementation.

output = torchvision.ops.nms(boxes, scores, threshold)
print(output)
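
For illustration, a minimal end-to-end eager example might look like the following; the shapes, values, and the 0.5 threshold are arbitrary placeholders, not part of the proposal:

import torch
import torchvision

# Arbitrary example inputs placed on the XPU device.
boxes = torch.rand(1000, 4, device="xpu")
boxes[:, 2:] += boxes[:, :2]   # make boxes valid (x2 >= x1, y2 >= y1)
scores = torch.rand(1000, device="xpu")

keep = torchvision.ops.nms(boxes, scores, 0.5)  # dispatches to the XPU Triton-backed kernel
print(keep)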

5.2. torch.compile Mode

The user can call torch.compile with the following:

nms_compile = torch.compile(torchvision.ops.nms)
compiled_output = nms_compile(boxes, scores, threshold)

torch.compile uses the Inductor compiler to decompose ops into primitive ops. When an op has no registered decomposition, Inductor falls back to the registered kernel. In this scenario, Inductor will not decompose our custom kernel, so it behaves the same as in eager mode.
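
As a rough sanity check (a sketch, not part of the proposal; it reuses the boxes, scores, and threshold from the eager example above), the compiled and eager results would be expected to match:

# Sketch: compiled and eager calls should produce the same indices,
# since Inductor falls back to the registered custom kernel.
eager_output = torchvision.ops.nms(boxes, scores, threshold)
compiled_output = torch.compile(torchvision.ops.nms)(boxes, scores, threshold)
torch.testing.assert_close(eager_output, compiled_output)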

6. Alternatives

  • Native C++ Implementation:

Currently, all the kernel backends are in torchvision/csrc/ops and are written in C++. This approach requires every device backend developer to implement the ops with their own compiler/language (for example, CUDA with nvcc). The pros of native C++ are that it can reach peak performance and that it has Windows support.

By comparison, Triton has some limitations. For example, Triton does not support manually placing a tensor into shared memory, so the logic of the CUDA nms kernel cannot be reproduced one-to-one in Triton. On the other hand, Triton is more general and achieves reasonable performance, and its optimization logic can be shared across different devices.

7. Limitations

  • Windows Support: Currently, Triton does not support Windows. Thus, kernels implemented with Triton will fail on Windows. However, once Triton supports the Windows platform, no additional effort will be needed on our side.
  • Performance Alignment: Since Triton has some limitations, performance may not fully match the native C++ implementation.

8. Additional context

  1. Is it possible for both a Triton kernel and a C++ implementation to exist for torch.compile mode?

In this case, one may register a custom lowering, as in the following sketch:

@register_lowering(torchvision.ops.nms)
def nms(...):
    if device == 'xpu' and USE_ATEN_BACKEND:
        return aten_cpp_nms(...)
    elif device == 'xpu' and USE_TRITON_BACKEND:
        return triton_nms(...)
    else:
        ...

cc: @EikanWang @fengyuan14 @riverliuintel

@EikanWang

@NicolasHug, may I know your comments?

@abhi-glitchhg

This is interesting! ❤️
