Skip to content

Commit

Permalink
Implement basic TensorFlow meta objects
Browse files Browse the repository at this point in the history
This commit provides basic TensorFlow graph interop.  There are still a few
design questions and choices to iterate on (e.g. the TF namespace/graph used
during intermediate base-object derived meta object construction steps), but the
basic class wrappers, helper functions, term representation and unification are
working.

Closes pymc-devs#3.
  • Loading branch information
brandonwillard committed Mar 20, 2019
1 parent 336eb83 commit 8e5a43f
Show file tree
Hide file tree
Showing 16 changed files with 1,173 additions and 251 deletions.
218 changes: 116 additions & 102 deletions symbolic_pymc/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,56 @@

from itertools import chain
from functools import partial
from collections.abc import Iterator
from collections.abc import Iterator, Mapping

from unification import isvar, Var

from .utils import _check_eq


# TODO: Replace `from_obj` with a generic function?
# from multipledispatch import dispatch

meta_repr = reprlib.Repr()
meta_repr.maxstring = 100
meta_repr.maxother = 100


def _make_hashable(x):
if isinstance(x, list):
return tuple(x)
elif isinstance(x, np.ndarray):
return x.data.tobytes()
else:
return x


def _meta_reify_iter(rands):
"""Recursively reify an iterable object and return a boolean indicating the
presence of un-reifiable objects, if any.
"""
# We want as many of the rands reified as possible,
any_unreified = False
reified_rands = []
for s in rands:
if isinstance(rands, Mapping):
_rands = rands.items()
else:
_rands = rands

for s in _rands:
if isinstance(s, MetaSymbol):
rrand = s.reify()
reified_rands += [rrand]
reified_rands.append(rrand)
any_unreified |= isinstance(rrand, MetaSymbol)
any_unreified |= isvar(rrand)
elif MetaSymbol.is_meta(s):
reified_rands += [s]
reified_rands.append(s)
any_unreified |= True
elif isinstance(s, (list, tuple)):
_reified_rands, _any_unreified = _meta_reify_iter(s)
reified_rands += [type(s)(_reified_rands)]
reified_rands.append(type(s)(_reified_rands))
any_unreified |= _any_unreified
else:
reified_rands += [s]

return reified_rands, any_unreified
return type(rands)(reified_rands), any_unreified


class MetaSymbolType(abc.ABCMeta):
Expand All @@ -58,16 +72,17 @@ def __new__(cls, name, bases, clsdict):

