// Patch summary: adds a "compile mode" switch to the deep_gemm runtime.
//  - DeviceRuntime gains an `int compile_mode` member (initialised to 0) with
//    set_compile_mode(), which asserts 0 <= value <= 1 via DG_HOST_ASSERT, and
//    get_compile_mode().
//  - register_apis() exposes both as pybind11 functions, and
//    deep_gemm/__init__.py imports them from deep_gemm_cpp.
//  - runtime_utils.hpp adds MAYBE_LAUNCH(EXPR), which evaluates EXPR only when
//    device_runtime->get_compile_mode() == 0.
//  - Every `<X>Runtime::launch(runtime, args)` call touched below is wrapped in
//    MAYBE_LAUNCH. The preceding `generate(args)` and `compiler->build(...)`
//    calls stay unconditional, so with compile_mode == 1 the kernel is still
//    generated and built but the launch is skipped.
//
// NOTE(review): MAYBE_LAUNCH is #defined inside `namespace deep_gemm`, but
//   macros are not namespace-scoped; it also names `device_runtime` unqualified,
//   so it only expands correctly where that name is visible.
// NOTE(review): get_compile_mode() is not const-qualified, and the mode is a
//   bare int with magic values 0/1 -- consider `const` and a named enum.
// NOTE(review): the -/+ pair on `return device_runtime->get_num_sms();` reads
//   identically here; presumably a whitespace-only change -- confirm and drop.
// NOTE(review): this copy of the diff has lost its line breaks and indentation,
//   and `std::shared_ptr cached_prop;` appears to have lost its template
//   argument in extraction -- verify against the original patch before applying.
diff --git a/csrc/apis/runtime.hpp b/csrc/apis/runtime.hpp index 9ef42078..c1511ec2 100644 --- a/csrc/apis/runtime.hpp +++ b/csrc/apis/runtime.hpp @@ -6,11 +6,17 @@ namespace deep_gemm::runtime { static void register_apis(pybind11::module_& m) { + m.def("set_compile_mode", [&](const int& new_compile_mode) { + device_runtime->set_compile_mode(new_compile_mode); + }); + m.def("get_compile_mode", [&]() { + return device_runtime->get_compile_mode(); + }); m.def("set_num_sms", [&](const int& new_num_sms) { device_runtime->set_num_sms(new_num_sms); }); m.def("get_num_sms", [&]() { - return device_runtime->get_num_sms(); + return device_runtime->get_num_sms(); }); m.def("set_tc_util", [&](const int& new_tc_util) { device_runtime->set_tc_util(new_tc_util); diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index 79139d6c..cf78dec8 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -10,6 +10,7 @@ namespace deep_gemm { class DeviceRuntime { int num_sms = 0, tc_util = 0; std::shared_ptr cached_prop; + int compile_mode = 0; public: explicit DeviceRuntime() = default; @@ -51,6 +52,15 @@ class DeviceRuntime { return num_sms; } + void set_compile_mode(const int& new_compile_mode) { + DG_HOST_ASSERT(0 <= new_compile_mode and new_compile_mode <= 1); + compile_mode = new_compile_mode; + } + + int get_compile_mode() { + return compile_mode; + } + void set_tc_util(const int& new_tc_util) { DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100); tc_util = new_tc_util; diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index ed9c5305..1578312e 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -170,4 +170,10 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, swizzle_mode); } +#define MAYBE_LAUNCH(EXPR) do { \ + if (device_runtime->get_compile_mode() == 0) { \ + (EXPR); \ + } \ +} while (0) + } // namespace deep_gemm diff 
--git a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp index 033a7b75..cc2d14c1 100644 --- a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -137,7 +137,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a, }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_gemm", code); - SM100BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 67272d9c..99e319a9 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -143,7 +143,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); - SM100FP8Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args)); } static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -200,7 +200,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code); - SM100FP8Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args)); } static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -258,7 +258,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code); - 
SM100FP8Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args)); } static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -338,7 +338,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code); - SM100FP8Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp index 727d1b74..028068ec 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp @@ -124,7 +124,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa }; const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d2d", code); - SM100FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args)); } static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -177,7 +177,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con }; const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", code); - SM100FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args)); } static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -231,7 +231,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t }; const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", code); - 
SM100FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index ea29883c..c010f669 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -116,7 +116,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_gemm", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, @@ -169,7 +169,7 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, @@ -223,7 +223,7 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 3afc2d33..f3e91f38 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -124,7 +124,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); - 
SM90FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args)); } static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -183,7 +183,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); - SM90FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args)); } static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -243,7 +243,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); - SM90FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 169e2e6b..4f5420ef 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -14,6 +14,8 @@ # Configs import deep_gemm_cpp from deep_gemm_cpp import ( + set_compile_mode, + get_compile_mode, set_num_sms, get_num_sms, set_tc_util,