Skip to content

Commit

Permalink
Implement basic TensorFlow meta objects
Browse files Browse the repository at this point in the history
This commit provides basic TensorFlow graph interop.  There are still a few
design questions and choices to iterate on (e.g. the TF namespace/graph used
during intermediate base-object derived meta object construction steps), but the
basic class wrappers, helper functions, term representation and unification are
working.

Closes pymc-devs#3.
  • Loading branch information
brandonwillard committed Mar 24, 2019
1 parent 0323e0e commit 5b5a6dd
Show file tree
Hide file tree
Showing 15 changed files with 1,147 additions and 212 deletions.
195 changes: 99 additions & 96 deletions symbolic_pymc/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,54 @@

from itertools import chain
from functools import partial
from collections.abc import Iterator
from collections.abc import Iterator, Mapping

from unification import isvar, Var

from .utils import _check_eq


# TODO: Replace `from_obj` with a generic function?
# from multipledispatch import dispatch

meta_repr = reprlib.Repr()
meta_repr.maxstring = 100
meta_repr.maxother = 100


def _make_hashable(x):
if isinstance(x, list):
return tuple(x)
elif isinstance(x, np.ndarray):
return x.data.tobytes()
else:
return x


def _meta_reify_iter(rands):
"""Recursively reify an iterable object and return a boolean indicating the presence of un-reifiable objects, if any."""
# We want as many of the rands reified as possible,
any_unreified = False
reified_rands = []
for s in rands:
if isinstance(rands, Mapping):
_rands = rands.items()
else:
_rands = rands

for s in _rands:
if isinstance(s, MetaSymbol):
rrand = s.reify()
reified_rands += [rrand]
reified_rands.append(rrand)
any_unreified |= isinstance(rrand, MetaSymbol)
any_unreified |= isvar(rrand)
elif MetaSymbol.is_meta(s):
reified_rands += [s]
reified_rands.append(s)
any_unreified |= True
elif isinstance(s, (list, tuple)):
_reified_rands, _any_unreified = _meta_reify_iter(s)
reified_rands += [type(s)(_reified_rands)]
reified_rands.append(type(s)(_reified_rands))
any_unreified |= _any_unreified
else:
reified_rands += [s]

return reified_rands, any_unreified
return type(rands)(reified_rands), any_unreified


class MetaSymbolType(abc.ABCMeta):
Expand All @@ -58,17 +70,17 @@ def __new__(cls, name, bases, clsdict):

