Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New scheduling directive to disallow partitioning. #7882

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class NoOpCollapsingMutator : public IRMutator {
if (is_no_op(body)) {
return body;
} else {
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3725,7 +3725,8 @@ void bounds_test() {
{Add::make(Call::make(in, input_site_1),
Call::make(in, input_site_2))},
output_site,
const_true()));
const_true()),
true);

map<string, Box> r;
r = boxes_required(loop);
Expand Down
4 changes: 2 additions & 2 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ class BoundsInference : public IRMutator {
}
}

return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}

Scope<> let_vars_in_scope;
Expand Down Expand Up @@ -1389,7 +1389,7 @@ Stmt bounds_inference(Stmt s,
s = Block::make(Evaluate::make(marker), s);

// Add a synthetic outermost loop to act as 'root'.
s = For::make("<outermost>", 0, 1, ForType::Serial, DeviceAPI::None, s);
s = For::make("<outermost>", 0, 1, ForType::Serial, DeviceAPI::None, s, false);

s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups,
outputs, func_bounds, target)
Expand Down
2 changes: 1 addition & 1 deletion src/CanonicalizeGPUVars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class CanonicalizeGPUVars : public IRMutator {
body.same_as(op->body)) {
return op;
} else {
return For::make(name, min, extent, op->for_type, op->device_api, body);
return For::make(name, min, extent, op->for_type, op->device_api, body, op->allow_partitioning);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/CodeGen_Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ class InjectHVXLocks : public IRMutator {
body = acquire_hvx_context(body, target);
body = substitute("uses_hvx", true, body);
Stmt new_for = For::make(op->name, op->min, op->extent, op->for_type,
op->device_api, body);
op->device_api, body, op->allow_partitioning);
Stmt prolog =
IfThenElse::make(uses_hvx_var, call_halide_qurt_hvx_unlock());
Stmt epilog =
Expand Down Expand Up @@ -407,7 +407,7 @@ class InjectHVXLocks : public IRMutator {
// halide_qurt_unlock
// }
s = For::make(op->name, op->min, op->extent, op->for_type,
op->device_api, body);
op->device_api, body, op->allow_partitioning);
}

uses_hvx = old_uses_hvx;
Expand Down
18 changes: 18 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,24 @@ Stage &Stage::gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &
return gpu_tile(x, y, z, x, y, z, tx, ty, tz, x_size, y_size, z_size, tail, device_api);
}

Stage &Stage::disallow_partitioning(const VarOrRVar &var) {
    // Exclude the loop over 'var' from loop partitioning, so it will not be
    // split into prologue/steady-state/epilogue sections. Returns *this to
    // allow scheduling-call chaining.
    definition.schedule().touched() = true;

    // Clear the allow_partitioning flag on every dim whose loop variable
    // matches 'var'. More than one dim may match, so scan the whole list
    // rather than stopping at the first hit.
    bool matched_any = false;
    for (Dim &d : definition.schedule().dims()) {
        if (var_name_match(d.var, var.name())) {
            d.allow_partitioning = false;
            matched_any = true;
        }
    }

    // It is a user error to name a var that does not exist in this schedule.
    user_assert(matched_any)
        << "In schedule for " << name()
        << ", could not find var " << var.name()
        << " to mark as disallow partitioning.\n"
        << dump_argument_list();
    return *this;
}

Stage &Stage::hexagon(const VarOrRVar &x) {
set_dim_device_api(x, DeviceAPI::Hexagon);
return *this;
Expand Down
1 change: 1 addition & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ class Stage {
const std::vector<Expr> &factors,
TailStrategy tail = TailStrategy::Auto);
Stage &reorder(const std::vector<VarOrRVar> &vars);
Stage &disallow_partitioning(const VarOrRVar &var);

template<typename... Args>
HALIDE_NO_USER_CODE_INLINE typename std::enable_if<Internal::all_are_convertible<VarOrRVar, Args...>::value, Stage &>::type
Expand Down
20 changes: 10 additions & 10 deletions src/FuseGPUThreadLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class NormalizeDimensionality : public IRMutator {
}
while (max_depth < block_size.threads_dimensions()) {
string name = thread_names[max_depth];
s = For::make("." + name, 0, 1, ForType::GPUThread, device_api, s);
s = For::make("." + name, 0, 1, ForType::GPUThread, device_api, s, false);
max_depth++;
}
return s;
Expand Down Expand Up @@ -398,7 +398,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator {
Expr v = Variable::make(Int(32), loop_name);
host_side_preamble = substitute(op->name, v, host_side_preamble);
host_side_preamble = For::make(loop_name, new_min, new_extent,
ForType::Serial, DeviceAPI::None, host_side_preamble);
ForType::Serial, DeviceAPI::None, host_side_preamble, op->allow_partitioning);
if (old_preamble.defined()) {
host_side_preamble = Block::make(old_preamble, host_side_preamble);
}
Expand All @@ -407,7 +407,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator {
}

return For::make(op->name, new_min, new_extent,
op->for_type, op->device_api, body);
op->for_type, op->device_api, body, op->allow_partitioning);
}

