Merged
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -387,6 +387,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/resize.cpp
${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp
${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/cub_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp
149 changes: 115 additions & 34 deletions csrc/scheduler/greedy.cpp
@@ -22,6 +22,7 @@
#include <scheduler/greedy.h>
#include <scheduler/mark_aliases.h>
#include <scheduler/runtime_info.h>
#include <scheduler/tools/cub_utils.h>
#include <scheduler/tools/inlining.h>
#include <scheduler/tools/loop_domain_scheduler.h>
#include <scheduler/tools/maxinfo_propagator.h>
@@ -634,6 +635,8 @@ class RunTimeChecker : private IterVisitor {
max_threads_per_block_(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock) {
traverse(fusion);

checkSharedMemoryBufferUsage();
}

void dispatch(Expr* expr) override {
@@ -644,46 +647,88 @@ class RunTimeChecker : private IterVisitor {
}

void handle(ArgsortOp* argsort) override {
checkDomainConstraints(
int64_t size_of_constrained_ids = checkDomainConstraints(
ir_utils::getTvOutput(argsort)->getLogicalDomain(),
{argsort->dim()},
dataTypeSizeByte(ir_utils::getTvOutput(argsort)->dtype()),
/*support_batching=*/true);

int64_t batch_size =
ceilDiv(size_of_constrained_ids, max_threads_per_block_);
int64_t bdimx = std::min(size_of_constrained_ids, max_threads_per_block_);
cub_shmem_buffer_.registerArgsort(
bdimx, batch_size, ir_utils::getTvInput(argsort)->dtype());
}
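A minimal sketch (not part of this diff) of how the launch shape above splits the constrained extent between threads and per-thread batches; the extent value and the ceilDiv helper are assumptions for illustration only.

#include <algorithm>
#include <cstdint>

// Stand-in for nvFuser's ceilDiv helper.
int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t size_of_constrained_ids = 3000; // assumed argsort extent
  const int64_t max_threads_per_block = 1024;
  // Each thread processes batch_size items; bdimx threads run per block.
  const int64_t batch_size =
      ceilDiv(size_of_constrained_ids, max_threads_per_block); // 3
  const int64_t bdimx =
      std::min(size_of_constrained_ids, max_threads_per_block); // 1024
  return (batch_size == 3 && bdimx == 1024) ? 0 : 1;
}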

void handle(PadOp* pad) override {
checkDomainConstraints(
ir_utils::getTvOutput(pad)->getLogicalDomain(), pad->getPaddedAxes());
ir_utils::getTvOutput(pad)->getLogicalDomain(),
pad->getPaddedAxes(),
dataTypeSizeByte(ir_utils::getTvOutput(pad)->dtype()));
}

void handle(ScanOp* scan) override {
checkDomainConstraints(
ir_utils::getTvOutput(scan)->getLogicalDomain(),
{scan->dim()},
dataTypeSizeByte(ir_utils::getTvOutput(scan)->dtype()),
/*support_batching=*/true);
}

void handle(TopKOp* topk) override {
checkDomainConstraints(
ir_utils::getTvOutput(topk)->getLogicalDomain(), {topk->dim()});
Collaborator Author

Use of the output here was wrong. Need to check the input as it has a larger extent.

// TopKOp produces two outputs: one has the same type as the input
// and the other is an integer index tensor.
int64_t size_of_constrained_ids = checkDomainConstraints(
TensorDomain::noReductions(
ir_utils::getTvInput(topk)->getLogicalDomain()),
{topk->dim()},
dataTypeSizeByte(ir_utils::getTvInput(topk)->dtype()) +
dataTypeSizeByte(DataType::Int));

int64_t batch_size =
ceilDiv(size_of_constrained_ids, max_threads_per_block_);
int64_t bdimx = std::min(size_of_constrained_ids, max_threads_per_block_);
cub_shmem_buffer_.registerTopK(
bdimx, batch_size, ir_utils::getTvInput(topk)->dtype());
}
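To make the bytes_per_element argument above concrete, here is a small sketch (not from the PR) of the per-work-item accounting for top-k; the float input dtype and the 8-byte size assumed for DataType::Int are illustrative assumptions.

#include <cstdint>

int main() {
  // Each top-k work item carries both a value and its integer index.
  const int64_t value_bytes = 4; // dataTypeSizeByte of a float input (assumed)
  const int64_t index_bytes = 8; // dataTypeSizeByte(DataType::Int), assumed 8
  const int64_t bytes_per_element = value_bytes + index_bytes; // 12

  // The constrained extent is measured on the input, not the k-sized output,
  // because the kernel must traverse the full input along the top-k dim.
  const int64_t size_of_constrained_ids = 4096; // assumed input extent
  return (bytes_per_element * size_of_constrained_ids == 49152) ? 0 : 1;
}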

void handle(ScatterOp* scatter) override {
auto out = ir_utils::getTvOutput(scatter);
auto index = scatter->index()->as<TensorView>();

// TODO: If the input and output are a fusion input and output,
// there will be no computation for the shape of the logical
// domain, so this check is not necessary.
checkDomainConstraints(
out->getLogicalDomain(),
{scatter->dim()},
dataTypeSizeByte(out->dtype()),
/*support_batching=*/true);

int64_t index_bytes = dataTypeSizeByte(index->dtype());
// If it's scalar, ignore the contribution
int64_t src_bytes = scatter->src()->isA<TensorView>()
? dataTypeSizeByte(scatter->src()->dtype())
: 0;

checkDomainConstraints(
TensorDomain::noReductions(index->getLogicalDomain()),
{scatter->dim()},
index_bytes + src_bytes,
/*support_batching=*/true);
}
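A small sketch (assumed values, not from the PR) of the second scatter check above: the index tensor always contributes to the per-work-item size, while a scalar src contributes nothing.

#include <cstdint>

int main() {
  const int64_t index_bytes = 8;    // 64-bit index tensor (assumed)
  const bool src_is_tensor = false; // scatter with a scalar src
  const int64_t src_bytes = src_is_tensor ? 4 : 0; // float src if a tensor
  const int64_t bytes_per_element = index_bytes + src_bytes; // 8
  return bytes_per_element == 8 ? 0 : 1;
}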

void checkDomainConstraints(
// Check the constraints on the given domain. bytes_per_element
// indicates the size of data required to hold one work item, which
// may correspond to multiple tensor elements. For example, in the
// case of TopKOp, two outputs are produced, so the size should
// cover both of them.
//
// Returns the size of the constrained IDs in bytes
int64_t checkDomainConstraints(
const std::vector<IterDomain*>& domain,
const std::vector<int64_t>& constrained_id_offsets,
int64_t bytes_per_element,
bool support_batching = false) {
int64_t size_of_constrained_ids = 1;
for (const auto i : constrained_id_offsets) {
Expand All @@ -697,42 +742,77 @@ class RunTimeChecker : private IterVisitor {
size_of_constrained_ids *= extent_val.as<int64_t>();
}

// The maximum supported size depends on several factors. The hard
// limit is the shared memory capacity since the kernel launch
// would just fail if the shared memory usage exceeds the
// available size. The next important limit would be the register
// usage as we would not want to have excessive register spilling.
//
const int64_t threads_per_block = max_threads_per_block_;

// At this moment, not all constrained ops support batching. If
// batching is not supported, the limit is simply set as the
// maximum number of threads per thread block. This is likely
// a sufficient condition even for shared memory, although not
// guaranteed.
if (!support_batching) {
if (size_of_constrained_ids > threads_per_block) {
reject(
"Extent of constrained logical IDs, ",
size_of_constrained_ids,
", exceeds the number of threads per thread block: ",
threads_per_block);
}
}

// The maximum supported size depends on several factors. The hard
// limit is the shared memory capacity since the kernel launch
// would just fail if the shared memory usage exceeds the
// available size. It is checked at the end of the RunTimeChecker
// constructor.
//
// When batching is supported, up to half of the shared memory
// capacity is allowed for now. This is a pretty rough estimate
// and does not guarantee the safety of kernel launches nor avoids
// register spilling but is used for now since more accurate
// estimation of shared memory usage remains to be done, and the
// register spilling is not a functional concern.
//
// TODO: More accurate estimation of resource requirements
int64_t max_supported_size = max_threads_per_block_;
if (support_batching) {
auto available_shmem_capacity =
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock / 2;
// TODO: don't assume it's always float.
auto element_size = sizeof(float);
max_supported_size =
static_cast<int64_t>(available_shmem_capacity / element_size);
}

if (size_of_constrained_ids > max_supported_size) {
// The next important limit would be the register usage as we
// would not want to have excessive register spilling. The
// register usage would be linearly correlated with the batching
// factor. For now, just put a simple upper limit to avoid
// disastrous regressions. Fine tuning would be necessary.
const int64_t register_count_per_thread =
ceilDiv(size_of_constrained_ids, threads_per_block) *
bytes_per_element / 4;
const int64_t available_register_count_per_thread =
at::cuda::getCurrentDeviceProperties()->regsPerBlock /
threads_per_block;
// Make sure at least 20 registers are always available
const int64_t reserved_register_count_per_thread = 20;
if (register_count_per_thread + reserved_register_count_per_thread >
available_register_count_per_thread) {
reject(
"Expected register usage, ",
register_count_per_thread,
", exceeds the available count, ",
available_register_count_per_thread);
}

return size_of_constrained_ids;
}
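A sketch of the register-budget arithmetic above, using assumed device and problem numbers (8192 work items, 1024 threads per block, 12 bytes per item, 64K registers per block); none of these values come from the PR.

#include <cstdint>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t size_of_constrained_ids = 8192; // assumed
  const int64_t threads_per_block = 1024;
  const int64_t bytes_per_element = 12; // e.g. top-k value + index (assumed)
  const int64_t regs_per_block = 65536; // assumed regsPerBlock

  // Registers needed per thread to hold one batch, at 4 bytes per register.
  const int64_t register_count_per_thread =
      ceilDiv(size_of_constrained_ids, threads_per_block) * bytes_per_element /
      4; // 8 * 12 / 4 = 24
  const int64_t available_register_count_per_thread =
      regs_per_block / threads_per_block; // 64
  const int64_t reserved_register_count_per_thread = 20;

  // 24 + 20 <= 64, so this case would not be rejected.
  const bool reject =
      register_count_per_thread + reserved_register_count_per_thread >
      available_register_count_per_thread;
  return reject ? 1 : 0;
}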

void checkSharedMemoryBufferUsage() {
// TODO: Use the constant and util functions added in #5272
auto aligned_size = [](int64_t x) { return (x + 127) / 128 * 128; };

const int64_t cub_buffer_size =
aligned_size(cub_shmem_buffer_.getTotalSizeInBytes());

// TODO: Shared memory may also be used for resolving mismatched
// parallelization of constrained ops.

const auto total_required_size = cub_buffer_size;

const auto available_size = static_cast<int64_t>(
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);

if (total_required_size > available_size) {
reject(
"Extent of constrained logical IDs, ",
size_of_constrained_ids,
", exceeds the maxinum supported size: ",
max_supported_size);
"Not enough shared memory. Required size for CUB: ",
cub_buffer_size,
". Total required size: ",
total_required_size,
". Available: ",
available_size);
}
}
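A sketch of the shared-memory check above with assumed sizes: a 1000-byte CUB temporary-storage request rounded up to a 128-byte boundary and compared against an assumed 48 KiB per-block capacity.

#include <cstdint>

int main() {
  // Mirrors the aligned_size lambda: round up to a 128-byte boundary.
  auto aligned_size = [](int64_t x) { return (x + 127) / 128 * 128; };

  const int64_t cub_raw_bytes = 1000; // assumed CUB temp-storage requirement
  const int64_t cub_buffer_size = aligned_size(cub_raw_bytes); // 1024
  const int64_t available_size = 49152; // assumed sharedMemPerBlock (48 KiB)

  // Only the CUB buffer is counted for now; reject if it exceeds capacity.
  const bool reject = cub_buffer_size > available_size;
  return (cub_buffer_size == 1024 && !reject) ? 0 : 1;
}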

@@ -750,6 +830,7 @@ class RunTimeChecker : private IterVisitor {
private:
SchedulerRuntimeInfo& runtime_info_;
int64_t max_threads_per_block_ = 0;
scheduler_tools::CubSharedMemoryBuffer cub_shmem_buffer_;

bool can_schedule_ = true;
std::string reject_reason_;