Skip to content

Commit

Permalink
Remove type inspection helpers from ApplySplitResult and Split (#8489)
Browse files Browse the repository at this point in the history
They hide bugs where a case hasn't been considered.
  • Loading branch information
alexreinking authored Nov 23, 2024
1 parent 922e469 commit 166cd92
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 102 deletions.
27 changes: 18 additions & 9 deletions src/ApplySplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,

Expr outer = Variable::make(Int(32), prefix + split.outer);
Expr outer_max = Variable::make(Int(32), prefix + split.outer + ".loop_max");
if (split.is_split()) {
switch (split.split_type) {
case Split::SplitVar: {
Expr inner = Variable::make(Int(32), prefix + split.inner);
Expr old_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max");
Expr old_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min");
Expand Down Expand Up @@ -129,8 +130,8 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,
// Define the original variable as the base value computed above plus the inner loop variable.
result.emplace_back(old_var_name, base_var + inner, ApplySplitResult::LetStmt);
result.emplace_back(base_name, base, ApplySplitResult::LetStmt);

} else if (split.is_fuse()) {
} break;
case Split::FuseVars: {
// Define the inner and outer in terms of the fused var
Expr fused = Variable::make(Int(32), prefix + split.old_var);
Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min");
Expand All @@ -154,10 +155,12 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,
outer_dim != dim_extent_alignment.end()) {
dim_extent_alignment[split.old_var] = inner_dim->second * outer_dim->second;
}
} else {
// rename or purify
} break;
case Split::RenameVar:
case Split::PurifyRVar:
result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::Substitution);
result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::LetStmt);
break;
}

return result;
Expand All @@ -173,7 +176,8 @@ vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &spl
Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent");
Expr old_var_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max");
Expr old_var_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min");
if (split.is_split()) {
switch (split.split_type) {
case Split::SplitVar: {
Expr inner_extent = split.factor;
Expr outer_extent = (old_var_max - old_var_min + split.factor) / split.factor;
let_stmts.emplace_back(prefix + split.inner + ".loop_min", 0);
Expand All @@ -182,20 +186,25 @@ vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &spl
let_stmts.emplace_back(prefix + split.outer + ".loop_min", 0);
let_stmts.emplace_back(prefix + split.outer + ".loop_max", outer_extent - 1);
let_stmts.emplace_back(prefix + split.outer + ".loop_extent", outer_extent);
} else if (split.is_fuse()) {
} break;
case Split::FuseVars: {
// Define bounds on the fused var using the bounds on the inner and outer
Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent");
Expr outer_extent = Variable::make(Int(32), prefix + split.outer + ".loop_extent");
Expr fused_extent = inner_extent * outer_extent;
let_stmts.emplace_back(prefix + split.old_var + ".loop_min", 0);
let_stmts.emplace_back(prefix + split.old_var + ".loop_max", fused_extent - 1);
let_stmts.emplace_back(prefix + split.old_var + ".loop_extent", fused_extent);
} else if (split.is_rename()) {
} break;
case Split::RenameVar:
let_stmts.emplace_back(prefix + split.outer + ".loop_min", old_var_min);
let_stmts.emplace_back(prefix + split.outer + ".loop_max", old_var_max);
let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent);
break;
case Split::PurifyRVar:
// Do nothing for purify
break;
}
// Do nothing for purify

return let_stmts;
}
Expand Down
25 changes: 0 additions & 25 deletions src/ApplySplit.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,31 +46,6 @@ struct ApplySplitResult {
ApplySplitResult(Expr val, Type t = Predicate)
: name(""), value(std::move(val)), type(t) {
}

bool is_substitution() const {
return (type == Substitution);
}
bool is_substitution_in_calls() const {
return (type == SubstitutionInCalls);
}
bool is_substitution_in_provides() const {
return (type == SubstitutionInProvides);
}
bool is_let() const {
return (type == LetStmt);
}
bool is_predicate() const {
return (type == Predicate);
}
bool is_predicate_calls() const {
return (type == PredicateCalls);
}
bool is_predicate_provides() const {
return (type == PredicateProvides);
}
bool is_blend_provides() const {
return (type == BlendProvides);
}
};

/** Given a Split schedule on a definition (init or update), return a list of
Expand Down
Loading

0 comments on commit 166cd92

Please sign in to comment.