Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove named shapes from avals #21069

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,9 @@ def xla_computation(fun: Callable,
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the XLA
computation and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape``,
``dtype``, and ``named_shape`` attributes representing the corresponding
types of the output leaves.
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
Expand Down Expand Up @@ -557,8 +557,8 @@ def computation_maker(*args, **kwargs):
m = mlir.module_to_bytecode(lowering_result.module)
built = xc._xla.mlir.mlir_module_to_xla_computation(
m, use_tuple_args=tuple_args, return_tuple=True)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
for out_aval in out_avals:
if not isinstance(out_aval, ShapedArray):
Expand Down Expand Up @@ -2337,8 +2337,8 @@ def make_jaxpr(fun: Callable,
wrapped function returns a pair where the first element is the
``ClosedJaxpr`` representation of ``fun`` and the second element is a
pytree with the same structure as the output of ``fun`` and where the
leaves are objects with ``shape``, ``dtype``, and ``named_shape``
attributes representing the corresponding types of the output leaves.
leaves are objects with ``shape`` and ``dtype`` attributes representing
the corresponding types of the output leaves.

Returns:
A wrapped version of ``fun`` that when applied to example arguments returns
Expand Down Expand Up @@ -2400,8 +2400,7 @@ def make_jaxpr_f(*args, **kwargs):
else:
jaxpr = traced.jaxpr
if return_shape:
out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None))
for o in jaxpr.out_avals]
out = [ShapeDtypeStruct(o.shape, o.dtype) for o in jaxpr.out_avals]
return jaxpr, tree_unflatten(tree_structure(traced.out_info), out)
return jaxpr

Expand Down Expand Up @@ -2691,12 +2690,13 @@ class ShapeDtypeStruct:
Args:
shape: a sequence of integers representing an array shape
dtype: a dtype-like object
named_shape: (optional) a dictionary representing a named shape
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "named_shape", "sharding", "_dll"]
__slots__ = ["shape", "dtype", "sharding", "_dll"]
named_shape = {}

def __init__(self, shape, dtype, named_shape=None, sharding=None):
def __init__(self, shape, dtype, sharding=None, named_shape=None):
del named_shape # ignored, vestigial
self.shape = tuple(shape)
if dtype is None:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
Expand All @@ -2713,7 +2713,6 @@ def __init__(self, shape, dtype, named_shape=None, sharding=None):
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
self.named_shape = {} if named_shape is None else dict(named_shape)

size = property(lambda self: math.prod(self.shape))
ndim = property(lambda self: len(self.shape))
Expand All @@ -2729,31 +2728,28 @@ def __len__(self):
raise TypeError("len() of unsized object") from e # same as numpy error

def __repr__(self):
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
sh = f", sharding={self.sharding}" if self.sharding is not None else ""
l = f", layout={self.layout}" if self._dll is not None else ""
return (f"{type(self).__name__}(shape={self.shape}, "
f"dtype={self.dtype.name}{ns}{sh}{l})")
f"dtype={self.dtype.name}{sh}{l})")

__str__ = __repr__

def __eq__(self, other):
if not isinstance(other, ShapeDtypeStruct):
return False
else:
return ((other.shape, other.dtype, other.named_shape, other.sharding, other.layout) ==
(self.shape, self.dtype, self.named_shape, self.sharding, self.layout))
return ((other.shape, other.dtype, other.sharding, other.layout) ==
(self.shape, self.dtype, self.sharding, self.layout))

def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/google/jax/issues/8182
named = frozenset(self.named_shape.items())
return hash((self.shape, self.dtype, named, self.sharding, self.layout))

return hash((self.shape, self.dtype, self.sharding, self.layout))

core.pytype_aval_mappings[ShapeDtypeStruct] = (
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=False, named_shape=x.named_shape))
weak_type=False))


@api_boundary
Expand Down
4 changes: 1 addition & 3 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,15 +570,13 @@ def _shaped_abstractify_slow(x):
pass

weak_type = getattr(x, 'weak_type', False)
named_shape = getattr(x, 'named_shape', {})
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
named_shape=named_shape)
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)

# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
def shaped_abstractify(x):
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:

def __reduce__(self):
fun, args, arr_state = self._value.__reduce__()
aval_state = {'weak_type': self.aval.weak_type,
'named_shape': self.aval.named_shape}
aval_state = {'weak_type': self.aval.weak_type}
return (_reconstruct_array, (fun, args, arr_state, aval_state))

@use_cpp_method()
Expand Down
Loading
Loading