diff --git a/coffee/__init__.py b/coffee/__init__.py index 27de2c6c..6fc464df 100644 --- a/coffee/__init__.py +++ b/coffee/__init__.py @@ -36,7 +36,6 @@ import sys from coffee.citations import update_citations -from coffee.vectorizer import VectStrategy from coffee.logger import LOG_DEFAULT, set_log_level, warn from coffee.system import set_architecture, set_compiler, set_isa @@ -169,8 +168,6 @@ def set_opt_level(optlevel): O1 = OptimizationLevel('O1', rewrite=1) O2 = OptimizationLevel('O2', rewrite=2, dead_ops_elimination=True) O3 = OptimizationLevel('O3', align_pad=True, **O2) -Ofast = OptimizationLevel('Ofast', vectorize=(VectStrategy.SPEC_UAJ_PADD, 2), - precompute='noloops', **O3) initialized = False diff --git a/coffee/base.py b/coffee/base.py index dc4ffd73..bc519546 100644 --- a/coffee/base.py +++ b/coffee/base.py @@ -447,11 +447,11 @@ class Symbol(Expr): a[i][3*j + 2]. """ - def __init__(self, symbol, rank=(), offset=()): + def __init__(self, symbol, rank=None, offset=None): super(Symbol, self).__init__([]) self.symbol = symbol - self.rank = rank - self.offset = offset or tuple([(1, 0) for r in rank]) + self.rank = Rank(rank or ()) + self.offset = offset or tuple([(1, 0) for r in self.rank]) def operands(self): return [self.symbol, self.rank, self.offset], {} @@ -465,6 +465,14 @@ def is_const(self): from .utils import is_const_dim return not self.rank or all(is_const_dim(r) for r in self.rank) + @property + def is_number(self): + try: + float(self.symbol) + return True + except ValueError: + return False + @property def is_const_offset(self): from .utils import is_const_dim, flatten @@ -715,7 +723,8 @@ def __init__(self, typ, sym, init=None, qualifiers=None, attributes=None, self._scope = scope or UNKNOWN def operands(self): - return [self.typ, self.sym, self.init, self.qual, self.attr], {} + return [self.typ, self.sym, self.init, self.qual, self.attr, + self.pointers], {} def pad(self, new_rank): self.sym.rank = new_rank @@ -893,6 +902,10 @@ def increment(self, value): def header(self): return (self.start, self.size, self.increment) + @property + def block(self): + return self.children[0] + @property def body(self): return self.children[0].children @@ -1216,6 +1229,26 @@ def gencode(self, not_scope=False): return self.children[0].gencode() +class Rank(tuple): + + def __contains__(self, val): + from coffee.visitors import FindInstances + if isinstance(val, Node): + val, search = str(val), type(val) + elif isinstance(val, str): + val, search = val, Symbol + else: + return False + for i in self: + if isinstance(i, Node): + items = FindInstances(search).visit(i) + if any(val == str(i) for i in items[search]): + return True + elif isinstance(i, str) and val == i: + return True + return False + + # Utility functions ### diff --git a/coffee/cse.py b/coffee/cse.py index 6347a79a..fb9867de 100644 --- a/coffee/cse.py +++ b/coffee/cse.py @@ -34,7 +34,6 @@ from __future__ import absolute_import, print_function, division from six.moves import zip -from sys import maxsize import operator from .base import * @@ -51,7 +50,7 @@ class Temporary(object): or an AugmentedAssig) that computes a temporary variable; that is, a variable that is read in more than one place.""" - def __init__(self, node, main_loop, nest, reads=None, linear_reads_costs=None): + def __init__(self, node, main_loop, nest, linear_reads_costs=None): self.level = -1 self.pushed = False self.readby = [] @@ -59,7 +58,6 @@ def __init__(self, node, main_loop, nest, reads=None, linear_reads_costs=None): self.node = node
self.main_loop = main_loop self.nest = nest - self.reads = reads or [] self.linear_reads_costs = linear_reads_costs or OrderedDict() self.flops = EstimateFlops().visit(node) @@ -72,9 +70,8 @@ def rank(self): return self.symbol.rank if self.symbol else None @property - def is_bilinear(self): - linear_loops = [l for l in self.loops if l.dim in self.rank] - return len(linear_loops) == 2 + def linearity_degree(self): + return len(self.main_linear_loops) @property def symbol(self): @@ -96,6 +93,10 @@ def expr(self): def urepr(self): return self.symbol.urepr + @property + def reads(self): + return FindInstances(Symbol).visit(self.expr)[Symbol] if self.expr else [] + @property def linear_reads(self): return self.linear_reads_costs.keys() if self.linear_reads_costs else [] @@ -105,17 +106,27 @@ def loops(self): return list(zip(*self.nest))[0] @property - def niters(self): - return reduce(operator.mul, [l.size for l in self.loops], 1) + def main_linear_loops(self): + return [l for l in self.main_loops if l.is_linear] + + @property + def main_linear_nest(self): + return [(l, p) for l, p in self.main_nest if l in self.linear_loops] + + @property + def main_loops(self): + index = self.loops.index(self.main_loop) + return [l for l in self.loops[:index + 1]] @property - def niters_after_licm(self): - return reduce(operator.mul, - [l.size for l in self.loops if l is not self.main_loop], 1) + def main_nest(self): + return [(l, p) for l, p in self.nest if l in self.main_loops] @property - def project(self): - return len(self.linear_reads) + def flops_projection(self): + # #muls + #sums + nmuls = len(self.linear_reads) + return (nmuls) + (nmuls - 1) @property def is_ssa(self): @@ -125,6 +136,35 @@ def is_ssa(self): def is_static_init(self): return isinstance(self.expr, ArrayInit) + @property + def is_increment(self): + return isinstance(self.node, Incr) + + @property + def reductions(self): + return [l for l in self.main_loops if l.dim not in self.rank] + + @property + def nreductions(self): + return len(self.reductions) + + def niters(self, mode='all', handle=None): + assert mode in ['all', 'outer', 'nonlinear', 'in', 'out'] + handle = handle or [] + limit = self.loops.index(self.main_loop) + loops = self.loops[:limit + 1] + if mode == 'all': + sizes = [l.size for l in loops] + elif mode == 'outer': + sizes = [l.size for l in loops if l is not self.main_loop] + elif mode == 'nonlinear': + sizes = [l.size for l in loops if not l.is_linear] + elif mode == 'in': + sizes = [l.size for l in loops if l.dim in handle] + else: + sizes = [l.size for l in loops if l.dim not in handle] + return reduce(operator.mul, sizes, 1) + def depends(self, others): """Return True if ``self`` reads a temporary or is read by a temporary that appears in the iterator ``others``, False otherwise.""" @@ -135,7 +175,7 @@ def depends(self, others): return False def reconstruct(self): - temporary = Temporary(self.node, self.main_loop, self.nest, list(self.reads), + temporary = Temporary(self.node, self.main_loop, self.nest, OrderedDict(self.linear_reads_costs)) temporary.level = self.level temporary.readby = list(self.readby) @@ -158,12 +198,11 @@ class CSEUnpicker(object): symbols (further information concerning loop linearity is available in the module ``expression.py``).""" - def __init__(self, exprs, header, hoisted, decls, expr_graph): + def __init__(self, exprs, header, hoisted, decls): self.exprs = exprs self.header = header self.hoisted = hoisted self.decls = decls - self.expr_graph = expr_graph @property def type(self): @@ 
-192,7 +231,8 @@ def is_pushable(temporary, temporaries): return False pushed_in = [global_trace.get(rb.urepr) for rb in temporary.readby] pushed_in = set(rb.main_loop.children[0] for rb in pushed_in if rb) - for s in temporary.reads: + reads = [s for s in temporary.reads if not s.is_number] + for s in reads: # ... all the read temporaries must be accessible in the loops in which # they will be pushed if s.urepr in global_trace and global_trace[s.urepr].pushed: @@ -212,13 +252,14 @@ def is_pushable(temporary, temporaries): global_trace[rb.urepr]) # The temporary is going to be pushed, so we can remove it as long as # it is not needed somewhere else - if t.node in t.main_loop.body and all(rb.urepr in trace for rb in t.readby): + if t.node in t.main_loop.body and\ + all(rb.urepr in global_trace for rb in t.readby): global_trace[t.urepr].pushed = True t.main_loop.body.remove(t.node) self.decls.pop(t.name, None) - # Transform the AST (note: node replacement must happend in the order - # in which temporaries have been encountered) + # Transform the AST (note: node replacement must happen in the order + # in which the temporaries have been encountered) modified_temporaries = sorted(modified_temporaries.values(), key=lambda t: global_trace.keys().index(t.urepr)) for t in modified_temporaries: @@ -245,14 +286,13 @@ def _transform_temporaries(self, temporaries): # Expand + Factorize rewriters = OrderedDict() for t in temporaries: - expr_info = MetaExpr(self.type, t.main_loop.children[0], t.nest, - tuple(l.dim for l in t.loops if l.is_linear)) + expr_info = MetaExpr(self.type, t.main_loop.block, t.main_nest) ew = ExpressionRewriter(t.node, expr_info, self.decls, self.header, - self.hoisted, self.expr_graph) + self.hoisted) ew.replacediv() ew.expand(mode='all', lda=lda) + ew.reassociate(lambda i: all(r != t.main_loop.dim for r in lda[i.symbol])) ew.factorize(mode='adhoc', adhoc={i.urepr: [] for i in t.linear_reads}, lda=lda) - ew.factorize(mode='heuristic') rewriters[t] = ew lda = loops_analysis(self.header, value='dim') @@ -260,14 +300,14 @@ def _transform_temporaries(self, temporaries): # Code motion for t, ew in rewriters.items(): ew.licm(mode='only_outlinear', lda=lda, global_cse=True) - if t.is_bilinear: - ew.licm(mode='only_linear') + if t.linearity_degree > 1: + ew.licm(mode='only_linear', lda=lda) - def _analyze_expr(self, expr, lda): + def _analyze_expr(self, expr, loop, lda): finder = FindInstances(Symbol) reads = finder.visit(expr, ret=FindInstances.default_retval())[Symbol] reads = [s for s in reads if s.symbol in self.decls] - syms = [s for s in reads if any(l in self.linear_dims for l in lda[s])] + syms = [s for s in reads if any(d in loop.dim for d in lda[s])] linear_reads_costs = OrderedDict() @@ -289,16 +329,18 @@ def wrapper(node, found=0): return reads, linear_reads_costs def _analyze_loop(self, loop, nest, lda, global_trace): - trace = OrderedDict() + linear_dims = [l.dim for l, _ in nest if l.is_linear] + trace = OrderedDict() for node in loop.body: if not isinstance(node, Writer): not_ssa = [trace[w] for w in in_written(node, key='urepr') if w in trace] for t in not_ssa: t.readby.append(t.symbol) continue - reads, linear_reads_costs = self._analyze_expr(node.rvalue, lda) - for s in linear_reads_costs.keys(): + reads, linear_reads_costs = self._analyze_expr(node.rvalue, loop, lda) + affected = [s for s in reads if any(i in linear_dims for i in lda[s])] + for s in affected: if s.urepr in global_trace: temporary = global_trace[s.urepr] temporary.readby.append(node.lvalue) @@ -308,7 
+350,7 @@ def _analyze_loop(self, loop, nest, lda, global_trace): else: temporary = trace.setdefault(s.urepr, Temporary(s, loop, nest)) temporary.readby.append(node.lvalue) - new_temporary = Temporary(node, loop, nest, reads, linear_reads_costs) + new_temporary = Temporary(node, loop, nest, linear_reads_costs) new_temporary.level = max([trace[s.urepr].level for s in new_temporary.linear_reads] or [-2]) + 1 trace[node.lvalue.urepr] = new_temporary @@ -328,86 +370,95 @@ def _cost_cse(self, levels, bounds=None): levels = {i: levels[i] for i in range(lb, up)} cost = 0 for level, temporaries in levels.items(): - cost += sum(t.flops*t.niters for t in temporaries) + cost += sum(t.flops*t.niters('all') for t in temporaries) return cost - def _cost_fact(self, trace, levels, bounds): + def _cost_fact(self, trace, levels, lda, bounds): # Check parameters - bounds = bounds or (min(levels.keys()), max(levels.keys())) assert len(bounds) == 2 and bounds[1] >= bounds[0] assert bounds[0] in levels.keys() and bounds[1] in levels.keys() # Determine current costs of individual loop regions - cse_cost = self._cost_cse(levels, (min(levels.keys()), bounds[0])) - uptolevel_cost = cse_cost - level_inloop_cost, total_outloop_cost, cse = 0, 0, 0 + input_cost = self._cost_cse(levels, (min(levels.keys()), max(levels.keys()))) + uptolevel_cost, post_cse_cost = input_cost, input_cost + level_inloop_cost, total_outloop_cost = 0, 0 # We are going to modify a copy of the temporaries dict new_trace = OrderedDict() for s, t in trace.items(): new_trace[s] = t.reconstruct() - best = (bounds[0], bounds[0], maxsize) + # Cost induced by the untransformed temporaries + pre_cse_cost = self._cost_cse(levels, (min(levels.keys()), bounds[0])) + + best = (bounds[0], bounds[0], uptolevel_cost) fact_levels = {k: v for k, v in levels.items() if k > bounds[0] and k <= bounds[1]} for level, temporaries in sorted(fact_levels.items(), key=lambda i_j: i_j[0]): level_inloop_cost = 0 for t in temporaries: - # The operation count, after fact+licm, outside /loop/, induced by /t/ - t_outloop_cost = 0 - # The operation count, after fact+licm, within /loop/, induced by /t/ - t_inloop_cost = 0 - - # Calculate the operation count for /t/ if we applied expansion + fact - linear_reads = [] + # Compute the cost induced by /t/ in the outer loops after fact+licm + t_outloop_cost, linear_reads = 0, [] for read, cost in t.linear_reads_costs.items(): - if read.urepr in new_trace: - linear_reads.extend(new_trace[read.urepr].linear_reads or - [read.urepr]) - t_outloop_cost += new_trace[read.urepr].project*cost + traced = new_trace.get(read.urepr) + if traced and traced.level >= bounds[0]: + handle = traced.linear_reads or [read] + if cost: + for i in handle: + # One prod in the closest linear loop + t_outloop_cost += t.niters('out', lda[i]) + # The rest falls outside of the linear loops + t_outloop_cost += (cost - 1)*t.niters('nonlinear') else: - linear_reads.extend([read.urepr]) - - # Factorization will kill duplicates and increase the number of sums - # in the outer loop - fact_syms = {s.urepr if isinstance(s, Symbol) else s for s in linear_reads} - t_outloop_cost += len(linear_reads) - len(fact_syms) - - # Note: if n=len(fact_syms), then we'll have n prods, n-1 sums - t_inloop_cost += 2*len(fact_syms) - 1 - - # Add to the total and scale up by the corresponding number of iterations - total_outloop_cost += t_outloop_cost*t.niters_after_licm - level_inloop_cost += t_inloop_cost*t.niters - - # Update the trace because we want to track the cost after "pushing" 
the - # temporaries on which /t/ depends into /t/ itself - new_trace[t.urepr].linear_reads_costs = {s: 1 for s in fact_syms} - - # Some temporaries at levels < /i/ may also appear in: - # 1) subsequent loops - # 2) levels beyond /i/ + handle = [read] + linear_reads.extend(handle) + factors = {as_urepr(i): i for i in linear_reads}.values() + # Take into account the increased number of sums (due to fact) + hoist_region = set.union(*[lda[i] for i in factors]) + niters = t.niters('out', hoist_region) + t_outloop_cost += (len(linear_reads) - len(factors))*niters + total_outloop_cost += t_outloop_cost + + # Compute the cost induced by /t/ in the main loop after fact+licm + # We end up creating n prods and n - 1 sums + t_inloop_cost = 2*len(factors) - 1 + level_inloop_cost += t_inloop_cost*t.niters('all') + + # Take into account any hoistable reductions + if t.is_increment: + for i in factors: + handle = [l.dim for l in t.reductions if l.dim not in i.rank] + level_inloop_cost -= t.niters('all') - t.niters('out', handle) + + # Keep the trace up-to-date + linear_reads_costs = {i: 1 for i in factors} + new_trace[t.urepr].linear_reads_costs = linear_reads_costs + + # Some temporaries within levels < /level/ might also appear in + # subsequent loops or levels beyond /level/, so they still contribute + # to the operation count for t in list(flatten([levels[j] for j in range(level)])): if any(rb.urepr not in new_trace for rb in t.readby) or \ any(new_trace[rb.urepr].level > level for rb in t.readby): - # Note: condition 1) is basically saying "if I'm read from + # Note: condition 1) is basically saying "if I'm read by # a temporary that is not in this loop's trace, then I must # be read in some other loops". - level_inloop_cost += t.flops*t.niters + level_inloop_cost += \ + new_trace[t.urepr].flops_projection*t.niters('all') + + post_cse_cost = self._cost_cse(fact_levels, (level + 1, bounds[1])) - # Total cost = cost_after_fact_up_to_level + cost_inloop_cse - # = cost_hoisted_subexprs + cost_inloop_fact + cost_inloop_cse - uptolevel_cost = cse_cost + total_outloop_cost + level_inloop_cost - uptolevel_cost += self._cost_cse(fact_levels, (level + 1, bounds[1])) + # Compute the total cost + total_inloop_cost = pre_cse_cost + level_inloop_cost + post_cse_cost + uptolevel_cost = total_outloop_cost + total_inloop_cost # Update the best alternative if uptolevel_cost < best[2]: best = (bounds[0], level, uptolevel_cost) - cse = self._cost_cse(fact_levels, (level + 1, bounds[1])) - - log('CSE: unpicking between levels [%d, %d]:' % bounds, COST_MODEL) - log('CSE: cost=%d (cse=%d, outloop=%d, inloop_fact=%d, inloop_cse=%d)' % - (uptolevel_cost, cse_cost, total_outloop_cost, level_inloop_cost, cse), COST_MODEL) + log('[CSE]: unpicking between [%d, %d]:' % (bounds[0], level), COST_MODEL) + log(' flops: %d -> %d (hoist=%d, preCSE=%d, fact=%d, postCSE=%d)' % + (input_cost, uptolevel_cost, total_outloop_cost, pre_cse_cost, + level_inloop_cost, post_cse_cost), COST_MODEL) return best @@ -416,7 +467,6 @@ def unpick(self): external_decls = [d for d in self.decls.values() if d.scope == EXTERNAL] fors = visit(self.header, info_items=['fors'])['fors'] lda = loops_analysis(self.header, value='dim') - ra = reachability_analysis(self.header, external_decls) # Collect all loops to be analyzed nests = OrderedDict() @@ -441,7 +491,7 @@ def unpick(self): current_cost = self._cost_cse(levels, (min_level, max_level)) global_best = (min_level, min_level, current_cost) for i in sorted(levels.keys()): - local_best = self._cost_fact(trace,
levels, (i, max_level)) + local_best = self._cost_fact(trace, levels, lda, (i, max_level)) if local_best[2] < global_best[2]: global_best = local_best @@ -449,6 +499,7 @@ def unpick(self): # Transform the loop for i in range(global_best[0] + 1, global_best[1] + 1): + ra = reachability_analysis(self.header, external_decls) self._push_temporaries(levels[i-1], trace, global_trace, ra) self._transform_temporaries(levels[i]) diff --git a/coffee/expander.py b/coffee/expander.py index 74663f0c..9ece54cc 100644 --- a/coffee/expander.py +++ b/coffee/expander.py @@ -50,12 +50,11 @@ class Expander(object): GROUP = 0 # Expression /will/ not trigger expansion EXPAND = 1 # Expression /could/ be expanded - def __init__(self, stmt, expr_info=None, decls=None, hoisted=None, expr_graph=None): + def __init__(self, stmt, expr_info=None, decls=None, hoisted=None): self.stmt = stmt self.expr_info = expr_info self.decls = decls self.hoisted = hoisted - self.expr_graph = expr_graph self.local_decls = {} @@ -74,7 +73,7 @@ def _expand(self, node, parent): return ([node], self.EXPAND) if self.should_expand(node) \ else ([node], self.GROUP) - elif isinstance(node, (Div, FunCall)): + elif isinstance(node, (Div, Ternary, FunCall)): # Try to expand /within/ the children, but then return saying "I'm not # expandable any further" for n in node.children: @@ -96,7 +95,7 @@ def _expand(self, node, parent): expansion = self._build(exp, grp) to_replace.setdefault(exp, []).append(expansion) ast_replace(node, {k: ast_make_expr(Sum, v) for k, v in to_replace.items()}, - mode='symbol') + copy=False, mode='symbol') # Update the parent node, since an expression has just been expanded expanded = node.right if l_type == self.GROUP else node.left parent.children[parent.children.index(node)] = expanded diff --git a/coffee/expression.py b/coffee/expression.py index f082bfd0..96d8e784 100644 --- a/coffee/expression.py +++ b/coffee/expression.py @@ -42,24 +42,18 @@ class MetaExpr(object): """Metadata container for a compute-intensive expression.""" - def __init__(self, type, parent, loops_info, linear_dims, mode=0): + def __init__(self, type, parent, loops_info, mode=0): """Initialize the MetaExpr. :param type: the C type of the expression. :param parent: the node in which the expression is embedded. :param loops_info: an iterator of 2-tuples; each tuple represents a loop enclosing the expression (first entry) and its parent (second entry). - :param linear_dims: the dimensions of the linear loops enclosing the - expression, as an n-tuple. The expression is affine in the symbols - varying along a linear loop. For a formal definition of a linear loop, - please refer to the paper ``An algorithm for the optimization of - finite element integration loop nests``. :param mode: the suggested rewrite mode. 
""" self._type = type self._parent = parent self._loops_info = list(loops_info) - self._linear_dims = linear_dims self._mode = mode @property @@ -72,12 +66,16 @@ def dims(self): @property def linear_dims(self): - return self._linear_dims + return tuple(l.dim for l in self.linear_loops) @property def out_linear_dims(self): return tuple(d for d in self.dims if d not in self.linear_dims) + @property + def reduction_dims(self): + return tuple(l.dim for l in self.reduction_loops) + @property def loops(self): return list(zip(*self._loops_info))[0] @@ -96,15 +94,15 @@ def loops_info(self): @property def linear_loops(self): - return tuple([l for l in self.loops if l.dim in self.linear_dims]) + return tuple([l for l in self.loops if l.is_linear]) @property def linear_loops_parents(self): - return tuple([p for l, p in self._loops_info if l.dim in self.linear_dims]) + return tuple([p for l, p in self._loops_info if l.is_linear]) @property def linear_loops_info(self): - return tuple([(l, p) for l, p in self._loops_info if l.dim in self.linear_dims]) + return tuple([(l, p) for l, p in self._loops_info if l.is_linear]) @property def out_linear_loops(self): @@ -118,6 +116,25 @@ def out_linear_loops_parents(self): def out_linear_loops_info(self): return tuple([i for i in self.loops_info if i not in self.linear_loops_info]) + @property + def reduction_loops(self): + stmts = FindInstances((Writer, Incr)).visit(self.parent) + if stmts[Incr]: + writers = flatten(stmts.values()) + return tuple(l for l in self.loops + if all(l.dim not in i.lvalue.rank for i in writers)) + else: + return () + + @property + def reduction_loops_parents(self): + retval = self.reduction_loops_info + return zip(*retval)[1] if retval else () + + @property + def reduction_loops_info(self): + return tuple((l, p) for l, p in self.loops_info if l in self.reduction_loops) + @property def perfect_loops(self): """Return the loops in a perfect loop nest for the expression.""" @@ -143,6 +160,22 @@ def outermost_linear_loop(self): def outermost_linear_loop_parent(self): return self.linear_loops_parents[0] if len(self.linear_loops_parents) > 0 else None + @property + def innermost_loop(self): + return self.loops[-1] if len(self.loops) > 0 else None + + @property + def innermost_parent(self): + return self.loops_parents[-1] if len(self.loops_parents) > 0 else None + + @property + def innermost_linear_loop(self): + return self.linear_loops[-1] if len(self.linear_loops) > 0 else None + + @property + def innermost_linear_loop_parent(self): + return self.linear_loops_parents[-1] if len(self.linear_loops_parents) > 0 else None + @property def dimension(self): return len(self.linear_dims) if not self.is_scalar else 0 @@ -176,10 +209,9 @@ def copy_metaexpr(expr_info, **kwargs): """Given a ``MetaExpr``, return a plain new ``MetaExpr`` starting from a copy of ``expr_info``, and replaces some attributes as specified in ``kwargs``. 
``kwargs`` accepts the following keys: parent, loops_info, - linear_dims, mode.""" + mode.""" parent = kwargs.get('parent', expr_info.parent) - linear_dims = kwargs.get('linear_dims', expr_info.linear_dims) mode = kwargs.get('mode', expr_info.mode) new_loops_info, old_loops_info = [], expr_info.loops_info @@ -193,4 +225,4 @@ def copy_metaexpr(expr_info, **kwargs): else: new_loops_info.append(loop_info) - return MetaExpr(expr_info.type, parent, new_loops_info, linear_dims, mode) + return MetaExpr(expr_info.type, parent, new_loops_info, mode) diff --git a/coffee/hoister.py b/coffee/hoister.py index 7b082e52..81cfbaac 100644 --- a/coffee/hoister.py +++ b/coffee/hoister.py @@ -36,7 +36,6 @@ from .base import * from .utils import * -from .logger import warn class Extractor(object): @@ -44,35 +43,11 @@ class Extractor(object): EXT = 0 # expression marker: extract STOP = 1 # expression marker: do not extract - @staticmethod - def factory(mode, stmt, expr_info): - if mode == 'normal': - should_extract = lambda d: True - return MainExtractor(stmt, expr_info, should_extract) - elif mode == 'only_const': - # Do not extract unless constant in all loops - should_extract = lambda d: not (d and d.issubset(set(expr_info.dims))) - return MainExtractor(stmt, expr_info, should_extract) - elif mode == 'only_outlinear': - should_extract = lambda d: d.issubset(set(expr_info.out_linear_dims)) - return MainExtractor(stmt, expr_info, should_extract) - elif mode == 'only_linear': - should_extract = lambda d: not (d.issubset(set(expr_info.out_linear_dims))) - return SoftExtractor(stmt, expr_info, should_extract) - elif mode == 'aggressive': - should_extract = lambda d: True - return AggressiveExtractor(stmt, expr_info, should_extract) - else: - raise RuntimeError("Requested an invalid Extractor (%s)" % mode) - def __init__(self, stmt, expr_info, should_extract): self.stmt = stmt self.expr_info = expr_info self.should_extract = should_extract - def _handle_expr(*args): - raise NotImplementedError("Extractor is an abstract class") - def _apply_cse(self): # Find common sub-expressions heuristically looking at binary terminal # operations (i.e., a terminal has two Symbols as children). 
This may @@ -108,20 +83,23 @@ def _visit(self, node): elif isinstance(node, (FunCall, Ternary)): arg_deps = [self._visit(n) for n in node.children] - dep = tuple(set(flatten([dep for dep, _ in arg_deps]))) + dep = set(flatten([dep for dep, _ in arg_deps])) info = self.EXT if all(i == self.EXT for _, i in arg_deps) else self.STOP return (dep, info) else: - left, right = node.children - dep_l, info_l = self._visit(left) - dep_r, info_r = self._visit(right) - - dep_l = {d for d in dep_l if d in self.expr_info.dims} - dep_r = {d for d in dep_r if d in self.expr_info.dims} - dep_n = dep_l | dep_r - - return self._handle_expr(left, right, dep_l, dep_r, dep_n, info_l, info_r) + retval = [(n,) + self._visit(n) for n in node.children] + dep = set.union(*[d for _, d, _ in retval]) + dep = {d for d in dep if d in self.expr_info.dims} + if self.should_extract(dep) or self._look_ahead: + # Still a chance of finding a bigger expression + return (dep, self.EXT) + else: + for n, n_dep, n_info in retval: + if n_info == self.EXT and not isinstance(n, Symbol): + k = sorted(n_dep, key=lambda i: self.expr_info.dims.index(i)) + self.extracted.setdefault(tuple(k), []).append(n) + return (dep, self.STOP) def extract(self, look_ahead, lda, with_cse=False): """Extract invariant subexpressions from /self.expr/.""" @@ -139,119 +117,18 @@ def extract(self, look_ahead, lda, with_cse=False): return self.extracted -class MainExtractor(Extractor): - - def _handle_expr(self, left, right, dep_l, dep_r, dep_n, info_l, info_r): - if info_l == self.EXT and info_r == self.EXT: - if dep_l == dep_r: - # E.g. alpha*beta, A[i] + B[i] - return (dep_l, self.EXT) - elif not dep_l: - # E.g. alpha*A[i,j] - self._try(left, dep_l) - if not (set(self.expr_info.linear_dims) & dep_r and self._try(left, dep_l)): - return (dep_r, self.EXT) - elif not dep_r: - # E.g. A[i,j]*alpha - self._try(right, dep_r) - if not (set(self.expr_info.linear_dims) & dep_l and self._try(left, dep_l)): - return (dep_l, self.EXT) - elif dep_l.issubset(dep_r): - # E.g. A[i]*B[i,j] - if not self._try(left, dep_l): - return (dep_n, self.EXT) - elif dep_r.issubset(dep_l): - # E.g. A[i,j]*B[i] - if not self._try(right, dep_r): - return (dep_n, self.EXT) - else: - # E.g. A[i]*B[j] - self._try(left, dep_l) - self._try(right, dep_r) - elif info_r == self.EXT: - self._try(right, dep_r) - elif info_l == self.EXT: - self._try(left, dep_l) - return (dep_n, self.STOP) - - -class SoftExtractor(Extractor): - - def _handle_expr(self, left, right, dep_l, dep_r, dep_n, info_l, info_r): - if info_l == self.EXT and info_r == self.EXT: - if dep_l == dep_r: - # E.g. alpha*beta, A[i] + B[i] - return (dep_l, self.EXT) - elif dep_l.issubset(dep_r): - # E.g. A[i]*B[i,j] - if not self._try(right, dep_r): - return (dep_n, self.EXT) - elif dep_r.issubset(dep_l): - # E.g. A[i,j]*B[i] - if not self._try(left, dep_l): - return (dep_n, self.EXT) - else: - # E.g. A[i]*B[j] - self._try(left, dep_l) - self._try(right, dep_r) - elif info_r == self.EXT: - self._try(right, dep_r) - elif info_l == self.EXT: - self._try(left, dep_l) - return (dep_n, self.STOP) - - -class AggressiveExtractor(Extractor): - - def _handle_expr(self, left, right, dep_l, dep_r, dep_n, info_l, info_r): - if info_l == self.EXT and info_r == self.EXT: - if dep_l == dep_r: - # E.g. alpha*beta, A[i] + B[i] - return (dep_l, self.EXT) - elif not dep_l: - # E.g. alpha*A[i,j], not hoistable anymore - self._try(right, dep_r) - elif not dep_r: - # E.g. 
A[i,j]*alpha, not hoistable anymore - self._try(left, dep_l) - elif dep_l.issubset(dep_r): - # E.g. A[i]*B[i,j] - if not self._try(left, dep_l): - return (dep_n, self.EXT) - elif dep_r.issubset(dep_l): - # E.g. A[i,j]*B[i] - if not self._try(right, dep_r): - return (dep_n, self.EXT) - else: - # E.g. A[i]*B[j], hoistable in TMP[i,j] - return (dep_n, self.EXT) - elif info_r == self.EXT: - self._try(right, dep_r) - elif info_l == self.EXT: - self._try(left, dep_l) - return (dep_n, self.STOP) - - class Hoister(object): - # How many times the hoister was invoked - _handled = 0 # Temporary variables template - _hoisted_sym = "%(loop_dep)s_%(expr_id)d_%(round)d_%(i)d" + _template = "ct%d" - def __init__(self, stmt, expr_info, header, decls, hoisted, expr_graph): + def __init__(self, stmt, expr_info, header, decls, hoisted): """Initialize the Hoister.""" self.stmt = stmt self.expr_info = expr_info self.header = header self.decls = decls self.hoisted = hoisted - self.expr_graph = expr_graph - - # Increment counters for unique variable names - self.nextracted = 0 - self.expr_id = Hoister._handled - Hoister._handled += 1 def _filter(self, dep, subexprs, make_unique=True, sharing=None): """Filter hoistable subexpressions.""" @@ -268,8 +145,7 @@ def _filter(self, dep, subexprs, make_unique=True, sharing=None): finder = FindInstances(Symbol) partitions = defaultdict(list) for e in subexprs: - retval = FindInstances.default_retval() - symbols = tuple(set(str(s) for s in finder.visit(e, ret=retval)[Symbol] + symbols = tuple(set(str(s) for s in finder.visit(e)[Symbol] if str(s) in sharing)) partitions[symbols].append(e) for shared, partition in partitions.items(): @@ -288,27 +164,52 @@ def _is_hoistable(self, subexprs, loop): reads = [s.symbol for s in reads[Symbol]] return set.isdisjoint(set(reads), set(written)) - def extract(self, mode, **kwargs): + def _locate(self, dep, subexprs, with_promotion=False): + # Start assuming no "real" hoisting can take place + # E.g.: for i {a[i]*(t1 + t2);} --> for i {t3 = t1 + t2; a[i]*t3;} + place, offset = self.expr_info.innermost_loop.block, self.stmt + + if with_promotion: + # Hoist outside a loop even though this doesn't result in any + # operation count reduction + should_jump = lambda l: True + else: + # "Standard" code motion case, i.e. 
moving /subexprs/ as far as + # possible in the loop nest such that dependencies are honored + should_jump = lambda l: l.dim not in dep + + loops = list(reversed(self.expr_info.loops)) + candidates = [l.block for l in loops[1:]] + [self.header] + + for loop, candidate in zip(loops, candidates): + if not self._is_hoistable(subexprs, loop): + break + if should_jump(loop): + place, offset = candidate, loop + + # Determine how much extra memory and whether clone loops are needed + jumped = loops[:candidates.index(place) + 1] + clone = tuple(l for l in reversed(jumped) if l.dim in dep) + + return place, offset, clone + + def extract(self, should_extract, **kwargs): """Return a dictionary of hoistable subexpressions.""" lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') - extractor = Extractor.factory(mode, self.stmt, self.expr_info) + extractor = Extractor(self.stmt, self.expr_info, should_extract) return extractor.extract(True, lda) - def licm(self, mode, **kwargs): + def licm(self, should_extract, **kwargs): """Perform generalized loop-invariant code motion.""" max_sharing = kwargs.get('max_sharing', False) + with_promotion = kwargs.get('with_promotion', False) iterative = kwargs.get('iterative', True) lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') global_cse = kwargs.get('global_cse', False) - expr_dims_loops = self.expr_info.loops_from_dims - expr_outermost_loop = self.expr_info.outermost_loop - expr_outermost_linear_loop = self.expr_info.outermost_linear_loop - is_bilinear = self.expr_info.is_bilinear + extractor = Extractor(self.stmt, self.expr_info, should_extract) - extractor = Extractor.factory(mode, self.stmt, self.expr_info) extracted = True - while extracted: extracted = extractor.extract(False, lda, global_cse) for dep, subexprs in extracted.items(): @@ -320,64 +221,16 @@ def licm(self, mode, **kwargs): if not subexprs: continue - # 2) Determine the loop nest level where invariant expressions - # should be hoisted. The goal is to hoist them as far as possible - # in the loop nest, while minimising temporary storage. 
- # We distinguish several cases: - depth = len(dep) - if depth == 0: - # As scalar, outside of the loop nest; - place = self.header - wrap_loop = () - offset = expr_outermost_loop - elif depth == 1 and len(expr_dims_loops) == 1: - # As scalar, within the only loop present - place = expr_outermost_loop.children[0] - wrap_loop = () - offset = place.children[place.children.index(self.stmt)] - elif depth == 1 and len(expr_dims_loops) > 1: - if expr_dims_loops[dep[0]] == expr_outermost_loop: - # As scalar, within the outermost loop - place = expr_outermost_loop.children[0] - wrap_loop = () - offset = od_find_next(expr_dims_loops, dep[0]) - else: - # As vector, outside of the loop nest; - place = self.header - wrap_loop = (expr_dims_loops[dep[0]],) - offset = expr_outermost_loop - elif mode == 'aggressive' and set(dep) == set(self.expr_info.dims) and \ - not any([self.expr_graph.is_written(e) for e in subexprs]): - # As n-dimensional vector (n == depth), outside of the loop nest - place = self.header - wrap_loop = tuple(expr_dims_loops.values()) - offset = expr_outermost_loop - elif depth == 2: - if self._is_hoistable(subexprs, expr_outermost_linear_loop): - # As vector, within the outermost loop imposing the dependency - place = expr_dims_loops[dep[0]].children[0] - wrap_loop = tuple(expr_dims_loops[dep[i]] for i in range(1, depth)) - offset = od_find_next(expr_dims_loops, dep[0]) - elif expr_outermost_linear_loop.dim == dep[-1] and is_bilinear: - # As scalar, within the closest loop imposing the dependency - place = expr_dims_loops[dep[-1]].children[0] - wrap_loop = () - offset = od_find_next(expr_dims_loops, dep[-1]) - else: - # As scalar, within the closest loop imposing the dependency - place = expr_dims_loops[dep[-1]].children[0] - wrap_loop = () - offset = place.children[place.children.index(self.stmt)] - else: - warn("Skipping unexpected code motion case.") - return + # 2) Determine the outermost loop where invariant expressions + # can be hoisted without breaking data dependencies. 
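+ # (/place/ is the block the invariant subexpressions are hoisted to, + # /offset/ the insertion point within that block, and /clone/ the + # loops that must be rebuilt around the hoisted assignments)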
+ place, offset, clone = self._locate(dep, subexprs, with_promotion) - loop_size = tuple([l.size for l in wrap_loop]) - loop_dim = tuple([l.dim for l in wrap_loop]) + loop_size = tuple(l.size for l in clone) + loop_dim = tuple(l.dim for l in clone) # 3) Create the required new AST nodes symbols, decls, stmts = [], [], [] - for i, e in enumerate(subexprs): + for e in subexprs: already_hoisted = False if global_cse and self.hoisted.get_symbol(e): name = self.hoisted.get_symbol(e) @@ -386,50 +239,116 @@ def licm(self, mode, **kwargs): place.children.index(decl) < place.children.index(offset): already_hoisted = True if not already_hoisted: - name = self._hoisted_sym % { - 'loop_dep': '_'.join(dep) if dep else 'c', - 'expr_id': self.expr_id, - 'round': self.nextracted, - 'i': i - } + name = self._template % (len(self.hoisted) + len(stmts)) stmts.append(Assign(Symbol(name, loop_dim), dcopy(e))) - decl = Decl(self.expr_info.type, Symbol(name, loop_size), - scope=LOCAL) - decls.append(decl) - self.decls[name] = decl + decls.append(Decl(self.expr_info.type, + Symbol(name, loop_size), + scope=LOCAL)) symbols.append(Symbol(name, loop_dim)) # 4) Replace invariant sub-expressions with temporaries - to_replace = dict(zip(subexprs, symbols)) - n_replaced = ast_replace(self.stmt.rvalue, to_replace) - - # 5) Update data dependencies - for s, e in zip(symbols, subexprs): - self.expr_graph.add_dependency(s, e) - if n_replaced[str(s)] > 1: - self.expr_graph.add_dependency(s, s) - lda[s] = dep - - # 6) Modify the AST adding the hoisted expressions - if wrap_loop: - outer_wrap_loop = ast_make_for(stmts, wrap_loop[-1]) - for l in reversed(wrap_loop[:-1]): - outer_wrap_loop = ast_make_for([outer_wrap_loop], l) - code = decls + [outer_wrap_loop] - wrap_loop = outer_wrap_loop + replacements = ast_replace(self.stmt, dict(zip(subexprs, symbols))) + + # 5) Modify the AST adding the hoisted expressions + if clone: + outer_clone = ast_make_for(stmts, clone[-1]) + for l in reversed(clone[:-1]): + outer_clone = ast_make_for([outer_clone], l) + code = decls + [outer_clone] + clone = outer_clone else: code = decls + stmts - wrap_loop = None - # Insert the new nodes at the right level in the loop nest + clone = None offset = place.children.index(offset) place.children[offset:offset] = code - # Track hoisted symbols + + # 6) Track hoisted symbols and data dependencies for i, j in zip(stmts, decls): - self.hoisted[j.sym.symbol] = (i, j, wrap_loop, place) + name = j.lvalue.symbol + self.hoisted[name] = (i, j, clone, place) + self.decls[name] = j + lda.update({s: set(dep) for s in replacements}) - self.nextracted += 1 if not iterative: break - # Finally, make sure symbols are unique in the AST - self.stmt.rvalue = dcopy(self.stmt.rvalue) + def trim(self, candidate, **kwargs): + """ + Remove unnecessary reduction loops from the expression loop nest. + Sometimes, reduction loops can be factored out in outer loops, thus + reducing the operation count, without breaking data dependencies. 
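+ + For example, in combination with code motion, a loop nest such as: :: + + for i + for j + a[j] += b[j]*c[i] + + is simplified as: :: + + for i + ct += c[i] + for j + a[j] += b[j]*ct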
+ """ + # Rule out unsafe cases + if not is_perfect_loop(self.expr_info.innermost_loop): + return + + # Find out all reducible symbols + lda = kwargs.get('lda') or loops_analysis(self.header) + reducible, other = [], [] + for i in summands(self.stmt.rvalue): + symbols = FindInstances(Symbol).visit(i)[Symbol] + unavoidable = set.intersection(*[set(lda[s]) for s in symbols]) + if candidate in unavoidable: + return + reducible.extend([s.symbol for s in symbols if candidate in lda[s]]) + other.extend([s.symbol for s in symbols if candidate not in lda[s]]) + + # Make sure we do not break data dependencies + make_reduce = [] + writes = FindInstances(Writer).visit(candidate) + for w in flatten(writes.values()): + if isinstance(w.rvalue, EmptyStatement): + continue + if any(s == w.lvalue.symbol for s in other): + return + if any(s == w.lvalue.symbol for s in reducible): + loop = lda[w.lvalue][-1] + make_reduce.append((w, loop)) + + assignments = [(w, p) for w, p in make_reduce if isinstance(w, Assign)] + loops, parents = zip(*self.expr_info.loops_info) + index = loops.index(candidate) + + # Perform a number of checks to ensure lifting reductions is safe + if not all(s in [w.lvalue.symbol for w, _ in make_reduce] for s in reducible): + return + if any(p != candidate and not is_perfect_loop(p) for w, p in make_reduce): + return + if any(candidate.dim in w.lvalue.rank for w, _ in assignments): + return + if any(set(loops[index + 1:]) & set(lda[w.lvalue]) for w, _ in make_reduce): + return + + # Inject the reductions into the AST + for w, p in make_reduce: + name = self._template % len(self.hoisted) + reduction = Incr(Symbol(name, w.lvalue.rank, w.lvalue.offset), + ast_reconstruct(w.rvalue)) + insert_at_elem(p.body, w, reduction) + handle = self.decls[w.lvalue.symbol] + declaration = Decl(handle.typ, Symbol(name, handle.lvalue.rank), + ArrayInit(np.array([0.0])), handle.qual, handle.attr) + insert_at_elem(parents[index].children, candidate, declaration) + ast_replace(self.stmt, {w.lvalue: reduction.lvalue}, copy=True) + self.hoisted[name] = (reduction, declaration, p, p.body) + + # Pull out the candidate reduction loop + pulling = loops[index + 1:] + pulling = list(zip(*[((l.start, l.end), l.dim) for l in pulling])) + pulling = ItSpace().to_for(*pulling, stmts=[self.stmt]) + insert_at_elem(parents[index].children, candidate, pulling[0], ofs=1) + if len(self.expr_info.parent.children) == 1: + loops[index].body.remove(loops[index + 1]) + else: + self.expr_info.parent.children.remove(self.stmt) + + # Clean up removing any now unnecessary symbols + reads = in_read(candidate, key='symbol') + declarations = FindInstances(Decl, with_parent=True).visit(self.header)[Decl] + declarations = dict(declarations) + for w, p in make_reduce: + if w.lvalue.symbol not in reads: + p.body.remove(w) + if not isinstance(w, Decl): + key = self.decls.pop(w.lvalue.symbol) + declarations[key].children.remove(key) diff --git a/coffee/optimizer.py b/coffee/optimizer.py index 11faed00..a9a9ceb0 100644 --- a/coffee/optimizer.py +++ b/coffee/optimizer.py @@ -68,8 +68,6 @@ def __init__(self, loop, header, decls, exprs): # Track nonzero regions accessed in each symbol self.nz_syms = {} - # Track data dependencies - self.expr_graph = ExpressionGraph(header) # Track hoisted expressions self.hoisted = StmtTracker() @@ -89,8 +87,6 @@ def rewrite(self, mode): that fully depend on reduction loops. 
* mode == 4: rewrite an expression based on its sharing graph """ - ExpressionRewriter.reset() - # Set a rewrite mode for each expression for stmt, expr_info in self.exprs.items(): expr_info.mode = mode @@ -110,7 +106,7 @@ def rewrite(self, mode): # Expression rewriting, expressed as a sequence of AST transformation passes for stmt, expr_info in self.exprs.items(): ew = ExpressionRewriter(stmt, expr_info, self.decls, self.header, - self.hoisted, self.expr_graph) + self.hoisted) if expr_info.mode == 1: if expr_info.dimension in [0, 1]: @@ -121,7 +117,8 @@ def rewrite(self, mode): elif expr_info.mode == 2: if expr_info.dimension > 0: ew.replacediv() - ew.SGrewrite() + ew.sharing_graph_rewrite() + ew.licm(mode='reductions') elif expr_info.mode == 3: ew.expand(mode='all') @@ -139,11 +136,11 @@ def rewrite(self, mode): ew.licm(mode='only_outlinear') if expr_info.dimension > 0: ew.licm(mode='only_linear', iterative=False, max_sharing=True) - ew.SGrewrite() + ew.sharing_graph_rewrite() ew.expand() # Try merging the loops created by expression rewriting - merged_loops = SSALoopMerger(self.expr_graph).merge(self.header) + merged_loops = SSALoopMerger().merge(self.header) # Update the trackers for merged, merged_in in merged_loops: for l in merged: @@ -157,7 +154,6 @@ def rewrite(self, mode): # Reduce memory pressure by avoiding useless temporaries self._min_temporaries() - self.expr_graph = ExpressionGraph(self.header) # Handle the effects, at the C-level, of the AST transformation self._recoil() @@ -167,131 +163,14 @@ def eliminate_zeros(self): avoid evaluation of arithmetic operations involving zero-valued blocks in statically initialized arrays.""" - zls = ZeroRemover(self.exprs, self.decls, self.hoisted, self.expr_graph) + zls = ZeroRemover(self.exprs, self.decls, self.hoisted) self.nz_syms = zls.reschedule(self.header) - def precompute(self, mode='perfect'): - """Precompute statements out of ``self.loop``. This is achieved through - scalar code hoisting. - - :arg mode: drives the precomputation. Two values are possible: ['perfect', - 'noloops']. The 'perfect' mode attempts to hoist everything, making the loop - nest perfect. The 'noloops' mode excludes inner loops from the precomputation. - - Example: :: - - for i - for r - A[r] += f(i, ...) - for j - for k - B[j][k] += g(A[r], ...) - - with mode='perfect', becomes: :: - - for i - for r - A[i][r] += f(...) - for i - for j - for k - B[j][k] += g(A[i][r], ...) 
- """ - - precomputed_block = [] - precomputed_syms = {} - - def _precompute(node, outer_block): - - if isinstance(node, Symbol): - if node.symbol in precomputed_syms: - node.rank = precomputed_syms[node.symbol] + node.rank - - elif isinstance(node, FlatBlock): - outer_block.append(node) - - elif isinstance(node, Expr): - for n in node.children: - _precompute(n, outer_block) - - elif isinstance(node, Writer): - sym, expr = node.children - precomputed_syms[sym.symbol] = (self.loop.dim,) - _precompute(sym, outer_block) - _precompute(expr, outer_block) - outer_block.append(node) - - elif isinstance(node, Decl): - outer_block.append(node) - if isinstance(node.init, Symbol): - node.init.symbol = "{%s}" % node.init.symbol - elif isinstance(node.init, Expr): - _precompute(Assign(dcopy(node.sym), node.init), outer_block) - node.init = EmptyStatement() - node.sym.rank = (self.loop.size,) + node.sym.rank - - elif isinstance(node, For): - new_children = [] - for n in node.body: - _precompute(n, new_children) - node.body = new_children - outer_block.append(node) - - else: - raise RuntimeError("Precompute error: unexpteced node: %s" % str(node)) - - # If the outermost loop is already perfect, there is nothing to precompute - if is_perfect_loop(self.loop): - return - - # Get the nodes that should not be precomputed - no_precompute = set() - if mode == 'noloops': - for l in self.hoisted.values(): - if l.loop: - no_precompute.add(l.decl) - no_precompute.add(l.loop) - - # Visit the AST and perform the precomputation - to_remove = [] - for n in self.loop.body: - if n in flatten(self.expr_linear_loops): - break - elif n not in no_precompute: - _precompute(n, precomputed_block) - to_remove.append(n) - - # Clean up - for n in to_remove: - self.loop.body.remove(n) - - # Wrap precomputed statements within a loop - searching, outer_block = [], [] - for n in precomputed_block: - if searching and not isinstance(n, Writer): - outer_block.append(ast_make_for(searching, self.loop)) - searching = [] - if isinstance(n, For): - outer_block.append(ast_make_for([n], self.loop)) - elif isinstance(n, Writer): - searching.append(n) - else: - outer_block.append(n) - if searching: - outer_block.append(ast_make_for(searching, self.loop)) - - # Update the AST ... - # ... adding the newly precomputed blocks - insert_at_elem(self.header.children, self.loop, outer_block) - # ... scalar-expanding the precomputed symbols - ast_update_rank(self.loop, precomputed_syms) - def _unpick_cse(self): """Search for factorization opportunities across temporaries created by common sub-expression elimination. If a gain in operation count is detected, unpick CSE and apply factorization + code motion.""" - cse_unpicker = CSEUnpicker(self.exprs, self.header, self.hoisted, - self.decls, self.expr_graph) + cse_unpicker = CSEUnpicker(self.exprs, self.header, self.hoisted, self.decls) cse_unpicker.unpick() def _min_temporaries(self): @@ -378,6 +257,8 @@ def _dissect(self, heuristics): # to too much temporary space, we have to partially drop it threshold = system.architecture['cache_size'] * 1.2 + expr_graph = ExpressionGraph(header) + # 1) Find out and unroll injectable loops. For unrolling we create new # expressions; that is, for now, we do not modify the AST in place. 
analyzed, injectable = [], {} @@ -476,7 +357,7 @@ def find_save(target_expr, expr_info): increase_factor = 0 for i in projection: partial = 1 - for j in self.expr_graph.shares(i): + for j in expr_graph.shares(i): # _n=number of unique elements, _k=group size _n = injectable[j[0]][1] _k = len(j) diff --git a/coffee/plan.py b/coffee/plan.py index a1287ad7..fbff92c8 100644 --- a/coffee/plan.py +++ b/coffee/plan.py @@ -85,7 +85,6 @@ def plan_cpu(self, opts): vectorize = opts.get('vectorize', (None, None)) align_pad = opts.get('align_pad') split = opts.get('split') - precompute = opts.get('precompute') dead_ops_elimination = opts.get('dead_ops_elimination') info = visit(kernel) @@ -93,10 +92,10 @@ def plan_cpu(self, opts): # Collect expressions and related metadata nests = defaultdict(OrderedDict) for stmt, expr_info in info['exprs'].items(): - parent, nest, linear_dims = expr_info + parent, nest = expr_info if not nest: continue - metaexpr = MetaExpr(check_type(stmt, decls), parent, nest, linear_dims) + metaexpr = MetaExpr(check_type(stmt, decls), parent, nest) nests[nest[0]].update({stmt: metaexpr}) loop_opts = [CPULoopOptimizer(loop, header, decls, exprs) for (loop, header), exprs in nests.items()] @@ -126,8 +125,6 @@ def plan_cpu(self, opts): # 2) Code specialization if split: loop_opt.split(split) - if precompute: - loop_opt.precompute(precompute) if coffee.initialized and flatten(loop_opt.expr_linear_loops): vect = LoopVectorizer(loop_opt, kernel) if align_pad: @@ -190,10 +187,10 @@ def plan_gpu(self): # Structure up expressions and related metadata nests = defaultdict(OrderedDict) for stmt, expr_info in info['exprs'].items(): - parent, nest, linear_dims = expr_info + parent, nest = expr_info if not nest: continue - metaexpr = MetaExpr(check_type(stmt, decls), parent, nest, linear_dims) + metaexpr = MetaExpr(check_type(stmt, decls), parent, nest) nests[nest[0]].update({stmt: metaexpr}) loop_opts = [GPULoopOptimizer(l, header, decls) for l, header in nests] diff --git a/coffee/rewriter.py b/coffee/rewriter.py index 4a3cd026..b20e8fa1 100644 --- a/coffee/rewriter.py +++ b/coffee/rewriter.py @@ -35,6 +35,8 @@ from six.moves import zip from collections import Counter +from itertools import combinations +from operator import itemgetter import pulp as ilp from .base import * @@ -54,7 +56,7 @@ class ExpressionRewriter(object): * Expansion: transform an expression ``(a + b)*c`` into ``(a*c + b*c)`` * Factorization: transform an expression ``a*b + a*c`` into ``a*(b+c)``""" - def __init__(self, stmt, expr_info, decls, header=None, hoisted=None, expr_graph=None): + def __init__(self, stmt, expr_info, decls, header=None, hoisted=None): """Initialize the ExpressionRewriter. :param stmt: the node whose rvalue is the expression for rewriting @@ -62,19 +64,17 @@ def __init__(self, stmt, expr_info, decls, header=None, hoisted=None, expr_graph :param decls: all declarations for the symbols in ``stmt``. 
:param header: the kernel's top node :param hoisted: dictionary that tracks all hoisted expressions - :param expr_graph: a graph for data dependence analysis """ self.stmt = stmt self.expr_info = expr_info self.decls = decls self.header = header or Root() self.hoisted = hoisted if hoisted is not None else StmtTracker() - self.expr_graph = expr_graph or ExpressionGraph(self.header) self.expr_hoister = Hoister(self.stmt, self.expr_info, self.header, - self.decls, self.hoisted, self.expr_graph) + self.decls, self.hoisted) self.expr_expander = Expander(self.stmt, self.expr_info, self.decls, - self.hoisted, self.expr_graph) + self.hoisted) self.expr_factorizer = Factorizer(self.stmt) def licm(self, mode='normal', **kwargs): @@ -84,13 +84,18 @@ def licm(self, mode='normal', **kwargs): http://dl.acm.org/citation.cfm?id=2687415 :param mode: drive code motion by specifying what subexpressions should - be hoisted + be hoisted and where. * normal: (default) all subexpressions that depend on one loop at most * aggressive: all subexpressions, depending on any number of loops. This may require introducing N-dimensional temporaries. + * incremental: apply, in sequence, only_const, only_outlinear, and + one sweep for each linear dimension * only_const: only all constant subexpressions * only_linear: only all subexpressions depending on linear loops * only_outlinear: only all subexpressions independent of linear loops + * reductions: all sub-expressions that are redundantly computed within + a reduction loop; if possible, pull the reduction loop out of + the nest. :param kwargs: * look_ahead: (default: False) should be set to True if only a projection of the hoistable subexpressions is needed (i.e., hoisting not performed) @@ -105,14 +110,110 @@ def licm(self, mode='normal', **kwargs): * global_cse: (default: False) search for common sub-expressions across all previously hoisted terms. Note that no data dependency analysis is performed, so this is at caller's risk. + * with_promotion: compute hoistable subexpressions within clone loops + even though this doesn't necessarily result in fewer operations. + + Examples + ======== + + 1) With mode='normal': :: + + for i + for j + for k + a[j][k] += (b[i][j] + c[i][j])*(d[i][k] + e[i][k]) + + Redundancies are spotted along both the i and j dimensions, resulting in: :: + + for i + for k + ct1[k] = d[i][k] + e[i][k] + for j + ct2 = b[i][j] + c[i][j] + for k + a[j][k] += ct2*ct1[k] + + 2) With mode='reductions'. + Consider the following loop nest: :: + + for i + for j + a[j] += b[j]*c[i] + + By unrolling the loops, one clearly sees that: :: + + a[0] += b[0]*c[0] + b[0]*c[1] + b[0]*c[2] + ... + a[1] += b[1]*c[0] + b[1]*c[1] + b[1]*c[2] + ... + + Which is identical to: :: + + ct = c[0] + c[1] + c[2] + ... 
+ a[0] += b[0]*ct + a[1] += b[1]*ct + + Thus, the original loop nest is simplified as: :: + + for i + ct += c[i] + for j + a[j] += b[j]*ct """ + dimension = self.expr_info.dimension + dims = set(self.expr_info.dims) + linear_dims = set(self.expr_info.linear_dims) + out_linear_dims = set(self.expr_info.out_linear_dims) + if kwargs.get('look_ahead'): - return self.expr_hoister.extract(mode, **kwargs) - if mode == 'aggressive': - # Reassociation may promote more hoisting in /aggressive/ mode + hoist = self.expr_hoister.extract + else: + hoist = self.expr_hoister.licm + + if mode == 'normal': + should_extract = lambda d: d != dims + hoist(should_extract, **kwargs) + elif mode == 'reductions': + should_extract = lambda d: d != dims + # Expansion and reassociation may create hoistable reduction loops + candidates = self.expr_info.reduction_loops + if not candidates: + return self + candidate = candidates[-1] + if candidate.size == 1: + # Otherwise the operation count will just end up increasing + return + self.expand(mode='all') + lda = loops_analysis(self.header, value='dim') + non_candidates = {l.dim for l in candidates[:-1]} + self.reassociate(lambda i: not lda[i].intersection(non_candidates)) + hoist(should_extract, with_promotion=True, lda=lda) + self.expr_hoister.trim(candidate) + elif mode == 'incremental': + lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') + should_extract = lambda d: not (d and d.issubset(dims)) + hoist(should_extract, lda=lda) + should_extract = lambda d: d.issubset(out_linear_dims) + hoist(should_extract, lda=lda) + for i in range(1, dimension): + should_extract = lambda d: len(d.intersection(linear_dims)) <= i + hoist(should_extract, lda=lda, **kwargs) + elif mode == 'only_const': + should_extract = lambda d: not (d and d.issubset(dims)) + hoist(should_extract, **kwargs) + elif mode == 'only_outlinear': + should_extract = lambda d: d.issubset(out_linear_dims) + hoist(should_extract, **kwargs) + elif mode == 'only_linear': + should_extract = lambda d: not d.issubset(out_linear_dims) and d != linear_dims + hoist(should_extract, **kwargs) + elif mode == 'aggressive': + should_extract = lambda d: True self.reassociate() - self.expr_hoister.licm(mode, **kwargs) + hoist(should_extract, with_promotion=True, **kwargs) + else: + warn('Skipping unknown licm strategy.') + return self + return self def expand(self, mode='standard', **kwargs): @@ -308,7 +409,7 @@ def _reassociate(node, parent): if isinstance(node, (Symbol, Div)): return - elif isinstance(node, (Sum, Sub, FunCall)): + elif isinstance(node, (Sum, Sub, FunCall, Ternary)): for n in node.children: _reassociate(n, node) @@ -443,40 +544,55 @@ def preevaluate(self): self.header.children.remove(hoisted_loop) return self - def SGrewrite(self): + def sharing_graph_rewrite(self): """Rewrite the expression based on its sharing graph. Details in the paper: - On Optimality of Finite Element Integration, Luporini et. al. + An algorithm for the optimization of finite element integration loops + (Luporini et al.)
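+ + For example, in ``a[i][j] += b[i]*c[j] + b[i]*d[j]``, with ``i`` and ``j`` + linear dimensions, the sharing graph has nodes ``b``, ``c`` and ``d``, and + edges ``(b, c)`` and ``(b, d)``; factorizing ``b``, the minimal cover found + by the ILP model below, rewrites the expression as ``b[i]*(c[j] + d[j])``, + which reduces the operation count.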
""" - lda = loops_analysis(self.expr_info.linear_loops[0], key='symbol', value='dim') - sg_visitor = SharingGraph(self.expr_info, lda) - - # Maximize the visibility of linear symbols - sgraph, mapper = sg_visitor.visit(self.stmt.rvalue) - if 'topsum' in mapper: - self.expand(mode='linear', subexprs=[mapper['topsum']]) - sgraph, mapper = sg_visitor.visit(self.stmt.rvalue) + linear_dims = self.expr_info.linear_dims + other_dims = self.expr_info.out_linear_dims + + # Maximize visibility of linear symbols + self.expand(mode='all') + + # Make sure that potential reductions are not hidden away + lda = loops_analysis(self.header, value='dim') + self.reassociate(lambda i: (not lda[i]) + lda[i].issubset(set(other_dims))) + + # Construct the sharing graph + nodes, edges = [], [] + for i in summands(self.stmt.rvalue): + symbols = [i] if isinstance(i, Symbol) else list(zip(*explore_operator(i)))[0] + lsymbols = [s for s in symbols if any(d in lda[s] for d in linear_dims)] + lsymbols = [s.urepr for s in lsymbols] + nodes.extend([j for j in lsymbols if j not in nodes]) + edges.extend(combinations(lsymbols, r=2)) + sgraph = nx.Graph(edges) + + # Transform everything outside the sharing graph (pure linear, no ambiguity) + isolated = [n for n in nodes if n not in sgraph.nodes()] + for n in isolated: + self.factorize(mode='adhoc', adhoc={n: [] for n in nodes}) + self.licm('only_const').licm('only_outlinear') + # Transform the expression based on the sharing graph nodes, edges = sgraph.nodes(), sgraph.edges() + if not (nodes and all(sgraph.degree(n) > 0 for n in nodes)): + self.factorize(mode='heuristic') + self.licm('only_const').licm('only_outlinear') + return + # Use short variable names otherwise Pulp might complain + nodes_vars = {i: n for i, n in enumerate(nodes)} + vars_nodes = {n: i for i, n in nodes_vars.items()} + edges = [(vars_nodes[i], vars_nodes[j]) for i, j in edges] - if self.expr_info.is_linear: - self.factorize(mode='adhoc', adhoc={n: [] for n in nodes}) - self.licm('only_outlinear') - elif self.expr_info.is_bilinear: - # Resort to an ILP formulation to find out the best factorization candidates - if not (nodes and all(sgraph.degree(n) > 0 for n in nodes)): - self.factorize(mode='heuristic') - self.licm(mode='only_outlinear') - return - # Note: need to use short variable names otherwise Pulp might complain - nodes_vars = {i: n for i, n in enumerate(nodes)} - vars_nodes = {n: i for i, n in nodes_vars.items()} - edges = [(vars_nodes[i], vars_nodes[j]) for i, j in edges] - + def setup(): # ... declare variables x = ilp.LpVariable.dicts('x', nodes_vars.keys(), 0, 1, ilp.LpBinary) - y = ilp.LpVariable.dicts('y', [(i, j) for i, j in edges] + [(j, i) for i, j in edges], + y = ilp.LpVariable.dicts('y', + [(i, j) for i, j in edges] + [(j, i) for i, j in edges], 0, 1, ilp.LpBinary) limits = defaultdict(int) for i, j in edges: @@ -496,20 +612,37 @@ def SGrewrite(self): # ... define the objective function (min number of factorizations) prob += ilp.lpSum(x[i] for i in nodes_vars) - # ... 
+
+        # Transform everything outside the sharing graph (pure linear, no ambiguity)
+        isolated = [n for n in nodes if n not in sgraph.nodes()]
+        for n in isolated:
+            self.factorize(mode='adhoc', adhoc={n: [] for n in nodes})
+            self.licm('only_const').licm('only_outlinear')
+
         # Transform the expression based on the sharing graph
         nodes, edges = sgraph.nodes(), sgraph.edges()
+        if not (nodes and all(sgraph.degree(n) > 0 for n in nodes)):
+            self.factorize(mode='heuristic')
+            self.licm('only_const').licm('only_outlinear')
+            return
+        # Use short variable names, otherwise Pulp might complain
+        nodes_vars = {i: n for i, n in enumerate(nodes)}
+        vars_nodes = {n: i for i, n in nodes_vars.items()}
+        edges = [(vars_nodes[i], vars_nodes[j]) for i, j in edges]
-        if self.expr_info.is_linear:
-            self.factorize(mode='adhoc', adhoc={n: [] for n in nodes})
-            self.licm('only_outlinear')
-        elif self.expr_info.is_bilinear:
-            # Resort to an ILP formulation to find out the best factorization candidates
-            if not (nodes and all(sgraph.degree(n) > 0 for n in nodes)):
-                self.factorize(mode='heuristic')
-                self.licm(mode='only_outlinear')
-                return
-            # Note: need to use short variable names otherwise Pulp might complain
-            nodes_vars = {i: n for i, n in enumerate(nodes)}
-            vars_nodes = {n: i for i, n in nodes_vars.items()}
-            edges = [(vars_nodes[i], vars_nodes[j]) for i, j in edges]
-
+        def setup():
             # ... declare variables
             x = ilp.LpVariable.dicts('x', nodes_vars.keys(), 0, 1, ilp.LpBinary)
-            y = ilp.LpVariable.dicts('y', [(i, j) for i, j in edges] + [(j, i) for i, j in edges],
+            y = ilp.LpVariable.dicts('y',
+                                     [(i, j) for i, j in edges] + [(j, i) for i, j in edges],
                                      0, 1, ilp.LpBinary)
             limits = defaultdict(int)
             for i, j in edges:
@@ -496,20 +612,37 @@ def SGrewrite(self):
             # ... define the objective function (min number of factorizations)
             prob += ilp.lpSum(x[i] for i in nodes_vars)

-            # ... solve the problem
-            prob.solve(ilp.GLPK(msg=0))
+            return x, prob

-            # Finally, factorize and hoist (note: the order in which factorizations are carried
-            # out is crucial)
-            nodes = [nodes_vars[n] for n, v in x.items() if v.value() == 1]
-            other_nodes = [nodes_vars[n] for n, v in x.items() if nodes_vars[n] not in nodes]
-            for n in nodes + other_nodes:
-                self.factorize(mode='adhoc', adhoc={n: []})
-                self.licm()
+        # Solve the ILP problem to find out the minimal-cost factorization strategy
+        x, prob = setup()
         prob.solve(ilp.GLPK(msg=0))

-        return self
+        # Also attempt to find another optimal factorization, but with
+        # additional constraints on the reduction dimensions. This may help in
+        # later rewrite steps
+        if len(other_dims) > 1:
+            z, prob = setup()
+            for i, n in nodes_vars.items():
+                if not set(n[1]).intersection(set(other_dims[:-1])):
+                    prob += z[i] == 0
+            prob.solve(ilp.GLPK(msg=0))
+            if ilp.LpStatus[prob.status] == 'Optimal':
+                x = z
+
+        # ... finally, apply the transformations. Observe that:
+        # 1) the order (first /nodes/, then /other_nodes/) in which the
+        #    factorizations are carried out is crucial
+        # 2) sorting /nodes/ and /other_nodes/ locally guarantees
+        #    deterministic output code
+        # 3) precedence is given to outer reduction loops; this maximises the
+        #    impact of later transformations, while not affecting this pass
+        # 4) with_promotion is set to True if there exist potential reductions
+        #    to simplify
+        nodes = [nodes_vars[n] for n, v in x.items() if v.value() == 1]
+        other_nodes = [nodes_vars[n] for n, v in x.items() if nodes_vars[n] not in nodes]
+        for n in sorted(nodes, key=itemgetter(1)) + sorted(other_nodes):
+            self.factorize(mode='adhoc', adhoc={n: []})
+            self.licm('incremental', with_promotion=len(other_dims) > 1)

-    @staticmethod
-    def reset():
-        Hoister._handled = 0
-        Expander._handled = 0
+        return self
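[Editor's note: a sketch of the shape of the ILP solved above, not COFFEE's
exact model (the real one also carries the degree-based /limits/ constraints).
On the toy graph from the previous sketch: pick the fewest vertices such that
every edge can be oriented towards a picked vertex:]

import pulp as ilp

nodes = [0, 1, 2, 3]                      # b0, c0, c1, b1
edges = [(0, 1), (0, 2), (3, 1)]

x = ilp.LpVariable.dicts('x', nodes, 0, 1, ilp.LpBinary)
y = ilp.LpVariable.dicts('y', edges + [(j, i) for i, j in edges], 0, 1, ilp.LpBinary)

prob = ilp.LpProblem('factorization', ilp.LpMinimize)
for i, j in edges:
    prob += y[(i, j)] + y[(j, i)] == 1    # each edge is assigned to one endpoint
    prob += y[(i, j)] <= x[j]             # ...which must be a picked vertex
    prob += y[(j, i)] <= x[i]
prob += ilp.lpSum(x[i] for i in nodes)    # objective: fewest factorizations

prob.solve()                              # default CBC solver; GLPK also works
picked = [i for i in nodes if x[i].value() == 1]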
diff --git a/coffee/scheduler.py b/coffee/scheduler.py
index d2dc9687..383f5068 100644
--- a/coffee/scheduler.py
+++ b/coffee/scheduler.py
@@ -58,37 +58,21 @@ class SSALoopMerger(LoopScheduler):
     Statements must be in "soft" SSA form: they can be declared and
     initialized at declaration time, then they can be assigned a value in only
     one place."""

-    def __init__(self, expr_graph):
-        """Initialize the SSALoopMerger.
+    def _merge_loops(self, root, loop_a, loop_b):
+        """Merge the body of ``loop_a`` into ``loop_b``."""
+        root.children.remove(loop_a)

-        :param expr_graph: the ExpressionGraph tracking all data dependencies
-            involving identifiers that appear in ``header``.
-        """
-        self.expr_graph = expr_graph
-        self.merged_loops = []
+        dims_a, dims_b = [loop_a.dim], [loop_b.dim]
+        while isinstance(loop_b.body[0], For):
+            loop_b = loop_b.body[0]
+            dims_b.append(loop_b.dim)
+        while isinstance(loop_a.body[0], For):
+            loop_a = loop_a.body[0]
+            dims_a.append(loop_a.dim)

-    def _merge_loops(self, root, loop_a, loop_b):
-        """Merge the body of ``loop_a`` in ``loop_b`` and eliminate ``loop_a``
-        from the tree rooted in ``root``. Return a reference to the block
-        containing the merged loop as well as the iteration variables used
-        in the respective iteration spaces."""
-        # Find the first statement in the perfect loop nest loop_b
-        dims_a, dims_b = [], []
-        while isinstance(loop_b.children[0], (Block, For)):
-            if isinstance(loop_b, For):
-                dims_b.append(loop_b.dim)
-            loop_b = loop_b.children[0]
-        # Find the first statement in the perfect loop nest loop_a
-        root_loop_a = loop_a
-        while isinstance(loop_a.children[0], (Block, For)):
-            if isinstance(loop_a, For):
-                dims_a.append(loop_a.dim)
-            loop_a = loop_a.children[0]
-        # Merge body of loop_a in loop_b
-        loop_b.children[0:0] = loop_a.children
-        # Remove loop_a from root
-        root.children.remove(root_loop_a)
-        return (loop_b, tuple(dims_a), tuple(dims_b))
+        loop_b.body = loop_a.body + loop_b.body
+
+        ast_update_rank(loop_b, dict(zip(dims_a, dims_b)))

     def _simplify(self, merged_loops):
         """Scan the list of merged loops and eliminate sub-expressions that became
@@ -111,16 +95,23 @@ def _simplify(self, merged_loops):
             A[i] = B[i] + C[i]
             D[i] = A[i]
         """
-        to_replace = {}
         for loop in merged_loops:
+            to_replace = {}
             for stmt in loop.body:
                 ast_replace(stmt, to_replace, copy=True)
-                to_replace[stmt.rvalue] = stmt.lvalue
+                if not isinstance(stmt, AugmentedAssign):
+                    to_replace[stmt.rvalue] = stmt.lvalue

     def merge(self, root):
         """Merge perfect loop nests in ``root``."""
-        found_nests = OrderedDict()
+
+        # Remove any empty loops within root before attempting the merge
+        remove_empty_loops(root)
+
+        expr_graph = ExpressionGraph(root)
+
+        # Collect iteration spaces by visiting the tree rooted in /root/
+        found_nests = OrderedDict()
         for n in root.children:
             if isinstance(n, For):
                 retval = FindLoopNests.default_retval()
@@ -165,7 +156,7 @@ def merge(self, root):
             in_writes = SymbolModes().visit(n, ret=SymbolModes.default_retval())
             in_writes = [s for s, m in in_writes.items()]
             for iw, lw in product(in_writes, l_writes):
-                if self.expr_graph.is_written(iw, lw):
+                if expr_graph.is_written(iw, lw):
                     is_mergeable = False
                     break

@@ -181,8 +172,7 @@ def merge(self, root):

         # If there is at least one mergeable loops, do the merging
         for l in reversed(mergeable):
-            merged, l_dims, m_dims = self._merge_loops(parent, l, merging_in)
-            ast_update_rank(merged, dict(zip(l_dims, m_dims)))
+            self._merge_loops(parent, l, merging_in)
         # Update the lists of merged loops
         all_merged.append((mergeable, merging_in))
         merged_loops.append(merging_in)
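[Editor's note: a toy model, not part of the patch, of what _simplify() does
after fusion: right-hand sides computed earlier in the merged loop body are
forwarded, exactly as in the docstring example above. Statements are modelled
as (lvalue, rvalue) strings:]

stmts = [('A[i]', 'B[i] + C[i]'),
         ('D[i]', 'B[i] + C[i]')]             # same rvalue, recomputed

to_replace = {}
simplified = []
for lvalue, rvalue in stmts:
    rvalue = to_replace.get(rvalue, rvalue)   # reuse an earlier result
    to_replace.setdefault(rvalue, lvalue)     # plain assignments only, in the real pass
    simplified.append((lvalue, rvalue))

assert simplified == [('A[i]', 'B[i] + C[i]'), ('D[i]', 'A[i]')]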
@@ -488,18 +478,16 @@ class ZeroRemover(LoopScheduler):

     THRESHOLD = 1  # Only skip if there are more than THRESHOLD consecutive zeros

-    def __init__(self, exprs, decls, hoisted, expr_graph):
+    def __init__(self, exprs, decls, hoisted):
         """Initialize the ZeroRemover.

         :param exprs: the expressions for which zero removal is performed.
         :param decls: lists of declarations visible to ``exprs``.
         :param hoisted: dictionary that tracks hoisted sub-expressions
-        :param expr_graph: expression graph that tracks symbol dependencies
         """
         self.exprs = exprs
         self.decls = decls
         self.hoisted = hoisted
-        self.expr_graph = expr_graph

     def _track_nz_expr(self, node, nz_syms, nest):
         """For the expression rooted in ``node``, return iteration space and
@@ -799,8 +787,7 @@ def _recombine(self, nz_info):

         for stmt, expr_info in new_exprs.items():
             ew = ExpressionRewriter(stmt, expr_info, self.decls,
-                                    expr_info.outermost_parent,
-                                    self.hoisted, self.expr_graph)
+                                    expr_info.outermost_parent, self.hoisted)
             ew.factorize('heuristic')

         if new_exprs:
diff --git a/coffee/utils.py b/coffee/utils.py
index eb252cb6..e08a47ee 100644
--- a/coffee/utils.py
+++ b/coffee/utils.py
@@ -44,6 +44,7 @@

 from coffee.base import *
 from coffee.visitors.inspectors import *
+from coffee.visitors.utilities import Reconstructor


 #####################################
@@ -51,40 +52,48 @@
 #####################################


-def ast_replace(node, to_replace, copy=False, mode='all'):
-    """Given a dictionary ``to_replace`` s.t. ``{sym: new_sym}``, replace the
-    various ``syms`` rooted in ``node`` with ``new_sym``.
+def ast_replace(node, to_replace, copy=True, mode='all'):
+    """
+    Given the ``to_replace`` dictionary ``{k: v}``, replace each
+    ``k`` rooted in ``node`` with ``v``.

-    :param copy: if True, a deep copy of the replacing symbol is created.
+    :param copy: pass False to avoid reconstructing ``v`` each time ``k``
+        is encountered.
     :param mode: either ``all``, in which case ``to_replace``'s keys are turned
         into strings, and all of the occurrences are removed from the AST; or
         ``symbol``, in which case only (all of) the references to the symbols
         given in ``to_replace`` are replaced.
     """
+    assert mode in ['all', 'symbol']
     if mode == 'all':
-        to_replace = dict(zip([str(s) for s in to_replace.keys()], to_replace.values()))
-        __ast_replace = lambda n: to_replace.get(str(n))
-    elif mode == 'symbol':
-        __ast_replace = lambda n: to_replace.get(n)
+        to_replace = {str(k): v for k, v in to_replace.items()}
+        should_replace = lambda n: to_replace.get(str(n))
     else:
-        raise ValueError
+        should_replace = lambda n: to_replace.get(n)
+
+    replacements = []

-    def _ast_replace(node, to_replace, n_replaced):
-        replaced = {}
+    def _ast_replace(node, to_replace):
+        replaced_children = {}
         for i, n in enumerate(node.children):
-            replacing = __ast_replace(n)
-            if replacing:
-                replaced[i] = replacing if not copy else dcopy(replacing)
-                n_replaced[str(replacing)] += 1
+            v = should_replace(n)
+            if v:
+                replaced_children[i] = ast_reconstruct(v) if copy else v
+                replacements.append(replaced_children[i])
             else:
-                _ast_replace(n, to_replace, n_replaced)
-        for i, r in replaced.items():
+                _ast_replace(n, to_replace)
+        for i, r in replaced_children.items():
             node.children[i] = r

-    n_replaced = defaultdict(int)
-    _ast_replace(node, to_replace, n_replaced)
-    return n_replaced
+    _ast_replace(node, to_replace)
+
+    return replacements
+
+
+def ast_reconstruct(node):
+    """Recursively reconstruct ``node``."""
+    return Reconstructor().visit(node)


 def ast_update_ofs(node, ofs, **kwargs):
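[Editor's note: a hedged usage sketch, not part of the patch, of the reworked
ast_replace(), which now returns the list of nodes actually substituted. The
Symbol/Sum/Prod/Assign constructors are assumptions based on coffee.base:]

from coffee.base import Assign, Prod, Sum, Symbol
from coffee.utils import ast_replace

# a[i] = b[i]*c[i] + b[i]
stmt = Assign(Symbol('a', ('i',)),
              Sum(Prod(Symbol('b', ('i',)), Symbol('c', ('i',))), Symbol('b', ('i',))))

# With copy=True (the new default) every occurrence of b[i] is replaced by an
# independent reconstruction of t[i], so later in-place edits cannot alias
replaced = ast_replace(stmt, {Symbol('b', ('i',)): Symbol('t', ('i',))})
assert len(replaced) == 2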
@@ -121,24 +130,15 @@ def ast_update_rank(node, mapper):
     ``rank``.

     :arg node: Root AST node
-    :arg mapper: Describe how to change the rank of a symbol.
-    :type mapper: a dictionary. Keys can either be Symbols -- in which case
-        values are interpreted as dimensions to be added to the rank -- or
-        actual ranks (strings, integers) -- which means rank dimensions are
-        replaced; for example, if mapper={'i': 'j'} and node='A[i] = B[i]',
-        node will be transformed into 'A[j] = B[j]'
+    :arg mapper: Describe how to change the rank of a symbol. For example,
+        if mapper={'i': 'j'} and node='A[i] = B[i]', then node will be
+        transformed into 'A[j] = B[j]'
     """
-    symbols = FindInstances(Symbol).visit(node, ret=FindInstances.default_retval())[Symbol]
-    for s in symbols:
-        if mapper.get(s.symbol):
-            # Add a dimension
-            s.rank = mapper[s.symbol] + s.rank
-        else:
-            # Try to replace dimensions
-            s.rank = tuple([r if r not in mapper else mapper[r] for r in s.rank])
-
-    return node
+    retval = FindInstances.default_retval()
+    FindInstances(Symbol).visit(node, ret=retval)
+    for s in retval[Symbol]:
+        s.rank = tuple([r if r not in mapper else mapper[r] for r in s.rank])


 def ast_update_id(symbol, name, id):
@@ -285,6 +285,7 @@ def loops_analysis(node, key='default', value='default'):
     :arg value: any value in ['default', 'dim']. If 'dim' is specified, then
         loop iteration dimensions are used in place of the actual object.
     """
+    symbols_dep = visit(node, info_items=['symbols_dep'])['symbols_dep']

     if key == 'default':
         gen_key = lambda s: s
@@ -296,16 +297,16 @@ def loops_analysis(node, key='default', value='default'):
         raise RuntimeError("Illegal key=%s for loop dependence analysis" % key)

     if value == 'default':
-        gen_value = lambda d: set(d)
+        lda = defaultdict(list)
+        update = lambda i, dep: i.extend(list(dep))
     elif value == 'dim':
-        gen_value = lambda d: {l.dim for l in d}
+        lda = defaultdict(set)
+        update = lambda i, dep: i.update({j.dim for j in dep})
     else:
         raise RuntimeError("Illegal value=%s for loop dependence analysis" % value)

-    symbols_dep = visit(node, info_items=['symbols_dep'])['symbols_dep']
-    lda = defaultdict(set)
     for s, dep in symbols_dep.items():
-        lda[gen_key(s)] |= gen_value(dep)
+        update(lda[gen_key(s)], dep)

     return lda
@@ -366,10 +367,10 @@ def in_written(node, key='default'):
     elif key == 'symbol':
         gen_key = lambda s: s.symbol
     else:
-        raise RuntimeError("Illegal key=%s for loop dependence analysis" % key)
+        raise RuntimeError("Illegal key=%s for in_written" % key)

     found = []
-    writers = FindInstances(Writer).visit(node, ret=FindInstances.default_retval())
+    writers = FindInstances(Writer).visit(node)
     for type, stmts in writers.items():
         for stmt in stmts:
             found.append(gen_key(stmt.lvalue))

@@ -377,6 +378,34 @@ def in_written(node, key='default'):
     return found


+def in_read(node, key='default'):
+    """
+    Return a list of symbols read in ``node``.
+
+    :arg key: any value in ['default', 'urepr', 'symbol']. With 'urepr' and
+        'symbol', different instances of the same Symbol are represented by
+        the same entry in the returned list.
+    """
+
+    if key == 'default':
+        gen_key = lambda s: s
+    elif key == 'urepr':
+        gen_key = lambda s: s.urepr
+    elif key == 'symbol':
+        gen_key = lambda s: s.symbol
+    else:
+        raise RuntimeError("Illegal key=%s for in_read" % key)
+
+    found = []
+    writers = FindInstances(Writer).visit(node)
+    for type, stmts in writers.items():
+        for stmt in stmts:
+            reads = FindInstances(Symbol).visit(stmt.rvalue)[Symbol]
+            found.extend([gen_key(s) for s in reads])
+
+    return found
+
+
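[Editor's note: a hedged sketch, not part of the patch, contrasting the new
in_read() with the existing in_written(); constructors assumed from coffee.base:]

from coffee.base import Assign, Prod, Symbol
from coffee.utils import in_read, in_written

# a[i] = b[i]*c[i]
stmt = Assign(Symbol('a', ('i',)), Prod(Symbol('b', ('i',)), Symbol('c', ('i',))))

assert in_written(stmt, key='symbol') == ['a']
assert sorted(in_read(stmt, key='symbol')) == ['b', 'c']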
 def count(node, mode='urepr', read_only=False):
     """Count the occurrences of all variables appearing in ``node``.
     For example, for the expression: ::
@@ -451,6 +480,33 @@ def find_expression(node, type=None, dims=None, in_syms=None, out_syms=None):

     return exprs


+def summands(node):
+    """
+    Return the top-level summands in /node/.
+
+    Examples
+    ========
+
+    a + b --> [a, b]
+    a*b*c --> [a*b*c]
+    a*b*c + c*d --> [a*b*c, c*d]
+    (a+b)*c + d --> [(a+b)*c, d]
+    foo(a) --> []
+    """
+
+    handle = list(zip(*explore_operator(node)))
+    if not handle:
+        return []
+    operands, parents = handle
+    if all(isinstance(p, Sum) for p in parents):
+        return operands
+    elif all(isinstance(p, Prod) for p in parents):
+        # Single top-level summand
+        return [node]
+    else:
+        return []
+
+
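[Editor's note: a hedged sketch, not part of the patch, exercising the
summands() helper added to utils.py just above on two of its documented cases;
constructors and Symbol stringification assumed from coffee.base:]

from coffee.base import Prod, Sum, Symbol
from coffee.utils import summands

a, b, c = Symbol('a'), Symbol('b'), Symbol('c')

assert [str(i) for i in summands(Sum(a, b))] == ['a', 'b']   # a + b  --> [a, b]
assert len(summands(Prod(Prod(a, b), c))) == 1               # a*b*c  --> [a*b*c]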
 #######################################################################
 # Functions to manipulate iteration spaces in various representations #
 #######################################################################
@@ -762,6 +818,15 @@ def reads(self, sym):

 od_find_next = lambda a, b: a.values()[a.keys().index(b)+1]


+def as_urepr(l):
+    convert = lambda i: i.urepr if isinstance(i, Symbol) else i
+    try:
+        converted = [convert(i) for i in l]
+    except TypeError:
+        converted = convert(l)
+    return tuple(converted)
+
+
 def is_const_dim(d):
     return isinstance(d, int) or (isinstance(d, str) and d.isdigit())
@@ -780,6 +845,19 @@ def uniquify(exprs):
     return OrderedDict([(e.urepr, e) for e in exprs]).values()


+def remove_empty_loops(node):
+    """Remove all empty loops within node."""
+
+    for nest in visit(node, info_items=['fors'])['fors']:
+        to_remove = (None, None)
+        for loop, parent in reversed(nest):
+            if not loop.body or all(i == to_remove[0] for i in loop.body):
+                to_remove = (loop, parent)
+        if all(to_remove):
+            loop, parent = to_remove
+            parent.children.remove(loop)
+
+
 def postprocess(node):
     """Rearrange the Nodes in the AST rooted in ``node`` to improve the code
     quality when unparsing the tree."""
diff --git a/coffee/vectorizer.py b/coffee/vectorizer.py
index 0109c5de..2659f1a5 100644
--- a/coffee/vectorizer.py
+++ b/coffee/vectorizer.py
@@ -77,7 +77,6 @@ def __init__(self, loop_opt, kernel=None):
         self.loop = loop_opt.loop
         self.decls = loop_opt.decls
         self.exprs = loop_opt.exprs
-        self.expr_graph = loop_opt.expr_graph
         self.nz_syms = loop_opt.nz_syms

     def autovectorize(self, p_dim=-1):
diff --git a/coffee/visitors/inspectors.py b/coffee/visitors/inspectors.py
index 90f8792d..90948295 100644
--- a/coffee/visitors/inspectors.py
+++ b/coffee/visitors/inspectors.py
@@ -1,6 +1,6 @@
 from __future__ import absolute_import, print_function, division
 from coffee.visitor import Visitor
-from coffee.base import READ, WRITE, LOCAL, EXTERNAL, Symbol, EmptyStatement
+from coffee.base import READ, WRITE, LOCAL, EXTERNAL, Symbol, EmptyStatement, Writer
 from collections import defaultdict, OrderedDict, Counter
 import itertools
@@ -195,9 +195,6 @@
     """

-    def extract_linear_dimensions(self, symbol):
-        return tuple(i for i in symbol.rank if isinstance(i, str) and not i.isdigit())
-
     def visit_object(self, o, ret=None, *args, **kwargs):
         return ret

@@ -216,8 +213,8 @@ def visit_Writer(self, o, ret=None, parent=None, *args, **kwargs):
             if len(opts) < 3:
                 continue
             if opts[1] == "coffee" and opts[2] == "expression":
-                # (parent, loop-nest, rank)
-                ret[o] = (parent, None, self.extract_linear_dimensions(o.lvalue))
+                # (parent, loop-nest)
+                ret[o] = (parent, None)
                 return ret
         return ret

@@ -238,7 +235,7 @@ def visit_For(self, o, ret=None, parent=None, *args, **kwargs):
         # Add nest structure to new items
         keys = list(ret.keys())[nval:]
         for k in keys:
-            p, nest, rank = ret[k]
+            p, nest = ret[k]
             if nest is None:
                 # Statement is directly underneath this loop, so the
                 # loop nest structure is just the current loop
@@ -247,7 +244,7 @@ def visit_For(self, o, ret=None, parent=None, *args, **kwargs):
                 # Inside a nested set of loops, so prepend current
                 # loop info to nest structure
                 nest = [me] + nest
-            ret[k] = p, nest, rank
+            ret[k] = p, nest
         return ret

@@ -566,6 +563,17 @@ def __init__(self, types, stop_when_found=False, with_parent=False):
         self.with_parent = with_parent
         super(FindInstances, self).__init__()

+    def useless_traversal(self, o):
+        """
+        Return True if traversing the sub-tree rooted in ``o`` cannot yield
+        any new matches for the node types being searched.
+
+        E.g., Writers cannot be nested.
+        """
+        if isinstance(o, Writer) and self.types == Writer:
+            return True
+        return False
+
     def visit_object(self, o, ret=None, *args, **kwargs):
         return ret

@@ -583,6 +591,8 @@ def visit_Node(self, o, ret=None, parent=None, *args, **kwargs):
         # Don't traverse children if stop-on-found
         if self.stop_when_found:
             return ret
+        if self.useless_traversal(o):
+            return ret
         # Not found, or traversing children anyway
         ops, _ = o.operands()
         for op in ops:
diff --git a/coffee/visitors/utilities.py b/coffee/visitors/utilities.py
index 0fe41eac..4018ad3a 100644
--- a/coffee/visitors/utilities.py
+++ b/coffee/visitors/utilities.py
@@ -6,15 +6,13 @@
 from copy import deepcopy
 from collections import OrderedDict, defaultdict
 import numpy as np
-import networkx as nx

 from coffee.visitor import Visitor
 from coffee.base import Sum, Sub, Prod, Div, ArrayInit, SparseArrayInit
-import coffee.utils

 __all__ = ["ReplaceSymbols", "CheckUniqueness", "Uniquify", "Evaluate",
-           "EstimateFlops", "ProjectExpansion", "SharingGraph"]
+           "EstimateFlops", "ProjectExpansion", "Reconstructor"]


 class ReplaceSymbols(Visitor):
@@ -124,7 +122,8 @@ def __init__(self, decls, track_zeros):
         import coffee.vectorizer
         self.up = coffee.vectorizer.vect_roundup
         self.down = coffee.vectorizer.vect_rounddown
-        self.make_itspace = coffee.utils.ItSpace
+        from coffee.utils import ItSpace
+        self.make_itspace = ItSpace
         super(Evaluate, self).__init__()

     def visit_object(self, o, *args, **kwargs):
@@ -302,17 +301,18 @@ def visit_Expr(self, o, parent=None, *args, **kwargs):
         return ret

     def visit_Prod(self, o, parent=None, *args, **kwargs):
+        from coffee.utils import flatten
         if isinstance(parent, Prod):
             projection = self.default_retval()
             for n in o.children:
                 projection.extend(self.visit(n, parent=o, *args, **kwargs))
-            return [list(coffee.utils.flatten(projection))]
+            return [list(flatten(projection))]
         else:
             # Only the top level Prod, in a chain of Prods, should do the
             # tensor product
             projection = [self.visit(n, parent=o, *args, **kwargs) for n in o.children]
             product = itertools.product(*projection)
-            ret = [list(coffee.utils.flatten(i)) for i in product] or projection
+            ret = [list(flatten(i)) for i in product] or projection
             return ret

     def visit_Symbol(self, o, *args, **kwargs):
@@ -376,104 +376,16 @@ def visit_Determinant3x3(self, o, *args, **kwargs):
         return 14


-class SharingGraph(Visitor):
-
-    @classmethod
-    def default_retval(cls):
-        return (nx.Graph(), OrderedDict())
+class Reconstructor(Visitor):

     """
-    A sharing graph is a particular graph in which vertices represent symbols
-    iterating along the expression's linear loops, while an edge between /v1/
-    and /v2/ indicates that both /v1/ and /v2/ appear in the same sub-expression,
-    or would appear in the same sub-expression if expansion were performed.
-
-    Simultaneously, build a mapper from symbols to nodes in the expression.
-    A symbol /s/ (a vertex in the sharing graph) is mapped to a list of nodes
-    /[n]/, with /n/ in /[n]/ being the root of a Sum in which /s/ appears in
-    both children (i.e., the Sum induces sharing).
-
-    :arg expr_info: A :class:`~.MetaExpr` object describing the expression for
-        which the sharing graph is built.
+    Recursively reconstruct abstract syntax trees.
     """

-    def __init__(self, expr_info, lda):
-        self.expr_info = expr_info
-        self.lda = lda
-        super(SharingGraph, self).__init__()
-
-    def _update_mapper(self, mapper, loc_syms, pointer=None):
-        if pointer:
-            old_pointer = None
-            for s in set.intersection(*loc_syms):
-                v = mapper.setdefault(s, [None])
-                old_pointer = v[-1]
-                v[-1] = pointer
-            for s in set.union(*loc_syms):
-                if s in mapper and mapper[s][-1] == old_pointer:
-                    mapper[s][-1] = pointer
-        else:
-            for s in set.union(*loc_syms):
-                if s in mapper:
-                    mapper[s].append(None)
-
-    def visit_object(self, o, ret=None, *args, **kwargs):
-        return self.default_retval()
-
-    def visit_Node(self, o, ret=None, parent=None, *args, **kwargs):
-        ops, _ = o.operands()
-        for op in ops:
-            ret = self.visit(op, ret=ret, parent=o, *args, **kwargs)
-        return ret
-
-    def visit_Prod(self, o, ret=None, syms=None, parent=None, *args, **kwargs):
-        if ret is None:
-            ret = self.default_retval()
-        if syms is None:
-            syms = set()
-        G, mapper = ret
-        ops, _ = o.operands()
-        loc_syms = [set() for i in ops]
-        for i, op in enumerate(ops):
-            ret = self.visit(op, ret=ret, syms=loc_syms[i], parent=o)
-        if all(i for i in loc_syms):
-            self._update_mapper(mapper, loc_syms)
-            loc_syms = itertools.product(*loc_syms)
-            loc_syms = [tuple(coffee.utils.flatten(i)) for i in loc_syms]
-            syms |= set(loc_syms)
-            G.add_edges_from(loc_syms)
-        else:
-            for i in loc_syms:
-                syms |= i
-        return ret
+    def visit_object(self, o):
+        return o

-    def visit_Sum(self, o, ret=None, syms=None, parent=None, *args, **kwargs):
-        if ret is None:
-            ret = self.default_retval()
-        if syms is None:
-            syms = set()
-        pointer = (o, parent)
-        _, mapper = ret
+    def visit_Node(self, o):
         ops, _ = o.operands()
-        loc_syms = [set() for i in ops]
-        for i, op in enumerate(ops):
-            ret = self.visit(op, ret=ret, syms=loc_syms[i], parent=o)
-            syms |= loc_syms[i]
-        self._update_mapper(mapper, loc_syms, pointer)
-        mapper['topsum'] = pointer
-        return ret
-
-    visit_Sub = visit_Sum
-
-    def visit_Symbol(self, o, ret=None, syms=None, *args, **kwargs):
-        if ret is None:
-            ret = self.default_retval()
-        G, _ = ret
-        deps = [d for d in self.lda[o.symbol]]
-        if syms is not None and any(i in self.expr_info.linear_dims for i in deps):
-            syms.add((o.urepr,))
-        try:
-            G.node[o.urepr]['occs'] += 1
-        except:
-            G.add_node(o.urepr, occs=1)
-        return ret
+        reconstructed_operands = [self.visit(op) for op in ops]
+        return o.reconstruct(*reconstructed_operands)
diff --git a/setup.cfg b/setup.cfg
index 796c344b..9478c194 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [flake8]
 ignore =
-    E501,F403,F405,E226,E265,E731,E402,E266,
+    E501,F403,F405,E226,E265,E731,E402,E266,F999,
     FI14,FI54,
     FI50,FI51,FI53
 exclude = .git,,__pycache__,build,dist,doc/source/conf.py
diff --git a/tests/test_visitors.py b/tests/test_visitors.py
index 5c615f0f..d1e30400 100644
--- a/tests/test_visitors.py
+++ b/tests/test_visitors.py
@@ -375,9 +375,8 @@ def test_find_coffee_expressions_single():

     val = ret[assign]

-    assert len(val) == 3
+    assert len(val) == 2

-    assert val[2] == ()
     assert val[1] == [(tree.children[0], tree)]
     assert val[0] == tree.children[0].children[0]
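[Editor's note: a hypothetical test, not part of the patch, for the new
Reconstructor visitor added in coffee/visitors/utilities.py above. It assumes
every Node implements reconstruct(), as the visitor's own visit_Node does:]

def test_reconstructor_returns_fresh_tree():
    from coffee.base import Prod, Sum, Symbol
    from coffee.visitors.utilities import Reconstructor

    tree = Sum(Prod(Symbol('a', ('i',)), Symbol('b', ('i',))), Symbol('c', ('i',)))
    clone = Reconstructor().visit(tree)

    assert str(clone) == str(tree)   # same structure...
    assert clone is not tree         # ...but a brand new object tree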
@@ -405,13 +404,11 @@ def test_find_coffee_expressions_nested():

     assert val1[0] == tree.children[0].children[0]
     assert val1[1] == [(tree.children[0], tree)]
-    assert val1[2] == ()

     assert val2[0] == tree.children[0].body[0].children[0].children[0]
     assert val2[1] == [(tree.children[0], tree),
                        (tree.children[0].body[0].children[0],
                         tree.children[0].body[0])]
-    assert val2[2] == ("i", )


 def test_symbol_modes_simple():