Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion csrc/apis/runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
namespace deep_gemm::runtime {

static void register_apis(pybind11::module_& m) {
m.def("set_compile_mode", [&](const int& new_compile_mode) {
device_runtime->set_compile_mode(new_compile_mode);
});
m.def("get_compile_mode", [&]() {
return device_runtime->get_compile_mode();
});
m.def("set_num_sms", [&](const int& new_num_sms) {
device_runtime->set_num_sms(new_num_sms);
});
m.def("get_num_sms", [&]() {
    // Expose the runtime's SM count to Python.
    // Fix: the original had this return statement duplicated; the second
    // copy was unreachable dead code.
    return device_runtime->get_num_sms();
});
m.def("set_tc_util", [&](const int& new_tc_util) {
device_runtime->set_tc_util(new_tc_util);
Expand Down
10 changes: 10 additions & 0 deletions csrc/jit/device_runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace deep_gemm {
class DeviceRuntime {
int num_sms = 0, tc_util = 0;
std::shared_ptr<cudaDeviceProp> cached_prop;
int compile_mode = 0;

public:
explicit DeviceRuntime() = default;
Expand Down Expand Up @@ -51,6 +52,15 @@ class DeviceRuntime {
return num_sms;
}

// Set the JIT compile mode.
// Valid values (enforced by the assert): 0 or 1. Per the MAYBE_LAUNCH
// macro in csrc/jit_kernels/impls/runtime_utils.hpp, mode 0 means kernels
// are launched after being built, while mode 1 skips the launch step.
void set_compile_mode(const int& new_compile_mode) {
    DG_HOST_ASSERT(0 <= new_compile_mode and new_compile_mode <= 1);
    compile_mode = new_compile_mode;
}

// Return the current compile mode (0 or 1; defaults to 0 via the member
// initializer). NOTE(review): could be marked `const` — confirm no caller
// relies on a non-const overload before changing.
int get_compile_mode() {
    return compile_mode;
}

void set_tc_util(const int& new_tc_util) {
DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100);
tc_util = new_tc_util;
Expand Down
6 changes: 6 additions & 0 deletions csrc/jit_kernels/impls/runtime_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,10 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
swizzle_mode);
}

// Evaluate EXPR only when the device runtime is in compile-and-launch mode
// (compile_mode == 0). In compile-only mode (compile_mode == 1) the kernel
// is still JIT-built at the call sites, but the launch expression is
// skipped. The do { ... } while (0) wrapper makes the macro expand to a
// single statement, so it composes safely with if/else without braces.
#define MAYBE_LAUNCH(EXPR) do { \
if (device_runtime->get_compile_mode() == 0) { \
(EXPR); \
} \
} while (0)

} // namespace deep_gemm
2 changes: 1 addition & 1 deletion csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_gemm", code);
SM100BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args));
}

} // namespace deep_gemm
8 changes: 4 additions & 4 deletions csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
}

static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -200,7 +200,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
}

static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -258,7 +258,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
}

static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -338,7 +338,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
}

} // namespace deep_gemm
6 changes: 3 additions & 3 deletions csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
};
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d2d", code);
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args));
}

static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -177,7 +177,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
};
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", code);
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args));
}

static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -231,7 +231,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t
};
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", code);
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args));
}

} // namespace deep_gemm
6 changes: 3 additions & 3 deletions csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_gemm", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
Expand Down Expand Up @@ -169,7 +169,7 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
Expand Down Expand Up @@ -223,7 +223,7 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

} // namespace deep_gemm
6 changes: 3 additions & 3 deletions csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
}

static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -183,7 +183,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
}

static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -243,7 +243,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
}

} // namespace deep_gemm
2 changes: 2 additions & 0 deletions deep_gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# Configs
import deep_gemm_cpp
from deep_gemm_cpp import (
set_compile_mode,
get_compile_mode,
set_num_sms,
get_num_sms,
set_tc_util,
Expand Down