Stmt visit(const Block *op) override {
Expand Down Expand Up @@ -1093,7 +1093,7 @@ class ExtractRegisterAllocations : public IRMutator {
allocations.swap(old);
}

return For::make(op->name, mutate(op->min), mutate(op->extent), op->for_type, op->device_api, body);
return For::make(op->name, mutate(op->min), mutate(op->extent), op->for_type, op->device_api, body, op->allow_partitioning);
}
}

Expand Down Expand Up @@ -1254,7 +1254,7 @@ class InjectThreadBarriers : public IRMutator {
body = Block::make(body, make_barrier(0));
}
return For::make(op->name, op->min, op->extent,
op->for_type, op->device_api, body);
op->for_type, op->device_api, body, op->allow_partitioning);
} else {
return IRMutator::visit(op);
}
Expand Down Expand Up @@ -1405,14 +1405,14 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator {
string thread_id = "." + thread_names[0];
// Add back in any register-level allocations
body = register_allocs.rewrap(body, thread_id);
body = For::make(thread_id, 0, block_size_x, innermost_loop_type, op->device_api, body);
body = For::make(thread_id, 0, block_size_x, innermost_loop_type, op->device_api, body, op->allow_partitioning);

// Rewrap the whole thing in other loops over threads
for (int i = 1; i < block_size.threads_dimensions(); i++) {
thread_id = "." + thread_names[i];
body = register_allocs.rewrap(body, thread_id);
body = For::make("." + thread_names[i], 0, block_size.num_threads(i),
ForType::GPUThread, op->device_api, body);
ForType::GPUThread, op->device_api, body, op->allow_partitioning);
}
thread_id.clear();
body = register_allocs.rewrap(body, thread_id);
Expand All @@ -1428,7 +1428,7 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator {
if (body.same_as(op->body)) {
return op;
} else {
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}
} else {
return IRMutator::visit(op);
Expand Down Expand Up @@ -1497,7 +1497,7 @@ class ZeroGPULoopMins : public IRMutator {
internal_assert(op);
Expr adjusted = Variable::make(Int(32), op->name) + op->min;
Stmt body = substitute(op->name, adjusted, op->body);
stmt = For::make(op->name, 0, op->extent, op->for_type, op->device_api, body);
stmt = For::make(op->name, 0, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}
return stmt;
}
Expand Down Expand Up @@ -1580,7 +1580,7 @@ class AddConditionToALoop : public IRMutator {
}

return For::make(op->name, op->min, op->extent, op->for_type, op->device_api,
IfThenElse::make(condition, op->body, Stmt()));
IfThenElse::make(condition, op->body, Stmt()), op->allow_partitioning);
}

public:
Expand Down
2 changes: 1 addition & 1 deletion src/HexagonOffload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ class InjectHexagonRpc : public IRMutator {
body = LetStmt::make(loop->name, loop->min, loop->body);
} else {
body = For::make(loop->name, loop->min, loop->extent, loop->for_type,
DeviceAPI::None, loop->body);
DeviceAPI::None, loop->body, loop->allow_partitioning);
}

// Build a closure for the device code.
Expand Down
3 changes: 2 additions & 1 deletion src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ Stmt ProducerConsumer::make_consume(const std::string &name, Stmt body) {
return ProducerConsumer::make(name, false, std::move(body));
}

