From 4edeaba57f0749a169e0b0e99cbc4c32f52f779a Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 12 Oct 2023 13:34:58 -0700 Subject: [PATCH] Don't lift loop vars outside of their loops in sliding window Sliding window, when operating in the mode that shifts the consumer's loop min backwards a few iterations to cover the warmup, was capable of inappropriately lifting the vars of for loops that sit inside that loop, but outside the produce node of the slid Func, out of their own loops. Fixes #7891 --- src/SlidingWindow.cpp | 8 +++++++- test/correctness/fuzz_schedule.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 479d71ce6fac..8101d66d3fff 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -225,6 +225,9 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { set &slid_dimensions; Scope scope; + // Loops between the loop being slid over and the produce node + Scope<> enclosing_loops; + map replacements; using IRMutator::visit; @@ -433,7 +436,9 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { new_loop_min_eq = simplify(new_loop_min_eq); Interval solve_result = solve_for_inner_interval(new_loop_min_eq, new_loop_min_name); internal_assert(!new_loop_min.defined()); - if (solve_result.has_upper_bound() && !equal(solve_result.max, loop_min)) { + if (solve_result.has_upper_bound() && + !equal(solve_result.max, loop_min) && + !expr_uses_vars(solve_result.max, enclosing_loops)) { new_loop_min = simplify(solve_result.max); // We have a new loop min, so we an assume every iteration has @@ -558,6 +563,7 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { // the var we're sliding over. 
Expr min = expand_expr(op->min, scope); Expr extent = expand_expr(op->extent, scope); + ScopedBinding<> bind(enclosing_loops, op->name); if (is_const_one(extent)) { // Just treat it like a let Stmt s = LetStmt::make(op->name, min, op->body); diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp index d5a2a664fec5..c4042aec22a8 100644 --- a/test/correctness/fuzz_schedule.cpp +++ b/test/correctness/fuzz_schedule.cpp @@ -96,6 +96,30 @@ int main(int argc, char **argv) { check_blur_output(buf, correct); } + // https://github.com/halide/Halide/issues/7891 + { + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + Var yo, yi, xo, xi, xio, xii, xiio, xiii; + blurry.split(y, yo, yi, 4, TailStrategy::Auto) + .split(x, xo, xi, 1, TailStrategy::Auto) + .split(xi, xio, xii, 4, TailStrategy::GuardWithIf) + .split(xii, xiio, xiii, 1, TailStrategy::RoundUp); + local_sum.compute_at(blurry, xiio); + input.compute_at(blurry, xiio); + input.store_root(); + Pipeline p({blurry}); + Buffer buf = p.realize({32, 32}); + check_blur_output(buf, correct); + } + printf("Success!\n"); return 0;