Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better sum factorisation for coefficients #81

Merged
merged 18 commits into from
Nov 30, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 47 additions & 22 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,15 @@ def __new__(cls, a, b):
assert not a.shape
assert not b.shape

# Zero folding
# Constant folding
if isinstance(a, Zero):
return b
elif isinstance(b, Zero):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value + b.value)

self = super(Sum, cls).__new__(cls)
self.children = a, b
return self
Expand All @@ -218,10 +221,18 @@ def __new__(cls, a, b):
assert not a.shape
assert not b.shape

# Zero folding
# Constant folding
if isinstance(a, Zero) or isinstance(b, Zero):
return Zero()

if a == one:
return b
if b == one:
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value * b.value)

self = super(Product, cls).__new__(cls)
self.children = a, b
return self
Expand All @@ -234,12 +245,18 @@ def __new__(cls, a, b):
assert not a.shape
assert not b.shape

# Zero folding
# Constant folding
if isinstance(b, Zero):
raise ValueError("division by zero")
if isinstance(a, Zero):
return Zero()

if b == one:
return a

if isinstance(a, Constant) and isinstance(b, Constant):
return Literal(a.value / b.value)

self = super(Division, cls).__new__(cls)
self.children = a, b
return self
Expand All @@ -258,7 +275,7 @@ def __new__(cls, base, exponent):
raise ValueError("cannot solve 0^0")
return Zero()
elif isinstance(exponent, Zero):
return Literal(1)
return one

self = super(Power, cls).__new__(cls)
self.children = base, exponent
Expand Down Expand Up @@ -448,6 +465,10 @@ def __new__(cls, aggregate, multiindex):
if isinstance(index, Index):
index.set_extent(extent)

# Empty multiindex
if not multiindex:
return aggregate

# Zero folding
if isinstance(aggregate, Zero):
return Zero()
Expand Down Expand Up @@ -553,26 +574,27 @@ def __new__(cls, expression, multiindex):


class IndexSum(Scalar):
__slots__ = ('children', 'index')
__back__ = ('index',)
__slots__ = ('children', 'multiindex')
__back__ = ('multiindex',)

def __new__(cls, summand, index):
def __new__(cls, summand, multiindex):
# Sum zeros
assert not summand.shape
if isinstance(summand, Zero):
return summand

# Sum a single expression
if index.extent == 1:
return Indexed(ComponentTensor(summand, (index,)), (0,))
# No indices case
multiindex = tuple(multiindex)
if not multiindex:
return summand

self = super(IndexSum, cls).__new__(cls)
self.children = (summand,)
self.index = index
self.multiindex = multiindex

# Collect shape and free indices
assert index in summand.free_indices
self.free_indices = unique(set(summand.free_indices) - {index})
assert set(multiindex) <= set(summand.free_indices)
self.free_indices = unique(set(summand.free_indices) - set(multiindex))

return self

Expand Down Expand Up @@ -641,7 +663,7 @@ def __new__(cls, i, j):

# \delta_{i,i} = 1
if i == j:
return Literal(1)
return one

# Fixed indices
if isinstance(i, int) and isinstance(j, int):
Expand Down Expand Up @@ -714,14 +736,13 @@ def unique(indices):
return tuple(sorted(set(indices), key=id))


def index_sum(expression, index):
    """Sum an expression over one of its free indices.

    If ``index`` is not among the free indices of ``expression``,
    there is nothing to sum over and the expression is handed back
    untouched.
    """
    if index not in expression.free_indices:
        return expression
    return IndexSum(expression, index)
def index_sum(expression, indices):
    """Sum an expression over several of its free indices.

    Only those entries of ``indices`` that are free indices of
    ``expression`` take part in the summation; the others are
    silently ignored.
    """
    selected = []
    for index in indices:
        if index in expression.free_indices:
            selected.append(index)
    return IndexSum(expression, tuple(selected))


def partial_indexed(tensor, indices):
Expand Down Expand Up @@ -764,3 +785,7 @@ def reshape(variable, *shapes):
dim2idxs.append((0, tuple(idxs)))
expr = FlexiblyIndexed(variable, tuple(dim2idxs))
return ComponentTensor(expr, tuple(indices))