Stmt For::make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body) {
Stmt For::make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body, bool allow_partitioning) {
internal_assert(min.defined()) << "For of undefined\n";
internal_assert(extent.defined()) << "For of undefined\n";
internal_assert(min.type() == Int(32)) << "For with non-integer min\n";
Expand All @@ -354,6 +354,7 @@ Stmt For::make(const std::string &name, Expr min, Expr extent, ForType for_type,
node->min = std::move(min);
node->extent = std::move(extent);
node->for_type = for_type;
node->allow_partitioning = allow_partitioning;
node->device_api = device_api;
node->body = std::move(body);
return node;
Expand Down
3 changes: 2 additions & 1 deletion src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,8 +798,9 @@ struct For : public StmtNode<For> {
ForType for_type;
DeviceAPI device_api;
Stmt body;
bool allow_partitioning;

static Stmt make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body);
static Stmt make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body, bool allow_partitioning);

bool is_unordered_parallel() const {
return Halide::Internal::is_unordered_parallel(for_type);
Expand Down
2 changes: 1 addition & 1 deletion src/IRMutator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Stmt IRMutator::visit(const For *op) {
return op;
}
return For::make(op->name, std::move(min), std::move(extent),
op->for_type, op->device_api, std::move(body));
op->for_type, op->device_api, std::move(body), op->allow_partitioning);
}

Stmt IRMutator::visit(const Store *op) {
Expand Down
4 changes: 2 additions & 2 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ void IRPrinter::test() {
internal_assert(expr_source.str() == "((x + 3)*((y/2) + 17))");

Stmt store = Store::make("buf", (x * 17) / (x - 3), y - 1, Parameter(), const_true(), ModulusRemainder());
Stmt for_loop = For::make("x", -2, y + 2, ForType::Parallel, DeviceAPI::Host, store);
Stmt for_loop = For::make("x", -2, y + 2, ForType::Parallel, DeviceAPI::Host, store, true);
vector<Expr> args(1);
args[0] = x % 3;
Expr call = Call::make(i32, "buf", args, Call::Extern);
Stmt store2 = Store::make("out", call + 1, x, Parameter(), const_true(), ModulusRemainder(3, 5));
Stmt for_loop2 = For::make("x", 0, y, ForType::Vectorized, DeviceAPI::Host, store2);
Stmt for_loop2 = For::make("x", 0, y, ForType::Vectorized, DeviceAPI::Host, store2, true);

Stmt producer = ProducerConsumer::make_produce("buf", for_loop);
Stmt consumer = ProducerConsumer::make_consume("buf", for_loop2);
Expand Down
6 changes: 3 additions & 3 deletions src/LICM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class LICM : public IRMutator {
internal_assert(loop);

new_stmt = For::make(loop->name, loop->min, loop->extent,
loop->for_type, loop->device_api, mutate(loop->body));
loop->for_type, loop->device_api, mutate(loop->body), loop->allow_partitioning);

// Wrap lets for the lifted invariants
for (size_t i = 0; i < exprs.size(); i++) {
Expand Down Expand Up @@ -564,15 +564,15 @@ class HoistIfStatements : public IRMutator {
is_pure(i->condition) &&
!expr_uses_var(i->condition, op->name)) {
Stmt s = For::make(op->name, op->min, op->extent,
op->for_type, op->device_api, i->then_case);
op->for_type, op->device_api, i->then_case, op->allow_partitioning);
return IfThenElse::make(i->condition, s);
}
}
if (body.same_as(op->body)) {
return op;
} else {
return For::make(op->name, op->min, op->extent,
op->for_type, op->device_api, body);
op->for_type, op->device_api, body, op->allow_partitioning);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/LoopCarry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ class LoopCarry : public IRMutator {
if (body.same_as(op->body)) {
stmt = op;
} else {
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}

// Inject the scratch buffer allocations.
Expand Down
3 changes: 2 additions & 1 deletion src/LowerParallelTasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ struct LowerParallelTasks : public IRMutator {
Variable::make(Int(32), loop_extent_name),
ForType::Serial,
DeviceAPI::None,
t.body);
t.body,
true);
} else {
internal_assert(is_const_one(t.extent));
}
Expand Down
4 changes: 2 additions & 2 deletions src/LowerWarpShuffles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ class LowerWarpShuffles : public IRMutator {
allocations.clear();

return For::make(op->name, op->min, warp_size,
op->for_type, op->device_api, body);
op->for_type, op->device_api, body, op->allow_partitioning);
} else {
return IRMutator::visit(op);
}
Expand Down Expand Up @@ -731,7 +731,7 @@ class HoistWarpShufflesFromSingleIfStmt : public IRMutator {
} else {
debug(3) << "Successfully hoisted shuffle out of for loop\n";
}
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body, op->allow_partitioning);
}

