Skip to content

Commit

Permalink
Refactor csrc with device dispatcher (#1463)
Browse files Browse the repository at this point in the history
* Add device registry for pytorch ops

* add declaration of CheckDeviceConsistency

* fix for torch130

* assert with torch check

* Refactor ops with dispatch

* update rest ops

* faster install

* update compatibility

* update compatibility, rename parameter

* move cpu implement to pytorch/cpu

* update ops/csrc/README.md

* fix rocm support

* update cn document

* update docs

* list instead of map
  • Loading branch information
q.yao authored Nov 23, 2021
1 parent ef8ba75 commit 230f9a3
Show file tree
Hide file tree
Showing 65 changed files with 3,122 additions and 2,646 deletions.
73 changes: 73 additions & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,76 @@
### v1.3.18

Some ops have different implementations on different devices. Lots of macros and type checks are scattered in several files, which makes the code hard to maintain. For example:

```c++
if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(rois);
CHECK_CUDA_INPUT(output);
CHECK_CUDA_INPUT(argmax_y);
CHECK_CUDA_INPUT(argmax_x);

roi_align_forward_cuda(input, rois, output, argmax_y, argmax_x,
aligned_height, aligned_width, spatial_scale,
sampling_ratio, pool_mode, aligned);
#else
AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
} else {
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(rois);
CHECK_CPU_INPUT(output);
CHECK_CPU_INPUT(argmax_y);
CHECK_CPU_INPUT(argmax_x);
roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
aligned_height, aligned_width, spatial_scale,
sampling_ratio, pool_mode, aligned);
}
```
Registry and dispatcher are added to manage these implementations.
```c++
void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);
void roi_align_forward_cuda(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
ROIAlignForwardCUDAKernelLauncher(
input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
spatial_scale, sampling_ratio, pool_mode, aligned);
}
// register cuda implementation
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);
REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
// roi_align.cpp
// use the dispatcher to invoke different implementation depending on device type of input tensors.
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y,
argmax_x, aligned_height, aligned_width, spatial_scale,
sampling_ratio, pool_mode, aligned);
}
```

### v1.3.11

In order to flexibly support more backends and hardwares like `NVIDIA GPUs` and `AMD GPUs`, the directory of `mmcv/ops/csrc` is refactored. Note that this refactoring will not affect the usage in API. For related information, please refer to [PR1206](https://github.com/open-mmlab/mmcv/pull/1206).
Expand Down
73 changes: 73 additions & 0 deletions docs_zh_CN/compatibility.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,76 @@
### v1.3.18

部分自定义算子对于不同的设备有不同实现,为此添加的大量宏命令与类型检查使得代码变得难以维护。例如:

```c++
if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(rois);
CHECK_CUDA_INPUT(output);
CHECK_CUDA_INPUT(argmax_y);
CHECK_CUDA_INPUT(argmax_x);

roi_align_forward_cuda(input, rois, output, argmax_y, argmax_x,
aligned_height, aligned_width, spatial_scale,
sampling_ratio, pool_mode, aligned);
#else
AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
} else {
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(rois);
CHECK_CPU_INPUT(output);
CHECK_CPU_INPUT(argmax_y);
CHECK_CPU_INPUT(argmax_x);
roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
aligned_height, aligned_width, spatial_scale,
sampling_ratio, pool_mode, aligned);
}
```
为此我们设计了注册与分发的机制以更好的管理这些算子实现。
```c++
void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);
void roi_align_forward_cuda(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
ROIAlignForwardCUDAKernelLauncher(
input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
spatial_scale, sampling_ratio, pool_mode, aligned);
}
// 注册算子的cuda实现
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);
REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
// roi_align.cpp
// 使用dispatcher根据参数中的Tensor device类型对实现进行分发
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y,
argmax_x, aligned_height, aligned_width, spatial_scale,
sampling_ratio, pool_mode, aligned);
}
```

### v1.3.11

为了灵活地支持更多的后端和硬件,例如 `NVIDIA GPUs``AMD GPUs`,我们重构了 `mmcv/ops/csrc` 目录。注意,这次重构不会影响 API 的使用。更多相关信息,请参考 [PR1206](https://github.com/open-mmlab/mmcv/pull/1206)
Expand Down
51 changes: 26 additions & 25 deletions mmcv/ops/csrc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
│ ├── parrots_cuda_helper.hpp
│ ├── pytorch_cpp_helper.hpp
│ ├── pytorch_cuda_helper.hpp
│ ├── pytorch_device_registry.hpp
│   └── cuda
│   ├── common_cuda_helper.hpp
│   ├── parrots_cudawarpfunction.cuh
Expand All @@ -37,9 +38,12 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
│   ├── pybind.cpp
│   ├── ...
│   ├── ops.cpp
│   └── cuda
│   ├── cuda
│   │   ├── ...
│   │   └── ops_cuda.cu
│   └── cpu
│      ├── ...
│      └── ops_cuda.cu
│      └── ops.cpp
└── tensorrt
├── trt_cuda_helper.cuh
├── trt_plugin_helper.hpp
Expand All @@ -64,6 +68,7 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
- `parrots`: **Parrots** is a deep learning frame for model training and inference. Parrots custom ops are placed in this directory.
- `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The ops implementation and binding codes are placed in this directory.
- `cuda`: This directory contains cuda kernel launchers, which feed memory pointers of tensor to the cuda kernel in `common/cuda`. The launchers provide c++ interface of cuda implementation of corresponding custom ops.
- `cpu`: This directory contain cpu implementations of corresponding custom ops.
- `tensorrt`: **TensorRT** support for custom ops.
- `plugins`: This directory contains the implementation of the supported custom ops. Some ops might also use shared cuda kernel in `common/cuda`.

Expand Down Expand Up @@ -102,42 +107,38 @@ This folder contains all non-python code for MMCV custom ops. Please follow the
}
```

2. Add ops implementation in `pytorch` directory. Select different implementations according to device type.
2. Register implementation for different devices.

```c++
// src/pytorch/new_ops.cpp
#ifdef MMCV_WITH_CUDA
// src/pytorch/cuda/cudabind.cpp
...

Tensor new_ops_forward_cuda(Tensor input, Tensor output, ...){
// implement cuda forward here
// use `NewOpsForwardCUDAKernelLauncher` here
}
#else
// declare interface here.
Tensor new_ops_forward_impl(Tensor input, Tensor output, ...);
// register the implementation for given device (CUDA here).
REGISTER_DEVICE_IMPL(new_ops_forward_impl, CUDA, new_ops_forward_cuda);
```

Tensor new_ops_forward_cpu(Tensor input, Tensor output, ...){
// implement cpu forward here
}
3. Add ops implementation in `pytorch` directory. Select different implementations according to device type.

```c++
// src/pytorch/new_ops.cpp
Tensor new_ops_forward_impl(Tensor input, Tensor output, ...){
// dispatch the implementation according to the device type of input.
DISPATCH_DEVICE_IMPL(new_ops_forward_impl, input, output, ...);
}
...

Tensor new_ops_forward(Tensor input, Tensor output, ...){
// select implementation by input device type
if (boxes.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(output);
return new_ops_forward_cuda(input, output, ...);
#else
AT_ERROR("new ops is not compiled with GPU support");
#endif
} else {
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(output);
return new_ops_forward_cpu(input, output, ...);
}
return new_ops_forward_impl(input, output, ...);
}
```

3. Binding the implementation in `pytorch/pybind.cpp`
4. Binding the implementation in `pytorch/pybind.cpp`

```c++
// src/pytorch/pybind.cpp
Expand All @@ -156,7 +157,7 @@ This folder contains all non-python code for MMCV custom ops. Please follow the

```

4. Build MMCV again. Enjoy new ops in python
5. Build MMCV again. Enjoy new ops in python

```python
from ..utils import ext_loader
Expand Down
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <torch/extension.h>

#include <iostream>
#include <vector>
Expand Down
7 changes: 0 additions & 7 deletions mmcv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ __global__ void ms_deformable_im2col_gpu_kernel(
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;

Expand Down Expand Up @@ -278,7 +277,6 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;

Expand Down Expand Up @@ -369,7 +367,6 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;

Expand Down Expand Up @@ -463,7 +460,6 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;

Expand Down Expand Up @@ -555,7 +551,6 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;

Expand Down Expand Up @@ -658,7 +653,6 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;

Expand Down Expand Up @@ -757,7 +751,6 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;

Expand Down
11 changes: 10 additions & 1 deletion mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ __device__ __forceinline__ static void reduceMax(double *address, double val) {
}

// get rid of meaningless warnings when compiling host code
#ifdef HIP_DIFF
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
atomicAdd(address, val);
}
__device__ __forceinline__ static void reduceAdd(double *address, double val) {
atomicAdd(address, val);
}
#else
#ifdef __CUDA_ARCH__
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
#if (__CUDA_ARCH__ < 200)
Expand Down Expand Up @@ -77,7 +85,8 @@ __device__ __forceinline__ static void reduceAdd(double *address, double val) {
atomicAdd(address, val);
#endif
}
#endif
#endif // __CUDA_ARCH__
#endif // HIP_DIFF

template <typename T>
__global__ void feats_reduce_kernel(
Expand Down
Loading

0 comments on commit 230f9a3

Please sign in to comment.