diff --git a/src/ApplySplit.cpp b/src/ApplySplit.cpp index 7bcdc521ba4e..b6491f063fba 100644 --- a/src/ApplySplit.cpp +++ b/src/ApplySplit.cpp @@ -17,7 +17,8 @@ vector apply_split(const Split &split, const string &prefix, Expr outer = Variable::make(Int(32), prefix + split.outer); Expr outer_max = Variable::make(Int(32), prefix + split.outer + ".loop_max"); - if (split.is_split()) { + switch (split.split_type) { + case Split::SplitVar: { Expr inner = Variable::make(Int(32), prefix + split.inner); Expr old_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max"); Expr old_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min"); @@ -129,8 +130,8 @@ vector apply_split(const Split &split, const string &prefix, // Define the original variable as the base value computed above plus the inner loop variable. result.emplace_back(old_var_name, base_var + inner, ApplySplitResult::LetStmt); result.emplace_back(base_name, base, ApplySplitResult::LetStmt); - - } else if (split.is_fuse()) { + } break; + case Split::FuseVars: { // Define the inner and outer in terms of the fused var Expr fused = Variable::make(Int(32), prefix + split.old_var); Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min"); @@ -154,10 +155,12 @@ vector apply_split(const Split &split, const string &prefix, outer_dim != dim_extent_alignment.end()) { dim_extent_alignment[split.old_var] = inner_dim->second * outer_dim->second; } - } else { - // rename or purify + } break; + case Split::RenameVar: + case Split::PurifyRVar: result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::Substitution); result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::LetStmt); + break; } return result; @@ -173,7 +176,8 @@ vector> compute_loop_bounds_after_split(const Split &spl Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent"); Expr old_var_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max"); Expr old_var_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min"); - if (split.is_split()) { + switch (split.split_type) { + case Split::SplitVar: { Expr inner_extent = split.factor; Expr outer_extent = (old_var_max - old_var_min + split.factor) / split.factor; let_stmts.emplace_back(prefix + split.inner + ".loop_min", 0); @@ -182,7 +186,8 @@ vector> compute_loop_bounds_after_split(const Split &spl let_stmts.emplace_back(prefix + split.outer + ".loop_min", 0); let_stmts.emplace_back(prefix + split.outer + ".loop_max", outer_extent - 1); let_stmts.emplace_back(prefix + split.outer + ".loop_extent", outer_extent); - } else if (split.is_fuse()) { + } break; + case Split::FuseVars: { // Define bounds on the fused var using the bounds on the inner and outer Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent"); Expr outer_extent = Variable::make(Int(32), prefix + split.outer + ".loop_extent"); @@ -190,12 +195,16 @@ vector> compute_loop_bounds_after_split(const Split &spl let_stmts.emplace_back(prefix + split.old_var + ".loop_min", 0); let_stmts.emplace_back(prefix + split.old_var + ".loop_max", fused_extent - 1); let_stmts.emplace_back(prefix + split.old_var + ".loop_extent", fused_extent); - } else if (split.is_rename()) { + } break; + case Split::RenameVar: let_stmts.emplace_back(prefix + split.outer + ".loop_min", old_var_min); let_stmts.emplace_back(prefix + split.outer + ".loop_max", old_var_max); let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent); + break; + case Split::PurifyRVar: + // Do nothing for purify + break; } - // Do nothing for purify return let_stmts; } diff --git a/src/ApplySplit.h b/src/ApplySplit.h index 4f74aea9ec62..b7a81f508ffe 100644 --- a/src/ApplySplit.h +++ b/src/ApplySplit.h @@ -46,31 +46,6 @@ struct ApplySplitResult { ApplySplitResult(Expr val, Type t = Predicate) : name(""), value(std::move(val)), type(t) { } - - bool is_substitution() const { - return (type == Substitution); - } - bool is_substitution_in_calls() const { - return (type == SubstitutionInCalls); - } - bool is_substitution_in_provides() const { - return (type == SubstitutionInProvides); - } - bool is_let() const { - return (type == LetStmt); - } - bool is_predicate() const { - return (type == Predicate); - } - bool is_predicate_calls() const { - return (type == PredicateCalls); - } - bool is_predicate_provides() const { - return (type == PredicateProvides); - } - bool is_blend_provides() const { - return (type == BlendProvides); - } }; /** Given a Split schedule on a definition (init or update), return a list of diff --git a/src/Func.cpp b/src/Func.cpp index ae49cf1e7485..c243e6950f3f 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -417,34 +417,43 @@ void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) { // Process the splits in reverse order to figure out which root vars have a // parallel child. for (const auto &split : reverse_view(sched.splits())) { - if (split.is_fuse()) { + switch (split.split_type) { + case Split::FuseVars: if (parallel.count(split.old_var)) { parallel.insert(split.inner); parallel.insert(split.old_var); } - } else if (split.is_rename() || split.is_purify()) { + break; + case Split::RenameVar: + case Split::PurifyRVar: if (parallel.count(split.outer)) { parallel.insert(split.old_var); } - } else { + break; + case Split::SplitVar: if (parallel.count(split.inner) || parallel.count(split.outer)) { parallel.insert(split.old_var); } + break; } } // Now propagate back to all children of the identified root vars, to assert // that none of them use a blending tail strategy. for (const auto &split : sched.splits()) { - if (split.is_fuse()) { + switch (split.split_type) { + case Split::FuseVars: if (parallel.count(split.inner) || parallel.count(split.outer)) { parallel.insert(split.old_var); } - } else if (split.is_rename() || split.is_purify()) { + break; + case Split::RenameVar: + case Split::PurifyRVar: if (parallel.count(split.old_var)) { parallel.insert(split.outer); } - } else { + break; + case Split::SplitVar: if (parallel.count(split.old_var)) { parallel.insert(split.inner); parallel.insert(split.old_var); @@ -457,6 +466,7 @@ void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) { << "This could cause a race condition.\n"; } } + break; } } } @@ -612,16 +622,20 @@ void apply_split_result(const vector> &bounds_let_stmts, vector &values) { for (const auto &res : splits_result) { - if (res.is_substitution() || res.is_let()) { + switch (res.type) { + case ApplySplitResult::Substitution: + case ApplySplitResult::LetStmt: // Apply substitutions to the list of predicates, args, and values. // Make sure we substitute in all the let stmts as well since we are // not going to add them to the exprs. substitute_var_in_exprs(res.name, res.value, predicates); substitute_var_in_exprs(res.name, res.value, args); substitute_var_in_exprs(res.name, res.value, values); - } else { - internal_assert(res.is_predicate()); + break; + default: + internal_assert(res.type == ApplySplitResult::Predicate); predicates.push_back(res.value); + break; } } @@ -640,7 +654,7 @@ void apply_split_result(const vector> &bounds_let_stmts, bool apply_split(const Split &s, vector &rvars, vector &predicates, vector &args, vector &values, map &dim_extent_alignment) { - internal_assert(s.is_split()); + internal_assert(s.split_type == Split::SplitVar); const auto it = std::find_if(rvars.begin(), rvars.end(), [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); @@ -673,7 +687,7 @@ bool apply_split(const Split &s, vector &rvars, bool apply_fuse(const Split &s, vector &rvars, vector &predicates, vector &args, vector &values, map &dim_extent_alignment) { - internal_assert(s.is_fuse()); + internal_assert(s.split_type == Split::FuseVars); const auto &iter_outer = std::find_if(rvars.begin(), rvars.end(), [&s](const ReductionVariable &rv) { return (s.outer == rv.var); }); const auto &iter_inner = std::find_if(rvars.begin(), rvars.end(), @@ -710,7 +724,7 @@ bool apply_fuse(const Split &s, vector &rvars, bool apply_purify(const Split &s, vector &rvars, vector &predicates, vector &args, vector &values, map &dim_extent_alignment) { - internal_assert(s.is_purify()); + internal_assert(s.split_type == Split::PurifyRVar); const auto &iter = std::find_if(rvars.begin(), rvars.end(), [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); if (iter != rvars.end()) { @@ -731,7 +745,7 @@ bool apply_purify(const Split &s, vector &rvars, bool apply_rename(const Split &s, vector &rvars, vector &predicates, vector &args, vector &values, map &dim_extent_alignment) { - internal_assert(s.is_rename()); + internal_assert(s.split_type == Split::RenameVar); const auto &iter = std::find_if(rvars.begin(), rvars.end(), [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); if (iter != rvars.end()) { @@ -765,14 +779,19 @@ bool apply_split_directive(const Split &s, vector &rvars, } bool found = false; - if (s.is_split()) { + switch (s.split_type) { + case Split::SplitVar: found = apply_split(s, rvars, predicates, args, values, dim_extent_alignment); - } else if (s.is_fuse()) { + break; + case Split::FuseVars: found = apply_fuse(s, rvars, predicates, args, values, dim_extent_alignment); - } else if (s.is_purify()) { + break; + case Split::PurifyRVar: found = apply_purify(s, rvars, predicates, args, values, dim_extent_alignment); - } else { + break; + case Split::RenameVar: found = apply_rename(s, rvars, predicates, args, values, dim_extent_alignment); + break; } if (found) { @@ -1173,19 +1192,24 @@ void Stage::split(const string &old, const string &outer, const string &inner, c // factor does not divide the outer split factor. std::set inner_vars; for (const Split &s : definition.schedule().splits()) { - if (s.is_split()) { + switch (s.split_type) { + case Split::SplitVar: inner_vars.insert(s.inner); if (inner_vars.count(s.old_var)) { inner_vars.insert(s.outer); } - } else if (s.is_rename() || s.is_purify()) { + break; + case Split::RenameVar: + case Split::PurifyRVar: if (inner_vars.count(s.old_var)) { inner_vars.insert(s.outer); } - } else if (s.is_fuse()) { + break; + case Split::FuseVars: if (inner_vars.count(s.inner) || inner_vars.count(s.outer)) { inner_vars.insert(s.old_var); } + break; } } round_up_ok = !inner_vars.count(old_name); @@ -1203,19 +1227,24 @@ void Stage::split(const string &old, const string &outer, const string &inner, c // is OK. Otherwise we can't prove it's safe. std::set inner_vars; for (const Split &s : definition.schedule().splits()) { - if (s.is_split()) { + switch (s.split_type) { + case Split::SplitVar: inner_vars.insert(s.inner); if (inner_vars.count(s.old_var)) { inner_vars.insert(s.outer); } - } else if (s.is_rename() || s.is_purify()) { + break; + case Split::RenameVar: + case Split::PurifyRVar: if (inner_vars.count(s.old_var)) { inner_vars.insert(s.outer); } - } else if (s.is_fuse()) { + break; + case Split::FuseVars: if (inner_vars.count(s.inner) || inner_vars.count(s.outer)) { inner_vars.insert(s.old_var); } + break; } } predicate_loads_ok = !inner_vars.count(old_name); @@ -1258,14 +1287,24 @@ void Stage::split(const string &old, const string &outer, const string &inner, c std::map descends_from_shiftinwards_outer; for (const Split &s : definition.schedule().splits()) { auto it = descends_from_shiftinwards_outer.find(s.old_var); - if (s.is_split() && s.tail == TailStrategy::ShiftInwards) { - descends_from_shiftinwards_outer[s.outer] = s.factor; - } else if (s.is_split() && it != descends_from_shiftinwards_outer.end()) { - descends_from_shiftinwards_outer[s.inner] = it->second; - descends_from_shiftinwards_outer[s.outer] = it->second; - } else if ((s.is_rename() || s.is_purify()) && - it != descends_from_shiftinwards_outer.end()) { - descends_from_shiftinwards_outer[s.outer] = it->second; + switch (s.split_type) { + case Split::SplitVar: + if (s.tail == TailStrategy::ShiftInwards) { + descends_from_shiftinwards_outer[s.outer] = s.factor; + } else if (it != descends_from_shiftinwards_outer.end()) { + descends_from_shiftinwards_outer[s.inner] = it->second; + descends_from_shiftinwards_outer[s.outer] = it->second; + } + break; + case Split::RenameVar: + case Split::PurifyRVar: + if (it != descends_from_shiftinwards_outer.end()) { + descends_from_shiftinwards_outer[s.outer] = it->second; + } + break; + case Split::FuseVars: + // Do nothing + break; } } auto it = descends_from_shiftinwards_outer.find(old_name); @@ -1524,7 +1563,8 @@ void Stage::remove(const string &var) { vector temp; for (const auto &split : reverse_view(splits)) { bool is_removed = false; - if (split.is_fuse()) { + switch (split.split_type) { + case Split::FuseVars: debug(4) << " checking fuse " << split.inner << " and " << split.inner << " into " << split.old_var << "\n"; if (split.inner == old_name || @@ -1541,7 +1581,8 @@ void Stage::remove(const string &var) { removed_vars.insert(split.outer); removed_vars.insert(split.inner); } - } else if (split.is_split()) { + break; + case Split::SplitVar: debug(4) << " splitting " << split.old_var << " into " << split.outer << " and " << split.inner << "\n"; if (should_remove(split.inner)) { @@ -1558,7 +1599,9 @@ void Stage::remove(const string &var) { << " because it has already been renamed or split.\n" << dump_argument_list(); } - } else { + break; + case Split::RenameVar: + case Split::PurifyRVar: debug(4) << " replace/rename " << split.old_var << " into " << split.outer << "\n"; if (should_remove(split.outer)) { @@ -1572,6 +1615,7 @@ void Stage::remove(const string &var) { << " because it has already been renamed or split.\n" << dump_argument_list(); } + break; } if (!is_removed) { temp.insert(temp.begin(), split); @@ -1626,7 +1670,8 @@ Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) { // If possible, rewrite the split or rename that defines it. found = false; for (auto &split : reverse_view(schedule.splits())) { - if (split.is_fuse()) { + switch (split.split_type) { + case Split::FuseVars: if (split.inner == old_name || split.outer == old_name) { user_error @@ -1641,7 +1686,11 @@ Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) { found = true; break; } - } else { + + break; + case Split::SplitVar: + case Split::RenameVar: + case Split::PurifyRVar: if (split.inner == old_name) { split.inner = new_name; found = true; @@ -1659,6 +1708,7 @@ Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) { << " because it has already been renamed or split.\n" << dump_argument_list(); } + break; } } diff --git a/src/Inline.cpp b/src/Inline.cpp index 5b5d21462c9e..54399cf77b76 100644 --- a/src/Inline.cpp +++ b/src/Inline.cpp @@ -55,16 +55,19 @@ void validate_schedule_inlined_function(Function f) { } for (const auto &split : stage_s.splits()) { - if (split.is_rename()) { + switch (split.split_type) { + case Split::RenameVar: user_warning << "It is meaningless to rename variable " << split.old_var << " of function " << f.name() << " to " << split.outer << " because " << f.name() << " is scheduled inline.\n"; - } else if (split.is_fuse()) { + break; + case Split::FuseVars: user_warning << "It is meaningless to fuse variables " << split.inner << " and " << split.outer << " because " << f.name() << " is scheduled inline.\n"; - } else { + break; + case Split::SplitVar: user_warning << "It is meaningless to split variable " << split.old_var << " of function " << f.name() << " into " @@ -72,6 +75,10 @@ void validate_schedule_inlined_function(Function f) { << split.factor << " + " << split.inner << " because " << f.name() << " is scheduled inline.\n"; + + break; + case Split::PurifyRVar: + break; } } diff --git a/src/Schedule.h b/src/Schedule.h index f32ce2265a0f..ea2692752a9e 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -348,19 +348,6 @@ struct Split { // If split_type is Fuse, then this does the opposite of a // split, it joins the outer and inner into the old_var. SplitType split_type; - - bool is_rename() const { - return split_type == RenameVar; - } - bool is_split() const { - return split_type == SplitVar; - } - bool is_fuse() const { - return split_type == FuseVars; - } - bool is_purify() const { - return split_type == PurifyRVar; - } }; /** Each Dim below has a dim_type, which tells you what diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index e7d603adc10a..c7a257dd085e 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -226,21 +226,27 @@ Stmt build_loop_nest( // an example like let a = 2*x in a + f[a]. stmt = substitute_in_all_lets(stmt); for (const auto &res : splits_result) { - if (res.is_substitution()) { + switch (res.type) { + case ApplySplitResult::Substitution: stmt = graph_substitute(res.name, res.value, stmt); - } else if (res.is_substitution_in_calls()) { + break; + case ApplySplitResult::SubstitutionInCalls: stmt = substitute_in(res.name, res.value, true, false, stmt); - } else if (res.is_substitution_in_provides()) { + break; + case ApplySplitResult::SubstitutionInProvides: stmt = substitute_in(res.name, res.value, false, true, stmt); - } else if (res.is_blend_provides() || - res.is_predicate_calls() || - res.is_predicate_provides()) { + break; + case ApplySplitResult::BlendProvides: + case ApplySplitResult::PredicateCalls: + case ApplySplitResult::PredicateProvides: stmt = add_predicates(res.value, func, res.type, stmt); - } else if (res.is_let()) { + break; + case ApplySplitResult::LetStmt: stmt = LetStmt::make(res.name, res.value, stmt); - } else { - internal_assert(res.is_predicate()); + break; + case ApplySplitResult::Predicate: stmt = IfThenElse::make(res.value, stmt, Stmt()); + break; } } stmt = common_subexpression_elimination(stmt); @@ -2230,13 +2236,24 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ // (Note that the splits are ordered, so a single reverse-pass catches all these cases.) for (const auto &split : reverse_view(s.splits())) { - if (split.is_split() && (parallel_vars.count(split.outer) || parallel_vars.count(split.inner))) { - parallel_vars.insert(split.old_var); - } else if (split.is_fuse() && parallel_vars.count(split.old_var)) { - parallel_vars.insert(split.inner); - parallel_vars.insert(split.outer); - } else if ((split.is_rename() || split.is_purify()) && parallel_vars.count(split.outer)) { - parallel_vars.insert(split.old_var); + switch (split.split_type) { + case Split::SplitVar: + if (parallel_vars.count(split.outer) || parallel_vars.count(split.inner)) { + parallel_vars.insert(split.old_var); + } + break; + case Split::FuseVars: + if (parallel_vars.count(split.old_var)) { + parallel_vars.insert(split.inner); + parallel_vars.insert(split.outer); + } + break; + case Split::RenameVar: + case Split::PurifyRVar: + if (parallel_vars.count(split.outer)) { + parallel_vars.insert(split.old_var); + } + break; } }