
misc: Finer-grained control over fp16/fp8 builds (#722)
Flags can be used to disable fp16 and either of the fp8 variants in
order to speed up AOT builds.

By default, the configuration remains unchanged and the
`FLASHINFER_ENABLE_FP8` flag will enable both fp8 modes.
nandor authored Jan 8, 2025
1 parent 06309c4 commit 13de896
Showing 14 changed files with 177 additions and 164 deletions.
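
As an illustration of the new options, a cmake/config.cmake along the following lines would keep the f16, bf16 and fp8_e4m3 kernels while skipping fp8_e5m2. This is only a sketch: the flag names are taken from the diff below, the particular combination is arbitrary, and the umbrella FLASHINFER_ENABLE_FP8 option is set to OFF because, when ON, it re-enables both fp8 variants. The same options can also be passed on the cmake command line as -DOPTION=VALUE.

    # Illustrative config.cmake for a trimmed AOT build.
    set(FLASHINFER_ENABLE_FP8 OFF)       # umbrella flag; ON would force both fp8 variants on
    set(FLASHINFER_ENABLE_FP8_E4M3 ON)   # keep fp8_e4m3 kernels
    set(FLASHINFER_ENABLE_FP8_E5M2 OFF)  # skip fp8_e5m2 kernels
    set(FLASHINFER_ENABLE_F16 ON)        # keep fp16 kernels
    set(FLASHINFER_ENABLE_BF16 ON)       # keep bf16 kernels
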
24 changes: 21 additions & 3 deletions CMakeLists.txt
@@ -24,6 +24,12 @@ endif()
# options. Alernatively, use cmake -DOPTION=VALUE through command-line.
flashinfer_option(FLASHINFER_ENABLE_FP8
"Whether to compile fp8 kernels or not." ON)
flashinfer_option(FLASHINFER_ENABLE_FP8_E4M3
"Whether to compile fp8_e4m3 kernels or not." ON)
flashinfer_option(FLASHINFER_ENABLE_FP8_E5M2
"Whether to compile fp8_e5m2 kernels or not." ON)
flashinfer_option(FLASHINFER_ENABLE_F16
"Whether to compile f16 kernels or not." ON)
flashinfer_option(FLASHINFER_ENABLE_BF16
"Whether to compile bf16 kernels or not." ON)
flashinfer_option(
@@ -98,10 +104,20 @@ find_package(Thrust REQUIRED)
set(FLASHINFER_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include)

if(FLASHINFER_ENABLE_FP8)
message(STATUS "Compile fp8 kernels.")
add_definitions(-DFLASHINFER_ENABLE_FP8)
set(FLASHINFER_ENABLE_FP8_E4M3 ON)
set(FLASHINFER_ENABLE_FP8_E5M2 ON)
endif(FLASHINFER_ENABLE_FP8)

if(FLASHINFER_ENABLE_FP8_E4M3)
message(STATUS "Compile fp8_e4m3 kernels.")
add_definitions(-DFLASHINFER_ENABLE_FP8_E4M3)
endif(FLASHINFER_ENABLE_FP8_E4M3)

if(FLASHINFER_ENABLE_FP8_E5M2)
message(STATUS "Compile fp8_e5m2 kernels.")
add_definitions(-DFLASHINFER_ENABLE_FP8_E5M2)
endif(FLASHINFER_ENABLE_FP8_E5M2)

if(FLASHINFER_ENABLE_BF16)
message(STATUS "Compile bf16 kernels.")
add_definitions(-DFLASHINFER_ENABLE_BF16)
@@ -130,8 +146,10 @@ set(AOT_GENERATE_COMMAND
--pos_encoding_modes ${POS_ENCODING_MODES}
--allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS}
--mask_modes ${MASK_MODES}
--enable_f16 ${FLASHINFER_ENABLE_F16}
--enable_bf16 ${FLASHINFER_ENABLE_BF16}
--enable_fp8 ${FLASHINFER_ENABLE_FP8})
--enable_fp8_e4m3 ${FLASHINFER_ENABLE_FP8_E4M3}
--enable_fp8_e5m2 ${FLASHINFER_ENABLE_FP8_E5M2})

execute_process(
COMMAND ${AOT_GENERATE_COMMAND}
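One point worth spelling out about the block above: when the umbrella FLASHINFER_ENABLE_FP8 flag is ON it sets both per-variant flags to ON, so disabling a single fp8 variant only takes effect when the umbrella flag is OFF. A minimal sketch of that precedence (values are illustrative):

    set(FLASHINFER_ENABLE_FP8 ON)
    set(FLASHINFER_ENABLE_FP8_E5M2 OFF)  # ignored below; the umbrella flag wins
    if(FLASHINFER_ENABLE_FP8)
      set(FLASHINFER_ENABLE_FP8_E4M3 ON)
      set(FLASHINFER_ENABLE_FP8_E5M2 ON)
    endif()
    # To compile only one fp8 variant, leave FLASHINFER_ENABLE_FP8 OFF and
    # set FLASHINFER_ENABLE_FP8_E4M3 / FLASHINFER_ENABLE_FP8_E5M2 individually.
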
42 changes: 33 additions & 9 deletions aot_build_utils/generate.py
@@ -40,8 +40,10 @@ def write_if_different(path: Path, content: str) -> None:
pos_encoding_modes: List[int] = args.pos_encoding_modes
allow_fp16_qk_reductions: List[int] = args.allow_fp16_qk_reductions
mask_modes: List[int] = args.mask_modes
enable_f16: bool = args.enable_f16
enable_bf16: bool = args.enable_bf16
enable_fp8: bool = args.enable_fp8
enable_fp8_e4m3: bool = args.enable_fp8_e4m3
enable_fp8_e5m2: bool = args.enable_fp8_e5m2

path.mkdir(parents=True, exist_ok=True)

@@ -59,16 +61,24 @@ def write_if_different(path: Path, content: str) -> None:
)

idtypes = ["i32"]
prefill_dtypes = ["f16"]
decode_dtypes = ["f16"]
fp16_dtypes = ["f16"]
fp8_dtypes = ["e4m3", "e5m2"]
prefill_dtypes = []
decode_dtypes = []
fp16_dtypes = []
fp8_dtypes = []
if enable_f16:
prefill_dtypes.append("f16")
decode_dtypes.append("f16")
fp16_dtypes.append("f16")
if enable_bf16:
prefill_dtypes.append("bf16")
decode_dtypes.append("bf16")
fp16_dtypes.append("bf16")
if enable_fp8:
decode_dtypes.extend(fp8_dtypes)
if enable_fp8_e4m3:
fp8_dtypes.extend(["e4m3"])
decode_dtypes.extend(["e4m3"])
if enable_fp8_e5m2:
fp8_dtypes.extend(["e5m2"])
decode_dtypes.extend(["e5m2"])

single_decode_uris = []
# single decode files
@@ -276,6 +286,13 @@ def write_if_different(path: Path, content: str) -> None:
nargs="+",
help="Mask modes",
)
parser.add_argument(
"--enable_fp16",
type=lambda x: x if isinstance(x, int) else x.lower() == "true",
required=True,
nargs="+",
help="Enable fp16",
)
parser.add_argument(
"--enable_bf16",
type=lambda x: x if isinstance(x, int) else x.lower() == "true",
@@ -284,11 +301,18 @@ def write_if_different(path: Path, content: str) -> None:
help="Enable bf16",
)
parser.add_argument(
"--enable_fp8",
"--enable_fp8_e4m3",
type=lambda x: x if isinstance(x, int) else x.lower() == "true",
default=True,
nargs="+",
help="Enable fp8_e4m3",
)
parser.add_argument(
"--enable_fp8_e5m2",
type=lambda x: x if isinstance(x, int) else x.lower() == "true",
default=True,
nargs="+",
help="Enable fp8",
help="Enable fp8_e5m2",
)
args = parser.parse_args()
get_instantiation_cu(args)
21 changes: 13 additions & 8 deletions aot_build_utils/generate_sm90.py
@@ -37,14 +37,19 @@ def write_if_different(path: Path, content: str) -> None:
pos_encoding_modes: List[int] = args.pos_encoding_modes
allow_fp16_qk_reductions: List[int] = args.allow_fp16_qk_reductions
mask_modes: List[int] = args.mask_modes
enable_f16: bool = args.enable_f16
enable_bf16: bool = args.enable_bf16

path.mkdir(parents=True, exist_ok=True)

idtypes = ["i32"]
prefill_dtypes = ["f16"]
decode_dtypes = ["f16"]
fp16_dtypes = ["f16"]
prefill_dtypes = []
decode_dtypes = []
fp16_dtypes = []
if enable_f16:
prefill_dtypes.append("f16")
decode_dtypes.append("f16")
fp16_dtypes.append("f16")
if enable_bf16:
prefill_dtypes.append("bf16")
decode_dtypes.append("bf16")
@@ -183,18 +188,18 @@ def write_if_different(path: Path, content: str) -> None:
help="Mask modes",
)
parser.add_argument(
"--enable_bf16",
"--enable_f16",
type=lambda x: x if isinstance(x, int) else x.lower() == "true",
required=True,
nargs="+",
help="Enable bf16",
help="Enable f16",
)
parser.add_argument(
"--enable_fp8",
"--enable_bf16",
type=lambda x: x if isinstance(x, int) else x.lower() == "true",
default=True,
required=True,
nargs="+",
help="Enable fp8",
help="Enable bf16",
)
args = parser.parse_args()
get_sm90_instantiation_cu(args)
3 changes: 2 additions & 1 deletion cmake/config.cmake
@@ -1,5 +1,6 @@
# Whether to compile fp8 kernels or not.
set(FLASHINFER_ENABLE_FP8 ON)
set(FLASHINFER_ENABLE_FP8_E4M3 ON)
set(FLASHINFER_ENABLE_FP8_E5M2 ON)
# Whether to compile bf16 kernels or not.
set(FLASHINFER_ENABLE_BF16 ON)
# Whether to compile tvm bindings or not.
15 changes: 8 additions & 7 deletions csrc/bmm_fp8.cu
@@ -33,19 +33,20 @@ void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::T
TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2),
"Result tensor has incorrect shape");

auto batch_size = A.size(0);
auto m = A.size(1);
auto k = A.size(2);
auto n = B.size(2);

auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
// PyTorch is row major by default. cuBLASLt is column major by default.
// We need row major D as expected.
// A ^ T * B = D, so D ^ T = B ^ T * A
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
auto batch_size = A.size(0);
auto m = A.size(1);
auto k = A.size(2);
auto n = B.size(2);

auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);

auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
workspace_buffer.data_ptr(), workspace_buffer.numel(),
static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
Expand Down