Commit 2fb745a

feat: add flashinfer as kernel backend for cuda device.
1 parent 2e96809 commit 2fb745a

Large commits have some content hidden by default; only a subset of the changed files is shown below.

63 files changed: +1748 −119 lines

.gitmodules

Lines changed: 6 additions & 0 deletions
@@ -28,3 +28,9 @@
 [submodule "third_party/Mooncake"]
     path = third_party/Mooncake
     url = https://gitcode.com/xLLM-AI/Mooncake.git
+[submodule "third_party/tvm-ffi"]
+    path = third_party/tvm-ffi
+    url = https://gitcode.com/xLLM-AI/tvm-ffi.git
+[submodule "third_party/dlpack"]
+    path = third_party/dlpack
+    url = https://gitcode.com/xLLM-AI/dlpack.git
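The two new submodules, tvm-ffi and dlpack, are presumably pulled in for the flashinfer integration this commit introduces (the diff itself does not say what consumes them). After checking out the commit, they would typically be fetched with git submodule update --init --recursive.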

CMakeLists.txt

Lines changed: 72 additions & 2 deletions
@@ -1,8 +1,10 @@
 cmake_minimum_required(VERSION 3.26)
 set_property(GLOBAL PROPERTY USE_FOLDERS ON)
+set(CMAKE_CUDA_COMPILER "/usr/local/cuda/bin/nvcc")

 option(USE_NPU "Enable NPU support" OFF)
 option(USE_MLU "Enable MLU support" OFF)
+option(USE_CUDA "Enable CUDA support" OFF)

 if(DEVICE_ARCH STREQUAL "ARM")
   set(CMAKE_SYSTEM_PROCESSOR aarch64)
@@ -101,7 +103,7 @@ set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS ON)

-if(USE_NPU)
+if(USE_NPU OR USE_CUDA)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
   add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
 elseif(USE_MLU)
@@ -178,6 +180,32 @@ if (DEFINED ENV{DEPENDENCES_ROOT})
   message(STATUS "Using DEPENDENCES_ROOT: $ENV{DEPENDENCES_ROOT}")
 endif()

+# set architecture for CUDA
+if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND USE_CUDA)
+  set(CMAKE_CUDA_ARCHITECTURES 80)
+endif()
+
+# Build TORCH_CUDA_ARCH_LIST
+if(USE_CUDA)
+  # Build TORCH_CUDA_ARCH_LIST
+  set(TORCH_CUDA_ARCH_LIST "")
+  foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
+    if(CUDA_ARCH MATCHES "^([0-9])([0-9])a$")
+      set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}a")
+    elseif(CUDA_ARCH MATCHES "^([0-9])([0-9])*$")
+      set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}")
+    elseif(CUDA_ARCH STREQUAL "native")
+      set(TORCH_ARCH "Auto")
+    else()
+      message(FATAL_ERROR "${CUDA_ARCH} is not supported")
+    endif()
+    list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH})
+  endforeach()
+
+  message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
+  message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}")
+endif()
+
 # configure vcpkg
 # have to set CMAKE_TOOLCHAIN_FILE before first project call.
 # if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE)
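The loop above converts each entry of CMAKE_CUDA_ARCHITECTURES into the dotted form PyTorch expects in TORCH_CUDA_ARCH_LIST. As a worked example (not part of the diff): the "80;89;90" list that setup.py passes becomes "8.0;8.9;9.0", an architecture-specific entry such as "90a" would become "9.0a", "native" maps to "Auto", and any other value aborts configuration with a FATAL_ERROR.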
@@ -217,7 +245,12 @@ endif()
 set(CPPREST_EXCLUDE_WEBSOCKETS ON CACHE BOOL "Exclude websockets functionality." FORCE)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format-truncation")

-project("xllm" LANGUAGES C CXX)
+if(USE_CUDA)
+  project("xllm" LANGUAGES C CXX CUDA)
+  find_package(CUDAToolkit REQUIRED)
+else()
+  project("xllm" LANGUAGES C CXX)
+endif()

 # find_package(CUDAToolkit REQUIRED)

@@ -352,6 +385,43 @@ if(USE_MLU)
   )
 endif()

+if(USE_CUDA)
+  add_definitions(-DUSE_CUDA)
+  add_compile_definitions(TORCH_CUDA=1)
+  set(CMAKE_VERBOSE_MAKEFILE ON)
+  include_directories(
+    $ENV{PYTHON_INCLUDE_PATH}
+    $ENV{PYTORCH_INSTALL_PATH}/include
+    $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
+  )
+
+  link_directories(
+    $ENV{PYTHON_LIB_PATH}
+    $ENV{PYTORCH_INSTALL_PATH}/lib
+    $ENV{CUDA_TOOLKIT_ROOT_DIR}/lib64
+  )
+
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -O3)
+  # The following definitions must be undefined since half-precision operation is required.
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS}
+      -U__CUDA_NO_HALF_OPERATORS__
+      -U__CUDA_NO_HALF_CONVERSIONS__
+      -U__CUDA_NO_HALF2_OPERATORS__
+      -U__CUDA_NO_BFLOAT16_CONVERSIONS__)
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} --use_fast_math -Xfatbin -compress-all)
+  message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")
+
+  # find_package(NCCL REQUIRED)
+
+  # find cudnn
+  execute_process(COMMAND python -c "import nvidia.cudnn; print(nvidia.cudnn.__file__)" OUTPUT_VARIABLE CUDNN_PYTHON_PATH)
+  get_filename_component(CUDNN_ROOT_DIR "${CUDNN_PYTHON_PATH}" DIRECTORY)
+  link_directories(
+    ${CUDNN_ROOT_DIR}/lib64
+    ${CUDNN_ROOT_DIR}/lib
+  )
+endif()
+
 # check if USE_CXX11_ABI is set correctly
 # if (DEFINED USE_CXX11_ABI)
 #   parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS")
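Note that the cuDNN lookup above shells out to python -c "import nvidia.cudnn; ..." and links against whatever that package ships, so the build appears to assume cuDNN is installed as the pip wheel (e.g. nvidia-cudnn-cu12) in the build environment; which exact wheel is expected is an assumption, the diff only shows the import.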

setup.py

Lines changed: 26 additions & 6 deletions
@@ -98,7 +98,6 @@ def get_python_include_path():
         return None


-# PYTORCH_INSTALL_PATH and LIBTORCH_ROOT
 def get_torch_root_path():
     try:
         import torch
@@ -115,6 +114,12 @@ def get_torch_mlu_root_path():
     except ImportError:
         return None

+def get_nccl_root_path():
+    try:
+        from nvidia import nccl
+        return str(Path(nccl.__file__).parent)
+    except ImportError:
+        return None

 def set_npu_envs():
     PYTORCH_NPU_INSTALL_PATH = os.getenv("PYTORCH_NPU_INSTALL_PATH")
@@ -212,7 +217,16 @@ def set_mlu_envs():
     os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
     os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
     os.environ["PYTORCH_MLU_INSTALL_PATH"] = get_torch_mlu_root_path()
-
+
+def set_cuda_envs():
+    os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path()
+    os.environ["PYTHON_LIB_PATH"] = get_torch_root_path()
+    os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
+    os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
+    os.environ["CUDA_TOOLKIT_ROOT_DIR"] = "/usr/local/cuda"
+    os.environ["NCCL_ROOT"] = get_nccl_root_path()
+    os.environ["NCCL_VERSION"] = "2"
+
 class CMakeExtension(Extension):
     def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
         super().__init__(name, sources=[])
@@ -223,7 +237,7 @@ def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
 class ExtBuild(build_ext):
     user_options = build_ext.user_options + [
         ("base-dir=", None, "base directory of xLLM project"),
-        ("device=", None, "target device type (a3 or a2 or mlu)"),
+        ("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
         ("arch=", None, "target arch type (x86 or arm)"),
         ("install-xllm-kernels=", None, "install xllm_kernels RPM package (true/false)"),
     ]
@@ -302,8 +316,14 @@ def build_extension(self, ext: CMakeExtension):
             cmake_args += ["-DUSE_MLU=ON"]
             # set mlu environment variables
             set_mlu_envs()
+        elif self.device == "cuda":
+            cuda_architectures = "80;89;90"
+            cmake_args += ["-DUSE_CUDA=ON",
+                           f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"]
+            # set cuda environment variables
+            set_cuda_envs()
         else:
-            raise ValueError("Please set --device to a2 or a3 or mlu.")
+            raise ValueError("Please set --device to a2 or a3 or mlu or cuda.")


        # Adding CMake arguments set as environment variable
@@ -353,7 +373,7 @@

 class BuildDistWheel(bdist_wheel):
     user_options = bdist_wheel.user_options + [
-        ("device=", None, "target device type (a3 or a2 or mlu)"),
+        ("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
         ("arch=", None, "target arch type (x86 or arm)"),
     ]

@@ -530,7 +550,7 @@ def apply_patch():
     idx = sys.argv.index('--device')
     if idx + 1 < len(sys.argv):
         device = sys.argv[idx+1].lower()
-        if device not in ('a2', 'a3', 'mlu'):
+        if device not in ('a2', 'a3', 'mlu', 'cuda'):
             print("Error: --device must be a2 or a3 or mlu (case-insensitive)")
             sys.exit(1)
     # Remove the arguments so setup() doesn't see them
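With these changes, passing the new device type wires up the whole CUDA build: set_cuda_envs() exports the toolchain paths, get_nccl_root_path() resolves NCCL from the nvidia pip namespace (so an nvidia-nccl wheel must be importable), and CMake is invoked with -DUSE_CUDA=ON and -DCMAKE_CUDA_ARCHITECTURES=80;89;90, i.e. Ampere (8.0), Ada (8.9) and Hopper (9.0). A plausible invocation, assuming the project's usual setuptools entry points, would be: python setup.py bdist_wheel --device cuda --arch x86.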

third_party/dlpack

Submodule dlpack added at 93c8f2a

third_party/tvm-ffi

Submodule tvm-ffi added at af898a2

xllm/core/common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@ cc_library(
     rate_limiter.h
     types.h
     device_monitor.h
+    flashinfer_workspace.h
   SRCS
     etcd_client.cpp
     global_flags.cpp
@@ -23,6 +24,7 @@
     options.cpp
     rate_limiter.cpp
     device_monitor.cpp
+    flashinfer_workspace.cpp
   DEPS
     util
     absl::random_random

xllm/core/common/flashinfer_workspace.cpp

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "flashinfer_workspace.h"
+
+#include <glog/logging.h>
+
+#include "global_flags.h"
+
+namespace xllm {
+
+void FlashinferWorkspace::initialize(const torch::Device& device) {
+  LOG(INFO) << "FlashinferWorkspace initialize on device: " << device;
+  float_workspace_buffer_ =
+      torch::empty({FLAGS_workspace_buffer_size},
+                   torch::dtype(torch::kUInt8).device(device));
+  int_workspace_buffer_ = torch::empty(
+      {128 * 1024 * 1024}, torch::dtype(torch::kUInt8).device(device));
+  page_locked_int_workspace_buffer_ = torch::empty(
+      {int_workspace_buffer_.size(0)},
+      torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true));
+  LOG(INFO) << "FlashinferWorkspace initialize end";
+}
+
+torch::Tensor FlashinferWorkspace::get_float_workspace_buffer() {
+  return float_workspace_buffer_;
+}
+
+torch::Tensor FlashinferWorkspace::get_int_workspace_buffer() {
+  return int_workspace_buffer_;
+}
+
+torch::Tensor FlashinferWorkspace::get_page_locked_int_workspace_buffer() {
+  return page_locked_int_workspace_buffer_;
+}
+
+}  // namespace xllm
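A note on the three buffers (the division of labor follows flashinfer's usual workspace convention and is an inference, not something stated in the diff): the float workspace, sized by the new workspace_buffer_size flag, holds intermediate attention results for the split-k path; the int workspace holds the planner's scheduling metadata on the device; and the page-locked buffer is a pinned host mirror of the int workspace, sized identically, so planning data can be staged to the GPU with fast asynchronous copies.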

xllm/core/common/flashinfer_workspace.h

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/torch.h>
+
+#include <cstdint>
+
+#include "macros.h"
+
+namespace xllm {
+
+class FlashinferWorkspace {
+ public:
+  static FlashinferWorkspace& get_instance() {
+    static FlashinferWorkspace instance;
+    return instance;
+  };
+
+  void initialize(const torch::Device& device);
+
+  torch::Tensor get_float_workspace_buffer();
+  torch::Tensor get_int_workspace_buffer();
+  torch::Tensor get_page_locked_int_workspace_buffer();
+
+ private:
+  FlashinferWorkspace() = default;
+  ~FlashinferWorkspace() = default;
+  DISALLOW_COPY_AND_ASSIGN(FlashinferWorkspace);
+
+  torch::Tensor float_workspace_buffer_;
+  torch::Tensor int_workspace_buffer_;
+  torch::Tensor page_locked_int_workspace_buffer_;
+};
+
+}  // namespace xllm
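A minimal usage sketch of the singleton declared above (the helper function and device index are hypothetical; how the workspace is actually wired into the attention path is not part of this commit):

#include <torch/torch.h>

#include "flashinfer_workspace.h"

// Hypothetical call site: initialize once per process on the local CUDA
// device, then hand the shared buffers to whatever code plans and runs the
// flashinfer attention kernels.
void init_flashinfer_workspace_example(int device_index) {
  auto& workspace = xllm::FlashinferWorkspace::get_instance();
  workspace.initialize(torch::Device(torch::kCUDA, device_index));

  // The getters return handles to the uint8 tensors whose sizes were fixed
  // in initialize(); callers share these buffers rather than reallocating.
  torch::Tensor float_buf = workspace.get_float_workspace_buffer();
  torch::Tensor int_buf = workspace.get_int_workspace_buffer();
  torch::Tensor pinned_buf = workspace.get_page_locked_int_workspace_buffer();
}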

xllm/core/common/global_flags.cpp

Lines changed: 11 additions & 0 deletions
@@ -353,6 +353,7 @@ DEFINE_int32(micro_batch_num,
              "Default use two micro batches for multi-stream parallel.");

 // --- dit config ---
+
 DEFINE_int32(max_requests_per_batch, 1, "Max number of request per batch.");

 // --- continuous kv cache config ---
@@ -377,15 +378,25 @@ DEFINE_int64(buffer_size_per_seq,
              "Buffer size per sequence in bytes, default 0.");

 // --- beam search config ---
+
 DEFINE_bool(enable_beam_search_kernel,
             false,
             "Whether to enable beam search kernel.");

 // --- reasoning parser config ---
+
 DEFINE_string(reasoning_parser,
               "",
               "Specify the reasoning parser for handling reasoning "
               "interactions(e.g. glm45, qwen3, deepseek-r1).");

 // --- qwen3 reranker config ---
+
 DEFINE_bool(enable_qwen3_reranker, false, "Whether to enable qwen3 reranker.");
+
+// --- flashinfer config ---
+
+DEFINE_int32(workspace_buffer_size,
+             128 * 1024 * 1024,
+             "The user reserved workspace buffer used to store intermediate "
+             "attention results in split-k algorithm for flashinfer.");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
@@ -202,3 +202,5 @@ DECLARE_bool(enable_qwen3_reranker);
 DECLARE_string(reasoning_parser);

 DECLARE_bool(enable_shm);
+
+DECLARE_int32(workspace_buffer_size);
