diff --git a/coffee/cse.py b/coffee/cse.py deleted file mode 100644 index dd79c3f2..00000000 --- a/coffee/cse.py +++ /dev/null @@ -1,508 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. 
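For orientation, the pass removed below undoes common sub-expression elimination (CSE) whenever inlining a temporary opens up factorization and code motion opportunities. A minimal sketch of the payoff on hypothetical arrays (plain Python, not COFFEE's AST machinery): ::

    # Before unpicking, t0 is a CSE temporary read twice in the j loop:
    #     t0 = a[i] + b[i]
    #     out[i][j] += t0*c[j] + t0*d[j]    # 2 products per (i, j) iteration
    # Inlining t0 exposes the factorization
    #     out[i][j] += (a[i] + b[i]) * (c[j] + d[j])
    # whose invariant operand can then be hoisted out of the j loop:
    def rewritten(out, a, b, c, d, n, m):
        for i in range(n):
            t1 = a[i] + b[i]                     # hoisted: n additions in total
            for j in range(m):
                out[i][j] += t1 * (c[j] + d[j])  # 1 product per (i, j) iteration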
- -from __future__ import absolute_import, print_function, division -from six import iterkeys, iteritems, itervalues -from six.moves import zip - -import operator - -from .base import * -from .utils import * -from coffee.visitors import EstimateFlops -from .expression import MetaExpr -from .logger import log, COST_MODEL -from functools import reduce - - -class Temporary(object): - - """A Temporary stores useful information for a statement (e.g., an Assign - or an AugmentedAssign) that computes a temporary variable; that is, a variable - that is read in more than one place.""" - - def __init__(self, node, main_loop, nest, linear_reads_costs=None): - self.level = -1 - self.pushed = False - self.readby = [] - - self.node = node - self.main_loop = main_loop - self.nest = nest - self.linear_reads_costs = linear_reads_costs or OrderedDict() - self.flops = EstimateFlops().visit(node) - - @property - def name(self): - return self.symbol.symbol if self.symbol else None - - @property - def rank(self): - return self.symbol.rank if self.symbol else None - - @property - def linearity_degree(self): - return len(self.main_linear_loops) - - @property - def symbol(self): - if isinstance(self.node, Writer): - return self.node.lvalue - elif isinstance(self.node, Symbol): - return self.node - else: - return None - - @property - def expr(self): - if isinstance(self.node, Writer): - return self.node.rvalue - else: - return None - - @property - def urepr(self): - return self.symbol.urepr - - @property - def reads(self): - return Find(Symbol).visit(self.expr)[Symbol] if self.expr else [] - - @property - def linear_reads(self): - return list(iterkeys(self.linear_reads_costs)) if self.linear_reads_costs else [] - - @property - def loops(self): - return list(zip(*self.nest))[0] - - @property - def main_linear_loops(self): - return [l for l in self.main_loops if l.is_linear] - - @property - def main_linear_nest(self): - return [(l, p) for l, p in self.main_nest if l in self.linear_loops] - - @property - def main_loops(self): - index = self.loops.index(self.main_loop) - return [l for l in self.loops[:index + 1]] - - @property - def main_nest(self): - return [(l, p) for l, p in self.nest if l in self.main_loops] - - @property - def flops_projection(self): - # #muls + #sums - nmuls = len(self.linear_reads) - return (nmuls) + (nmuls - 1) - - @property - def is_ssa(self): - return self.symbol not in self.readby - - @property - def is_static_init(self): - return isinstance(self.expr, ArrayInit) - - @property - def is_increment(self): - return isinstance(self.node, Incr) - - @property - def reductions(self): - return [l for l in self.main_loops if l.dim not in self.rank] - - @property - def nreductions(self): - return len(self.reductions) - - def niters(self, mode='all', handle=None): - assert mode in ['all', 'outer', 'nonlinear', 'in', 'out'] - handle = handle or [] - limit = self.loops.index(self.main_loop) - loops = self.loops[:limit + 1] - if mode == 'all': - sizes = [l.size for l in loops] - elif mode == 'outer': - sizes = [l.size for l in loops if l is not self.main_loop] - elif mode == 'nonlinear': - sizes = [l.size for l in loops if not l.is_linear] - elif mode == 'in': - sizes = [l.size for l in loops if l.dim in handle] - else: - sizes = [l.size for l in loops if l.dim not in handle] - return reduce(operator.mul, sizes, 1) - - def depends(self, others): - """Return True if ``self`` reads a temporary or is read by a temporary - that appears in the iterator ``others``, False otherwise.""" - dependencies = 
self.linear_reads + self.reads - for t in others: - if any(s.urepr == t.urepr for s in dependencies): - return True - return False - - def reconstruct(self): - temporary = Temporary(self.node, self.main_loop, self.nest, - OrderedDict(self.linear_reads_costs)) - temporary.level = self.level - temporary.readby = list(self.readby) - return temporary - - def __str__(self): - return "%s: level=%d, flops/iter=%d, linear_reads=[%s], isread=[%s]" % \ - (self.symbol, self.level, self.flops, - ", ".join([str(i) for i in self.linear_reads]), - ", ".join([str(i) for i in self.readby])) - - -class CSEUnpicker(object): - - """Analyze loops in which some temporary variables are computed and, applying - a cost model, decide whether to leave a temporary intact or inline it to - create factorization and code motion opportunities. - - The cost model exploits one particular property of loops, namely linearity in - symbols (further information concerning loop linearity is available in the module - ``expression.py``).""" - - def __init__(self, exprs, header, hoisted): - self.exprs = exprs - self.header = header - self.hoisted = hoisted - - @property - def type(self): - return list(itervalues(self.exprs))[0].type - - @property - def linear_dims(self): - return list(itervalues(self.exprs))[0].linear_dims - - def _push_temporaries(self, temporaries, trace, global_trace, ra, decls): - - def is_pushable(temporary, temporaries): - # To be pushable ... - if not temporary.is_ssa: - # ... must be written only once - return False - if not temporary.readby: - # ... must actually be read by some other temporaries (the output - # variables are not) - return False - if temporary.is_static_init: - # ... its rvalue must not be an array initializer - return False - if temporary.depends(temporaries): - # ... it cannot depend on other temporaries in the same level - return False - pushed_in = [global_trace.get(rb.urepr) for rb in temporary.readby] - pushed_in = set(rb.main_loop.children[0] for rb in pushed_in if rb) - reads = [s for s in temporary.reads if not s.is_number] - for s in reads: - # ... 
all the read temporaries must be accessible in the loops in which - # they will be pushed - if s.urepr in global_trace and global_trace[s.urepr].pushed: - continue - if s.symbol not in decls: - continue - if any(l not in ra[decls[s.symbol]] for l in pushed_in): - return False - return True - - to_replace, modified_temporaries = {}, OrderedDict() - for t in temporaries: - # Track temporaries to be pushed from /level-1/ into the later /level/s - if not is_pushable(t, temporaries): - continue - to_replace[t.symbol] = t.expr or t.symbol - for rb in t.readby: - modified_temporaries[rb.urepr] = trace.get(rb.urepr, - global_trace[rb.urepr]) - # The temporary is going to be pushed, so we can remove it as long as - # it is not needed somewhere else - if t.node in t.main_loop.body and\ - all(rb.urepr in global_trace for rb in t.readby): - global_trace[t.urepr].pushed = True - t.main_loop.body.remove(t.node) - - # Transform the AST (note: node replacement must happen in the order - # in which the temporaries have been encountered) - modified_temporaries = sorted(modified_temporaries.values(), - key=lambda t: list(iterkeys(global_trace)).index(t.urepr)) - for t in modified_temporaries: - ast_replace(t.node, to_replace, copy=True) - replaced = [t.urepr for t in to_replace.keys()] - - # Update the temporaries - for t in modified_temporaries: - for r, c in list(iteritems(t.linear_reads_costs)): - if r.urepr in replaced: - t.linear_reads_costs.pop(r) - r_linear_reads_costs = global_trace[r.urepr].linear_reads_costs - for p, p_c in r_linear_reads_costs.items() or [(r, 0)]: - t.linear_reads_costs[p] = c + p_c - - def _transform_temporaries(self, temporaries, decls): - from .rewriter import ExpressionRewriter - - # Never attempt to transform the main expression - temporaries = [t for t in temporaries if t.node not in self.exprs] - - lda = loops_analysis(self.header, key='symbol', value='dim') - - # Expand + Factorize - rewriters = OrderedDict() - for t in temporaries: - expr_info = MetaExpr(self.type, t.main_loop.block, t.main_nest) - ew = ExpressionRewriter(t.node, expr_info, self.header, self.hoisted) - ew.replacediv() - ew.expand(mode='all', lda=lda) - ew.reassociate(lambda i: all(r != t.main_loop.dim for r in lda[i.symbol])) - ew.factorize(mode='adhoc', adhoc={i.urepr: [] for i in t.linear_reads}, lda=lda) - rewriters[t] = ew - - lda = loops_analysis(self.header, value='dim') - - # Code motion - for t, ew in rewriters.items(): - ew.licm(mode='only_outlinear', lda=lda, global_cse=True) - if t.linearity_degree > 1: - ew.licm(mode='only_linear', lda=lda) - - # Keep track of new declarations (recomputation might otherwise be too expensive) - decls.update(OrderedDict([(k, v.decl) for k, v in self.hoisted.items()])) - - def _analyze_expr(self, expr, loop, lda, decls): - reads = Find(Symbol).visit(expr)[Symbol] - reads = [s for s in reads if s.symbol in decls] - syms = [s for s in reads if any(d in loop.dim for d in lda[s])] - - linear_reads_costs = OrderedDict() - - def wrapper(node, found=0): - if isinstance(node, Symbol): - if node in syms: - linear_reads_costs.setdefault(node, 0) - linear_reads_costs[node] += found - return - elif isinstance(node, (EmptyStatement, ArrayInit)): - return - elif isinstance(node, (Prod, Div)): - found += 1 - operands = list(zip(*explore_operator(node)))[0] - for o in operands: - wrapper(o, found) - wrapper(expr) - - return reads, linear_reads_costs - - def _analyze_loop(self, loop, nest, lda, global_trace, decls): - linear_dims = [l.dim for l, _ in nest if l.is_linear] - - 
trace = OrderedDict() - for node in loop.body: - if not isinstance(node, Writer): - not_ssa = [trace[w] for w in in_written(node, key='urepr') if w in trace] - for t in not_ssa: - t.readby.append(t.symbol) - continue - reads, linear_reads_costs = self._analyze_expr(node.rvalue, loop, lda, decls) - affected = [s for s in reads if any(i in linear_dims for i in lda[s])] - for s in affected: - if s.urepr in global_trace: - temporary = global_trace[s.urepr] - temporary.readby.append(node.lvalue) - temporary = temporary.reconstruct() - temporary.level = -1 - trace[s.urepr] = temporary - else: - temporary = trace.setdefault(s.urepr, Temporary(s, loop, nest)) - temporary.readby.append(node.lvalue) - new_temporary = Temporary(node, loop, nest, linear_reads_costs) - new_temporary.level = max([trace[s.urepr].level for s - in new_temporary.linear_reads] or [-2]) + 1 - trace[node.lvalue.urepr] = new_temporary - - return trace - - def _group_by_level(self, trace): - levels = defaultdict(list) - - for temporary in trace.values(): - levels[temporary.level].append(temporary) - return levels - - def _cost_cse(self, levels, bounds=None): - if bounds is not None: - lb, up = bounds[0], bounds[1] + 1 - levels = {i: levels[i] for i in range(lb, up)} - cost = 0 - for level, temporaries in levels.items(): - cost += sum(t.flops*t.niters('all') for t in temporaries) - return cost - - def _cost_fact(self, trace, levels, lda, bounds): - # Check parameters - assert len(bounds) == 2 and bounds[1] >= bounds[0] - assert bounds[0] in levels.keys() and bounds[1] in levels.keys() - - # Determine current costs of individual loop regions - input_cost = self._cost_cse(levels, (min(levels.keys()), max(levels.keys()))) - uptolevel_cost, post_cse_cost = input_cost, input_cost - level_inloop_cost, total_outloop_cost = 0, 0 - - # We are going to modify a copy of the temporaries dict - new_trace = OrderedDict() - for s, t in trace.items(): - new_trace[s] = t.reconstruct() - - # Cost induced by the untransformed temporaries - pre_cse_cost = self._cost_cse(levels, (min(levels.keys()), bounds[0])) - - best = (bounds[0], bounds[0], uptolevel_cost) - fact_levels = {k: v for k, v in levels.items() if k > bounds[0] and k <= bounds[1]} - for level, temporaries in sorted(fact_levels.items(), key=lambda i_j: i_j[0]): - level_inloop_cost = 0 - for t in temporaries: - # Compute the cost induced by /t/ in the outer loops after fact+licm - t_outloop_cost, linear_reads = 0, [] - for read, cost in t.linear_reads_costs.items(): - traced = new_trace.get(read.urepr) - if traced and traced.level >= bounds[0]: - handle = traced.linear_reads or [read] - if cost: - for i in handle: - # One prod in the closest linear loop - t_outloop_cost += t.niters('out', lda[i]) - # The rest falls outside of the linear loops - t_outloop_cost += (cost - 1)*t.niters('nonlinear') - else: - handle = [read] - linear_reads.extend(handle) - factors = list(itervalues({as_urepr(i): i for i in linear_reads})) - # Take into account the increased number of sums (due to fact) - hoist_region = set.union(*[lda[i] for i in factors]) - niters = t.niters('out', hoist_region) - t_outloop_cost += (len(linear_reads) - len(factors))*niters - total_outloop_cost += t_outloop_cost - - # Compute the cost induced by /t/ in the main loop after fact+licm - # We end up creating n prods and n -1 sums - t_inloop_cost = 2*len(factors) - 1 - level_inloop_cost += t_inloop_cost*t.niters('all') - - # Take into account any hoistable reductions - if t.is_increment: - for i in factors: - handle = [l.dim for l 
in t.reductions if l.dim not in i.rank] - level_inloop_cost -= t.niters('all') - t.niters('out', handle) - - # Keep the trace up-to-date - linear_reads_costs = {i: 1 for i in factors} - new_trace[t.urepr].linear_reads_costs = linear_reads_costs - - # Some temporaries within levels < /level/ might also appear in - # subsequent loops or levels beyond /level/, so they still contribute - # to the operation count - for t in list(flatten([levels[j] for j in range(level)])): - if any(rb.urepr not in new_trace for rb in t.readby) or \ - any(new_trace[rb.urepr].level > level for rb in t.readby): - # Note: condition 1) is basically saying "if I'm read by - # a temporary that is not in this loop's trace, then I must - # be read in some other loops". - level_inloop_cost += \ - new_trace[t.urepr].flops_projection*t.niters('all') - - post_cse_cost = self._cost_cse(fact_levels, (level + 1, bounds[1])) - - # Compute the total cost - total_inloop_cost = pre_cse_cost + level_inloop_cost + post_cse_cost - uptolevel_cost = total_outloop_cost + total_inloop_cost - - # Update the best alternative - if uptolevel_cost < best[2]: - best = (bounds[0], level, uptolevel_cost) - - log('[CSE]: unpicking between [%d, %d]:' % (bounds[0], level), COST_MODEL) - log(' flops: %d -> %d (hoist=%d, preCSE=%d, fact=%d, postCSE=%d)' % - (input_cost, uptolevel_cost, total_outloop_cost, pre_cse_cost, - level_inloop_cost, post_cse_cost), COST_MODEL) - - return best - - def unpick(self): - # Collect all necessary info - info = visit(self.header, info_items=['decls', 'fors']) - decls, fors = info['decls'], info['fors'] - lda = loops_analysis(self.header, value='dim') - - # Collect all loops to be analyzed - nests = OrderedDict() - for nest in fors: - for loop, parent in nest: - if loop.is_linear: - nests[loop] = nest - - # Analyze loops - global_trace = OrderedDict() - mapper = OrderedDict() - for loop, nest in nests.items(): - trace = self._analyze_loop(loop, nest, lda, global_trace, decls) - if trace: - mapper[loop] = trace - global_trace.update(trace) - - for loop, trace in mapper.items(): - # Compute the best cost alternative - levels = self._group_by_level(trace) - min_level, max_level = min(levels.keys()), max(levels.keys()) - current_cost = self._cost_cse(levels, (min_level, max_level)) - global_best = (min_level, min_level, current_cost) - for i in sorted(levels.keys()): - local_best = self._cost_fact(trace, levels, lda, (i, max_level)) - if local_best[2] < global_best[2]: - global_best = local_best - - log("-- Best: [%d, %d] (cost=%d) --" % global_best, COST_MODEL) - - # Transform the loop - for i in range(global_best[0] + 1, global_best[1] + 1): - ra = reachability_analysis(self.header) - self._push_temporaries(levels[i-1], trace, global_trace, ra, decls) - self._transform_temporaries(levels[i], decls) - - cleanup(self.header) diff --git a/coffee/expander.py b/coffee/expander.py deleted file mode 100644 index 6d2e312a..00000000 --- a/coffee/expander.py +++ /dev/null @@ -1,118 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. 
-# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division - -import itertools - -from .base import * -from .utils import * -from .exceptions import UnexpectedNode - - -class Expander(object): - - """Expand the products in an expression according to a set of rules. For a - comprehensive list of possible rules, refer to the documentation of the - corresponding wrapper function ``expand`` in ``ExpressionRewriter``.""" - - # Constants used by the /expand/ method to characterize sub-expressions: - GROUP = 0 # Expression /will/ not trigger expansion - EXPAND = 1 # Expression /could/ be expanded - - def __init__(self, stmt): - self.stmt = stmt - - def _build(self, exp, grp, expansions): - """Create a node for the expansion and keep track of it.""" - expansion = Prod(exp, dcopy(grp)) - # Track the new expansion - expansions.append(expansion) - # Untrack any expansions that occurred in child nodes - if grp in expansions: - expansions.remove(grp) - return expansion - - def _expand(self, node, parent, expansions): - if isinstance(node, Symbol): - return ([node], self.EXPAND) if self.should_expand(node) \ - else ([node], self.GROUP) - - elif isinstance(node, (Div, Ternary, FunCall)): - # Try to expand /within/ the children, but then return saying "I'm not - # expandable any further" - for n in node.children: - self._expand(n, node, expansions) - return ([node], self.GROUP) - - elif isinstance(node, Prod): - l_exps, l_type = self._expand(node.left, node, expansions) - r_exps, r_type = self._expand(node.right, node, expansions) - if l_type == self.GROUP and r_type == self.GROUP: - return ([node], self.GROUP) - # At least one child is expandable (marked as EXPAND), whereas the - # other could either be expandable as well or groupable (marked - # as GROUP): so we can perform the expansion - groupable = l_exps if l_type == self.GROUP else r_exps - expandable = r_exps if l_type == self.GROUP else l_exps - to_replace = OrderedDict() - for exp, grp in itertools.product(expandable, groupable): - expansion = self._build(exp, grp, expansions) - to_replace.setdefault(exp, []).append(expansion) - ast_replace(node, {k: ast_make_expr(Sum, v) for k, v in to_replace.items()}, - copy=False, mode='symbol') - # Update the parent node, since an expression has just been expanded - expanded = node.right if l_type == self.GROUP else 
node.left - parent.children[parent.children.index(node)] = expanded - return (list(flatten(to_replace.values())) or [expanded], self.EXPAND) - - elif isinstance(node, (Sum, Sub)): - l_exps, l_type = self._expand(node.left, node, expansions) - r_exps, r_type = self._expand(node.right, node, expansions) - if l_type == self.EXPAND and r_type == self.EXPAND and isinstance(node, Sum): - return (l_exps + r_exps, self.EXPAND) - elif l_type == self.EXPAND and r_type == self.EXPAND and isinstance(node, Sub): - return (l_exps + [Neg(r) for r in r_exps], self.EXPAND) - else: - return ([node], self.GROUP) - - else: - raise UnexpectedNode("Expansion: %s" % str(node)) - - def expand(self, should_expand, **kwargs): - expressions = kwargs.get('subexprs', [(self.stmt.rvalue, self.stmt)]) - - self.should_expand = should_expand - - for node, parent in expressions: - self._expand(node, parent, []) diff --git a/coffee/factorizer.py b/coffee/factorizer.py deleted file mode 100644 index 48c2d7a4..00000000 --- a/coffee/factorizer.py +++ /dev/null @@ -1,236 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division - -import operator - -from .base import * -from .utils import * -from functools import reduce - - -class Term(object): - """A Term represents a product between 'operands' and 'factors'. In a - product /a*(b+c)/, /a/ is the 'operand', while /b/ and /c/ are the 'factors'. - The symbol /+/ is the 'op' of the Term. 
- """ - - def __init__(self, operands, factors=None, op=None): - self.operands = operands - self.factors = factors or [] - self.op = op - - @property - def operands_ast(self): - return ast_make_expr(Prod, self.operands) - - @property - def factors_ast(self): - return ast_make_expr(self.op, self.factors) - - @property - def generate_ast(self): - if len(self.factors) == 0: - return self.operands_ast - elif len(self.operands) == 0: - return self.factors_ast - elif len(self.factors) == 1 and \ - all(isinstance(i, Symbol) and i.symbol == 1.0 for i in self.factors): - return self.operands_ast - else: - return Prod(self.operands_ast, self.factors_ast) - - def add_operands(self, operands): - for o in operands: - if o not in self.operands: - self.operands.append(o) - - def remove_operands(self, operands): - for o in operands: - if o in self.operands: - self.operands.remove(o) - - def add_factors(self, factors): - for f in factors: - if f not in self.factors: - self.factors.append(f) - - def remove_factors(self, factors): - for f in factors: - if f in self.factors: - self.factors.remove(f) - - @staticmethod - def process(symbols, should_factorize, op=None): - operands = [s for s in symbols if should_factorize(s)] - factors = [s for s in symbols if not should_factorize(s)] - return Term(operands, factors, op) - - -class Factorizer(object): - - """Factorize terms in an expression according to a set of rules. For a - comprehensive list of possible rules, refer to the documentation of the - corresponding wrapper function ``factorize`` in ``ExpressionRewriter``.""" - - def __init__(self, stmt): - self.stmt = stmt - - def _simplify_sum(self, terms): - unique_terms = OrderedDict() - for t in terms: - unique_terms.setdefault(str(t.generate_ast), list()).append(t) - - for t_repr, t_list in unique_terms.items(): - occurrences = len(t_list) - unique_terms[t_repr] = t_list[0] - if occurrences > 1: - unique_terms[t_repr].add_factors([Symbol(occurrences)]) - - terms[:] = unique_terms.values() - - def _heuristic_collection(self, terms): - if not self.heuristic or any(t.operands for t in terms): - return - tracker = OrderedDict() - for t in terms: - symbols = [s for s in t.factors if isinstance(s, Symbol)] - for s in symbols: - tracker.setdefault(s.urepr, []).append(t) - reverse_tracker = OrderedDict() - for s, ts in tracker.items(): - reverse_tracker.setdefault(tuple(ts), []).append(s) - # 1) At least one symbol appearing in all terms: use that as operands ... - operands = [(ts, s) for ts, s in reverse_tracker.items() if ts == tuple(terms)] - # 2) ... 
Or simply pick operands greedily - if not operands: - handled = set() - for ts, s in reverse_tracker.items(): - if len(ts) > 1 and all(t not in handled for t in ts): - operands.append((ts, s)) - handled |= set(ts) - for ts, s in operands: - for t in ts: - new_operands = [i for i in t.factors if - isinstance(i, Symbol) and i.urepr in s] - t.remove_factors(new_operands) - t.add_operands(new_operands) - - def _premultiply_symbols(self, symbols): - floats = [s for s in symbols if isinstance(s.symbol, (int, float))] - if len(floats) > 1: - other_symbols = [s for s in symbols if s not in floats] - prem = reduce(operator.mul, [s.symbol for s in floats], 1.0) - prem = [Symbol(prem)] if prem not in [1, 1.0] else [] - return prem + other_symbols - else: - return symbols - - def _filter(self, factorizable_term): - o = factorizable_term.operands_ast - grp = self.adhoc.get(o.urepr, []) if isinstance(o, Symbol) else [] - if not grp: - return False - for f in factorizable_term.factors: - symbols = Find(Symbol).visit(f)[Symbol] - if any(s.urepr in grp for s in symbols): - return False - return True - - def _factorize(self, node, parent): - if isinstance(node, Symbol): - return Term.process([node], self.should_factorize) - - elif isinstance(node, (FunCall, Div)): - # Try to factorize /within/ the children, but then return saying - # "I'm not factorizable any further" - for n in node.children: - self._factorize(n, node) - return Term([], [node]) - - elif isinstance(node, Prod): - children = explore_operator(node) - symbols = [n for n, _ in children if isinstance(n, Symbol)] - other_nodes = [(n, p) for n, p in children if n not in symbols] - symbols = self._premultiply_symbols(symbols) - factorized = Term.process(symbols, self.should_factorize, Prod) - terms = [self._factorize(n, p) for n, p in other_nodes] - for t in terms: - factorized.add_operands(t.operands) - factorized.add_factors(t.factors) - return factorized - - # The fundamental case is when /node/ is a Sum (or Sub, equivalently). 
- # Here, we try to factorize the terms composing the operation - elif isinstance(node, (Sum, Sub)): - children = explore_operator(node) - # First try to factorize within /node/'s children - terms = [self._factorize(n, p) for n, p in children] - # Check if it's possible to aggregate operations - # Example: replace (a*b)+(a*b) with 2*(a*b) - self._simplify_sum(terms) - # No global factorization rule is used, so just try to maximize - # factorization within /this/ Sum/Sub - self._heuristic_collection(terms) - # Finally try to factorize some of the operands composing the operation - factorized = OrderedDict() - for t in terms: - operand = [t.operands_ast] if t.operands else [] - factor = [t.factors_ast] if t.factors else [Symbol(1.0)] - factorizable_term = Term(operand, factor, node.__class__) - if self._filter(factorizable_term): - # Skip - factorized[t] = t - else: - # Do factorize - _t = factorized.setdefault(str(t.operands_ast), factorizable_term) - _t.add_factors(factor) - factorized = [t.generate_ast for t in factorized.values()] - factorized = ast_make_expr(Sum, factorized) - parent.children[parent.children.index(node)] = factorized - return Term([], [factorized]) - - else: - return Term([], [node]) - - def factorize(self, should_factorize, **kwargs): - expressions = kwargs.get('subexprs', [(self.stmt.rvalue, self.stmt)]) - adhoc = kwargs.get('adhoc', {}) - - self.should_factorize = should_factorize - self.adhoc = adhoc if any(v for v in adhoc.values()) else {} - self.heuristic = kwargs.get('heuristic', False) - - for node, parent in expressions: - self._factorize(node, parent) diff --git a/coffee/hoister.py b/coffee/hoister.py deleted file mode 100644 index 3c15dfe3..00000000 --- a/coffee/hoister.py +++ /dev/null @@ -1,351 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from __future__ import absolute_import, print_function, division -from six.moves import zip - -from .base import * -from .utils import * - - -class Extractor(object): - - EXT = 0 # expression marker: extract - STOP = 1 # expression marker: do not extract - - def __init__(self, stmt, expr_info, should_extract): - self.stmt = stmt - self.expr_info = expr_info - self.should_extract = should_extract - - def _apply_cse(self): - # Find common sub-expressions heuristically looking at binary terminal - # operations (i.e., a terminal has two Symbols as children). This may - # induce more sweeps of extraction to find all common sub-expressions, - # but at least it keeps the algorithm simple and probably more effective - finder = Find(Symbol, with_parent=True) - for dep, subexprs in self.extracted.items(): - cs = OrderedDict() - values = [finder.visit(e)[Symbol] for e in subexprs] - binexprs = list(zip(*flatten(values)))[1] - binexprs = [b for b in binexprs if binexprs.count(b) > 1] - for b in binexprs: - t = cs.setdefault(b.urepr, []) - if b not in t: - t.append(b) - cs = [v for k, v in cs.items() if len(v) > 1] - if cs: - self.extracted[dep] = list(flatten(cs)) - - def _try(self, node, dep): - if isinstance(node, Symbol): - return False - should_extract = self.should_extract(dep) - if should_extract or self._look_ahead: - dep = sorted(dep, key=lambda i: self.expr_info.dims.index(i)) - self.extracted.setdefault(tuple(dep), []).append(node) - return should_extract - - def _visit(self, node): - if isinstance(node, Symbol): - return (self._lda[node], self.EXT) - - elif isinstance(node, (FunCall, Ternary)): - arg_deps = [self._visit(n) for n in node.children] - dep = set(flatten([dep for dep, _ in arg_deps])) - info = self.EXT if all(i == self.EXT for _, i in arg_deps) else self.STOP - return (dep, info) - - else: - retval = [(n,) + self._visit(n) for n in node.children] - dep = set.union(*[d for _, d, _ in retval]) - dep = {d for d in dep if d in self.expr_info.dims} - if self.should_extract(dep) or self._look_ahead: - # Still a chance of finding a bigger expression - return (dep, self.EXT) - else: - for n, n_dep, n_info in retval: - if n_info == self.EXT and not isinstance(n, Symbol): - k = sorted(n_dep, key=lambda i: self.expr_info.dims.index(i)) - self.extracted.setdefault(tuple(k), []).append(n) - return (dep, self.STOP) - - def extract(self, look_ahead, lda, with_cse=False): - """Extract invariant subexpressions from /self.expr/.""" - self._lda = lda - self._look_ahead = look_ahead - self.extracted = OrderedDict() - - self._visit(self.stmt.rvalue) - if with_cse: - self._apply_cse() - - del self._lda - del self._look_ahead - - return self.extracted - - -class Hoister(object): - - # Temporary variables template - _template = "ct%d" - - def __init__(self, stmt, expr_info, header, hoisted): - """Initialize the Hoister.""" - self.stmt = stmt - self.expr_info = expr_info - self.header = header - self.hoisted = hoisted - - def _filter(self, dep, subexprs, make_unique=True, sharing=None): - """Filter hoistable subexpressions.""" - if make_unique: - # Uniquify expressions - subexprs = uniquify(subexprs) - - if sharing: - # Partition expressions such that expressions sharing the same - # set of symbols are in the same partition - if dep == self.expr_info.dims: - return [] - sharing = [str(s) for s in sharing] - partitions = defaultdict(list) - for e in subexprs: - symbols = tuple(set(str(s) for s in Find(Symbol).visit(e)[Symbol] - if str(s) in sharing)) - partitions[symbols].append(e) - for shared, 
partition in partitions.items(): - if len(partition) > len(shared): - subexprs = [e for e in subexprs if e not in partition] - - return subexprs - - def _is_hoistable(self, subexprs, loop): - """Return True if the sub-expressions provided in ``subexprs`` are - hoistable outside of ``loop``, False otherwise.""" - written = in_written(loop, 'symbol') - reads = Find.default_retval() - for e in subexprs: - Find(Symbol).visit(e, ret=reads) - reads = [s.symbol for s in reads[Symbol]] - return set.isdisjoint(set(reads), set(written)) - - def _locate(self, dep, subexprs, with_promotion=False): - # Start assuming no "real" hoisting can take place - # E.g.: for i {a[i]*(t1 + t2);} --> for i {t3 = t1 + t2; a[i]*t3;} - place, offset = self.expr_info.innermost_loop.block, self.stmt - - if with_promotion: - # Hoist outside a loop even though this doesn't result in any - # operation count reduction - should_jump = lambda l: True - else: - # "Standard" code motion case, i.e. moving /subexprs/ as far as - # possible in the loop nest such that dependencies are honored - should_jump = lambda l: l.dim not in dep - - loops = list(reversed(self.expr_info.loops)) - candidates = [l.block for l in loops[1:]] + [self.header] - - for loop, candidate in zip(loops, candidates): - if not self._is_hoistable(subexprs, loop): - break - if should_jump(loop): - place, offset = candidate, loop - - # Determine how much extra memory and whether clone loops are needed - jumped = loops[:candidates.index(place) + 1] - clone = tuple(l for l in reversed(jumped) if l.dim in dep) - - return place, offset, clone - - def extract(self, should_extract, **kwargs): - """Return a dictionary of hoistable subexpressions.""" - lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') - extractor = Extractor(self.stmt, self.expr_info, should_extract) - return extractor.extract(True, lda) - - def licm(self, should_extract, **kwargs): - """Perform generalized loop-invariant code motion.""" - max_sharing = kwargs.get('max_sharing', False) - with_promotion = kwargs.get('with_promotion', False) - iterative = kwargs.get('iterative', True) - lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') - global_cse = kwargs.get('global_cse', False) - - extractor = Extractor(self.stmt, self.expr_info, should_extract) - - extracted = True - while extracted: - extracted = extractor.extract(False, lda, global_cse) - for dep, subexprs in extracted.items(): - # 1) Filter subexpressions that will be hoisted - sharing = [] - if max_sharing: - sharing = uniquify([s for s, d in lda.items() if d == dep]) - subexprs = self._filter(dep, subexprs, sharing=sharing) - if not subexprs: - continue - - # 2) Determine the outermost loop where invariant expressions - # can be hoisted without breaking data dependencies. 
- place, offset, clone = self._locate(dep, subexprs, with_promotion) - - loop_size = tuple(l.size for l in clone) - loop_dim = tuple(l.dim for l in clone) - - # 3) Create the required new AST nodes - symbols, decls, stmts = [], [], [] - for e in subexprs: - already_hoisted = False - if global_cse and self.hoisted.get_symbol(e): - name = self.hoisted.get_symbol(e) - decl = self.hoisted[name].decl - if decl in place.children and \ - place.children.index(decl) < place.children.index(offset): - already_hoisted = True - if not already_hoisted: - name = self._template % (len(self.hoisted) + len(stmts)) - stmts.append(Assign(Symbol(name, loop_dim), dcopy(e))) - decls.append(Decl(self.expr_info.type, - Symbol(name, loop_size), - scope=LOCAL)) - symbols.append(Symbol(name, loop_dim)) - - # 4) Replace invariant sub-expressions with temporaries - replacements = ast_replace(self.stmt, dict(zip(subexprs, symbols))) - - # 5) Modify the AST adding the hoisted expressions - if clone: - outer_clone = ast_make_for(stmts, clone[-1]) - for l in reversed(clone[:-1]): - outer_clone = ast_make_for([outer_clone], l) - code = decls + [outer_clone] - clone = outer_clone - else: - code = decls + stmts - clone = None - offset = place.children.index(offset) - place.children[offset:offset] = code - - # 6) Track hoisted symbols and data dependencies - for i, j in zip(stmts, decls): - name = j.lvalue.symbol - self.hoisted[name] = (i, j, clone, place) - lda.update({s: set(dep) for s in replacements}) - - if not iterative: - break - - def trim(self, candidate, **kwargs): - """ - Remove unnecessary reduction loops from the expression loop nest. - Sometimes, reduction loops can be factored out in outer loops, thus - reducing the operation count, without breaking data dependencies. - """ - # Rule out unsafe cases - if not is_perfect_loop(self.expr_info.innermost_loop): - return - - # Find out all reducible symbols - lda = kwargs.get('lda') or loops_analysis(self.header) - reducible, other = [], [] - for i in summands(self.stmt.rvalue): - symbols = Find(Symbol).visit(i)[Symbol] - unavoidable = set.intersection(*[set(lda[s]) for s in symbols]) - if candidate in unavoidable: - return - reducible.extend([s.symbol for s in symbols if candidate in lda[s]]) - other.extend([s.symbol for s in symbols if candidate not in lda[s]]) - - # Make sure we do not break data dependencies - make_reduce = [] - writes = Find(Writer).visit(candidate) - for w in flatten(writes.values()): - if isinstance(w.rvalue, EmptyStatement): - continue - if any(s == w.lvalue.symbol for s in other): - return - if any(s == w.lvalue.symbol for s in reducible): - loop = lda[w.lvalue][-1] - make_reduce.append((w, loop)) - - assignments = [(w, p) for w, p in make_reduce if isinstance(w, Assign)] - loops, parents = zip(*self.expr_info.loops_info) - index = loops.index(candidate) - - # Perform a number of checks to ensure lifting reductions is safe - if not all(s in [w.lvalue.symbol for w, _ in make_reduce] for s in reducible): - return - if any(p != candidate and not is_perfect_loop(p) for w, p in make_reduce): - return - if any(candidate.dim in w.lvalue.rank for w, _ in assignments): - return - if any(set(loops[index + 1:]) & set(lda[w.lvalue]) for w, _ in make_reduce): - return - - # Inject the reductions into the AST - decls = visit(self.header, info_items=['decls'])['decls'] - for w, p in make_reduce: - name = self._template % len(self.hoisted) - reduction = Incr(Symbol(name, w.lvalue.rank, w.lvalue.offset), - ast_reconstruct(w.rvalue)) - insert_at_elem(p.body, w, 
reduction) - handle = decls[w.lvalue.symbol] - declaration = Decl(handle.typ, Symbol(name, handle.lvalue.rank), - ArrayInit(np.array([0.0])), handle.qual, handle.attr) - insert_at_elem(parents[index].children, candidate, declaration) - ast_replace(self.stmt, {w.lvalue: reduction.lvalue}, copy=True) - self.hoisted[name] = (reduction, declaration, p, p.body) - - # Pull out the candidate reduction loop - pulling = loops[index + 1:] - pulling = list(zip(*[((l.start, l.end), l.dim) for l in pulling])) - pulling = ItSpace().to_for(*pulling, stmts=[self.stmt]) - insert_at_elem(parents[index].children, candidate, pulling[0], ofs=1) - if len(self.expr_info.parent.children) == 1: - loops[index].body.remove(loops[index + 1]) - else: - self.expr_info.parent.children.remove(self.stmt) - - # Clean up removing any now unnecessary symbols - reads = in_read(candidate, key='symbol') - declarations = Find(Decl, with_parent=True).visit(self.header)[Decl] - declarations = dict(declarations) - for w, p in make_reduce: - if w.lvalue.symbol not in reads: - p.body.remove(w) - if not isinstance(w, Decl): - key = decls[w.lvalue.symbol] - declarations[key].children.remove(key) diff --git a/coffee/optimizer.py b/coffee/optimizer.py index fb4cd231..1a437e00 100644 --- a/coffee/optimizer.py +++ b/coffee/optimizer.py @@ -31,24 +31,7 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED # OF THE POSSIBILITY OF SUCH DAMAGE. -from __future__ import absolute_import, print_function, division -from six.moves import zip - -import operator -import resource -from collections import OrderedDict -from itertools import combinations -from math import factorial as fact - -from . import system -from .base import * -from .utils import * -from .scheduler import ExpressionFissioner, ZeroRemover, SSALoopMerger -from .rewriter import ExpressionRewriter -from .cse import CSEUnpicker -from .logger import warn -from coffee.visitors import Find, ProjectExpansion -from functools import reduce +from .utils import StmtTracker class LoopOptimizer(object): @@ -69,436 +52,6 @@ def __init__(self, loop, header, exprs): # Track hoisted expressions self.hoisted = StmtTracker() - def rewrite(self, mode): - """Rewrite all compute-intensive expressions detected in the loop nest to - minimize the number of floating point operations performed. - - :param mode: Any value in (0, 1, 2, 3, 4). Each ``mode`` corresponds to a - different expression rewriting strategy. - - * mode == 0: no rewriting is performed. - * mode == 1: generalized loop-invariant code motion. - * mode == 2: apply four passes: generalized loop-invariant code motion; - expansion of inner-loop dependent expressions; factorization of - inner-loop dependent terms; generalized loop-invariant code motion. - * mode == 3: apply multiple passes; aims at pre-evaluating sub-expressions - that fully depend on reduction loops. - * mode == 4: rewrite an expression based on its sharing graph - """ - # Set a rewrite mode for each expression - for stmt, expr_info in self.exprs.items(): - expr_info.mode = mode - - # Analyze the individual expressions and try to select an optimal rewrite - # mode for each of them. 
A preliminary transformation of the loop nest may - # take place in this pass (e.g., injection) - if mode == 'auto': - self._dissect('greedy') - elif mode == 'auto-aggressive': - self._dissect('aggressive') - - # Search for factorization opportunities across temporaries in the kernel - if mode > 1 and self.exprs: - self._unpick_cse() - - # Expression rewriting, expressed as a sequence of AST transformation passes - for stmt, expr_info in self.exprs.items(): - ew = ExpressionRewriter(stmt, expr_info, self.header, self.hoisted) - - if expr_info.mode == 1: - if expr_info.dimension in [0, 1]: - ew.licm(mode='only_outlinear') - else: - ew.licm() - - elif expr_info.mode == 2: - if expr_info.dimension > 0: - ew.replacediv() - ew.sharing_graph_rewrite() - ew.licm(mode='reductions') - - elif expr_info.mode == 3: - ew.expand(mode='all') - ew.factorize(mode='all') - ew.licm(mode='only_const') - ew.factorize(mode='constants') - ew.licm(mode='aggressive') - ew.preevaluate() - ew.factorize(mode='linear') - ew.licm(mode='only_const') - - elif expr_info.mode == 4: - ew.replacediv() - ew.factorize() - ew.licm(mode='only_outlinear') - if expr_info.dimension > 0: - ew.licm(mode='only_linear', iterative=False, max_sharing=True) - ew.sharing_graph_rewrite() - ew.expand() - - # Try merging the loops created by expression rewriting - merged_loops = SSALoopMerger().merge(self.header) - # Update the trackers - for merged, merged_in in merged_loops: - for l in merged: - self.hoisted.update_loop(l, merged_in) - # Was /merged/ an expression loop? If so, we need to update the - # corresponding MetaExpr - for stmt, expr_info in self.exprs.items(): - if expr_info.loops[-1] == l: - expr_info._loops_info[-1] = (merged_in, expr_info.loops_parents[-1]) - expr_info._parent = merged_in.children[0] - - # Reduce memory pressure by avoiding useless temporaries - self._min_temporaries() - - # Handle the effects, at the C level, of the AST transformation - self._recoil() - - def eliminate_zeros(self): - """Restructure the iteration spaces nested in this LoopOptimizer to - avoid evaluation of arithmetic operations involving zero-valued blocks - in statically initialized arrays.""" - - zls = ZeroRemover(self.exprs, self.hoisted) - self.nz_syms = zls.reschedule(self.header) - - def _unpick_cse(self): - """Search for factorization opportunities across temporaries created by - common sub-expression elimination. If a gain in operation count is detected, - unpick CSE and apply factorization + code motion.""" - cse_unpicker = CSEUnpicker(self.exprs, self.header, self.hoisted) - cse_unpicker.unpick() - - def _min_temporaries(self): - """Remove unnecessary temporaries, thus relieving memory pressure. 
- A temporary is removed iff: - - * it is written once, AND - * it is read once OR it is read n times, but it hosts only a Symbol - """ - - occurrences = count(self.header, mode='symbol_id', read_only=True) - - for l in self.hoisted.all_loops: - info = visit(l, info_items=['symbol_refs', 'symbols_mode']) - to_replace, to_remove = {}, [] - for (temporary, _, _), c in count(l, read_only=True).items(): - if temporary not in self.hoisted: - continue - if self.hoisted[temporary].loop is not l: - continue - if occurrences.get(temporary) != c: - continue - decl = self.hoisted[temporary].decl - place = self.hoisted[temporary].place - expr = self.hoisted[temporary].stmt.rvalue - if c > 1 and explore_operator(expr): - continue - references = info['symbol_refs'][temporary] - syms_mode = info['symbols_mode'] - # Note: only one write is possible at this point - write = [(s, p) for s, p in references if syms_mode[s][0] == WRITE][0] - to_replace[write[0]] = expr - to_remove.append(write[1]) - place.children.remove(decl) - # Update trackers - self.hoisted.pop(temporary) - - # Replace temporary symbols and clean up - l_innermost_body = inner_loops(l)[-1].body - for stmt in l_innermost_body: - if stmt.lvalue in to_replace: - continue - while ast_replace(stmt, to_replace, copy=True): - pass - for stmt in to_remove: - l_innermost_body.remove(stmt) - - def _dissect(self, heuristics): - """Analyze the set of expressions in the LoopOptimizer and infer an - optimal rewrite mode for each of them. - - If an expression is embedded in a non-perfect loop nest, then injection - may be performed. Injection consists of unrolling any loops outside of - the expression iteration space into the expression itself. - For example: :: - - for i - for r - a += B[r]*C[i][r] - for j - for k - A[j][k] += ...f(a)... // the expression at hand - - gets transformed into: - - for i - for j - for k - A[j][k] += ...f(B[0]*C[i][0] + B[1]*C[i][1] + ...)... - - Injection could be necessary to maximize the impact of rewrite mode=3, - which tries to pre-evaluate subexpressions whose values are known at - code generation time. Injection is essential to factorize such subexprs. - - :arg heuristics: any value in ['greedy', 'aggressive']. With 'greedy', a greedy - approach is used to decide which of the expressions that would benefit from - injection should be dissected (injection increases the memory footprint, and - some memory constraints must always be preserved). - With 'aggressive', the whole space of possibilities is analyzed. - """ - # The memory threshold. The total size of temporaries must not be - # greater than this value. If we predict that injection will lead - # to too much temporary space, we have to partially drop it - threshold = system.architecture['cache_size'] * 1.2 - - expr_graph = ExpressionGraph(header) - - # 1) Find out and unroll injectable loops. For unrolling we create new - # expressions; that is, for now, we do not modify the AST in place. 
- analyzed, injectable = [], {} - for stmt, expr_info in self.exprs.items(): - # Get all loop nests, then discard the one enclosing the expression - nests = [n for n in visit(expr_info.loops_parents[0])['fors']] - injectable_nests = [n for n in nests if list(zip(*n))[0] != expr_info.loops] - - for nest in injectable_nests: - to_unroll = [(l, p) for l, p in nest if l not in expr_info.loops] - unroll_cost = reduce(operator.mul, (l.size for l, p in to_unroll)) - - nest_writers = Find(Writer).visit(to_unroll[0][0]) - for op, i_stmts in nest_writers.items(): - # Check safety of unrolling - if op in [Assign, IMul, IDiv]: - continue - assert op in [Incr, Decr] - - for i_stmt in i_stmts: - i_sym, i_expr = i_stmt.children - - # Avoid injecting twice the same loop - if i_stmt in analyzed + [l.incr for l, p in to_unroll]: - continue - analyzed.append(i_stmt) - - # Create unrolled, injectable expressions - for l, p in reversed(to_unroll): - i_expr = [dcopy(i_expr) for i in range(l.size)] - for i, e in enumerate(i_expr): - e_syms = Find(Symbol).visit(e)[Symbol] - for s in e_syms: - s.rank = tuple([r if r != l.dim else i for r in s.rank]) - i_expr = ast_make_expr(Sum, i_expr) - - # Track the unrolled, injectable expressions and their cost - if i_sym.symbol in injectable: - old_i_expr, old_cost = injectable[i_sym.symbol] - new_i_expr = ast_make_expr(Sum, [i_expr, old_i_expr]) - new_cost = unroll_cost + old_cost - injectable[i_sym.symbol] = (new_i_expr, new_cost) - else: - injectable[i_sym.symbol] = (i_expr, unroll_cost) - - # 2) Will rewrite mode=3 be cheaper than rewrite mode=2? - def find_save(target_expr, expr_info): - save_factor = [l.size for l in expr_info.out_linear_loops] or [1] - save_factor = reduce(operator.mul, save_factor) - # The save factor should be multiplied by the number of terms - # that will /not/ be pre-evaluated. To obtain this number, we - # can exploit the linearity of the expression in the terms - # depending on the linear loops. - syms = Find(Symbol).visit(target_expr)[Symbol] - inner = lambda s: any(r == expr_info.linear_dims[-1] for r in s.rank) - nterms = len(set(s.symbol for s in syms if inner(s))) - save = nterms * save_factor - return save_factor, save - - should_unroll = True - storage = 0 - i_syms, injected = injectable.keys(), defaultdict(list) - for stmt, expr_info in self.exprs.items(): - sym, expr = stmt.children - - # Divide /expr/ into subexpressions, each subexpression affected - # differently by injection - if i_syms: - dissected = find_expression(expr, Prod, expr_info.linear_dims, i_syms) - leftover = find_expression(expr, dims=expr_info.linear_dims, out_syms=i_syms) - leftover = {(): list(flatten(leftover.values()))} - dissected = dict(dissected.items() + leftover.items()) - else: - dissected = {(): [expr]} - if any(i not in flatten(dissected.keys()) for i in i_syms): - should_unroll = False - continue - - # Apply the profitability model - analysis = OrderedDict() - for i_syms, target_exprs in dissected.items(): - for target_expr in target_exprs: - - # *** Save *** - save_factor, save = find_save(target_expr, expr_info) - - # *** Cost *** - # The number of operations increases by a factor which - # corresponds to the number of possible /combinations with - # repetitions/ in the injected-values set. We consider - # combinations and not dispositions to take into account the - # (future) effect of factorization. 
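The increase factor computed just below is the multiset coefficient (_n + _k - 1)! / (_k! (_n - 1)!), i.e., the number of _k-element combinations with repetition drawn from _n injected values. A standalone sanity check of that closed form (a hedged aside, not part of the deleted module): ::

    from itertools import combinations_with_replacement
    from math import factorial as fact

    def increase_factor(_n, _k):
        # combinations with repetitions: choose _k from _n, order ignored
        return fact(_n + _k - 1) // (fact(_k) * fact(_n - 1))

    # e.g. groups of 3 drawn from 4 injected values: C(4+3-1, 3) = 20
    assert increase_factor(4, 3) == \
        len(list(combinations_with_replacement(range(4), 3)))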
- retval = ProjectExpansion.default_retval() - projection = ProjectExpansion(i_syms).visit(target_expr, ret=retval) - projection = [i for i in projection if i] - increase_factor = 0 - for i in projection: - partial = 1 - for j in expr_graph.shares(i): - # _n=number of unique elements, _k=group size - _n = injectable[j[0]][1] - _k = len(j) - partial *= fact(_n + _k - 1)//(fact(_k)*fact(_n - 1)) - increase_factor += partial - increase_factor = increase_factor or 1 - if increase_factor > save_factor: - # We immediately give up if this holds since it ensures - # that /cost > save/ (but not that cost <= save) - should_unroll = False - continue - # The increase factor should be multiplied by the number of - # terms that will be pre-evaluated. To obtain this number, - # we need to project the output of factorization. - fake_stmt = stmt.__class__(stmt.children[0], dcopy(target_expr)) - fake_parent = expr_info.parent.children - fake_parent[fake_parent.index(stmt)] = fake_stmt - ew = ExpressionRewriter(fake_stmt, expr_info) - ew.expand(mode='all').factorize(mode='all').factorize(mode='linear') - nterms = ew.licm(mode='aggressive', look_ahead=True) - nterms = len(uniquify(nterms[expr_info.dims])) or 1 - fake_parent[fake_parent.index(fake_stmt)] = stmt - cost = nterms * increase_factor - - # Pre-evaluation will also increase the working set size by - # /cost/ * /sizeof(term)/. - size = [l.size for l in expr_info.linear_loops] - size = reduce(operator.mul, size, 1) - storage_increase = cost * size * system.architecture[expr_info.type] - - # Track the injectable sub-expression and its cost/save. The - # final decision of whether to actually perform injection or not - # is postponed until all dissected expressions have been analyzed - analysis[target_expr] = (cost, save, storage_increase) - - # So what should we inject after all? Time to *use* the cost model - if heuristics == 'greedy': - for target_expr, (cost, save, storage_increase) in analysis.items(): - if cost > save or storage_increase + storage > threshold: - should_unroll = False - else: - # Update the available storage - storage += storage_increase - # At this point, we can happily inject - to_replace = {k: v[0] for k, v in injectable.items()} - ast_replace(target_expr, to_replace, copy=True) - injected[stmt].append(target_expr) - elif heuristics == 'aggressive': - # A) Remove expressions that we already know should never be injected - not_injected = [] - for target_expr, (cost, save, storage_increase) in analysis.items(): - if cost > save: - should_unroll = False - analysis.pop(target_expr) - not_injected.append(target_expr) - # B) Find all possible bipartitions: each bipartition represents - # the set of expressions that will be pre-evaluated and the set - # of expressions that could also be pre-evaluated, but might not - # (e.g. because of memory constraints) - target_exprs = analysis.keys() - bipartitions = [] - for i in range(len(target_exprs)+1): - for e1 in combinations(target_exprs, i): - bipartitions.append((e1, tuple(e2 for e2 in target_exprs - if e2 not in e1))) - # C) Eliminate those bipartitions that would lead to exceeding - # the memory threshold - bipartitions = [(e1, e2) for e1, e2 in bipartitions if - sum(analysis[i][2] for i in e1) <= threshold] - # D) Find out what is best to pre-evaluate (and therefore - # what should be injected) - totals = OrderedDict() - for e1, e2 in bipartitions: - # Is there any value in actually not pre-evaluating the - # expressions in /e2/? 
- fake_expr = ast_make_expr(Sum, list(e2) + not_injected) - _, save = find_save(fake_expr, expr_info) if fake_expr else (0, 0) - cost = sum(analysis[i][0] for i in e1) - totals[(e1, e2)] = save + cost - best = min(totals, key=totals.get) - # At this point, we can happily inject - to_replace = {k: v[0] for k, v in injectable.items()} - for target_expr in best[0]: - ast_replace(target_expr, to_replace, copy=True) - injected[stmt].append(target_expr) - if best[1]: - # At least one non-injected expressions, let's be sure we - # don't unroll everything - should_unroll = False - - # 3) Purge the AST from now useless symbols/expressions - if should_unroll: - decls = visit(self.header, info_items=['decls'])['decls'] - for stmt, expr_info in self.exprs.items(): - nests = [n for n in visit(expr_info.loops_parents[0])['fors']] - injectable_nests = [n for n in nests if list(zip(*n))[0] != expr_info.loops] - for nest in injectable_nests: - unrolled = [(l, p) for l, p in nest if l not in expr_info.loops] - for l, p in unrolled: - p.children.remove(l) - for i_sym in injectable.keys(): - decl = decls.get(i_sym) - if decl and decl in p.children: - p.children.remove(decl) - - # 4) Split the expressions if injection has been performed - for stmt, expr_info in self.exprs.items(): - expr_info.mode = 4 - inj_exprs = injected.get(stmt) - if not inj_exprs: - continue - fissioner = ExpressionFissioner(match=inj_exprs, loops='all', perfect=True) - new_exprs = fissioner.fission(stmt, self.exprs.pop(stmt)) - self.exprs.update(new_exprs) - for stmt, expr_info in new_exprs.items(): - expr_info.mode = 3 if stmt in fissioner.matched else 4 - - def _recoil(self): - """Increase the stack size if the kernel arrays exceed the stack limit - threshold (at the C level).""" - decls = visit(self.header, info_items=['decls'])['decls'] - - # Assume the size of a C type double is 8 bytes - c_double_size = 8 - # Assume the stack size is 1.7 MB (2 MB is usually the limit) - stack_size = 1.7*1024*1024 - - decls = [d for d in decls.values() if d.size] - size = sum([reduce(operator.mul, d.sym.rank) for d in decls]) - - if size * c_double_size > stack_size: - # Increase the stack size if the kernel's stack size seems to outreach - # the space available - try: - resource.setrlimit(resource.RLIMIT_STACK, (resource.RLIM_INFINITY, - resource.RLIM_INFINITY)) - except resource.error: - warn("Stack may blow up, could not increase its size.") - - @property - def expr_loops(self): - """Return ``[(loop1, loop2, ...), ...]``, where each tuple contains all - loops enclosing expressions.""" - return [expr_info.loops for expr_info in self.exprs.values()] - @property def expr_linear_loops(self): """Return ``[(loop1, loop2, ...), ...]``, where each tuple contains all @@ -510,70 +63,7 @@ class CPULoopOptimizer(LoopOptimizer): """Loop optimizer for CPU architectures.""" - def split(self, cut=1): - """Split expressions into multiple chunks exploiting sum's associativity. - Each chunk will have ``cut`` summands. 
- - For example, consider the following piece of code: :: - - for i - for j - A[i][j] += X[i]*Y[j] + Z[i]*K[j] + B[i]*X[j] - - If ``cut=1`` the expression is cut into chunks of length 1: :: - - for i - for j - A[i][j] += X[i]*Y[j] - for i - for j - A[i][j] += Z[i]*K[j] - for i - for j - A[i][j] += B[i]*X[j] - - If ``cut=2`` the expression is cut into chunks of length 2, plus a - remainder chunk of size 1: :: - - for i - for j - A[i][j] += X[i]*Y[j] + Z[i]*K[j] - # Remainder: - for i - for j - A[i][j] += B[i]*X[j] - """ - - new_exprs = OrderedDict() - elf = ExpressionFissioner(cut=cut, loops='expr') - for stmt, expr_info in self.exprs.items(): - new_exprs.update(elf.fission(stmt, expr_info)) - self.exprs = new_exprs - class GPULoopOptimizer(LoopOptimizer): """Loop optimizer for GPU architectures.""" - - def extract(self): - """Remove the fully-parallel loops of the loop nest. No data dependency - analysis is performed; rather, these are the loops that are marked with - ``pragma coffee itspace``.""" - - info = visit(self.loop, self.header, info_items=['symbols_dep', 'fors']) - symbols = info['symbols_dep'] - - itspace_vrs = set() - for nest in info['fors']: - for loop, parent in reversed(nest): - if '#pragma coffee itspace' not in loop.pragma: - continue - parent = parent.children - for n in loop.body: - parent.insert(parent.index(loop), n) - parent.remove(loop) - itspace_vrs.add(loop.dim) - - accessed_vrs = [s for s in symbols if any_in(s.rank, itspace_vrs)] - - return (itspace_vrs, accessed_vrs) diff --git a/coffee/rewriter.py b/coffee/rewriter.py deleted file mode 100644 index 5a2621f2..00000000 --- a/coffee/rewriter.py +++ /dev/null @@ -1,639 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. 
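To make the semantics of ``split`` above concrete, the chunking can be mimicked in plain Python (an illustrative helper, not part of COFFEE): ::

    def chunks(summands, cut):
        # associativity of + lets each chunk live in its own loop nest
        return [summands[i:i + cut] for i in range(0, len(summands), cut)]

    assert chunks(['X[i]*Y[j]', 'Z[i]*K[j]', 'B[i]*X[j]'], 2) == \
        [['X[i]*Y[j]', 'Z[i]*K[j]'], ['B[i]*X[j]']]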
- -from __future__ import absolute_import, print_function, division -from six.moves import zip - -from collections import Counter -from itertools import combinations -from operator import itemgetter - -from .base import * -from .utils import * -from coffee.visitors import * -from .hoister import Hoister -from .expander import Expander -from .factorizer import Factorizer -from .logger import warn - - -class ExpressionRewriter(object): - """Provide operations to re-write an expression: - - * Loop-invariant code motion: find and hoist sub-expressions which are - invariant with respect to a loop - * Expansion: transform an expression ``(a + b)*c`` into ``(a*c + b*c)`` - * Factorization: transform an expression ``a*b + a*c`` into ``a*(b+c)``""" - - def __init__(self, stmt, expr_info, header=None, hoisted=None): - """Initialize the ExpressionRewriter. - - :param stmt: the node whose rvalue is the expression for rewriting - :param expr_info: ``MetaExpr`` object describing the expression - :param header: the kernel's top node - :param hoisted: dictionary that tracks all hoisted expressions - """ - self.stmt = stmt - self.expr_info = expr_info - self.header = header or Root() - self.hoisted = hoisted if hoisted is not None else StmtTracker() - - self.codemotion = Hoister(self.stmt, self.expr_info, self.header, self.hoisted) - self.expander = Expander(self.stmt) - self.factorizer = Factorizer(self.stmt) - - def licm(self, mode='normal', **kwargs): - """Perform generalized loop-invariant code motion, a transformation - detailed in a paper available at: - - http://dl.acm.org/citation.cfm?id=2687415 - - :param mode: drive code motion by specifying what subexpressions should - be hoisted and where. - * normal: (default) all subexpressions that depend on one loop at most - * aggressive: all subexpressions, depending on any number of loops. - This may require introducing N-dimensional temporaries. - * incremental: apply, in sequence, only_const, only_outlinear, and - one sweep for each linear dimension - * only_const: only all constant subexpressions - * only_linear: only all subexpressions depending on linear loops - * only_outlinear: only all subexpressions independent of linear loops - * reductions: all sub-expressions that are redundantly computed within - a reduction loop; if possible, pull the reduction loop out of - the nest. - :param kwargs: - * look_ahead: (default: False) should be set to True if only a projection - of the hoistable subexpressions is needed (i.e., hoisting not performed) - * max_sharing: (default: False) should be set to True if hoisting should be - avoided in case the same set of symbols appears in different hoistable - sub-expressions. By not hoisting, factorization opportunities are preserved - * iterative: (default: True) should be set to False if interested in - hoisting only the smallest subexpressions matching /mode/ - * lda: an up-to-date loop dependence analysis, as returned by a call - to ``loops_analysis(node, 'dim'). By providing this information, loop - dependence analysis can be avoided, thus speeding up the transformation. - * global_cse: (default: False) search for common sub-expressions across - all previously hoisted terms. Note that no data dependency analysis is - performed, so this is at caller's risk. - * with_promotion: compute hoistable subexpressions within clone loops - even though this doesn't necessarily result in fewer operations. 
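Each mode ultimately reduces to a predicate over the set of loop dimensions a subexpression depends on. A minimal plain-Python sketch, with assumed example sets (the real predicates appear in the implementation below): ::

    dims = {'i', 'j', 'k'}           # all loops enclosing the expression
    out_linear_dims = {'i'}          # e.g. /i/ is the only out-linear loop

    normal = lambda d: d != dims
    only_outlinear = lambda d: d.issubset(out_linear_dims)

    assert normal({'i'}) and not normal({'i', 'j', 'k'})
    assert only_outlinear(set()) and not only_outlinear({'j'})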
- - Examples - ======== - - 1) With mode='normal': :: - - for i - for j - for k - a[j][k] += (b[i][j] + c[i][j])*(d[i][k] + e[i][k]) - - Redundancies are spotted along both the i and j dimensions, resulting in: :: - - for i - for k - ct1[k] = d[i][k] + e[i][k] - for j - ct2 = b[i][j] + c[i][j] - for k - a[j][k] += ct2*ct1[k] - - 2) With mode='reductions'. - Consider the following loop nest: :: - - for i - for j - a[j] += b[j]*c[i] - - By unrolling the loops, one clearly sees that: :: - - a[0] += b[0]*c[0] + b[0]*c[1] + b[0]*c[2] + ... - a[1] += b[1]*c[0] + b[1]*c[1] + b[1]*c[2] + ... - - Which is identical to: :: - - ct = c[0] + c[1] + c[2] + ... - a[0] += b[0]*ct - a[1] += b[1]*ct - - Thus, the original loop nest is simplified as: :: - - for i - ct += c[i] - for j - a[j] += b[j]*ct - """ - - dimension = self.expr_info.dimension - dims = set(self.expr_info.dims) - linear_dims = set(self.expr_info.linear_dims) - out_linear_dims = set(self.expr_info.out_linear_dims) - - if kwargs.get('look_ahead'): - hoist = self.codemotion.extract - else: - hoist = self.codemotion.licm - - if mode == 'normal': - should_extract = lambda d: d != dims - hoist(should_extract, **kwargs) - elif mode == 'reductions': - should_extract = lambda d: d != dims - # Expansion and reassociation may create hoistable reduction loops - candidates = self.expr_info.reduction_loops - if not candidates: - return self - candidate = candidates[-1] - if candidate.size == 1: - # Otherwise the operation count will just end up increasing - return - self.expand(mode='all') - lda = loops_analysis(self.header, value='dim') - non_candidates = {l.dim for l in candidates[:-1]} - self.reassociate(lambda i: not lda[i].intersection(non_candidates)) - hoist(should_extract, with_promotion=True, lda=lda) - self.codemotion.trim(candidate) - elif mode == 'incremental': - lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') - should_extract = lambda d: not (d and d.issubset(dims)) - hoist(should_extract, lda=lda) - should_extract = lambda d: d.issubset(out_linear_dims) - hoist(should_extract, lda=lda) - for i in range(1, dimension): - should_extract = lambda d: len(d.intersection(linear_dims)) <= i - hoist(should_extract, lda=lda, **kwargs) - elif mode == 'only_const': - should_extract = lambda d: not (d and d.issubset(dims)) - hoist(should_extract, **kwargs) - elif mode == 'only_outlinear': - should_extract = lambda d: d.issubset(out_linear_dims) - hoist(should_extract, **kwargs) - elif mode == 'only_linear': - should_extract = lambda d: not d.issubset(out_linear_dims) and d != linear_dims - hoist(should_extract, **kwargs) - elif mode == 'aggressive': - should_extract = lambda d: True - self.reassociate() - hoist(should_extract, with_promotion=True, **kwargs) - else: - warn('Skipping unknown licm strategy.') - return self - - return self - - def expand(self, mode='standard', **kwargs): - """Expand expressions based on different rules. For example: :: - - (X[i] + Y[j])*F + ... - - can be expanded into: :: - - (X[i]*F + Y[j]*F) + ... - - The expanded term could also be lifted. For example, if we have: :: - - Y[j] = f(...) - (X[i]*Y[j])*F + ... - - where ``Y`` was produced by code motion, expansion results in: :: - - Y[j] = f(...)*F - (X[i]*Y[j]) + ... 
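A quick numeric check of the lifting rewrite above (plain Python, illustrative values): ::

    X, Y0, F = [1.0, 2.0], [3.0, 4.0], 5.0
    before = [(x*y)*F for x in X for y in Y0]
    Y = [y*F for y in Y0]                    # Y[j] = f(...)*F absorbed F
    after = [x*y for x in X for y in Y]
    assert before == after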
- - Reasons for expanding expressions include: - - * Exposing factorization opportunities - * Exposing higher level operations (e.g., matrix multiplies) - * Relieving register pressure - - :param mode: multiple expansion strategies are possible - * mode == 'standard': expand along the loop dimension appearing most - often in different symbols - * mode == 'dimensions': expand along the loop dimensions provided in - /kwargs['dimensions']/ - * mode == 'all': expand when symbols depend on at least one of the - expression's dimensions - * mode == 'linear': expand when symbols depend on the expressions's - linear loops. - * mode == 'outlinear': expand when symbols are independent of the - expression's linear loops. - :param kwargs: - * subexprs: an iterator of subexpressions rooted in /self.stmt/. If - provided, expansion will be performed only within these trees, - rather than within the whole expression. - * lda: an up-to-date loop dependence analysis, as returned by a call - to ``loops_analysis(node, 'symbol', 'dim'). By providing this - information, loop dependence analysis can be avoided, thus - speeding up the transformation. - """ - - if mode == 'standard': - symbols = Find(Symbol).visit(self.stmt.rvalue)[Symbol] - # The heuristics privileges linear dimensions - dims = self.expr_info.out_linear_dims - if not dims or self.expr_info.dimension >= 2: - dims = self.expr_info.linear_dims - # Get the dimension occurring most often - occurrences = [tuple(r for r in s.rank if r in dims) for s in symbols] - occurrences = [i for i in occurrences if i] - if not occurrences: - return self - # Finally, establish the expansion dimension - dimension = Counter(occurrences).most_common(1)[0][0] - should_expand = lambda n: set(dimension).issubset(set(n.rank)) - elif mode == 'dimensions': - dimensions = kwargs.get('dimensions', ()) - should_expand = lambda n: set(dimensions).issubset(set(n.rank)) - elif mode in ['all', 'linear', 'outlinear']: - lda = kwargs.get('lda') or loops_analysis(self.expr_info.outermost_loop, - key='symbol', value='dim') - if mode == 'all': - should_expand = lambda n: lda.get(n.symbol) and \ - any(r in self.expr_info.dims for r in lda[n.symbol]) - elif mode == 'linear': - should_expand = lambda n: lda.get(n.symbol) and \ - any(r in self.expr_info.linear_dims for r in lda[n.symbol]) - elif mode == 'outlinear': - should_expand = lambda n: lda.get(n.symbol) and \ - not lda[n.symbol].issubset(set(self.expr_info.linear_dims)) - else: - warn('Skipping unknown expansion strategy.') - return - - self.expander.expand(should_expand, **kwargs) - return self - - def factorize(self, mode='standard', **kwargs): - """Factorize terms in the expression. For example: :: - - A[i]*B[j] + A[i]*C[j] - - becomes :: - - A[i]*(B[j] + C[j]). - - :param mode: multiple factorization strategies are possible. Note that - different strategies may expose different code motion opportunities - - * mode == 'standard': factorize symbols along the dimension that appears - most often in the expression. - * mode == 'dimensions': factorize symbols along the loop dimensions provided - in /kwargs['dimensions']/ - * mode == 'all': factorize symbols depending on at least one of the - expression's dimensions. - * mode == 'linear': factorize symbols depending on the expression's - linear loops. - * mode == 'outlinear': factorize symbols independent of the expression's - linear loops. - * mode == 'constants': factorize symbols independent of any loops enclosing - the expression. 
- * mode == 'adhoc': factorize only symbols in /kwargs['adhoc']/ (details below) - * mode == 'heuristic': no global factorization rule is used; rather, within - each Sum tree, factorize the symbols appearing most often in that tree - :param kwargs: - * subexprs: an iterator of subexpressions rooted in /self.stmt/. If - provided, factorization will be performed only within these trees, - rather than within the whole expression - * adhoc: a list of symbols that can be factorized and, for each symbol, - a list of symbols that can be grouped. For example, if we have - ``kwargs['adhoc'] = [(A, [B, C]), (D, [E, F, G])]``, and the - expression is ``A*B + D*E + A*C + A*F``, the result will be - ``A*(B+C) + A*F + D*E``. If the A's list were empty, all of the - three symbols B, C, and F would be factorized. Recall that this - option is ignored unless ``mode == 'adhoc'``. - * lda: an up-to-date loop dependence analysis, as returned by a call - to ``loops_analysis(node, 'symbol', 'dim'). By providing this - information, loop dependence analysis can be avoided, thus - speeding up the transformation. - """ - - if mode == 'standard': - symbols = Find(Symbol).visit(self.stmt.rvalue)[Symbol] - # The heuristics privileges linear dimensions - dims = self.expr_info.out_linear_dims - if not dims or self.expr_info.dimension >= 2: - dims = self.expr_info.linear_dims - # Get the dimension occurring most often - occurrences = [tuple(r for r in s.rank if r in dims) for s in symbols] - occurrences = [i for i in occurrences if i] - if not occurrences: - return self - # Finally, establish the factorization dimension - dimension = Counter(occurrences).most_common(1)[0][0] - should_factorize = lambda n: set(dimension).issubset(set(n.rank)) - elif mode == 'dimensions': - dimensions = kwargs.get('dimensions', ()) - should_factorize = lambda n: set(dimensions).issubset(set(n.rank)) - elif mode == 'adhoc': - adhoc = kwargs.get('adhoc') - if not adhoc: - return self - should_factorize = lambda n: n.urepr in adhoc - elif mode == 'heuristic': - kwargs['heuristic'] = True - should_factorize = lambda n: False - elif mode in ['all', 'linear', 'outlinear', 'constants']: - lda = kwargs.get('lda') or loops_analysis(self.expr_info.outermost_loop, - key='symbol', value='dim') - if mode == 'all': - should_factorize = lambda n: lda.get(n.symbol) and \ - any(r in self.expr_info.dims for r in lda[n.symbol]) - elif mode == 'linear': - should_factorize = lambda n: lda.get(n.symbol) and \ - any(r in self.expr_info.linear_dims for r in lda[n.symbol]) - elif mode == 'outlinear': - should_factorize = lambda n: lda.get(n.symbol) and \ - not lda[n.symbol].issubset(set(self.expr_info.linear_dims)) - elif mode == 'constants': - should_factorize = lambda n: not lda.get(n.symbol) - else: - warn('Skipping unknown factorization strategy.') - return - - # Perform the factorization - self.factorizer.factorize(should_factorize, **kwargs) - return self - - def reassociate(self, reorder=None): - """Reorder symbols in associative operations following a convention. - By default, the convention is to order the symbols based on their rank. - For example, the terms in the expression :: - - a*b[i]*c[i][j]*d - - are reordered as :: - - a*d*b[i]*c[i][j] - - This as achieved by reorganizing the AST of the expression. 
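In plain Python, the default convention is just a stable sort of the factors by their rank (illustrative tuples, not COFFEE types): ::

    factors = [('b', ('i',)), ('c', ('i', 'j')), ('a', ()), ('d', ())]
    ordered = sorted(factors, key=lambda f: f[1])
    assert [f[0] for f in ordered] == ['a', 'd', 'b', 'c']   # a*d*b[i]*c[i][j]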
- """ - - def _reassociate(node, parent): - if isinstance(node, (Symbol, Div)): - return - - elif isinstance(node, (Sum, Sub, FunCall, Ternary)): - for n in node.children: - _reassociate(n, node) - - elif isinstance(node, Prod): - children = explore_operator(node) - # Reassociate symbols - symbols = [n for n, p in children if isinstance(n, Symbol)] - # Capture the other children and recur on them - other_nodes = [(n, p) for n, p in children if not isinstance(n, Symbol)] - for n, p in other_nodes: - _reassociate(n, p) - # Create the reassociated product and modify the original AST - children = list(zip(*other_nodes))[0] if other_nodes else () - children += tuple(sorted(symbols, key=reorder)) - reassociated_node = ast_make_expr(Prod, children, balance=False) - parent.children[parent.children.index(node)] = reassociated_node - - else: - warn('Unexpected node %s while reassociating' % typ(node)) - - reorder = reorder if reorder else lambda n: (n.rank, n.dim) - _reassociate(self.stmt.rvalue, self.stmt) - return self - - def replacediv(self): - """Replace divisions by a constant with multiplications.""" - divisions = Find(Div).visit(self.stmt.rvalue)[Div] - to_replace = {} - for i in divisions: - if isinstance(i.right, Symbol): - if isinstance(i.right.symbol, (int, float)): - to_replace[i] = Prod(i.left, 1 / i.right.symbol) - elif isinstance(i.right.symbol, str) and i.right.symbol.isdigit(): - to_replace[i] = Prod(i.left, 1 / float(i.right.symbol)) - else: - to_replace[i] = Prod(i.left, Div(1.0, i.right)) - ast_replace(self.stmt, to_replace, copy=True, mode='symbol') - return self - - def preevaluate(self): - """Preevaluates subexpressions which values are compile-time constants. - In this process, reduction loops might be removed if the reduction itself - could be pre-evaluated.""" - # Aliases - stmt, expr_info = self.stmt, self.expr_info - - # Simplify reduction loops - if not isinstance(stmt, (Incr, Decr, IMul, IDiv)): - # Not a reduction expression, give up - return - expr_syms = Find(Symbol).visit(stmt.rvalue)[Symbol] - reduction_loops = expr_info.out_linear_loops_info - if any([not is_perfect_loop(l) for l, p in reduction_loops]): - # Unsafe if not a perfect loop nest - return - # The following check is because it is unsafe to simplify if non-loop or - # non-constant dimensions are present - hoisted_stmts = self.hoisted.all_stmts - hoisted_syms = [Find(Symbol).visit(h)[Symbol] for h in hoisted_stmts] - hoisted_dims = [s.rank for s in flatten(hoisted_syms)] - hoisted_dims = set([r for r in flatten(hoisted_dims) if not is_const_dim(r)]) - if any(d not in expr_info.dims for d in hoisted_dims): - # Non-loop dimension or non-constant dimension found, e.g. A[i], with /i/ - # not being a loop iteration variable - return - for i, (l, p) in enumerate(reduction_loops): - syms_dep = SymbolDependencies().visit(l, **SymbolDependencies.default_args) - if not all([(tuple(syms_dep[s]) == expr_info.loops and s.dim == len(expr_info.loops)) - for s in expr_syms if syms_dep[s]]): - # A sufficient (although not necessary) condition for loop reduction to - # be safe is that all symbols in the expression are either constants or - # tensors assuming a distinct value in each point of the iteration space. 
- # So if this condition fails, we give up - return - # At this point, tensors can be reduced along the reducible dimensions - reducible_syms = [s for s in expr_syms if not s.is_const] - # All involved symbols must result from hoisting - if not all([s.symbol in self.hoisted for s in reducible_syms]): - return - # Replace hoisted assignments with reductions - finder = Find(Assign, stop_when_found=True, with_parent=True) - for hoisted_loop in self.hoisted.all_loops: - for assign, parent in finder.visit(hoisted_loop)[Assign]: - sym, expr = assign.children - decl = self.hoisted[sym.symbol].decl - if sym.symbol in [s.symbol for s in reducible_syms]: - parent.children[parent.children.index(assign)] = Incr(sym, expr) - sym.rank = self.expr_info.linear_dims - decl.sym.rank = decl.sym.rank[i+1:] - # Remove the reduction loop - p.children[p.children.index(l)] = l.body[0] - # Update symbols' ranks - for s in reducible_syms: - s.rank = self.expr_info.linear_dims - # Update expression metadata - self.expr_info._loops_info.remove((l, p)) - - # Precompute constant expressions - decls = visit(self.header, info_items=['decls'])['decls'] - evaluator = Evaluate(decls, any(d.nonzero for s, d in decls.items())) - for hoisted_loop in self.hoisted.all_loops: - evals = evaluator.visit(hoisted_loop, **Evaluate.default_args) - # First, find out identical tables - mapper = defaultdict(list) - for s, values in evals.items(): - mapper[str(values)].append(s) - # Then, map identical tables to a single symbol - for values, symbols in mapper.items(): - to_replace = {s: symbols[0] for s in symbols[1:]} - ast_replace(self.stmt, to_replace, copy=True) - # Clean up - for s in symbols[1:]: - s_decl = self.hoisted[s.symbol].decl - self.header.children.remove(s_decl) - self.hoisted.pop(s.symbol) - evals.pop(s) - # Finally, update the hoisted symbols - for s, values in evals.items(): - hoisted = self.hoisted[s.symbol] - hoisted.decl.init = values - hoisted.decl.qual = ['static', 'const'] - self.hoisted.pop(s.symbol) - # Move all decls at the top of the kernel - self.header.children.remove(hoisted.decl) - self.header.children.insert(0, hoisted.decl) - self.header.children.insert(0, FlatBlock("// Preevaluated tables")) - # Clean up - self.header.children.remove(hoisted_loop) - return self - - def sharing_graph_rewrite(self): - """Rewrite the expression based on its sharing graph. Details in the - paper: - - An algorithm for the optimization of finite element integration loops - (Luporini et. al.) 
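As a minimal sketch, such a sharing graph can be assembled with networkx from the linear symbols of each summand (illustrative symbols; the actual construction follows below): ::

    import networkx as nx
    from itertools import combinations

    summands = [('A', 'B'), ('A', 'C'), ('D', 'E')]
    sgraph = nx.Graph()
    for syms in summands:
        sgraph.add_edges_from(combinations(syms, 2))
    assert sgraph.degree('A') == 2   # 'A' is the profitable factor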
- """ - linear_dims = self.expr_info.linear_dims - other_dims = self.expr_info.out_linear_dims - - # Maximize visibility of linear symbols - self.expand(mode='all') - - # Make sure that potential reductions are not hidden away - lda = loops_analysis(self.header, value='dim') - self.reassociate(lambda i: (not lda[i]) + lda[i].issubset(set(other_dims))) - - # Construct the sharing graph - nodes, edges = [], [] - for i in summands(self.stmt.rvalue): - symbols = [i] if isinstance(i, Symbol) else list(zip(*explore_operator(i)))[0] - lsymbols = [s for s in symbols if any(d in lda[s] for d in linear_dims)] - lsymbols = [s.urepr for s in lsymbols] - nodes.extend([j for j in lsymbols if j not in nodes]) - edges.extend(combinations(lsymbols, r=2)) - sgraph = nx.Graph(edges) - - # Transform everything outside the sharing graph (pure linear, no ambiguity) - isolated = [n for n in nodes if n not in sgraph.nodes()] - for n in isolated: - self.factorize(mode='adhoc', adhoc={n: [] for n in nodes}) - self.licm('only_const').licm('only_outlinear') - - # Transform the expression based on the sharing graph - nodes = [n for n in nodes if n in sgraph.nodes()] - if not (nodes and all(sgraph.degree(n) > 0 for n in nodes)): - self.factorize(mode='heuristic') - self.licm('only_const').licm('only_outlinear') - return - # Use short variable names otherwise Pulp might complain - nodes_vars = {i: n for i, n in enumerate(nodes)} - vars_nodes = {n: i for i, n in nodes_vars.items()} - edges = [(vars_nodes[i], vars_nodes[j]) for i, j in edges] - - import pulp as ilp - - def setup(): - # ... declare variables - x = ilp.LpVariable.dicts('x', nodes_vars.keys(), 0, 1, ilp.LpBinary) - y = ilp.LpVariable.dicts('y', - [(i, j) for i, j in edges] + [(j, i) for i, j in edges], - 0, 1, ilp.LpBinary) - limits = defaultdict(int) - for i, j in edges: - limits[i] += 1 - limits[j] += 1 - - # ... define the problem - prob = ilp.LpProblem("Factorizer", ilp.LpMinimize) - - # ... define the constraints - for i in nodes_vars: - prob += ilp.lpSum(y[(i, j)] for j in nodes_vars if (i, j) in y) <= limits[i]*x[i] - - for i, j in edges: - prob += y[(i, j)] + y[(j, i)] == 1 - - # ... define the objective function (min number of factorizations) - prob += ilp.lpSum(x[i] for i in nodes_vars) - - return x, prob - - # Solve the ILP problem to find out the minimal-cost factorization strategy - x, prob = setup() - prob.solve(ilp.GLPK(msg=0)) - - # Also attempt to find another optimal factorization, but with - # additional constraints on the reduction dimensions. This may help in - # later rewrite steps - if len(other_dims) > 1: - z, prob = setup() - for i, n in nodes_vars.items(): - if not set(n[1]).intersection(set(other_dims[:-1])): - prob += z[i] == 0 - prob.solve(ilp.GLPK(msg=0)) - if ilp.LpStatus[prob.status] == 'Optimal': - x = z - - # ... finally, apply the transformations. 
Observe that: - # 1) the order (first /nodes/, than /other_nodes/) in which - # the factorizations are carried out is crucial - # 2) sorting /nodes/ and /other_nodes/ locally ensures guarantees - # deterministic output code - # 3) precedence is given to outer reduction loops; this maximises the - # impact of later transformations, while not affecting this pass - # 4) with_promotion is set to true if there exist potential reductions - # to simplify - nodes = [nodes_vars[n] for n, v in x.items() if v.value() == 1] - other_nodes = [nodes_vars[n] for n, v in x.items() if nodes_vars[n] not in nodes] - for n in sorted(nodes, key=itemgetter(1)) + sorted(other_nodes): - self.factorize(mode='adhoc', adhoc={n: []}) - self.licm('incremental', with_promotion=len(other_dims) > 1) - - return self diff --git a/coffee/scheduler.py b/coffee/scheduler.py deleted file mode 100644 index 3da18964..00000000 --- a/coffee/scheduler.py +++ /dev/null @@ -1,856 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division -from six import iteritems -from six.moves import range, zip - -from collections import OrderedDict, defaultdict -from itertools import product -from copy import deepcopy as dcopy - -from .base import * -from .utils import * -from .expression import copy_metaexpr -from .rewriter import ExpressionRewriter -from .exceptions import ControlFlowError, UnexpectedNode -from coffee.visitors import FindLoopNests - - -class LoopScheduler(object): - - """Base class for classes that handle loop scheduling; that is, loop fusion, - loop distribution, etc.""" - - -class SSALoopMerger(LoopScheduler): - - """Analyze data dependencies and iteration spaces, then merge fusible loops. 
- Statements must be in "soft" SSA form: they can be declared and initialized - at declaration time, then they can be assigned a value in only one place.""" - - def _merge_loops(self, root, loop_a, loop_b): - """Merge the body of ``loop_a`` into ``loop_b``.""" - root.children.remove(loop_a) - - dims_a, dims_b = [loop_a.dim], [loop_b.dim] - while isinstance(loop_b.body[0], For): - dims_b.append(loop_b.dim) - loop_b = loop_b.body[0] - while isinstance(loop_a.body[0], For): - dims_a.append(loop_a.dim) - loop_a = loop_a.body[0] - - loop_b.body = loop_a.body + loop_b.body - - ast_update_rank(loop_b, dict(zip(dims_a, dims_b))) - - def _simplify(self, merged_loops): - """Scan the list of merged loops and eliminate sub-expressions that became - duplicate as now iterating along the same iteration space. For example: :: - - for i = 0 to N - A[i] = B[i] + C[i] - for j = 0 to N - D[j] = B[j] + C[j] - - After merging this becomes: :: - - for i = 0 to N - A[i] = B[i] + C[i] - D[i] = B[i] + C[i] - - And finally, after simplification (i.e. after ``simplify`` is applied): :: - - for i = 0 to N - A[i] = B[i] + C[i] - D[i] = A[i] - """ - for loop in merged_loops: - to_replace = {} - for stmt in loop.body: - ast_replace(stmt, to_replace, copy=True) - if not isinstance(stmt, AugmentedAssign): - to_replace[stmt.rvalue] = stmt.lvalue - - def merge(self, root): - """Merge perfect loop nests in ``root``.""" - - # Make sure there are no empty loops within root, otherwise kill them - remove_empty_loops(root) - - expr_graph = ExpressionGraph(root) - - # Collect iteration spaces visiting the tree rooted in /root/ - found_nests = OrderedDict() - for n in root.children: - if isinstance(n, For): - retval = FindLoopNests.default_retval() - loops_infos = FindLoopNests().visit(n, parent=root, ret=retval) - for li in loops_infos: - loops, loops_parents = zip(*li) - # Note that only inner loops can be fused, and that they must - # share the same parent node - key = (tuple(l.header for l in loops), loops_parents[-1]) - found_nests.setdefault(key, []).append(loops[-1]) - - all_merged, merged_loops = [], [] - # A perfect loop nest L1 is mergeable in a loop nest L2 if - # 1 - their iteration space is identical; implicitly true because the keys, - # in the dictionary, are iteration spaces. - # 2 - between the two nests, there are no statements that read/write values - # computed in L1. This is checked later - # 3 - there are no read-after-write dependencies between variables written - # in L1 and read in L2. 
This is checked later - # In the following, convention is that L2 = /merging_in/, L1 = /l/ - for (itspace, parent), loop_nests in found_nests.items(): - if len(loop_nests) == 1: - # At least two loops are necessary for merging to be meaningful - continue - mergeable = [] - merging_in = loop_nests[-1] - retval = SymbolModes.default_retval() - merging_in_reads = SymbolModes().visit(merging_in.body, ret=retval) - merging_in_reads = [s for s, m in merging_in_reads.items() if m[0] == READ] - for l in loop_nests[:-1]: - is_mergeable = True - # Get the symbols written in /l/ - l_writes = SymbolModes().visit(l.body, ret=SymbolModes.default_retval()) - l_writes = [s for s, m in l_writes.items() if m[0] == WRITE] - - # Check condition 2 - # Get the symbols written between loop /l/ (excluded) and loop - # merging_in (excluded) - bound_left = parent.children.index(l)+1 - bound_right = parent.children.index(merging_in) - for n in parent.children[bound_left:bound_right]: - in_writes = SymbolModes().visit(n, ret=SymbolModes.default_retval()) - in_writes = [s for s, m in in_writes.items()] - for iw, lw in product(in_writes, l_writes): - if expr_graph.is_written(iw, lw): - is_mergeable = False - break - - # Check condition 3 - for lw, mir in product(l_writes, merging_in_reads): - if lw.symbol == mir.symbol and not lw.rank and not mir.rank: - is_mergeable = False - break - - # Track mergeable loops - if is_mergeable: - mergeable.append(l) - - # If there is at least one mergeable loop, do the merging - for l in reversed(mergeable): - self._merge_loops(parent, l, merging_in) - # Update the lists of merged loops - all_merged.append((mergeable, merging_in)) - merged_loops.append(merging_in) - - # Reuse temporaries in merged loops - self._simplify(merged_loops) - - return all_merged - - -class ExpressionFissioner(LoopScheduler): - - """Split expressions embedded in a loop nest.""" - - def __init__(self, **kwargs): - """Initialize the ExpressionFissioner. - - :arg kwargs: - * cut: the number of operands an expression should be fissioned into - * match: a list of subexpressions that should be cut from the input - expression. ``cut`` is ignored if ``match`` is provided. - * loops: a value in ['all', 'expr', 'none']. 'all' means that an - expression is split and its "chunks" are placed in separate loop - nests. 'expr' implies that the chunks are placed within the non - linear loops surrounding the expression. 'none' means that all - chunks are simply placed within the original loop nest - * perfect: if True, create perfect loop nests.
This means that any - new loop nest in which a chunk is placed is purged from any extra - statement (apart, obviously, from the chunk itself) - """ - self.cut = kwargs.get('cut', -1) - self.match = [str(i) for i in kwargs.get('match', [])] - self.loops = kwargs.get('loops', 'expr') - self.perfect = kwargs.get('perfect', False) - - if 'match' in kwargs: - self.cutter = self.CutterMatch(self) - elif self.cut > 0: - self.cutter = self.CutterSum(self) - else: - raise RuntimeError("Must specify a `cut` or a `match`.") - - class Cutter(object): - - def __init__(self, expr_fissioner): - self.expr_fissioner = expr_fissioner - - def cut(self, node): - """ - Split ``node`` into /two halves/, called /split/ and /remainder/ - - For example, consider the expression a*b + c*d; if the expression is cut - into chunks containing only one operand (i.e., self.cut=1), then we have - precisely two chunks, /split/ = a*b, /remainder/ = c*d - - If the input expression is a*b + c*d + e*f, and still self.cut=1, then we - have two chunks, /split/ = a*b, /remainder/ = c*d + e*f; that is, - /remainder/ always contains the subexpression after the fission point - """ - self._success = False - left = dcopy(node) - self._cut(left.children[1], left, 'split') - - self._success = False - right = dcopy(node) - self._cut(right.children[1], right, 'remainder') - - return left, right - - class CutterSum(Cutter): - - def _cut(self, node, parent, side, topsum=None): - if isinstance(node, (Symbol, FunCall, Ternary)): - return 0 - - elif isinstance(node, Div): - return self._cut(node.children[0], node, side, topsum) - - elif isinstance(node, Prod): - if topsum: - return 0 - if self._cut(node.left, node, side, topsum) == 0: - return self._cut(node.right, node, side, topsum) - # Prods zero the sum/sub counter - return 0 - - elif isinstance(node, (Sum, Sub)): - topsum = topsum or (parent, parent.children.index(node)) - counter = 1 - counter += self._cut(node.left, node, side, topsum) - counter += self._cut(node.right, node, side, topsum) - if not self._success and counter >= self.expr_fissioner.cut: - # We now are on the topleft sum of this sub-expression such - # that enough sum/sub have been encountered - if not parent: - return 0 - self._success = True - if side == 'split': - topsum[0].children[topsum[1]] = node.left - else: - right = Neg(node.right) if isinstance(node, Sub) else node.right - parent.children[parent.children.index(node)] = right - return counter - else: - return counter - - else: - raise UnexpectedNode("Fission: %s" % str(node)) - - def cut(self, node, expr_info): - left, right = ExpressionFissioner.Cutter.cut(self, node) - if self._success: - index = expr_info.parent.children.index(node) - - # Append /left/ to the original loop nest - expr_info.parent.children[index] = left - split = (left, copy_metaexpr(expr_info)) - - # Append /right/ ... - if self.expr_fissioner.loops in ['expr', 'all']: - # ... in a new loop nest ... - right_info = self.expr_fissioner._embedexpr(right, expr_info) - else: - # ... 
to the original loop nest - expr_info.parent.children.insert(index, right) - right_info = copy_metaexpr(expr_info) - splittable = (right, right_info) - - return (split, splittable) - return ((node, expr_info), ()) - - class CutterMatch(Cutter): - - def __init__(self, expr_fissioner): - ExpressionFissioner.Cutter.__init__(self, expr_fissioner) - self.matched = [] - - def _cut(self, node, parent, side, topsum=None): - if not self._success and str(node) in self.expr_fissioner.match: - # We initially assume that the found 'match' corresponds - # to the entire node provided as input to the /CutterMatch/. - # Recurring back, we might switch /_success/ to 'match_and_cut', - # if /node/ actually was a summand of a Sum/Sub - self._success = 'match' - return node - - elif isinstance(node, (Symbol, FunCall)): - return None - - elif isinstance(node, Div): - return self._cut(node.left, node, side) - - elif isinstance(node, Prod): - cutting = self._cut(node.left, node, side) - if cutting: - # Found a match /within/ /node.left/; for correctness, we - # need to be sure we will be cutting the whole Prod, so we - # return /node/ instead of /cutting/. - return node - cutting = self._cut(node.right, node, side) - if cutting: - # Same as above - return node - return None - - elif isinstance(node, (Sum, Sub)): - topsum = topsum or (parent, parent.children.index(node)) - # Find out if one of the two children is cuttable - cutting = self._cut(node.left, node, side, topsum) - if cutting and side == 'remainder': - # Need to swap - cutting = node.right - elif not cutting: - cutting = self._cut(node.right, node, side, topsum) - if cutting and side == 'remainder': - # Need to swap - cutting = node.left - if not cutting: - return None - # Adjust if a Sub - if isinstance(node, Sub) and cutting == node.right: - cutting = Neg(cutting) - self._success = 'match_and_cut' - if side == 'split': - # In a tree of Sum/Subs, only the /top/ Sum/Sub performs the - # actual cut, while the others just propagate upwards the - # notification "a cut point was found" - if parent == topsum[0]: - topsum[0].children[topsum[1]] = cutting - return parent - else: - return cutting - else: - parent.children[parent.children.index(node)] = cutting - return None - - else: - raise UnexpectedNode("Fission: %s" % str(node)) - - def cut(self, node, expr_info): - left, right = ExpressionFissioner.Cutter.cut(self, node) - - if self._success == 'match_and_cut': - # Append /left/ to a new loop nest - split = (left, self.expr_fissioner._embedexpr(left, expr_info)) - self.matched.append(left) - - # Append /right/ to the original loop nest - index = expr_info.parent.children.index(node) - expr_info.parent.children[index] = right - splittable = (right, copy_metaexpr(expr_info)) - return (split, splittable) - - elif self._success == 'match': - # A match was actualy found, but there's just nothing to cut - # (i.e., the /match/ is a direct child of /node/) - self.matched.append(node) - - return ((node, expr_info), ()) - - def _embedexpr(self, stmt, expr_info): - """Build a loop nest for ``stmt`` and return its :class:`MetaExpr` object.""" - if self.loops == 'none': - return copy_metaexpr(expr_info) - - # Handle the linear loops - linear_loops = ItSpace(mode=2).to_for(expr_info.linear_loops, stmts=[stmt]) - linear_outerloop = linear_loops[0] - - # Handle the out-linear loops - if self.loops == 'all' and expr_info.out_linear_loops_info: - out_linear_loop, out_linear_loop_parent = expr_info.out_linear_loops_info[0] - index = 
out_linear_loop.body.index(expr_info.linear_loops[0]) - out_linear_loop = dcopy(out_linear_loop) - if self.perfect: - out_linear_loop.body[:] = [linear_outerloop] - else: - out_linear_loop.body[index] = linear_outerloop - out_linear_loops_info = ((out_linear_loop, out_linear_loop_parent),) - linear_outerloop_parent = out_linear_loop.children[0] - else: - out_linear_loops_info = expr_info.out_linear_loops_info - linear_outerloop_parent = expr_info.linear_loops_parents[0] - - # Build new loops info - finder, env = FindLoopNests(), {'node_parent': linear_outerloop_parent} - loops_info = out_linear_loops_info - loops_info += tuple(finder.visit(linear_outerloop, env=env)[0]) - - # Append the newly created loop nest - if self.loops == 'all' and expr_info.out_linear_loops_info: - expr_info.outermost_parent.children.append(out_linear_loop) - else: - linear_outerloop_parent.children.append(linear_outerloop) - - # Finally, create and return the MetaExpr object - parent = loops_info[-1][0].children[0] - return copy_metaexpr(expr_info, parent=parent, loops_info=loops_info) - - @property - def matched(self): - return self.cutter.matched if self.match else [] - - def fission(self, stmt, expr_info): - """Split, or fission, an expression ``stmt``, whose metadata are provided - through ``expr_info``. - - Return a dictionary mapping expression chunks to :class:`MetaExpr` objects. - - :arg stmt: the expression to be fissioned - :arg expr_info: ``MetaExpr`` object describing ``stmt`` - """ - exprs = OrderedDict() - splittable = (stmt, expr_info) - while splittable: - split, splittable = self.cutter.cut(*splittable) - exprs[split[0]] = split[1] - return exprs - - -class ZeroRemover(LoopScheduler): - - """Analyze data dependencies and iteration spaces to remove arithmetic - operations in loops that iterate over zero-valued blocks. Consequently, - loop nests can be fissioned and/or merged. For example: :: - - for i = 0, N - A[i] = C[i]*D[i] - B[i] = E[i]*F[i] - - If the evaluation of A requires iterating over a block of zero (0.0) values, - because for instance C and D are block-sparse, then A is evaluated in a - different, smaller (i.e., with less iterations) loop nest: :: - - for i = 0 < (N-k) - A[i+k] = C[i+k][i+k] - for i = 0, N - B[i] = E[i]*F[i] - - The implementation is based on symbolic execution. Control flow is not - admitted. - """ - - THRESHOLD = 1 # Only skip if there more than THRESHOLD consecutive zeros - - def __init__(self, exprs, hoisted): - """Initialize the ZeroRemover. - - :param exprs: the expressions for which zero removal is performed. - :param hoisted: dictionary that tracks hoisted sub-expressions - """ - self.exprs = exprs - self.hoisted = hoisted - - def _track_nz_expr(self, node, nz_syms, nest): - """For the expression rooted in ``node``, return iteration space and - offset required to iterate over non zero-valued blocks. For example: :: - - for i = 0 to N - for j = 0 to N - A[i][j] = B[i]*C[j] - - If B along `i` is non-zero in ranges [0, k1] and [k2, k3], while C along - `j` is non-zero in range [N-k4, N], return the intersection of the non-zero - regions as: :: - - [(('i', k1, 0), ('j', N-(N-k4), N-k4))), - (('i', k3-k2, k2), ('j', N-(N-k4), N-k4))] - - That is, for each iteration space variable, return a list of 2-tuples, - in which the first entry represents the size of the iteration space, - and the second entry represents the offset in memory to access the - correct values. 
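Concretely, intersecting one loop with one non-zero region is plain interval arithmetic (illustrative numbers): ::

    loop_size, offset = 10, 0     # the loop covers [0, 10)
    nz_size, nz_ofs = 4, 6        # the non-zero block covers [6, 10)
    start = max(offset, nz_ofs)
    size = max(min(offset + loop_size, nz_size + nz_ofs) - start, 0)
    assert (size, start - offset) == (4, 6)   # iterate 4 points at offset 6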
- """ - - if isinstance(node, Symbol): - itspace = [] - def_itspace = [tuple((l.dim, Region(l.size, 0)) for l, p in nest)] - nz_bounds = zip(*nz_syms.get(node.symbol, [])) - for i, (r, o, nz_bs) in enumerate(zip(node.rank, node.offset, nz_bounds)): - if o[0] != 1 or isinstance(o[1], str) or is_const_dim(r): - # Cannot handle jumps, non-integer offsets, or constant accesses - continue - try: - # Am I tracking the loop with iteration variable == /r/ ? - loop = [l for l, p in nest if l.dim == r][0] - except IndexError: - # No, so I just assume it covers the entire non zero-valued region - itspace.append([(r, nz_b) for nz_b in nz_bs]) - continue - # Now I can intersect the loop's iteration space with the non - # zero-valued regions - offset = o[1] - r_region = [] - for nz_b in nz_bs: - nz_b_size, nz_b_offset = nz_b - end = nz_b_size + nz_b_offset - start = max(offset, nz_b_offset) - r_offset = start - offset - r_size = max(min(offset + loop.size, end) - start, 0) - r_region.append((r, Region(r_size, r_offset))) - itspace.append(r_region) - itspace = list(zip(*itspace)) or def_itspace - return itspace - - elif isinstance(node, FunCall): - return self._track_nz_expr(node.children[0], nz_syms, nest) - - elif isinstance(node, Ternary): - raise ControlFlowError - - else: - itspace_l = self._track_nz_expr(node.left, nz_syms, nest) - itspace_r = self._track_nz_expr(node.right, nz_syms, nest) - itspace = OrderedDict() - for l in itspace_l: - for i, region in l: - itspace.setdefault(i, []).append(region) - asdict = OrderedDict() - for r in itspace_r: - for i, region in r: - asdict.setdefault(i, []).append(region) - itspace_r = asdict - for i, region in itspace_r.items(): - if i not in itspace: - itspace[i] = region - elif isinstance(node, (Prod, Div)): - result = [] - for j in product(itspace[i], region): - # Products over zero-valued regions are ininfluent - result += [ItSpace(mode=1).intersect(j)] - itspace[i] = result - elif isinstance(node, (Sum, Sub)): - # Sums over zeros remove the zero-valued region (in other words, - # the non zero-valued regions get /merged/) - itspace[i] = ItSpace(mode=1).merge(itspace[i] + region) - else: - raise UnexpectedNode("Zero-avoidance: %s", str(node)) - itspace = list(set(tuple(zip(itspace, i)) - for i in product(*itspace.values()))) - return itspace - - def _track_nz_blocks(self, node, nz_syms, nz_info, nest=None, parent=None, candidates=None): - """Track the propagation of zero-valued blocks in the AST rooted in ``node`` - - ``nz_syms`` contains, for each known identifier, the ranges of - its non zero-valued blocks. For example, assuming identifier A is an - array and has non-zero values in positions [0, k] and [N-k, N], then - ``nz_syms`` will contain an entry {"A": ((0, k), (N-k, N))}. - If A is modified by some statements rooted in ``node``, then - ``nz_syms["A"]`` will be modified accordingly. - - This method also populates ``nz_info``, which maps loop nests to the - enclosed symbols' non-zero blocks. For example, given the following - code: :: - - { // root - ... - for i - for j - A = ... - B = ... 
- } - - After the traversal of the AST, the ``nz_info`` dictionary will look like: :: - - ((, ), root) -> {A: (i, (nz_along_i)), (j, (nz_along_j))} - - """ - if isinstance(node, Writer): - sym, expr = node.children - - # Outer, non-perfect loops are discarded for transformation safety - # as splitting (a consequence of zero-removal) non-perfect nests is unsafe - nest = tuple([(l, p) for l, p in (nest or []) if is_perfect_loop(l)]) - if not nest: - return - - if nest[-1][0] not in candidates: - return - - # Track the propagation of non zero-valued blocks: ... - # ... within the rvalue - itspaces = self._track_nz_expr(expr, nz_syms, nest) - for i in itspaces: - # ... and then through the lvalue (merging overlaps) - nz_expr = tuple(dict(i).get(r) for r in sym.rank if not is_const_dim(r)) - if any(j is None for j in nz_expr): - break - nz_node = list(nz_syms.setdefault(sym.symbol, [nz_expr])) - if not nz_expr: - continue - merged = False - for e, j in enumerate(nz_node): - # Merging condition: complete overlap in all dimensions but - # the innermost one, for which partial overlap is accepted - inner_merge = ItSpace(mode=1).merge([nz_expr[-1], j[-1]]) - if len(inner_merge) == 1 and \ - all(ItSpace(mode=1).intersect([m, n]) == m for m, n in - zip(nz_expr[:-1], j[:-1])): - nz_syms[sym.symbol][e] = j[:-1] + tuple(inner_merge) - merged = True - break - if not merged: - nz_syms[sym.symbol].append(nz_expr) - - # Record loop nest bounds and memory offsets for /node/ - dims = [l.dim for l, p in nest] - itspaces = [tuple(j for j in i if j[0] in dims) for i in itspaces] - nz_info.setdefault(nest, []).append((node, itspaces)) - - elif isinstance(node, For): - new_nest = (nest or []) + [(node, parent)] - self._track_nz_blocks(node.children[0], nz_syms, nz_info, new_nest, - node, candidates) - - elif isinstance(node, (Root, Block)): - for n in node.children: - self._track_nz_blocks(n, nz_syms, nz_info, nest, node, candidates) - - else: - raise ControlFlowError - - def _reschedule_itspace(self, root, candidates, decls): - """Consider two statements A and B, and their iteration space. If the - two iteration spaces have - - * Same size and same bounds, then put A and B in the same loop nest: :: - - for i, for j - W1[i][j] = W2[i][j] - Z1[i][j] = Z2[i][j] - - * Same size but different bounds, then put A and B in the same loop - nest, but add suitable offsets to all of the involved iteration - variables: :: - - for i, for j - W1[i][j] = W2[i][j] - Z1[i+k][j+k] = Z2[i+k][j+k] - - * Different size, then put A and B in two different loop nests: :: - - for i, for j - W1[i][j] = W2[i][j] - for i, for j // Different loop bounds - Z1[i][j] = Z2[i][j] - - A dictionary describing the structure of the new iteration spaces is - returned. - """ - nz_info = OrderedDict() - - # Compute the initial sparsity pattern of the symbols in /root/ - nz_syms = defaultdict(list) - for s, d in decls.items(): - if not d.nonzero: - continue - for nz_b in product(*d.nonzero): - entries = [list(range(i.ofs, i.ofs + i.size)) for i in nz_b] - if not np.all(d.init.values[np.ix_(*entries)] == 0.0): - nz_syms[s].append(nz_b) - - # Track the propagation of non zero-valued blocks through symbolic - # execution. 
This populates /nz_info/ and updates /nz_syms/ - try: - self._track_nz_blocks(root, nz_syms, nz_info, candidates=candidates) - except ControlFlowError: - # Couldn't perform symbolic execution due to runtime-dependent data - return nz_syms, OrderedDict() - - # At this point we know where non-zero blocks are located, so we have - # to create proper loop nests to access them - new_exprs, new_nz_info = OrderedDict(), OrderedDict() - for nest, stmt_itspaces in nz_info.items(): - loops, loops_parents = zip(*nest) - fissioned_nests = defaultdict(list) - # Fission the nest to get rid of computation over zero-valued blocks - for stmt, itspaces in stmt_itspaces: - sym, expr = stmt.children - # For each non zero-valued region iterated over... - for i in itspaces: - dim_offset = {d: o for d, (sz, o) in i} - dim_size = tuple(((0, dict(i)[l.dim][0]), l.dim) for l in loops) - # ...add an offset to /stmt/ to access the correct values - new_stmt = ast_update_ofs(dcopy(stmt), dim_offset, increase=True) - # ...add /stmt/ to a new, shorter loop nest - fissioned_nests[dim_size].append((new_stmt, dim_offset)) - # ...initialize arrays to 0.0 for correctness - if sym.symbol in self.hoisted: - self.hoisted[sym.symbol].decl.init = ArrayInit(np.array([0.0])) - # ...track fissioned expressions - if stmt in self.exprs: - new_exprs[new_stmt] = self.exprs[stmt] - # Generate the fissioned loop nests - # Note: the dictionary is sorted because smaller loop nests should - # be executed first, since larger ones depend on them - for dim_size, stmt_dim_offsets in sorted(fissioned_nests.items()): - if all([sz == (0, 0) for sz, dim in dim_size]): - # Discard empty loop nests - continue - # Create the new loop nest ... - new_loops = ItSpace(mode=0).to_for(*zip(*dim_size)) - for stmt, _ in stmt_dim_offsets: - # ... populate it - new_loops[-1].body.append(stmt) - # ... and update tracked data - if stmt in new_exprs: - new_nest = list(zip(new_loops, loops_parents)) - new_exprs[stmt] = copy_metaexpr(new_exprs[stmt], - parent=new_loops[-1].body, - loops_info=new_nest) - self.hoisted.update_stmt(stmt.children[0].symbol, - loop=new_loops[0], - place=loops_parents[0]) - new_nz_info[tuple(new_loops)] = stmt_dim_offsets - # Append the new loops to the root - insert_at_elem(loops_parents[0].children, loops[0], new_loops[0]) - loops_parents[0].children.remove(loops[0]) - - self.exprs.clear() - self.exprs.update(new_exprs) - return nz_syms, new_nz_info - - def _recombine(self, nz_info): - """Recombine expressions writing to the same lvalue.""" - new_exprs = OrderedDict() - ops = {Incr: Sum, Decr: Sub, IMul: Prod} - - for nest, stmt_dim_offsets in nz_info.items(): - mapper = OrderedDict() - for stmt, dim_offsets in stmt_dim_offsets: - sym, expr = stmt.children - if type(stmt) in ops: - # The /key/ means: I'm in the same loop nest, I'm writing to - # the same symbol, and in particular to the same symbol - # locations, and I'm doing an associative AugmentedAssignment. 
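# (Illustrative aside, not original code.) For instance, the two increments
#     A[i] += B[i];  A[i] += C[i]
# share the key (str(A[i]), Incr) and are recombined below into the single
# statement A[i] += B[i] + C[i], to which factorization is then applied.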
- key = (str(sym), type(stmt)) - mapper.setdefault(key, []).append(stmt) - - for (_, op), stmts in mapper.items(): - exprs = [i.children[1] for i in stmts] - for i in stmts: - nest[-1].body.remove(i) - stmt = op(i.children[0], ast_make_expr(ops[op], exprs)) - nest[-1].body.append(stmt) - # Update the tracked expressions, if necessary - if all(i in self.exprs for i in stmts): - new_exprs[stmt] = self.exprs[i] - - for stmt, expr_info in new_exprs.items(): - ew = ExpressionRewriter(stmt, expr_info) - ew.factorize('heuristic') - - if new_exprs: - self.exprs.clear() - self.exprs.update(new_exprs) - - def _should_skip(self, zero_decls): - """Return False if, based on heuristics, it seems worth skipping the - computation over zeros, True otherwise. True is returned if it - is thought that the implications on low-level performance would be - worse than the gain in operation count (e.g., because spatial locality - within loop would go lost).""" - - if not zero_decls: - return True - - for d in zero_decls: - for d_dim in d.nonzero: - if all(size < ZeroRemover.THRESHOLD for size, offset in d_dim): - return True - - return False - - def reschedule(self, root): - """Restructure the loop nests in ``root`` to avoid computation over - zero-valued data spaces. This is achieved through symbolic execution - starting from ``root``. Control flow, in the form of If, Switch, etc., - is forbidden.""" - decls = visit(root, info_items=['decls'])['decls'] - - # Avoid rescheduling if zero-valued blocks are too small - zero_decls = [d for d in decls.values() if d.nonzero] - if self._should_skip(zero_decls): - return {} - - # Determine the analyzable loops (inner loops in which statements have no - # read-after-write dependencies) - linear_expr_loops = [(l for l in ei.linear_loops) for ei in self.exprs.values()] - linear_expr_loops = set(flatten(linear_expr_loops)) - candidates = [l for l in inner_loops(root) - if not l.is_linear or l in linear_expr_loops] - candidates = [l for l in candidates - if not ExpressionGraph(l.body).has_dependency()] - if not candidates: - return {} - - if linear_expr_loops & set(candidates): - # Split the main expressions to maximize the impact of the rescheduling (this - # helps if different summands have zero-valued blocks at different offsets) - elf = ExpressionFissioner(cut=1, loops='none') - new_exprs = {} - for stmt, expr_info in iteritems(self.exprs): - if expr_info.is_scalar: - new_exprs[stmt] = expr_info - else: - new_exprs.update(elf.fission(stmt, expr_info)) - self.exprs = new_exprs - - # Apply the rescheduling - nz_syms, nz_info = self._reschedule_itspace(root, candidates, decls) - - # Finally, "inline" the expressions that were originally split, if possible - self._recombine(nz_info) - else: - # Apply the rescheduling - nz_syms, nz_info = self._reschedule_itspace(root, candidates, decls) - - return nz_syms diff --git a/coffee/version.py b/coffee/version.py deleted file mode 100644 index 1ca68113..00000000 --- a/coffee/version.py +++ /dev/null @@ -1,4 +0,0 @@ -from __future__ import absolute_import, print_function, division - -__version_info__ = (0, 1, 0) -__version__ = '.'.join(map(str, __version_info__))
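For reference, the guard in ``ZeroRemover._should_skip`` above amounts to simple block-size bookkeeping (illustrative threshold and blocks): ::

    THRESHOLD = 4                       # illustrative; the class default is 1
    nonzero_blocks = [(2, 0), (3, 8)]   # (size, offset) pairs along one dim
    skip = all(size < THRESHOLD for size, _ in nonzero_blocks)
    assert skip   # blocks this small are not worth fissioning the nest for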