-
-
Notifications
You must be signed in to change notification settings - Fork 816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[venom]: add loop invariant hoisting pass #4175
base: master
Are you sure you want to change the base?
Changes from 8 commits
13944eb
f626b68
942c731
399a0a4
2a8dd4a
302aa21
013840c
1ee0068
f703408
c5dbf05
cd655fc
ecb272a
18c7610
5583cc5
bdb2896
788bd0d
715b128
8edce11
cf6b25e
77f97b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import pytest | ||
|
||
from vyper.venom.analysis.analysis import IRAnalysesCache | ||
from vyper.venom.analysis.loop_detection import LoopDetectionAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable | ||
from vyper.venom.context import IRContext | ||
from vyper.venom.function import IRFunction | ||
from vyper.venom.passes.loop_invariant_hosting import LoopInvariantHoisting | ||
|
||
|
||
def _create_loops(fn, depth, loop_id, body_fn=lambda _: (), top=True): | ||
bb = fn.get_basic_block() | ||
cond = IRBasicBlock(IRLabel(f"cond{loop_id}{depth}"), fn) | ||
body = IRBasicBlock(IRLabel(f"body{loop_id}{depth}"), fn) | ||
if top: | ||
exit_block = IRBasicBlock(IRLabel(f"exit_top{loop_id}{depth}"), fn) | ||
else: | ||
exit_block = IRBasicBlock(IRLabel(f"exit{loop_id}{depth}"), fn) | ||
fn.append_basic_block(cond) | ||
fn.append_basic_block(body) | ||
|
||
bb.append_instruction("jmp", cond.label) | ||
|
||
cond_var = IRVariable(f"cond_var{loop_id}{depth}") | ||
cond.append_instruction("iszero", 0, ret=cond_var) | ||
assert isinstance(cond_var, IRVariable) | ||
cond.append_instruction("jnz", cond_var, body.label, exit_block.label) | ||
body_fn(fn, loop_id, depth) | ||
if depth > 1: | ||
_create_loops(fn, depth - 1, loop_id, body_fn, top=False) | ||
bb = fn.get_basic_block() | ||
bb.append_instruction("jmp", cond.label) | ||
fn.append_basic_block(exit_block) | ||
|
||
|
||
def _simple_body(fn, loop_id, depth): | ||
assert isinstance(fn, IRFunction) | ||
bb = fn.get_basic_block() | ||
add_var = IRVariable(f"add_var{loop_id}{depth}") | ||
bb.append_instruction("add", 1, 2, ret=add_var) | ||
|
||
|
||
def _hoistable_body(fn, loop_id, depth): | ||
assert isinstance(fn, IRFunction) | ||
bb = fn.get_basic_block() | ||
add_var_a = IRVariable(f"add_var_a{loop_id}{depth}") | ||
bb.append_instruction("add", 1, 2, ret=add_var_a) | ||
add_var_b = IRVariable(f"add_var_b{loop_id}{depth}") | ||
bb.append_instruction("add", add_var_a, 2, ret=add_var_b) | ||
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4)) | ||
@pytest.mark.parametrize("count", range(1, 4)) | ||
def test_loop_detection_analysis(depth, count): | ||
ctx = IRContext() | ||
fn = ctx.create_function("_global") | ||
|
||
for c in range(count): | ||
_create_loops(fn, depth, c, _simple_body) | ||
|
||
bb = fn.get_basic_block() | ||
bb.append_instruction("ret") | ||
|
||
ac = IRAnalysesCache(fn) | ||
analysis = ac.request_analysis(LoopDetectionAnalysis) | ||
assert len(analysis.loops) == depth * count | ||
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4)) | ||
@pytest.mark.parametrize("count", range(1, 4)) | ||
def test_loop_invariant_hoisting_simple(depth, count): | ||
ctx = IRContext() | ||
fn = ctx.create_function("_global") | ||
|
||
for c in range(count): | ||
_create_loops(fn, depth, c, _simple_body) | ||
|
||
bb = fn.get_basic_block() | ||
bb.append_instruction("ret") | ||
|
||
ac = IRAnalysesCache(fn) | ||
LoopInvariantHoisting(ac, fn).run_pass() | ||
|
||
entry = fn.entry | ||
assignments = list(map(lambda x: x.value, entry.get_assignments())) | ||
for bb in filter(lambda bb: bb.label.name.startswith("exit_top"), fn.get_basic_blocks()): | ||
assignments.extend(map(lambda x: x.value, bb.get_assignments())) | ||
|
||
assert len(assignments) == depth * count * 2 | ||
for loop_id in range(count): | ||
for d in range(1, depth + 1): | ||
assert f"%add_var{loop_id}{d}" in assignments, repr(fn) | ||
assert f"%cond_var{loop_id}{d}" in assignments, repr(fn) | ||
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4)) | ||
@pytest.mark.parametrize("count", range(1, 4)) | ||
def test_loop_invariant_hoisting_dependant(depth, count): | ||
ctx = IRContext() | ||
fn = ctx.create_function("_global") | ||
|
||
for c in range(count): | ||
_create_loops(fn, depth, c, _hoistable_body) | ||
|
||
bb = fn.get_basic_block() | ||
bb.append_instruction("ret") | ||
|
||
ac = IRAnalysesCache(fn) | ||
LoopInvariantHoisting(ac, fn).run_pass() | ||
|
||
entry = fn.entry | ||
assignments = list(map(lambda x: x.value, entry.get_assignments())) | ||
for bb in filter(lambda bb: bb.label.name.startswith("exit_top"), fn.get_basic_blocks()): | ||
assignments.extend(map(lambda x: x.value, bb.get_assignments())) | ||
|
||
assert len(assignments) == depth * count * 3 | ||
for loop_id in range(count): | ||
for d in range(1, depth + 1): | ||
assert f"%add_var_a{loop_id}{d}" in assignments, repr(fn) | ||
assert f"%add_var_b{loop_id}{d}" in assignments, repr(fn) | ||
assert f"%cond_var{loop_id}{d}" in assignments, repr(fn) | ||
|
||
def _unhoistable_body(fn, loop_id, depth): | ||
assert isinstance(fn, IRFunction) | ||
bb = fn.get_basic_block() | ||
add_var_a = IRVariable(f"add_var_a{loop_id}{depth}") | ||
bb.append_instruction("mload", 64, ret=add_var_a) | ||
add_var_b = IRVariable(f"add_var_b{loop_id}{depth}") | ||
bb.append_instruction("add", add_var_a, 2, ret=add_var_b) | ||
|
||
@pytest.mark.parametrize("depth", range(1, 4)) | ||
@pytest.mark.parametrize("count", range(1, 4)) | ||
def test_loop_invariant_hoisting_unhoistable(depth, count): | ||
ctx = IRContext() | ||
fn = ctx.create_function("_global") | ||
|
||
for c in range(count): | ||
_create_loops(fn, depth, c, _unhoistable_body) | ||
|
||
bb = fn.get_basic_block() | ||
bb.append_instruction("ret") | ||
|
||
ac = IRAnalysesCache(fn) | ||
LoopInvariantHoisting(ac, fn).run_pass() | ||
|
||
entry = fn.entry | ||
assignments = list(map(lambda x: x.value, entry.get_assignments())) | ||
for bb in filter(lambda bb: bb.label.name.startswith("exit_top"), fn.get_basic_blocks()): | ||
assignments.extend(map(lambda x: x.value, bb.get_assignments())) | ||
|
||
assert len(assignments) == depth * count | ||
for loop_id in range(count): | ||
for d in range(1, depth + 1): | ||
assert f"%cond_var{loop_id}{d}" in assignments, repr(fn) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from vyper.utils import OrderedSet | ||
from vyper.venom.analysis.analysis import IRAnalysis | ||
from vyper.venom.analysis.cfg import CFGAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock | ||
|
||
|
||
class LoopDetectionAnalysis(IRAnalysis): | ||
""" | ||
Detects loops and computes basic blocks | ||
and the block which is before the loop | ||
""" | ||
|
||
# key = start of the loop (last bb not in the loop) | ||
# value = all the block that loop contains | ||
loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] | ||
|
||
done: OrderedSet[IRBasicBlock] | ||
visited: OrderedSet[IRBasicBlock] | ||
|
||
def analyze(self): | ||
self.analyses_cache.request_analysis(CFGAnalysis) | ||
self.loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] = dict() | ||
self.done = OrderedSet() | ||
self.visited = OrderedSet() | ||
entry = self.function.entry | ||
self.dfs(entry) | ||
|
||
def invalidate(self): | ||
return super().invalidate() | ||
|
||
def dfs(self, bb: IRBasicBlock, before: IRBasicBlock = None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this function does loop detection, why name it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: we use the convention |
||
if bb in self.visited: | ||
assert before is not None, "Loop must have one basic block before it" | ||
loop = self.collect_path(before, bb) | ||
in_bb = bb.cfg_in.difference({before}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think this might be clearer as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just because it checks for natural loops, so I check if it has only one input into the loop. |
||
assert len(in_bb) == 1, "Loop must have one input basic block" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
input_bb = in_bb.first() | ||
self.loops[input_bb] = loop | ||
self.done.add(bb) | ||
return | ||
|
||
self.visited.add(bb) | ||
|
||
for neighbour in bb.cfg_out: | ||
if neighbour not in self.done: | ||
self.dfs(neighbour, bb) | ||
|
||
self.done.add(bb) | ||
return | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. redundant |
||
|
||
def collect_path(self, bb_from: IRBasicBlock, bb_to: IRBasicBlock) -> OrderedSet[IRBasicBlock]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: |
||
loop = OrderedSet() | ||
collect_visit = OrderedSet() | ||
self.collect_path_inner(bb_from, bb_to, loop, collect_visit) | ||
return loop | ||
|
||
def collect_path_inner( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
self, | ||
act_bb: IRBasicBlock, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does "act_bb" stand for "active bb"? maybe |
||
bb_to: IRBasicBlock, | ||
loop: OrderedSet[IRBasicBlock], | ||
collect_visit: OrderedSet[IRBasicBlock], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe just |
||
): | ||
if act_bb in collect_visit: | ||
return | ||
collect_visit.add(act_bb) | ||
loop.add(act_bb) | ||
if act_bb == bb_to: | ||
return | ||
|
||
for before in act_bb.cfg_in: | ||
self.collect_path_inner(before, bb_to, loop, collect_visit) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from vyper.utils import OrderedSet | ||
from vyper.venom.analysis.cfg import CFGAnalysis | ||
from vyper.venom.analysis.dfg import DFGAnalysis | ||
from vyper.venom.analysis.liveness import LivenessAnalysis | ||
from vyper.venom.analysis.loop_detection import LoopDetectionAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLiteral, IRVariable | ||
from vyper.venom.function import IRFunction | ||
from vyper.venom.passes.base_pass import IRPass | ||
|
||
def _ignore_instruction(instruction: IRInstruction) -> bool: | ||
return ( | ||
instruction.is_volatile | ||
or instruction.is_bb_terminator | ||
or instruction.opcode == "returndatasize" | ||
or instruction.opcode == "phi" | ||
) | ||
|
||
|
||
def _is_correct_store(instruction: IRInstruction) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is basically |
||
return ( | ||
instruction.opcode == "store" | ||
and len(instruction.operands) == 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this condition is always true, no? |
||
and isinstance(instruction.operands[0], IRLiteral) | ||
) | ||
|
||
|
||
class LoopInvariantHoisting(IRPass): | ||
""" | ||
This pass detects invariants in loops and hoists them above the loop body. | ||
Any VOLATILE_INSTRUCTIONS, BB_TERMINATORS CFG_ALTERING_INSTRUCTIONS are ignored | ||
""" | ||
|
||
from typing import Iterator | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need for local import, please hoist this to top of file |
||
|
||
function: IRFunction | ||
loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] | ||
dfg: DFGAnalysis | ||
|
||
def run_pass(self): | ||
self.analyses_cache.request_analysis(CFGAnalysis) | ||
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) | ||
loops = self.analyses_cache.request_analysis(LoopDetectionAnalysis) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
self.loops = loops.loops | ||
while True: | ||
change = False | ||
for from_bb, loop in self.loops.items(): | ||
hoistable: list[ | ||
tuple[IRBasicBlock, IRBasicBlock, IRInstruction] | ||
] = self._get_hoistable_loop(from_bb, loop) | ||
if len(hoistable) == 0: | ||
continue | ||
change |= True | ||
self._hoist(hoistable) | ||
if not change: | ||
break | ||
# I have this inside the loop because I dont need to | ||
# invalidate if you dont hoist anything | ||
self.analyses_cache.invalidate_analysis(LivenessAnalysis) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's bring this outside the loop, and add a flag inside the loop to detect if |
||
|
||
def _hoist(self, hoistable: list[tuple[IRBasicBlock, IRBasicBlock, IRInstruction]]): | ||
for loop_idx, bb, inst in hoistable: | ||
bb.remove_instruction(inst) | ||
bb_before: IRBasicBlock = loop_idx | ||
bb_before.insert_instruction(inst, index=len(bb_before.instructions) - 1) | ||
|
||
def _get_hoistable_loop( | ||
self, from_bb: IRBasicBlock, loop: OrderedSet[IRBasicBlock] | ||
) -> list[tuple[IRBasicBlock, IRBasicBlock, IRInstruction]]: | ||
result: list[tuple[IRBasicBlock, IRBasicBlock, IRInstruction]] = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these type signatures are hard to read -- either make them less specific (e.g. |
||
for bb in loop: | ||
result.extend(self._get_hoistable_bb(bb, from_bb)) | ||
return result | ||
|
||
def _get_hoistable_bb( | ||
self, bb: IRBasicBlock, loop_idx: IRBasicBlock | ||
) -> list[tuple[IRBasicBlock, IRBasicBlock, IRInstruction]]: | ||
result: list[tuple[IRBasicBlock, IRBasicBlock, IRInstruction]] = [] | ||
for instruction in bb.instructions: | ||
if self._can_hoist_instruction(instruction, self.loops[loop_idx]): | ||
result.append((loop_idx, bb, instruction)) | ||
|
||
return result | ||
|
||
def _can_hoist_instruction( | ||
self, instruction: IRInstruction, loop: OrderedSet[IRBasicBlock] | ||
) -> bool: | ||
if _ignore_instruction(instruction): | ||
return False | ||
for bb in loop: | ||
if self._in_bb(instruction, bb): | ||
return False | ||
|
||
if _is_correct_store(instruction): | ||
for used_instruction in self.dfg.get_uses(instruction.output): | ||
if not self._can_hoist_instruction_ignore_stores(used_instruction, loop): | ||
return False | ||
|
||
return True | ||
|
||
def _in_bb(self, instruction: IRInstruction, bb: IRBasicBlock): | ||
for in_var in instruction.get_input_variables(): | ||
assert isinstance(in_var, IRVariable) | ||
source_ins = self.dfg._dfg_outputs[in_var] | ||
if source_ins in bb.instructions: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think we can just check |
||
return True | ||
return False | ||
|
||
def _can_hoist_instruction_ignore_stores( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. leave a comment explaining why this is necessary? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe just call this |
||
self, instruction: IRInstruction, loop: OrderedSet[IRBasicBlock] | ||
) -> bool: | ||
if _ignore_instruction(instruction): | ||
return False | ||
for bb in loop: | ||
if self._in_bb_ignore_store(instruction, bb): | ||
return False | ||
return True | ||
|
||
def _in_bb_ignore_store(self, instruction: IRInstruction, bb: IRBasicBlock): | ||
for in_var in instruction.get_input_variables(): | ||
assert isinstance(in_var, IRVariable) | ||
source_ins = self.dfg._dfg_outputs[in_var] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. prefer instruction to always be abbreviated as prev_inst = self.dfg.get_producing_instruction(in_var)
# can add `assert prev_inst is not None` if desired |
||
if _is_correct_store(source_ins): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how can this condition ever hold true? |
||
continue | ||
|
||
if source_ins in bb.instructions: | ||
return True | ||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need for this implementation