Skip to content

Commit 7655800

Browse files
bitzyzroot
authored andcommitted
feat: add cambricon bf16 & fp16 data type
1 parent 6bf2479 commit 7655800

10 files changed

Lines changed: 192 additions & 65 deletions

File tree

src/CMakeLists.txt

Lines changed: 10 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,19 @@
11
add_library(infiniops SHARED)
2-
32
file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc")
43
target_sources(infiniops PRIVATE ${BASE_SRCS})
5-
64
set(DEVICE_LIST "")
7-
85
if(WITH_CPU)
96
set(CPU_PATTERNS
107
"cpu/*.cc"
118
"cpu/*.cpp"
129
)
13-
1410
file(GLOB_RECURSE CPU_SOURCES CONFIGURE_DEPENDS ${CPU_PATTERNS})
1511
list(APPEND CORE_SOURCES ${CPU_SOURCES})
16-
1712
target_compile_definitions(infiniops PUBLIC WITH_CPU=1)
18-
1913
find_package(OpenMP REQUIRED)
2014
target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX)
21-
2215
list(APPEND DEVICE_LIST "cpu")
2316
endif()
24-
2517
if(WITH_NVIDIA)
2618
set(NVIDIA_PATTERNS
2719
"cuda/*.cc"
@@ -31,24 +23,18 @@ if(WITH_NVIDIA)
3123
"nvidia/*.cpp"
3224
"nvidia/*.cu"
3325
)
34-
3526
file(GLOB_RECURSE NVIDIA_SOURCES CONFIGURE_DEPENDS ${NVIDIA_PATTERNS})
36-
3727
enable_language(CUDA)
38-
3928
target_compile_definitions(infiniops PUBLIC WITH_NVIDIA=1)
4029
target_sources(infiniops PRIVATE ${NVIDIA_SOURCES})
41-
4230
find_package(CUDAToolkit REQUIRED)
4331
target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver)
44-
4532
list(APPEND DEVICE_LIST "nvidia")
4633
set_target_properties(infiniops PROPERTIES
4734
CUDA_STANDARD 17
4835
CUDA_STANDARD_REQUIRED ON
4936
)
5037
endif()
51-
5238
if(WITH_ILUVATAR)
5339
set(ILUVATAR_PATTERNS
5440
"cuda/*.cc"
@@ -58,135 +44,112 @@ if(WITH_ILUVATAR)
5844
"iluvatar/*.cpp"
5945
"iluvatar/*.cu"
6046
)
61-
6247
file(GLOB_RECURSE ILUVATAR_SOURCES CONFIGURE_DEPENDS ${ILUVATAR_PATTERNS})
63-
6448
enable_language(CUDA)
65-
6649
target_compile_definitions(infiniops PUBLIC WITH_ILUVATAR=1)
6750
target_sources(infiniops PRIVATE ${ILUVATAR_SOURCES})
68-
6951
find_package(CUDAToolkit REQUIRED)
7052
target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver)
71-
7253
set_target_properties(infiniops PROPERTIES
7354
CUDA_STANDARD 17
7455
CUDA_STANDARD_REQUIRED ON
7556
)
76-
7757
list(APPEND DEVICE_LIST "iluvatar")
7858
endif()
79-
8059
if(WITH_METAX)
8160
set(METAX_PATTERNS
8261
"cuda/*.cc"
8362
"cuda/*.cpp"
8463
"metax/*.cc"
8564
"metax/*.maca"
8665
)
87-
8866
file(GLOB_RECURSE METAX_SOURCES CONFIGURE_DEPENDS ${METAX_PATTERNS})
89-
9067
set_source_files_properties(${METAX_SOURCES} PROPERTIES LANGUAGE CXX)
91-
9268
target_compile_definitions(infiniops PRIVATE WITH_METAX=1)
9369
target_compile_options(infiniops PUBLIC "-x" "maca")
9470
target_sources(infiniops PRIVATE ${METAX_SOURCES})
95-
9671
target_include_directories(infiniops PUBLIC "${MACA_PATH}/include")
9772
target_link_libraries(infiniops PUBLIC
9873
${MACA_RUNTIME_LIB}
9974
${MACA_DNN_LIB}
10075
${MACA_BLAS_LIB}
10176
)
102-
10377
list(APPEND DEVICE_LIST "metax")
10478
endif()
105-
10679
if(WITH_CAMBRICON)
10780
file(GLOB_RECURSE CAMBRICON_MLU_SOURCES CONFIGURE_DEPENDS "cambricon/*/*.mlu")
10881
find_program(CNCC_COMPILER cncc HINTS "${NEUWARE_HOME}/bin" "$ENV{NEUWARE_HOME}/bin" /usr/local/neuware/bin)
109-
11082
if(CNCC_COMPILER)
11183
message(STATUS "Found cncc: ${CNCC_COMPILER}")
112-
11384
set(MLU_COMPILE_OPTS
11485
-c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread
11586
-I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include
87+
# FIX 1: Explicitly pass the fallback directory to the custom cncc command
88+
-idirafter /usr/local/neuware/lib/clang/11.1.0/include
11689
)
117-
11890
function(compile_mlu_file src_file)
11991
get_filename_component(name ${src_file} NAME_WE)
12092
get_filename_component(path ${src_file} DIRECTORY)
12193
set(out_file "${CMAKE_CURRENT_BINARY_DIR}/${path}/${name}.o")
12294
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${path}")
123-
12495
add_custom_command(OUTPUT ${out_file}
12596
COMMAND ${CNCC_COMPILER} ${MLU_COMPILE_OPTS} -c ${src_file} -o ${out_file}
12697
DEPENDS ${src_file}
12798
COMMENT "Building MLU kernel: ${src_file}"
12899
)
129100
set_property(DIRECTORY APPEND PROPERTY CAMBRICON_OBJECTS ${out_file})
130101
endfunction()
131-
132102
foreach(src ${CAMBRICON_MLU_SOURCES})
133103
compile_mlu_file(${src})
134104
endforeach()
135-
136105
get_directory_property(CAMBRICON_OBJECT_FILES CAMBRICON_OBJECTS)
137106
if(CAMBRICON_OBJECT_FILES)
138107
target_sources(infiniops PRIVATE ${CAMBRICON_OBJECT_FILES})
139108
endif()
140109
else()
141110
message(WARNING "cncc compiler not found. MLU kernels will not be compiled.")
142111
endif()
143-
144112
target_compile_definitions(infiniops PRIVATE WITH_CAMBRICON=1)
145113
target_include_directories(infiniops PUBLIC "${NEUWARE_HOME}/include")
146114
target_link_libraries(infiniops PUBLIC ${CAMBRICON_RUNTIME_LIB} ${CAMBRICON_CNNL_LIB} ${CAMBRICON_CNNL_EXTRA_LIB} ${CAMBRICON_PAPI_LIB})
147-
115+
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
116+
# FIX 2 & 3: Make it PUBLIC so the Pybind ops target sees it.
117+
# Use SHELL: so CMake doesn't split it, and constrain it to CXX files.
118+
target_compile_options(infiniops PUBLIC
119+
"$<$<COMPILE_LANGUAGE:CXX>:SHELL:-idirafter /usr/local/neuware/lib/clang/11.1.0/include>"
120+
)
121+
endif()
148122
list(APPEND DEVICE_LIST "cambricon")
149123
endif()
150-
151124
target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
152-
153125
if(GENERATE_PYTHON_BINDINGS)
154126
execute_process(
155127
COMMAND python ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST}
156128
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
157129
RESULT_VARIABLE script_result
158130
)
159-
160131
if(NOT script_result EQUAL 0)
161132
message(FATAL_ERROR "Generating wrappers - failed")
162133
else()
163134
message(STATUS "Generating wrappers - done")
164135
endif()
165-
166136
set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc")
167-
168137
# TODO: There might be a better solution.
169138
if(WITH_NVIDIA OR WITH_ILUVATAR)
170139
set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA)
171140
endif()
172-
173141
find_package(Python COMPONENTS Interpreter Development)
174142
find_package(pybind11 CONFIG)
175-
176143
if(PYBIND11_ENABLE_EXTRAS)
177144
pybind11_add_module(ops ${PYBIND11_SOURCES})
178145
else()
179146
pybind11_add_module(ops NO_EXTRAS ${PYBIND11_SOURCES})
180147
endif()
181-
182148
target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR})
183149
target_link_libraries(ops PRIVATE infiniops)
184-
185150
set_target_properties(infiniops PROPERTIES INSTALL_RPATH "$ORIGIN")
186151
set_target_properties(ops PROPERTIES INSTALL_RPATH "$ORIGIN")
187-
188152
install(TARGETS infiniops ops DESTINATION .)
189-
190153
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" "")
191154
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" DESTINATION .)
192-
endif()
155+
endif()