Stmt visit(const Store *op) override {
Expand Down
16 changes: 10 additions & 6 deletions src/PartitionLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,10 @@ class PartitionLoops : public IRMutator {
bool in_gpu_loop = false;

Stmt visit(const For *op) override {
if (!op->allow_partitioning) {
return IRMutator::visit(op);
}

Stmt body = op->body;

ScopedValue<bool> old_in_gpu_loop(in_gpu_loop, in_gpu_loop ||
Expand Down Expand Up @@ -706,16 +710,16 @@ class PartitionLoops : public IRMutator {
// Bust simple serial for loops up into three.
if (op->for_type == ForType::Serial && !op->body.as<Acquire>()) {
stmt = For::make(op->name, min_steady, max_steady - min_steady,
op->for_type, op->device_api, simpler_body);
op->for_type, op->device_api, simpler_body, true);

if (make_prologue) {
prologue = For::make(op->name, op->min, min_steady - op->min,
op->for_type, op->device_api, prologue);
op->for_type, op->device_api, prologue, true);
stmt = Block::make(prologue, stmt);
}
if (make_epilogue) {
epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady,
op->for_type, op->device_api, epilogue);
op->for_type, op->device_api, epilogue, true);
stmt = Block::make(stmt, epilogue);
}
} else {
Expand Down Expand Up @@ -743,7 +747,7 @@ class PartitionLoops : public IRMutator {
stmt = IfThenElse::make(loop_var < min_steady, prologue, stmt);
}
}
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, stmt);
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, stmt, true);
}

if (make_epilogue) {
Expand Down Expand Up @@ -866,7 +870,7 @@ class RenormalizeGPULoops : public IRMutator {
internal_assert(!expr_uses_var(f->min, op->name) &&
!expr_uses_var(f->extent, op->name));
Stmt inner = LetStmt::make(op->name, op->value, f->body);
inner = For::make(f->name, f->min, f->extent, f->for_type, f->device_api, inner);
inner = For::make(f->name, f->min, f->extent, f->for_type, f->device_api, inner, f->allow_partitioning);
return mutate(inner);
} else if (a && in_gpu_loop && !in_thread_loop) {
internal_assert(a->extents.size() == 1);
Expand Down Expand Up @@ -944,7 +948,7 @@ class RenormalizeGPULoops : public IRMutator {
for_a->min.same_as(for_b->min) &&
for_a->extent.same_as(for_b->extent)) {
Stmt inner = IfThenElse::make(op->condition, for_a->body, for_b->body);
inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->device_api, inner);
inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->device_api, inner, for_a->allow_partitioning);
return mutate(inner);
} else {
internal_error << "Unexpected construct inside if statement: " << Stmt(op) << "\n";
Expand Down
6 changes: 3 additions & 3 deletions src/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class InjectPlaceholderPrefetch : public IRMutator {

Stmt stmt;
if (!body.same_as(op->body)) {
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, std::move(body));
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, std::move(body), op->allow_partitioning);
} else {
stmt = op;
}
Expand Down Expand Up @@ -303,7 +303,7 @@ class ReducePrefetchDimension : public IRMutator {
stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic));
for (size_t i = 0; i < index_names.size(); ++i) {
stmt = For::make(index_names[i], 0, prefetch->args[(i + max_dim) * 2 + 2],
ForType::Serial, DeviceAPI::None, stmt);
ForType::Serial, DeviceAPI::None, stmt, true);
}
debug(5) << "\nReduce prefetch to " << max_dim << " dim:\n"
<< "Before:\n"
Expand Down Expand Up @@ -374,7 +374,7 @@ class SplitPrefetch : public IRMutator {
stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic));
for (size_t i = 0; i < index_names.size(); ++i) {
stmt = For::make(index_names[i], 0, extents[i],
ForType::Serial, DeviceAPI::None, stmt);
ForType::Serial, DeviceAPI::None, stmt, true);
}
debug(5) << "\nSplit prefetch to max of " << max_byte_size << " bytes:\n"
<< "Before:\n"
Expand Down
Loading
Loading