# Static one object for quicker constant folding
one = Literal(1)
11 changes: 11 additions & 0 deletions gem/impero.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ class For(Node):
__slots__ = ('index', 'children')
__front__ = ('index',)

def __new__(cls, index, statement):
    """Allocate a loop node, or a ``Noop`` if the loop body is empty."""
    # Related: https://github.com/coneoproject/COFFEE/issues/98
    assert isinstance(statement, Block)
    if statement.children:
        return super(For, cls).__new__(cls)
    # Empty loop: hand out a Noop instead.  This "works" because the
    # loop_shape of this node is not asked any more.
    return Noop(None)

def __init__(self, index, statement):
self.index = index
self.children = (statement,)
160 changes: 150 additions & 10 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@
expressions."""

from __future__ import absolute_import, print_function, division
from six.moves import map
from six import itervalues
from six.moves import map, zip

from collections import OrderedDict, deque
from functools import reduce
from itertools import permutations

import numpy
from singledispatch import singledispatch

from gem.node import Memoizer, MemoizerArg, reuse_if_untouched, reuse_if_untouched_arg
from gem.gem import (Node, Terminal, Failure, Identity, Literal, Zero,
Sum, Comparison, Conditional, Index,
Product, Sum, Comparison, Conditional, Index,
VariableIndex, Indexed, FlexiblyIndexed,
IndexSum, ComponentTensor, ListTensor, Delta,
partial_indexed)
partial_indexed, one)


@singledispatch
Expand Down Expand Up @@ -190,6 +195,121 @@ def select_expression(expressions, index):
return ComponentTensor(selected, alpha)


def delta_elimination(sum_indices, factors):
    """IndexSum-Delta cancellation.

    As long as some factor is a Delta that involves a contraction
    index, that index is removed from the contraction indices and
    replaced by the Delta's other index in every factor.

    :arg sum_indices: free indices for contractions
    :arg factors: product factors
    :returns: optimised (sum_indices, factors)
    """
    # Work on a private copy, the caller's sequence is left alone
    remaining = list(sum_indices)

    while True:
        candidates = [(factor, idx)
                      for factor in factors if isinstance(factor, Delta)
                      for idx in (factor.i, factor.j) if idx in remaining]
        if not candidates:
            break

        # Only the first candidate is processed per round; the
        # candidates are recomputed after every substitution.
        delta, source = candidates[0]
        target, = {delta.i, delta.j} - {source}

        remaining.remove(source)

        substitute = MemoizerArg(filtered_replace_indices)
        factors = [substitute(factor, ((source, target),))
                   for factor in factors]

    # Factors equal to one contribute nothing to the product
    return remaining, [factor for factor in factors if factor != one]


def sum_factorise(sum_indices, factors):
    """Optimise a tensor product through sum factorisation.

    The factors are first grouped by their free indices.  Then every
    ordering of the contraction indices is tried, contracting one
    index at a time, and the result with the lowest estimated
    operation count is returned.

    :arg sum_indices: free indices for contractions
    :arg factors: product factors
    :returns: optimised GEM expression; the constant ``one`` when
        there is nothing to multiply and nothing to contract
    :raises NotImplementedError: if there are more than five
        contraction indices
    """
    # All permutations of the contraction indices are enumerated
    # below, so the number of indices is capped.
    if len(sum_indices) > 5:
        raise NotImplementedError("Too many indices for sum factorisation!")

    factors = list(factors)
    if len(factors) == 0 and len(sum_indices) == 0:
        # Empty product: reducing an empty list of terms with Product
        # would raise a TypeError, so return the neutral element.
        return one

    # Form groups by free indices
    groups = OrderedDict()
    for factor in factors:
        groups.setdefault(factor.free_indices, []).append(factor)
    groups = [reduce(Product, terms) for terms in itervalues(groups)]

    # Sum factorisation
    expression = None
    best_flops = numpy.inf

    # Consider all orderings of contraction indices
    for ordering in permutations(sum_indices):
        terms = groups[:]
        flops = 0
        # Apply contraction index by index
        for sum_index in ordering:
            # Select terms that need to be part of the contraction
            contract = [t for t in terms if sum_index in t.free_indices]
            deferred = [t for t in terms if sum_index not in t.free_indices]

            # A further optimisation opportunity is to consider
            # various ways of building the product tree.
            product = reduce(Product, contract)
            term = IndexSum(product, (sum_index,))
            # For the operation count estimation we assume that no
            # operations were saved with the particular product tree
            # that we built above.
            flops += len(contract) * numpy.prod(
                [i.extent for i in product.free_indices], dtype=int)

            # Replace the contracted terms with the result of the
            # contraction.
            terms = deferred + [term]

        # If some contraction indices were independent, then we may
        # still have several terms at this point.
        expr = reduce(Product, terms)
        flops += (len(terms) - 1) * numpy.prod(
            [i.extent for i in expr.free_indices], dtype=int)

        if flops < best_flops:
            expression = expr
            best_flops = flops

    return expression


def contraction(expression):
    """Optimise the contractions of the tensor product at the root of
    the expression, including:

    - IndexSum-Delta cancellation
    - Sum factorisation

    This routine was designed with finite element coefficient
    evaluation in mind.
    """
    # Get rid of ComponentTensors before inspecting the tree
    root, = remove_componenttensors([expression])

    # Walk the IndexSum/Product tree at the root breadth-first,
    # gathering contraction indices and the non-product leaves.
    contraction_indices = []
    leaves = []

    pending = deque((root,))
    while pending:
        node = pending.popleft()
        if isinstance(node, IndexSum):
            contraction_indices.extend(node.multiindex)
            pending.append(node.children[0])
            continue
        if isinstance(node, Product):
            pending.extend(node.children)
            continue
        leaves.append(node)

    reduced_indices, reduced_factors = delta_elimination(contraction_indices,
                                                         leaves)
    return sum_factorise(reduced_indices, reduced_factors)


@singledispatch
def _replace_delta(node, self):
raise AssertionError("cannot handle type %s" % type(node))
Expand Down Expand Up @@ -222,13 +342,13 @@ def expression(index):
raise ValueError("Cannot convert running index to expression.")
e_i = expression(i)
e_j = expression(j)
return Conditional(Comparison("==", e_i, e_j), Literal(1), Zero())
return Conditional(Comparison("==", e_i, e_j), one, Zero())


def replace_delta(expressions):
"""Lowers all Deltas in a multi-root expression DAG."""
mapper = Memoizer(_replace_delta)
return map(mapper, expressions)
return list(map(mapper, expressions))


@singledispatch
Expand All @@ -246,13 +366,18 @@ def _unroll_indexsum(node, self):

@_unroll_indexsum.register(IndexSum) # noqa
def _(node, self):
if node.index.extent <= self.max_extent:
unroll = tuple(index for index in node.multiindex
if index.extent <= self.max_extent)
if unroll:
# Unrolling
summand = self(node.children[0])
return reduce(Sum,
(Indexed(ComponentTensor(summand, (node.index,)), (i,))
for i in range(node.index.extent)),
Zero())
shape = tuple(index.extent for index in unroll)
unrolled = reduce(Sum,
(Indexed(ComponentTensor(summand, unroll), alpha)
for alpha in numpy.ndindex(shape)),
Zero())
return IndexSum(unrolled, tuple(index for index in node.multiindex
if index not in unroll))
else:
return reuse_if_untouched(node, self)

Expand All @@ -267,3 +392,18 @@ def unroll_indexsum(expressions, max_extent):
mapper = Memoizer(_unroll_indexsum)
mapper.max_extent = max_extent
return list(map(mapper, expressions))


def aggressive_unroll(expression):
"""Aggressively unrolls all loop structures."""
# Unroll expression shape
if expression.shape:
tensor = numpy.empty(expression.shape, dtype=object)
for alpha in numpy.ndindex(expression.shape):
tensor[alpha] = Indexed(expression, alpha)
expression, = remove_componenttensors((ListTensor(tensor),))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This converts a numpy array to a GEM object.

As well as this line...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see...


# Unroll summation
expression, = unroll_indexsum((expression,), max_extent=numpy.inf)
expression, = remove_componenttensors((expression,))
return expression
2 changes: 1 addition & 1 deletion tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_loop_fusion():

def make_expression(i, j):
A = Variable('A', (6,))
s = IndexSum(Indexed(A, (j,)), j)
s = IndexSum(Indexed(A, (j,)), (j,))
return Product(Indexed(A, (i,)), s)

e1 = make_expression(i, j)
Expand Down
Loading