def __setattr__(self, attr, obj):
"""If a slot value is changed, discard any associated non-meta/base
objects.
objects. Also, guarantee that underlying base objects cannot be
meta objects.
"""
if (getattr(self, 'obj', None) is not None and
if attr == 'obj':
if isinstance(obj, MetaSymbol):
raise ValueError('base object cannot be a meta object!')
elif (getattr(self, 'obj', None) is not None and
not isinstance(self.obj, Var) and
attr in getattr(self, '__all_slots__', {}) and
hasattr(self, attr) and getattr(self, attr) != obj):
self.obj = None
elif attr == 'obj':
if isinstance(obj, MetaSymbol):
raise ValueError('base object cannot be a meta object!')

object.__setattr__(self, attr, obj)

Expand All @@ -79,18 +94,12 @@ def __setattr__(self, attr, obj):
# E.g. cls.register(bases)
return res

# @property
# @abc.abstractmethod
# def base(self):
# """The base type/rator for this meta object.
# """
# raise NotImplementedError()


class MetaSymbol(metaclass=MetaSymbolType):
"""Meta objects for unification and such.
"""
TODO: Should `MetaSymbol.obj` be an abstract property and a `weakref`?
"""
@classmethod
def base_classes(cls, mro_order=True):
res = tuple(c.base for c in cls.__subclasses__())
Expand Down Expand Up @@ -120,24 +129,40 @@ def to_base_obj(cls, obj):
def from_obj(cls, obj):
"""Create a meta object for a given base object.
XXX: Be careful when overriding this: `isvar` checks are necessary!
Generic `MetaSymbol` classes (i.e. those not modeling base
classes/types) handle the conversion of standard Python objects and
delegation to other meta object-modeling sub-classes. Otherwise, if
`from_obj` is called from a subclass of `MetaSymbol` modeling a `base`
class/type, conversion is strict (i.e. the object must be convertible
to an object of the modeled type).
XXX: Be careful when overriding this without using `super`: `isvar`
checks are necessary!
TODO: Consider replacing this with a generic function.
"""
if (cls.is_meta(obj) or obj is None or
isinstance(obj, (types.FunctionType, partial,
str, dict))):
if cls.is_meta(obj):
return obj

if isinstance(obj, (set, list, tuple, Iterator)):
# Convert elements of the iterable
return type(obj)([cls.from_obj(o) for o in obj])
if not hasattr(cls, 'base'):
if (obj is None or isinstance(
obj, (types.FunctionType, partial, str, dict))):
return obj

if isinstance(obj, (set, list, tuple, Iterator)):
# Convert elements of an iterable
return type(obj)([cls.from_obj(o) for o in obj])

if inspect.isclass(obj) and issubclass(obj, cls.base_classes()):
# This is a class/type covered by a meta class/type.
# This object is a class/type object. Let's see if there's a
# an existing meta class related to it. If there is, we return
# the corresponding meta class; otherwise, if the object is a
# subclass, we'll create a new meta subclass for it.
try:
obj_cls = next(filter(lambda t: issubclass(obj, t.base),
cls.__subclasses__()))
except StopIteration:
# The current class is the best fit.
# The current class is the best match.
if cls.base == obj:
return cls

Expand All @@ -149,11 +174,9 @@ def from_obj(cls, obj):

if not isinstance(obj, cls.base_classes()):
# We might've been given something convertible to a type with a
# meta type, so let's try that
# meta type, so let's try that.
try:
obj = cls.to_base_obj(obj)
# tt.as_tensor_variable(obj)
# tt.AsTensorError
except ValueError:
pass

Expand Down Expand Up @@ -188,7 +211,7 @@ def reify(self):
"""Create a concrete base object from this meta object (and its
rands).
"""
if self.obj and not isinstance(self.obj, Var):
if self.obj is not None and not isinstance(self.obj, Var):
return self.obj
else:
reified_rands, any_unreified = _meta_reify_iter(self.rands())
Expand All @@ -203,65 +226,54 @@ def reify(self):

return res

# def __getattr__(self, attr_name):
# """Fallback to the underlying base object, if any, when unknown
# attributes are accessed.
# """
# if attr_name != 'obj':
# return getattr(object.__getattribute__(self, 'obj'), attr_name)
# raise AttributeError(f'{attr_name} not found')
def __eq__(self, other):
"""Syntactic equality between meta objects and their bases.
"""Implements an equivalence between meta objects and their bases.
"""
# TODO: Allow a sort of cross-inheritance equivalence (e.g. a
# `tt.Variable` or `tt.TensorVariable`)?
# a_sub_b = isinstance(self, type(other))
# b_sub_a = isinstance(other, type(self))
# if not (a_sub_b or b_sub_a):
# return False
if self is other:
return True

if not (type(self) == type(other)):
return False

# TODO: ?
# Same for base objects
# a_sub_b = isinstance(self.base, type(other.base))
# b_sub_a = isinstance(other.base, type(self.base))
# if not (a_sub_b or b_sub_a):
# return False
if not (self.base == other.base):
return False

# TODO: ?
# # `self` is the super class, that might be generalizing
# # `other`
a_slots = getattr(self, '__slots__', [])
# b_slots = getattr(other, '__slots__', [])
# if (b_sub_a and not a_sub_b and
# not all(getattr(self, attr) == getattr(other, attr)
# for attr in a_slots)):
# return False
# # `other` is the super class, that might be generalizing
# # `self`
# elif (a_sub_b and not b_sub_a and
# not all(getattr(self, attr) == getattr(other, attr)
# for attr in b_slots)):
# return False
if not all(_check_eq(getattr(self, attr), getattr(other, attr))
for attr in a_slots):
a_slots = getattr(self, '__slots__', None)
if a_slots is not None:
if not all(_check_eq(getattr(self, attr),
getattr(other, attr))
for attr in a_slots):
return False
elif getattr(other, '__slots__', None) is not None:
# The other object has slots, but this one doesn't.
return False

# if (self.obj and not isvar(self.obj) and
# other.obj and not isvar(other.objj)):
# assert self.obj == other.obj
else:
# Neither have slots, so best we can do is compare
# base objects (if any).
# If there aren't base objects, we say they're not equal.
# (Maybe we should *require* base objects in this case
# and raise an exception?)
return (getattr(self, 'obj', None) == getattr(other, 'obj', None)
is not None)

return True

def __ne__(self, other):
return not self.__eq__(other)

def __hash__(self):
def _make_hashable(x):
if isinstance(x, list):
return tuple(x)
elif isinstance(x, np.ndarray):
return x.data.tobytes()
else:
return x
rands = tuple(_make_hashable(p) for p in self.rands())
return hash(rands + (self.base,))
if getattr(self, '__slots__', None) is not None:
rands = tuple(_make_hashable(p) for p in self.rands())
return hash(rands + (self.base,))
else:
return hash((self.base, self.obj))

def __str__(self):
obj = getattr(self, 'obj', None)
Expand Down Expand Up @@ -296,17 +308,19 @@ def _repr_pretty_(self, p, cycle):
p.text(name)
p.text('=')
p.pretty(item)

obj = getattr(self, 'obj', None)
if obj:
if idx:

if obj is not None:
if idx is not None:
p.text(',')
p.breakable()
p.text('obj=')
p.pretty(obj)


class MetaOp(MetaSymbol):
"""A meta object that represents a `MetaVaribale`-producing operator.
"""A meta object that represents a `MetaVariable`-producing operator.
In some cases, operators hold their own inputs and outputs
(e.g. TensorFlow), and, in others, an intermediary "application" node holds
Expand All @@ -316,21 +330,20 @@ class MetaOp(MetaSymbol):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.op_sig = inspect.signature(self.obj.make_node)

@property
def obj(self):
return self._obj
return object.__getattribute__(self, '_obj')

@obj.setter
def obj(self, x):
if hasattr(self, '_obj'):
raise ValueError('Cannot reset obj in an `Op`')
self._obj = x
object.__setattr__(self, '_obj', x)

@abc.abstractmethod
def out_meta_type(self, inputs=None):
"""Return the type of meta variable this `Op` is expected to produce
def out_meta_types(self, inputs=None):
"""Return the types of meta variables this `Op` is expected to produce
given the inputs.
"""
raise NotImplementedError()
Expand All @@ -339,22 +352,23 @@ def out_meta_type(self, inputs=None):
def __call__(self, *args, ttype=None, index=None, **kwargs):
raise NotImplementedError()

def __eq__(self, other):
# Since these have no rands/slots, we can only really compare against
# the underlying base objects (which should be there!).
if not super().__eq__(other):
return False

assert self.obj

if self.obj != other.obj:
return False

return True

def __hash__(self):
return hash((self.base, self.obj))
class MetaVariable(MetaSymbol):
@property
@abc.abstractmethod
def operator(self):
"""Return a meta object representing an operator, if any, capable of
producing this variable.
It should be callable with all inputs necessary to reproduce this
tensor given by `MetaVariable.inputs`.
"""
raise NotImplementedError()

class MetaVariable(MetaSymbol):
pass
@property
@abc.abstractmethod
def inputs(self):
"""Return the inputs necessary for `MetaVariable.operator` to produced
this variable, if any.
"""
raise NotImplementedError()
Loading

0 comments on commit 8e5a43f

Please sign in to comment.