src/cambricon/rms_norm/kernel.mlu

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "../../data_type.h"
12
#include "rms_norm.h"
23

34
__nram__ char nram_buffer[NRAM_MAX_SIZE];
@@ -255,3 +256,39 @@ void rmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrt
255256

256257
cnrtQueueSync(queue);
257258
}
259+
260+
template void rmsnormUnion<half, half>(
261+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
262+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
263+
264+
template void rmsnormUnion<half, bfloat16_t>(
265+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
266+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
267+
268+
template void rmsnormUnion<half, float>(
269+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
270+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
271+
272+
template void rmsnormUnion<bfloat16_t, half>(
273+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
274+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
275+
276+
template void rmsnormUnion<bfloat16_t, bfloat16_t>(
277+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
278+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
279+
280+
template void rmsnormUnion<bfloat16_t, float>(
281+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
282+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
283+
284+
template void rmsnormUnion<float, half>(
285+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
286+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
287+
288+
template void rmsnormUnion<float, bfloat16_t>(
289+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
290+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
291+
292+
template void rmsnormUnion<float, float>(
293+
void *, int, int, cnrtQueue_t, void *, const void *, const void *,
294+
const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);

src/common/cambricon/cast.h

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#ifndef INFINI_OPS_COMMON_CAMBRICON_CAST_H_
2+
#define INFINI_OPS_COMMON_CAMBRICON_CAST_H_
3+
4+
#include "bang_fp16.h"
5+
#include "bang_bf16.h"
6+
7+
#include "data_type.h"
8+
9+
namespace infini::ops {
10+
11+
namespace detail {
12+
13+
template <typename T>
14+
using PureType = std::remove_cv_t<std::remove_reference_t<T>>;
15+
16+
template <typename T>
17+
__host__ __device__ constexpr float ToFloatHelper(T&& x) {
18+
using PureSrc = PureType<T>;
19+
if constexpr (IsBFloat16<PureSrc>) {
20+
return __bfloat162float__(x);
21+
} else if constexpr (IsFP16<PureSrc>) {
22+
return __half2float(x);
23+
} else {
24+
return static_cast<float>(std::forward<T>(x));
25+
}
26+
}
27+
28+
template <typename Dst>
29+
__host__ __device__ constexpr Dst FromFloatHelper(float f) {
30+
using PureDst = PureType<Dst>;
31+
if constexpr (IsBFloat16<PureDst>) {
32+
return __float2bfloat16__(f);
33+
} else if constexpr (IsFP16<PureDst>) {
34+
return __float2half__(f);
35+
} else {
36+
return static_cast<Dst>(f);
37+
}
38+
}
39+
40+
// Priority tags for overload resolution.
41+
struct PriorityLow {};
42+
43+
struct PriorityHigh : PriorityLow {};
44+
45+
// Fallback: lowest priority. This always matches if nothing else does.
46+
template <typename Dst, typename Src>
47+
__host__ __device__ constexpr Dst HardwareCast(Src&& x, PriorityLow) {
48+
return FromFloatHelper<Dst>(ToFloatHelper(std::forward<Src>(x)));
49+
}
50+
51+
// Usage: `DEFINE_DIRECT_CAST(INTRINSIC, CONDITION)`.
52+
#define DEFINE_DIRECT_CAST(INTRINSIC, ...) \
53+
template <typename Dst, typename Src> \
54+
__host__ __device__ auto HardwareCast(Src x, PriorityHigh) \
55+
->std::enable_if_t<(__VA_ARGS__), \
56+
decltype(INTRINSIC(std::declval<Src>()))> { \
57+
return INTRINSIC(x); \
58+
}
59+
60+
DEFINE_DIRECT_CAST(
61+
__bfloat162int_rz__,
62+
std::is_same_v<PureType<Dst>, int>&& IsBFloat16<PureType<Src>>)
63+
DEFINE_DIRECT_CAST(
64+
__bfloat162short_rz__,
65+
std::is_same_v<PureType<Dst>, short>&& IsBFloat16<PureType<Src>>)
66+
DEFINE_DIRECT_CAST(
67+
__int2bfloat16_rn__,
68+
IsBFloat16<PureType<Dst>>&& std::is_same_v<PureType<Src>, int>)
69+
DEFINE_DIRECT_CAST(__int2half_rn__,
70+
IsFP16<PureType<Dst>>&& std::is_same_v<PureType<Src>, int>)
71+
DEFINE_DIRECT_CAST(
72+
__float2bfloat16__,
73+
IsBFloat16<PureType<Dst>>&& std::is_same_v<PureType<Src>, double>)
74+
DEFINE_DIRECT_CAST(
75+
__float2half__,
76+
IsFP16<PureType<Dst>>&& std::is_same_v<PureType<Src>, double>)
77+
DEFINE_DIRECT_CAST(__half, IsFP16<PureType<Dst>>&& IsBFloat16<PureType<Src>>)
78+
#undef DEFINE_DIRECT_CAST
79+
80+
} // namespace detail
81+
82+
template <typename Dst, typename Src>
83+
__host__ __device__ Dst Cast(Src&& x) {
84+
static_assert(!std::is_reference_v<Dst>,
85+
"`Cast` cannot return reference types");
86+
87+
using PureSrc = std::remove_cv_t<std::remove_reference_t<Src>>;
88+
using PureDst = std::remove_cv_t<std::remove_reference_t<Dst>>;
89+
90+
if constexpr (std::is_same_v<PureSrc, PureDst>) {
91+
return std::forward<Src>(x);
92+
} else {
93+
return detail::HardwareCast<PureDst>(std::forward<Src>(x),
94+
detail::PriorityHigh{});
95+
}
96+
}
97+
98+
} // namespace infini::ops
99+
100+
#endif

src/common/cast.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_METAX)
55
#include "common/cuda/cast.h"
6+
#elif defined(WITH_CAMBRICON)
7+
#include "common/cambricon/cast.h"
68
#else
79
#include "common/cpu/cast.h"
810
#endif

src/cpu/add/add.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <utility>
55

66
#include "base/add.h"
7-
#include "common/cast.h"
7+
#include "common/cpu/cast.h"
88
#include "common/generic_utils.h"
99

1010
namespace infini::ops {

src/cpu/causal_softmax/causal_softmax.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <cmath>
55

66
#include "base/causal_softmax.h"
7-
#include "common/cast.h"
7+
#include "common/cpu/cast.h"
88
#include "common/generic_utils.h"
99
#include "data_type.h"
1010
#include "tensor.h"

src/cpu/gemm/gemm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <utility>
55

66
#include "base/gemm.h"
7-
#include "common/cast.h"
7+
#include "common/cpu/cast.h"
88
#include "common/generic_utils.h"
99

1010
namespace infini::ops {

src/cpu/rms_norm/rms_norm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <cmath>
55

66
#include "base/rms_norm.h"
7-
#include "common/cast.h"
7+
#include "common/cpu/cast.h"
88
#include "common/generic_utils.h"
99
#include "data_type.h"
1010
#include "tensor.h"

src/cpu/swiglu/swiglu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <cmath>
55

66
#include "base/swiglu.h"
7-
#include "common/cast.h"
7+
#include "common/cpu/cast.h"
88
#include "common/generic_utils.h"
99

1010
namespace infini::ops {

0 commit comments

Comments
 (0)