def __setattr__(self, attr, obj):
"""If a slot value is changed, discard any associated non-meta/base objects."""
if (
if attr == "obj":
if isinstance(obj, MetaSymbol):
raise ValueError("base object cannot be a meta object!")
elif (
getattr(self, "obj", None) is not None
and not isinstance(self.obj, Var)
and attr in getattr(self, "__all_slots__", {})
and hasattr(self, attr)
and getattr(self, attr) != obj
):
self.obj = None
elif attr == "obj":
if isinstance(obj, MetaSymbol):
raise ValueError("base object cannot be a meta object!")

object.__setattr__(self, attr, obj)

Expand All @@ -82,7 +94,10 @@ def __setattr__(self, attr, obj):


class MetaSymbol(metaclass=MetaSymbolType):
"""Meta objects for unification and such."""
"""Meta objects for unification and such.
TODO: Should `MetaSymbol.obj` be an abstract property and a `weakref`?
"""

@classmethod
def base_classes(cls, mro_order=True):
Expand Down Expand Up @@ -113,26 +128,38 @@ def to_base_obj(cls, obj):
def from_obj(cls, obj):
"""Create a meta object for a given base object.
XXX: Be careful when overriding this: `isvar` checks are necessary!
Generic `MetaSymbol` classes (i.e. those not modeling base
classes/types) handle the conversion of standard Python objects and
delegation to other meta object-modeling sub-classes. Otherwise, if
`from_obj` is called from a subclass of `MetaSymbol` modeling a `base`
class/type, conversion is strict (i.e. the object must be convertible
to an object of the modeled type).
XXX: Be careful when overriding this without using `super`: `isvar`
checks are necessary!
TODO: Consider replacing this with a generic function.
"""
if (
cls.is_meta(obj)
or obj is None
or isinstance(obj, (types.FunctionType, partial, str, dict))
):
if cls.is_meta(obj):
return obj

if isinstance(obj, (set, list, tuple, Iterator)):
# Convert elements of the iterable
return type(obj)([cls.from_obj(o) for o in obj])
if not hasattr(cls, "base"):
if obj is None or isinstance(obj, (types.FunctionType, partial, str, dict)):
return obj

if isinstance(obj, (set, list, tuple, Iterator)):
# Convert elements of an iterable
return type(obj)([cls.from_obj(o) for o in obj])

if inspect.isclass(obj) and issubclass(obj, cls.base_classes()):
# This is a class/type covered by a meta class/type.
# This object is a class/type object. Let's see if there's a
# an existing meta class related to it. If there is, we return
# the corresponding meta class; otherwise, if the object is a
# subclass, we'll create a new meta subclass for it.
try:
obj_cls = next(filter(lambda t: issubclass(obj, t.base), cls.__subclasses__()))
except StopIteration:
# The current class is the best fit.
# The current class is the best match.
if cls.base == obj:
return cls

Expand All @@ -144,11 +171,9 @@ def from_obj(cls, obj):

if not isinstance(obj, cls.base_classes()):
# We might've been given something convertible to a type with a
# meta type, so let's try that
# meta type, so let's try that.
try:
obj = cls.to_base_obj(obj)
# tt.as_tensor_variable(obj)
# tt.AsTensorError
except ValueError:
pass

Expand All @@ -175,7 +200,7 @@ def rands(self):

def reify(self):
"""Create a concrete base object from this meta object (and its rands)."""
if self.obj and not isinstance(self.obj, Var):
if self.obj is not None and not isinstance(self.obj, Var):
return self.obj
else:
reified_rands, any_unreified = _meta_reify_iter(self.rands())
Expand All @@ -191,63 +216,42 @@ def reify(self):
return res

def __eq__(self, other):
"""Syntactic equality between meta objects and their bases."""
# TODO: Allow a sort of cross-inheritance equivalence (e.g. a
# `tt.Variable` or `tt.TensorVariable`)?
# a_sub_b = isinstance(self, type(other))
# b_sub_a = isinstance(other, type(self))
# if not (a_sub_b or b_sub_a):
# return False
"""Implement an equivalence between meta objects and their bases."""
if self is other:
return True

if not (type(self) == type(other)):
return False

# TODO: ?
# Same for base objects
# a_sub_b = isinstance(self.base, type(other.base))
# b_sub_a = isinstance(other.base, type(self.base))
# if not (a_sub_b or b_sub_a):
# return False
if not (self.base == other.base):
return False

# TODO: ?
# # `self` is the super class, that might be generalizing
# # `other`
a_slots = getattr(self, "__slots__", [])
# b_slots = getattr(other, '__slots__', [])
# if (b_sub_a and not a_sub_b and
# not all(getattr(self, attr) == getattr(other, attr)
# for attr in a_slots)):
# return False
# # `other` is the super class, that might be generalizing
# # `self`
# elif (a_sub_b and not b_sub_a and
# not all(getattr(self, attr) == getattr(other, attr)
# for attr in b_slots)):
# return False
if not all(_check_eq(getattr(self, attr), getattr(other, attr)) for attr in a_slots):
a_slots = getattr(self, "__slots__", None)
if a_slots is not None:
if not all(_check_eq(getattr(self, attr), getattr(other, attr)) for attr in a_slots):
return False
elif getattr(other, "__slots__", None) is not None:
# The other object has slots, but this one doesn't.
return False

# if (self.obj and not isvar(self.obj) and
# other.obj and not isvar(other.objj)):
# assert self.obj == other.obj
else:
# Neither have slots, so best we can do is compare
# base objects (if any).
# If there aren't base objects, we say they're not equal.
# (Maybe we should *require* base objects in this case
# and raise an exception?)
return getattr(self, "obj", None) == getattr(other, "obj", None) is not None

return True

def __ne__(self, other):
return not self.__eq__(other)

def __hash__(self):
def _make_hashable(x):
if isinstance(x, list):
return tuple(x)
elif isinstance(x, np.ndarray):
return x.data.tobytes()
else:
return x

rands = tuple(_make_hashable(p) for p in self.rands())
return hash(rands + (self.base,))
if getattr(self, "__slots__", None) is not None:
rands = tuple(_make_hashable(p) for p in self.rands())
return hash(rands + (self.base,))
else:
return hash((self.base, self.obj))

def __str__(self):
obj = getattr(self, "obj", None)
Expand Down Expand Up @@ -280,17 +284,19 @@ def _repr_pretty_(self, p, cycle):
p.text(name)
p.text("=")
p.pretty(item)

obj = getattr(self, "obj", None)
if obj:
if idx:

if obj is not None:
if idx is not None:
p.text(",")
p.breakable()
p.text("obj=")
p.pretty(obj)


class MetaOp(MetaSymbol):
"""A meta object that represents a `MetaVaribale`-producing operator.
"""A meta object that represents a `MetaVariable`-producing operator.
Also, make sure to override `Op.out_meta_type` and make it return the
expected meta variable type, if it isn't the default: `MetaTensorVariable`.
Expand All @@ -303,43 +309,40 @@ class MetaOp(MetaSymbol):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.op_sig = inspect.signature(self.obj.make_node)

@property
def obj(self):
return self._obj
return object.__getattribute__(self, "_obj")

@obj.setter
def obj(self, x):
if hasattr(self, "_obj"):
raise ValueError("Cannot reset obj in an `Op`")
self._obj = x
object.__setattr__(self, "_obj", x)

@abc.abstractmethod
def out_meta_type(self, inputs=None):
"""Return the type of meta variable this `Op` is expected to produce given the inputs."""
def out_meta_types(self, inputs=None):
"""Return the types of meta variables this `Op` is expected to produce given the inputs."""
raise NotImplementedError()

@abc.abstractmethod
def __call__(self, *args, ttype=None, index=None, **kwargs):
raise NotImplementedError()

def __eq__(self, other):
# Since these have no rands/slots, we can only really compare against
# the underlying base objects (which should be there!).
if not super().__eq__(other):
return False

assert self.obj

if self.obj != other.obj:
return False

return True

def __hash__(self):
return hash((self.base, self.obj))
class MetaVariable(MetaSymbol):
@property
@abc.abstractmethod
def operator(self):
"""Return a meta object representing an operator, if any, capable of producing this variable.
It should be callable with all inputs necessary to reproduce this
tensor given by `MetaVariable.inputs`.
"""
raise NotImplementedError()

class MetaVariable(MetaSymbol):
pass
@property
@abc.abstractmethod
def inputs(self):
"""Return the inputs necessary for `MetaVariable.operator` to produced this variable, if any."""
raise NotImplementedError()
2 changes: 2 additions & 0 deletions symbolic_pymc/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Needed to register generic functions
from .unify import *
Loading

0 comments on commit 5b5a6dd

Please sign in to comment.