Skip to content

Commit

Permalink
Compute preserved rdims set earlier to drop find_rvar
Browse files Browse the repository at this point in the history
  • Loading branch information
alexreinking committed Nov 25, 2024
1 parent 1355d3e commit a201afe
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,13 +624,6 @@ Func Stage::rfactor(const RVar &r, const Var &v) {
// Helpers for rfactor implementation
namespace {

optional<RVar> find_rvar(const vector<RVar> &items, const Dim &dim) {
const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) {
return dim_match(dim, x);
});
return has_v == items.end() ? std::nullopt : std::make_optional(*has_v);
}

optional<Dim> find_dim(const vector<Dim> &items, const VarOrRVar &v) {
const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) {
return dim_match(x, v);
Expand Down Expand Up @@ -841,32 +834,35 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
vector<RVar> preserved_rvars;
vector<Var> preserved_vars;
vector<Dim> preserved_rdims;
unordered_set<string> preserved_rdims_set;
vector<Dim> intermediate_rdims;
{
unordered_map<string, int> dim_ordering;
for (size_t i = 0; i < definition.schedule().dims().size(); i++) {
dim_ordering.emplace(definition.schedule().dims()[i].var, i);
}

vector<tuple<RVar, Var, Dim>> preserved_with_dims;
using PreservedData = tuple<RVar, Var, Dim>;
vector<PreservedData> preserved_with_dims;
for (const auto &[rv, v] : preserved) {
const optional<Dim> rdim = find_dim(definition.schedule().dims(), rv);
internal_assert(rdim);
preserved_with_dims.emplace_back(rv, v, *rdim);
}

std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const auto &lhs, const auto &rhs) {
std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const PreservedData &lhs, const PreservedData &rhs) {
return dim_ordering.at(std::get<2>(lhs).var) < dim_ordering.at(std::get<2>(rhs).var);
});

for (const auto &[rv, v, dim] : preserved_with_dims) {
preserved_rvars.push_back(rv);
preserved_vars.push_back(v);
preserved_rdims.push_back(dim);
preserved_rdims_set.insert(dim.var);
}

for (const Dim &dim : definition.schedule().dims()) {
if (dim.is_rvar() && !find_rvar(preserved_rvars, dim)) {
if (dim.is_rvar() && !preserved_rdims_set.count(dim.var)) {
intermediate_rdims.push_back(dim);
}
}
Expand Down Expand Up @@ -971,14 +967,9 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {

vector<Dim> reducing_dims;
{
unordered_set<string> preserved_rdim_set;
for (const auto &dim : preserved_rdims) {
preserved_rdim_set.insert(dim.var);
}

// Remove rvar dims NOT IN the preserved list from the REDUCING Func
for (const auto &dim : definition.schedule().dims()) {
if (!dim.is_rvar() || preserved_rdim_set.count(dim.var)) {
if (!dim.is_rvar() || preserved_rdims_set.count(dim.var)) {
reducing_dims.push_back(dim);
}
}
Expand Down

0 comments on commit a201afe

Please sign in to comment.