1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -396,6 +396,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/resize.cpp
${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp
${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/cub_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp
74 changes: 70 additions & 4 deletions csrc/scheduler/greedy.cpp
@@ -22,6 +22,7 @@
#include <scheduler/greedy.h>
#include <scheduler/mark_aliases.h>
#include <scheduler/runtime_info.h>
#include <scheduler/tools/cub_utils.h>
#include <scheduler/tools/inlining.h>
#include <scheduler/tools/loop_domain_scheduler.h>
#include <scheduler/tools/maxinfo_propagator.h>
@@ -617,6 +618,40 @@ class RunTimeChecker : private IterVisitor {
max_threads_per_block_(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock) {
traverse(fusion);

checkSharedMemoryBufferUsage();
}

void dispatch(Expr* expr) override {
if (!can_schedule_) {
return;
}
IterVisitor::dispatch(expr);

// Accumulate the largest type size for estimating the buffer size
// required for resolving scheduling conflicts. This is a
// conservative, coarse estimate that should be refined as needed.
for (auto inp_tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
largest_data_type_size_ =
std::max(largest_data_type_size_, dataTypeSizeByte(inp_tv->dtype()));
}
for (auto out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
largest_data_type_size_ =
std::max(largest_data_type_size_, dataTypeSizeByte(out_tv->dtype()));
}
}
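
A hedged sketch of the accumulation above, not part of the patch: the two loops amount to a running max over the element sizes of every tensor input and output. Assuming dataTypeSizeByte returns 2 for fp16 and 4 for fp32, an expression casting fp16 inputs to an fp32 output ends at 4 bytes.

#include <algorithm>
#include <cstdint>

inline int64_t largestTypeSizeSketch() {
  int64_t largest = 0;
  // Two hypothetical fp16 inputs (2 bytes) and one fp32 output (4 bytes).
  for (int64_t sz : {int64_t{2}, int64_t{2}, int64_t{4}}) {
    largest = std::max(largest, sz);
  }
  return largest; // == 4
}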

void handle(ArgsortOp* argsort) override {
checkDomainConstraints(
int64_t size_of_constrained_ids = checkDomainConstraints(
ir_utils::getTvOutput(argsort)->getLogicalDomain(),
{argsort->dim()},
/*support_batching=*/true);

int64_t batch_size =
ceilDiv(size_of_constrained_ids, max_threads_per_block_);
int64_t bdimx = std::min(size_of_constrained_ids, max_threads_per_block_);
cub_shmem_buffer_.registerArgsort(
bdimx, batch_size, ir_utils::getTvInput(argsort)->dtype());
}
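
To make the launch-shape arithmetic above concrete, a small sketch with hypothetical numbers, assuming ceilDiv is round-up division: a constrained extent of 3000 against a 1024-thread block gives three elements per thread.

#include <algorithm>
#include <cstdint>

// Round-up division, as ceilDiv is assumed to behave.
constexpr int64_t ceilDivSketch(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}
static_assert(ceilDivSketch(3000, 1024) == 3, "batch_size: elements per thread");
static_assert(std::min<int64_t>(3000, 1024) == 1024, "bdimx: threads along x");
// A smaller extent fits in one pass and simply narrows the block:
static_assert(ceilDivSketch(512, 1024) == 1, "");
static_assert(std::min<int64_t>(512, 1024) == 512, "");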

void handle(PadOp* pad) override {
@@ -646,8 +667,16 @@ class RunTimeChecker : private IterVisitor {
}

void handle(TopKOp* topk) override {
checkDomainConstraints(
ir_utils::getTvOutput(topk)->getLogicalDomain(), {topk->dim()});
Collaborator Author: Use of the output here was wrong. Need to check the input as it has a larger extent.

int64_t size_of_constrained_ids = checkDomainConstraints(
TensorDomain::noReductions(
ir_utils::getTvInput(topk)->getLogicalDomain()),
{topk->dim()});

int64_t batch_size =
ceilDiv(size_of_constrained_ids, max_threads_per_block_);
int64_t bdimx = std::min(size_of_constrained_ids, max_threads_per_block_);
cub_shmem_buffer_.registerTopK(
bdimx, batch_size, ir_utils::getTvInput(topk)->dtype());
}
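
A brief sketch of why the input, not the output, drives the check here (hypothetical numbers, not part of the patch): for topk with k = 8 over rows of extent 3000, the output's logical extent along the topk dimension is only 8, yet CUB still has to scan all 3000 input elements per row.

// Sizing from the input extent yields a real launch shape; the output
// extent (k) would drastically undercount the per-thread batching.
static_assert((3000 + 1024 - 1) / 1024 == 3, "batch_size from input extent");
static_assert((8 + 1024 - 1) / 1024 == 1, "output extent would undercount");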

void handle(ScatterOp* scatter) override {
@@ -664,7 +693,8 @@ class RunTimeChecker : private IterVisitor {
/*support_batching=*/true);
}

void checkDomainConstraints(
// Returns the size (extent) of the constrained IDs
int64_t checkDomainConstraints(
const std::vector<IterDomain*>& domain,
const std::vector<int64_t>& constrained_id_offsets,
bool support_batching = false) {
@@ -717,6 +747,39 @@ class RunTimeChecker : private IterVisitor {
", exceeds the maximum supported size: ",
max_supported_size);
}

max_constraint_size_ =
std::max(max_constraint_size_, size_of_constrained_ids);

return size_of_constrained_ids;
}

void checkSharedMemoryBufferUsage() {
// TODO: Use the constant and util functions added in #5272
auto aligned_size = [](int64_t x) { return (x + 127) / 128 * 128; };

const int64_t cub_buffer_size =
aligned_size(cub_shmem_buffer_.getTotalSizeInBytes());

// Shared memory may also be used for resolving mismatched
// parallelization of constrained IDs
const auto resolution_size =
aligned_size(max_constraint_size_ * largest_data_type_size_);

const auto total_required_size = cub_buffer_size + resolution_size;

const auto available_size =
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;

if (total_required_size > available_size) {
reject(
"Not enough shared memory. Required size for CUB: ",
cub_buffer_size,
". Total required size: ",
total_required_size,
". Available: ",
available_size);
}
}
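
To illustrate the budget computed above with hypothetical numbers, a hedged sketch (the real constants and helpers are deferred to #5272): aligned_size rounds up to the next 128-byte boundary, and the fusion is accepted only if the aligned CUB buffer plus the aligned conflict-resolution buffer fit in the device's shared memory per block.

#include <cstdint>

// Mirror of the aligned_size lambda: round up to a 128-byte boundary.
constexpr int64_t alignedSizeSketch(int64_t x) {
  return (x + 127) / 128 * 128;
}
static_assert(alignedSizeSketch(1) == 128, "");
static_assert(alignedSizeSketch(128) == 128, "");
static_assert(alignedSizeSketch(129) == 256, "");
// Hypothetical budget: a 24 KiB CUB request plus 3000 elements * 4 bytes
// (12000 -> 12032 aligned) fits within 48 KiB of shared memory per block.
static_assert(
    alignedSizeSketch(24576) + alignedSizeSketch(12000) <= 49152, "fits");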

template <typename... Args>
@@ -733,6 +796,9 @@
private:
SchedulerRuntimeInfo& runtime_info_;
int64_t max_threads_per_block_ = 0;
int64_t max_constraint_size_ = 0;
int64_t largest_data_type_size_ = 0;
scheduler_tools::CubSharedMemoryBuffer cub_shmem_buffer_;

bool can_schedule_ = true;
std::string reject_reason_;