Skip to content

Commit 3dfb950

Browse files
issue/1105 fix kernel header include
1 parent 937f9e3 commit 3dfb950

File tree

25 files changed

+51
-194
lines changed

25 files changed

+51
-194
lines changed

src/infiniop/devices/metax/metax_common.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,11 @@
66
#ifdef ENABLE_METAX_MC_API
77
#include <mcblas/mcblas.h>
88
#include <mcdnn/mcdnn.h>
9+
#include <mcr/mc_runtime.h>
910
#else
1011
#include <hcblas/hcblas.h>
1112
#include <hcdnn/hcdnn.h>
13+
#include <hcr/hc_runtime.h>
1214
#endif
1315
#include <functional>
1416
#include <memory>

src/infiniop/devices/metax/metax_kernel_common.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,12 @@
11
#define INFINIOP_METAX_KERNEL __global__ void
22

33
#ifdef ENABLE_METAX_MC_API
4+
#include <maca_bfloat16.h>
5+
#include <maca_fp16.h>
46
#include <maca_fp8.h>
57
#else
8+
#include <hpcc_bfloat16.h>
9+
#include <hpcc_fp16.h>
610
#include <hpcc_fp8.h>
711
#endif
812

src/infiniop/ops/addcmul/cuda/kernel.cuh

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,6 @@
11
#ifndef __ADDCMUL_CUDA_CUH__
22
#define __ADDCMUL_CUDA_CUH__
33

4-
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
5-
#include <cuda_bf16.h>
6-
#include <cuda_fp16.h>
7-
#endif
84
#include <type_traits>
95

106
namespace op::addcmul::cuda {

src/infiniop/ops/argwhere/cpu/argwhere_cpu.cc

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -37,27 +37,36 @@ infiniStatus_t calculateArgWhere(
3737
const void *x) {
3838

3939
const Tdata *x_data = reinterpret_cast<const Tdata *>(x);
40-
// int64_t *y_data = reinterpret_cast<int64_t *>(y);
41-
std::vector<size_t> positions;
42-
// #pragma omp parallel for
40+
41+
std::vector<int64_t> positions;
42+
const size_t ndim = info.shapes.size();
43+
4344
for (size_t i = 0; i < info.num_elements; i++) {
44-
size_t pos = 0, tem = i;
45-
std::vector<size_t> position(info.strides.size());
46-
for (size_t j = info.strides.size() - 1; j >= 0; j--) {
47-
position[j] = tem % info.shapes[j];
48-
tem /= info.shapes[j];
49-
pos += position[j] * info.strides[j];
45+
size_t pos = 0;
46+
size_t tmp = i;
47+
48+
std::vector<int64_t> coord(ndim);
49+
50+
// unravel index
51+
for (size_t j = ndim; j-- > 0;) {
52+
coord[j] = tmp % info.shapes[j];
53+
tmp /= info.shapes[j];
54+
pos += coord[j] * info.strides[j];
5055
}
51-
if (fabs(x_data[pos] - 0.0f) > 1e-5) {
52-
for (auto p : position) {
53-
positions.push_back(p);
56+
57+
// PyTorch semantics: != 0
58+
if (x_data[pos] != Tdata(0)) {
59+
for (size_t j = 0; j < ndim; j++) {
60+
positions.push_back(coord[j]);
5461
}
5562
}
5663
}
5764

65+
*count = positions.size() / ndim;
66+
5867
*y = new int64_t[positions.size()];
5968
memcpy(*y, positions.data(), positions.size() * sizeof(int64_t));
60-
*count = positions.size() / info.strides.size();
69+
6170
return INFINI_STATUS_SUCCESS;
6271
}
6372

src/infiniop/ops/atanh/cuda/kernel.cuh

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,6 @@
11
#ifndef __ATANH_CUDA_H__
22
#define __ATANH_CUDA_H__
33

4-
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
5-
#include <cuda_bf16.h>
6-
#include <cuda_fp16.h>
7-
#endif
8-
94
namespace op::atanh::cuda {
105
typedef struct AtanhOp {
116
public:

src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,9 @@
11
#include "../../../devices/metax/metax_common.h"
22
#include "../../../devices/metax/metax_handle.h"
33
#include "../../../devices/metax/metax_kernel_common.h"
4+
45
#include "binary_cross_entropy_with_logits_metax.h"
5-
#if defined(ENABLE_METAX_MC_API)
6-
#include <mc_runtime.h>
7-
#else
8-
#include <hc_runtime.h>
9-
#endif
6+
107
#include <type_traits>
118

129
namespace op::bce_with_logits::metax {

src/infiniop/ops/equal/cuda/kernel.cuh

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,6 @@
11
#ifndef __EQUAL_CUDA_H__
22
#define __EQUAL_CUDA_H__
33

4-
#if ENABLE_METAX_API
5-
#if defined(ENABLE_METAX_MC_API)
6-
#include <maca_bfloat16.h>
7-
#include <maca_fp16.h>
8-
#else
9-
#include <hpcc_bfloat16.h>
10-
#include <hpcc_fp16.h>
11-
#endif
12-
#elif defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
13-
#include <cuda_bf16.h>
14-
#include <cuda_fp16.h>
15-
#endif
164
#include <type_traits>
175

186
namespace op::equal::cuda {

src/infiniop/ops/hardswish/cuda/kernel.cuh

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,18 +2,6 @@
22
#define __HARDSWISH_CUDA_H__
33

44
#include <cmath>
5-
#if ENABLE_METAX_API
6-
#if defined(ENABLE_METAX_MC_API)
7-
#include <maca_bfloat16.h>
8-
#include <maca_fp16.h>
9-
#else
10-
#include <hpcc_bfloat16.h>
11-
#include <hpcc_fp16.h>
12-
#endif
13-
#elif defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
14-
#include <cuda_bf16.h>
15-
#include <cuda_fp16.h>
16-
#endif
175

186
namespace op::hardswish::cuda {
197

src/infiniop/ops/hardtanh/cuda/kernel.cuh

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,6 @@
11
#ifndef __HARDTANH_CUDA_H__
22
#define __HARDTANH_CUDA_H__
33

4-
#if ENABLE_METAX_API
5-
#if defined(ENABLE_METAX_MC_API)
6-
#include <maca_bfloat16.h>
7-
#include <maca_fp16.h>
8-
#else
9-
#include <hpcc_bfloat16.h>
10-
#include <hpcc_fp16.h>
11-
#endif
12-
#elif defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
13-
#include <cuda_bf16.h>
14-
#include <cuda_fp16.h>
15-
#endif
164
#include <type_traits>
175

186
namespace op::hardtanh::cuda {

src/infiniop/ops/hypot/cuda/kernel.cuh

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,18 +3,6 @@
33

44
#include <cmath>
55
#include <type_traits>
6-
#if ENABLE_METAX_API
7-
#if defined(ENABLE_METAX_MC_API)
8-
#include <maca_bfloat16.h>
9-
#include <maca_fp16.h>
10-
#else
11-
#include <hpcc_bfloat16.h>
12-
#include <hpcc_fp16.h>
13-
#endif
14-
#elif defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
15-
#include <cuda_bf16.h>
16-
#include <cuda_fp16.h>
17-
#endif
186

197
namespace op::hypot::cuda {
208

0 commit comments

Comments
 (0)