diff --git a/src/Func.cpp b/src/Func.cpp index a4a9363494dd..a17bc3ee4057 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -624,13 +624,6 @@ Func Stage::rfactor(const RVar &r, const Var &v) { // Helpers for rfactor implementation namespace { -optional find_rvar(const vector &items, const Dim &dim) { - const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { - return dim_match(dim, x); - }); - return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); -} - optional find_dim(const vector &items, const VarOrRVar &v) { const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { return dim_match(x, v); @@ -841,6 +834,7 @@ Func Stage::rfactor(const vector> &preserved) { vector preserved_rvars; vector preserved_vars; vector preserved_rdims; + unordered_set preserved_rdims_set; vector intermediate_rdims; { unordered_map dim_ordering; @@ -848,14 +842,15 @@ Func Stage::rfactor(const vector> &preserved) { dim_ordering.emplace(definition.schedule().dims()[i].var, i); } - vector> preserved_with_dims; + using PreservedData = tuple; + vector preserved_with_dims; for (const auto &[rv, v] : preserved) { const optional rdim = find_dim(definition.schedule().dims(), rv); internal_assert(rdim); preserved_with_dims.emplace_back(rv, v, *rdim); } - std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const auto &lhs, const auto &rhs) { + std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const PreservedData &lhs, const PreservedData &rhs) { return dim_ordering.at(std::get<2>(lhs).var) < dim_ordering.at(std::get<2>(rhs).var); }); @@ -863,10 +858,11 @@ Func Stage::rfactor(const vector> &preserved) { preserved_rvars.push_back(rv); preserved_vars.push_back(v); preserved_rdims.push_back(dim); + preserved_rdims_set.insert(dim.var); } for (const Dim &dim : definition.schedule().dims()) { - if (dim.is_rvar() && !find_rvar(preserved_rvars, dim)) { + if (dim.is_rvar() && !preserved_rdims_set.count(dim.var)) { intermediate_rdims.push_back(dim); } } @@ -971,14 +967,9 @@ Func Stage::rfactor(const vector> &preserved) { vector reducing_dims; { - unordered_set preserved_rdim_set; - for (const auto &dim : preserved_rdims) { - preserved_rdim_set.insert(dim.var); - } - // Remove rvar dims NOT IN the preserved list from the REDUCING Func for (const auto &dim : definition.schedule().dims()) { - if (!dim.is_rvar() || preserved_rdim_set.count(dim.var)) { + if (!dim.is_rvar() || preserved_rdims_set.count(dim.var)) { reducing_dims.push_back(dim); } }