Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed May 4, 2024
1 parent 7386128 commit 4d744f6
Show file tree
Hide file tree
Showing 23 changed files with 123 additions and 419 deletions.
35 changes: 15 additions & 20 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,9 @@ def xla_computation(fun: Callable,
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the XLA
computation and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape``,
``dtype``, and ``named_shape`` attributes representing the corresponding
types of the output leaves.
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
Expand Down Expand Up @@ -552,8 +552,8 @@ def computation_maker(*args, **kwargs):

built = xc._xla.mlir.mlir_module_to_xla_computation(
m, use_tuple_args=tuple_args, return_tuple=True)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
for out_aval in out_avals:
if not isinstance(out_aval, ShapedArray):
Expand Down Expand Up @@ -2331,8 +2331,8 @@ def make_jaxpr(fun: Callable,
wrapped function returns a pair where the first element is the
``ClosedJaxpr`` representation of ``fun`` and the second element is a
pytree with the same structure as the output of ``fun`` and where the
leaves are objects with ``shape``, ``dtype``, and ``named_shape``
attributes representing the corresponding types of the output leaves.
leaves are objects with ``shape`` and ``dtype`` attributes representing
the corresponding types of the output leaves.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns
Expand Down Expand Up @@ -2385,7 +2385,7 @@ def make_jaxpr_f(*args, **kwargs):
specialized = jit(fun, static_argnums=static_argnums,
abstracted_axes=abstracted_axes).specialize(*args, **kwargs)
if return_shape:
out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None))
out = [ShapeDtypeStruct(o.shape, o.dtype)
for o in specialized.jaxpr.out_avals]
return specialized.jaxpr, tree_unflatten(specialized.out_tree, out)
return specialized.jaxpr
Expand Down Expand Up @@ -2668,12 +2668,11 @@ class ShapeDtypeStruct:
Args:
shape: a sequence of integers representing an array shape
dtype: a dtype-like object
named_shape: (optional) a dictionary representing a named shape
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "named_shape", "sharding", "_dll"]
__slots__ = ["shape", "dtype", "sharding", "_dll"]

def __init__(self, shape, dtype, named_shape=None, sharding=None):
def __init__(self, shape, dtype, sharding=None):
self.shape = tuple(shape)
if dtype is None:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
Expand All @@ -2690,7 +2689,6 @@ def __init__(self, shape, dtype, named_shape=None, sharding=None):
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
self.named_shape = {} if named_shape is None else dict(named_shape)

size = property(lambda self: math.prod(self.shape))
ndim = property(lambda self: len(self.shape))
Expand All @@ -2706,31 +2704,28 @@ def __len__(self):
raise TypeError("len() of unsized object") from e # same as numpy error

def __repr__(self):
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
sh = f", sharding={self.sharding}" if self.sharding is not None else ""
l = f", layout={self.layout}" if self._dll is not None else ""
return (f"{type(self).__name__}(shape={self.shape}, "
f"dtype={self.dtype.name}{ns}{sh}{l})")
f"dtype={self.dtype.name}{sh}{l})")

__str__ = __repr__

def __eq__(self, other):
if not isinstance(other, ShapeDtypeStruct):
return False
else:
return ((other.shape, other.dtype, other.named_shape, other.sharding, other.layout) ==
(self.shape, self.dtype, self.named_shape, self.sharding, self.layout))
return ((other.shape, other.dtype, other.sharding, other.layout) ==
(self.shape, self.dtype, self.sharding, self.layout))

def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/google/jax/issues/8182
named = frozenset(self.named_shape.items())
return hash((self.shape, self.dtype, named, self.sharding, self.layout))

return hash((self.shape, self.dtype, self.sharding, self.layout))

core.pytype_aval_mappings[ShapeDtypeStruct] = (
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=False, named_shape=x.named_shape))
weak_type=False))


@api_boundary
Expand Down
4 changes: 1 addition & 3 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,15 +570,13 @@ def _shaped_abstractify_slow(x):
pass

weak_type = getattr(x, 'weak_type', False)
named_shape = getattr(x, 'named_shape', {})
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
named_shape=named_shape)
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)

# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
def shaped_abstractify(x):
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:

def __reduce__(self):
fun, args, arr_state = self._value.__reduce__() # type: ignore
aval_state = {'weak_type': self.aval.weak_type,
'named_shape': self.aval.named_shape}
aval_state = {'weak_type': self.aval.weak_type}
return (_reconstruct_array, (fun, args, arr_state, aval_state))

@use_cpp_method()
Expand Down
Loading

0 comments on commit 4d744f6

Please sign in to comment.