diff --git a/jax/_src/api.py b/jax/_src/api.py index 2256a44b7ed9..dad1f020842d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -381,9 +381,9 @@ def xla_computation(fun: Callable, return_shape: Optional boolean, defaults to ``False``. If ``True``, the wrapped function returns a pair where the first element is the XLA computation and the second element is a pytree with the same structure as - the output of ``fun`` and where the leaves are objects with ``shape``, - ``dtype``, and ``named_shape`` attributes representing the corresponding - types of the output leaves. + the output of ``fun`` and where the leaves are objects with ``shape`` and + ``dtype`` attributes representing the corresponding types of the output + leaves. donate_argnums: Specify which arguments are "donated" to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated @@ -557,8 +557,8 @@ def computation_maker(*args, **kwargs): m = mlir.module_to_bytecode(lowering_result.module) built = xc._xla.mlir.mlir_module_to_xla_computation( m, use_tuple_args=tuple_args, return_tuple=True) - out_shapes_flat = [ - ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] + out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] + out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] out_shape = tree_unflatten(out_tree(), out_shapes_flat) for out_aval in out_avals: if not isinstance(out_aval, ShapedArray): @@ -2337,8 +2337,8 @@ def make_jaxpr(fun: Callable, wrapped function returns a pair where the first element is the ``ClosedJaxpr`` representation of ``fun`` and the second element is a pytree with the same structure as the output of ``fun`` and where the - leaves are objects with ``shape``, ``dtype``, and ``named_shape`` - attributes representing the corresponding types of the output leaves. + leaves are objects with ``shape`` and ``dtype`` attributes representing + the corresponding types of the output leaves. Returns: A wrapped version of ``fun`` that when applied to example arguments returns @@ -2400,8 +2400,7 @@ def make_jaxpr_f(*args, **kwargs): else: jaxpr = traced.jaxpr if return_shape: - out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None)) - for o in jaxpr.out_avals] + out = [ShapeDtypeStruct(o.shape, o.dtype) for o in jaxpr.out_avals] return jaxpr, tree_unflatten(tree_structure(traced.out_info), out) return jaxpr @@ -2691,12 +2690,11 @@ class ShapeDtypeStruct: Args: shape: a sequence of integers representing an array shape dtype: a dtype-like object - named_shape: (optional) a dictionary representing a named shape sharding: (optional) a :class:`jax.Sharding` object """ - __slots__ = ["shape", "dtype", "named_shape", "sharding", "_dll"] + __slots__ = ["shape", "dtype", "sharding", "_dll"] - def __init__(self, shape, dtype, named_shape=None, sharding=None): + def __init__(self, shape, dtype, sharding=None): self.shape = tuple(shape) if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") @@ -2713,7 +2711,6 @@ def __init__(self, shape, dtype, named_shape=None, sharding=None): f" layout in a `ShapeDtypeStruct`. Got {sharding}") self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None - self.named_shape = {} if named_shape is None else dict(named_shape) size = property(lambda self: math.prod(self.shape)) ndim = property(lambda self: len(self.shape)) @@ -2729,11 +2726,10 @@ def __len__(self): raise TypeError("len() of unsized object") from e # same as numpy error def __repr__(self): - ns = f", named_shape={self.named_shape}" if self.named_shape else "" sh = f", sharding={self.sharding}" if self.sharding is not None else "" l = f", layout={self.layout}" if self._dll is not None else "" return (f"{type(self).__name__}(shape={self.shape}, " - f"dtype={self.dtype.name}{ns}{sh}{l})") + f"dtype={self.dtype.name}{sh}{l})") __str__ = __repr__ @@ -2741,19 +2737,17 @@ def __eq__(self, other): if not isinstance(other, ShapeDtypeStruct): return False else: - return ((other.shape, other.dtype, other.named_shape, other.sharding, other.layout) == - (self.shape, self.dtype, self.named_shape, self.sharding, self.layout)) + return ((other.shape, other.dtype, other.sharding, other.layout) == + (self.shape, self.dtype, self.sharding, self.layout)) def __hash__(self): # TODO(frostig): avoid the conversion from dict by addressing # https://github.com/google/jax/issues/8182 - named = frozenset(self.named_shape.items()) - return hash((self.shape, self.dtype, named, self.sharding, self.layout)) - + return hash((self.shape, self.dtype, self.sharding, self.layout)) core.pytype_aval_mappings[ShapeDtypeStruct] = ( lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), - weak_type=False, named_shape=x.named_shape)) + weak_type=False)) @api_boundary diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 98700306e425..dd1cdcbe6bb8 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -570,15 +570,13 @@ def _shaped_abstractify_slow(x): pass weak_type = getattr(x, 'weak_type', False) - named_shape = getattr(x, 'named_shape', {}) if hasattr(x, 'dtype'): dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) else: raise TypeError( f"Cannot interpret value of type {type(x)} as an abstract array; it " "does not have a dtype attribute") - return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type, - named_shape=named_shape) + return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type) # TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior def shaped_abstractify(x): diff --git a/jax/_src/array.py b/jax/_src/array.py index c76546d94eae..ec1971d372b5 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -466,8 +466,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: def __reduce__(self): fun, args, arr_state = self._value.__reduce__() - aval_state = {'weak_type': self.aval.weak_type, - 'named_shape': self.aval.named_shape} + aval_state = {'weak_type': self.aval.weak_type} return (_reconstruct_array, (fun, args, arr_state, aval_state)) @use_cpp_method() diff --git a/jax/_src/core.py b/jax/_src/core.py index 5ed1e1871cb2..a65e454dc0d6 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1418,9 +1418,6 @@ def __repr__(self): def strip_weak_type(self) -> AbstractValue: return self - def strip_named_shape(self) -> AbstractValue: - return self - def join(self, other): raise NotImplementedError("must override") @@ -1695,6 +1692,8 @@ def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]: Returns: A tuple of canonical dimension values. """ + if isinstance(shape, int): + shape = shape, try: return tuple(unsafe_map(_canonicalize_dimension, shape)) except TypeError: @@ -1733,25 +1732,22 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'named_shape'] + __slots__ = ['shape'] array_abstraction_level = 2 - def __init__(self, shape, dtype, weak_type=False, named_shape=None): + def __init__(self, shape, dtype, weak_type=False): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type - self.named_shape = {} if named_shape is None else dict(named_shape) - def update(self, shape=None, dtype=None, weak_type=None, named_shape=None): + def update(self, shape=None, dtype=None, weak_type=None): if shape is None: shape = self.shape if dtype is None: dtype = self.dtype if weak_type is None: weak_type = self.weak_type - if named_shape is None: - named_shape = self.named_shape - return ShapedArray(shape, dtype, weak_type, named_shape) + return ShapedArray(shape, dtype, weak_type) ndim = property(lambda self: len(self.shape)) size = property(lambda self: @@ -1766,25 +1762,22 @@ def update(self, shape=None, dtype=None, weak_type=None, named_shape=None): def __eq__(self, other): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape - and self.weak_type == other.weak_type - and self.named_shape == other.named_shape) + and self.weak_type == other.weak_type) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) - return hash((self.shape, self.dtype, self.weak_type, - tuple(self.named_shape.items()))) + return hash((self.shape, self.dtype, self.weak_type)) def at_least_vspace(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, self.named_shape) + self.weak_type) def join(self, other): if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype: weak_type = self.weak_type and other.weak_type - named_shape = join_named_shapes(self.named_shape, other.named_shape) - return self.update(weak_type=weak_type, named_shape=named_shape) + return self.update(weak_type=weak_type) elif self.dtype == other.dtype: return UnshapedArray(self.dtype) else: @@ -1794,14 +1787,7 @@ def str_short(self, short_dtypes=False): dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = ','.join(map(str, self.shape)) - if self.named_shape: - named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items()) - return f'{dt_str}[{shapestr};{named_shapestr}]' - else: - return f'{dt_str}[{shapestr}]' - - def strip_named_shape(self): - return self.update(named_shape={}) + return f'{dt_str}[{shapestr}]' def _len(self, ignored_tracer): try: @@ -1849,12 +1835,9 @@ def join(self, other) -> AbstractValue: return self elif self.shape == other.shape and self.dtype == other.dtype: weak_type = self.weak_type and other.weak_type - named_shape = join_named_shapes(self.named_shape, other.named_shape) - return ShapedArray( - self.shape, self.dtype, weak_type=weak_type, named_shape=named_shape) + return ShapedArray(self.shape, self.dtype, weak_type=weak_type) elif self.dtype == other.dtype: - return UnshapedArray(self.dtype, - weak_type=self.weak_type and other.weak_type) + return UnshapedArray(self.dtype, weak_type=self.weak_type and other.weak_type) else: raise TypeError(self, other) @@ -2090,8 +2073,7 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None): Bot: lambda aval, _: aval, UnshapedArray: lambda aval, _: aval, ShapedArray: lambda aval, weak_type: ShapedArray( - aval.shape, aval.dtype, weak_type, aval.named_shape - ), + aval.shape, aval.dtype, weak_type), DConcreteArray: lambda aval, weak_type: DShapedArray( aval.shape, aval.dtype, weak_type ), @@ -2282,94 +2264,6 @@ def dim_constant(ct: int): def dim_value_aval() -> AbstractValue: return ShapedArray((), dim_value_dtype(), weak_type=True) -# ------------------- Named shapes ------------------- - - -class NamedShape: - def __init__(self, *args, **kwargs): - self.__positional = canonicalize_shape(args) - # TODO: Assert that kwargs match axis env? - self.__named = dict(kwargs) - - @property - def rank(self): - return len(self.__positional) + len(self.__named) - - @property - def positional_rank(self): - return len(self.__positional) - - @property - def named_rank(self): - return len(self.__named) - - @property - def positional(self): - return self.__positional - - @property - def names(self): - return self.__named.keys() - - @property - def named_sizes(self): - return self.__named.values() - - @property - def named_items(self): - return self.__named.items() - - def __getitem__(self, idx): - try: - idx = operator.index(idx) - return self.__positional[idx] - except TypeError: - pass - return self.__named[idx] - - @property - def total(self): - total = 1 - for s in self.__positional: total *= s - for s in self.__named.values(): total *= s - return total - - def __str__(self): - # TODO(mattjj,frostig): revise not to miss commas - if not self.__named: - return str(self.__positional) - return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}" - f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})") - - def __eq__(self, other): - if isinstance(other, NamedShape): - return (self.__positional, self.__named) == (other.__positional, other.__named) - if isinstance(other, tuple): - return not self.__named and self.__positional == other - return False - - def __hash__(self): - named = frozenset(self.__named.items()) - return hash((self.__positional, named)) - -def join_named_shapes(*named_shapes): - result = {} - for named_shape in named_shapes: - for name, size in named_shape.items(): - if result.setdefault(name, size) != size: - raise TypeError( - f"Axis name {name} used with inconsistent sizes: {result[name]} != {size}") - return result - -# TODO: Make canonicalize_shape return named shapes? -def as_named_shape(shape) -> NamedShape: - if isinstance(shape, int): - shape = (shape,) - if isinstance(shape, NamedShape): - return shape - return NamedShape(*shape) - - # ------------------- Call ------------------- class CallPrimitive(Primitive): @@ -2574,17 +2468,15 @@ def _map_shaped_array( # TODO: Extend the named shape if axis is None: return aval return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - named_shape=aval.named_shape, weak_type=aval.weak_type) + weak_type=aval.weak_type) def _unmap_shaped_array( size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray ) -> ShapedArray: - named_shape = dict(aval.named_shape) - named_shape.pop(axis_name, None) # TODO: make this mandatory - if axis is None: return aval.update(named_shape=named_shape) + if axis is None: return aval elif type(axis) is int: return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - named_shape=named_shape, weak_type=aval.weak_type) + weak_type=aval.weak_type) else: raise TypeError(axis) def _map_dshaped_array( @@ -2780,16 +2672,8 @@ def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> V # Var identity is load-bearing, so we can't have duplicates! if isinstance(v, DropVar): return v assert v not in var_map - if not hasattr(v.aval, 'named_shape'): - var_map[v] = v - return v - names = tuple(it.chain.from_iterable(subst(name) for name in v.aval.named_shape)) - named_shape = {name: axis_frame(name).size for name in names} - if len(named_shape) != len(names): - raise DuplicateAxisNameError(v) - new_v = Var(v.suffix, v.aval.update(named_shape=named_shape)) - var_map[v] = new_v - return new_v + var_map[v] = v + return v def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn: invars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars] @@ -2857,31 +2741,20 @@ def typecheck(aval: AbstractValue, x) -> bool: return typecompat(aval, get_aval(x)) def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool: - """Determine whether `aval` conforms to `aval_ref`. - - Ignores weak_type and named_shape, other than to check that an axis name isn't - used with different sizes. - """ + """Determine whether `aval` conforms to `aval_ref`. Ignores weak_type.""" try: return typematch(aval_ref, lattice_join(aval_ref, aval)) except TypeError: return False def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool: - """Determine whether `aval1` and `aval2` are equivalent. - - Ignores weak_type and named_shape, other than to check that an axis name isn't - used with different sizes. - """ + """Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type.""" if aval1 == aval2: return True # unequal avals may still represent the same type, because type is represented - # by avals at the shaped level, and because weak type tags and (for now) named - # shape components aren't considered part of the type - if isinstance(aval1, ShapedArray) and isinstance(aval2, ShapedArray): - # a bonus check for whether any named axes have inconsistent sizes - join_named_shapes(aval1.named_shape, aval2.named_shape) - return (raise_to_shaped(aval1, weak_type=False).strip_named_shape() == - raise_to_shaped(aval2, weak_type=False).strip_named_shape()) + # by avals at the shaped level, and because weak type tags aren't considered + # part of the type + return (raise_to_shaped(aval1, weak_type=False) == + raise_to_shaped(aval2, weak_type=False)) class JaxprTypeError(TypeError): pass diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 9683a8b1deb1..13b9caf3d749 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -327,13 +327,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) # TODO(mattjj): compare primals' tangent types to tangent objects' types - primal_avals_out = [ - raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape() - for x in primals_out] - tangent_avals_out = [ - raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape() - if type(t) is not SymbolicZero else t.aval.strip_weak_type() - for t in tangents_out] + primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) + for x in primals_out] + tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) + if type(t) is not SymbolicZero else t.aval.strip_weak_type() + for t in tangents_out] if primal_avals_out != tangent_avals_out: if len(primal_avals_out) == 1: (av1,), (av2,) = primal_avals_out, tangent_avals_out diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 5c83d40089df..48cfe5e71a89 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1237,8 +1237,7 @@ def pp_arg_dim(dim_idx: int | None) -> str: out_avals = tuple( core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, *exported_dim_values), - dtype=out_aval.dtype, weak_type=out_aval.weak_type, - named_shape=out_aval.named_shape) + dtype=out_aval.dtype, weak_type=out_aval.weak_type) for out_aval in exported.out_avals) return out_avals, set(exported.ordered_effects + exported.unordered_effects) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 596d1979b936..ea9da4574e3d 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -200,8 +200,8 @@ def write_cotangent(prim, v, ct): # TODO(mattjj): add back these checks for dynamic shapes # if config.enable_checks.value: # ct_aval = core.get_aval(ct_env[v]) - # joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape() - # assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval) + # joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type() + # assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval) def read_cotangent(v): return ct_env.pop(v, Zero(v.aval.at_least_vspace())) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a095837446a6..a6f4fa14b410 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1091,7 +1091,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, input_output_aliases = list(input_output_aliases) # To match-up in-avals to out-avals we only care about the number of # bytes, so we strip off unrelated aval metadata (eg. the named shape) - strip_metadata = lambda a: a.strip_named_shape().strip_weak_type() + strip_metadata = lambda a: a.strip_weak_type() avals_in = map(strip_metadata, avals_in) avals_out = map(strip_metadata, avals_out) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 886b7eec7adc..ed341edde7c4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -854,12 +854,10 @@ def lower_parallel_callable( def _pmap_unmap_shaped_array( size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray ) -> ShapedArray: - named_shape = dict(aval.named_shape) - named_shape.pop(axis_name, None) # TODO: make this mandatory - if axis is None: return aval.update(named_shape=named_shape) + if axis is None: return aval elif type(axis) is int: return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype, - named_shape=named_shape, weak_type=aval.weak_type) + weak_type=aval.weak_type) else: raise TypeError(axis) @@ -1507,22 +1505,17 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval): assert isinstance(aval, ShapedArray) shape = list(aval.shape) - named_shape = dict(aval.named_shape) for name, axis in in_axes.items(): assert shape[axis] % axis_sizes[name] == 0 - assert name not in named_shape - named_shape[name] = axis_sizes[name] shape[axis] //= axis_sizes[name] - return aval.update(shape=tuple(shape), named_shape=named_shape) + return aval.update(shape=tuple(shape)) def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval): assert isinstance(aval, ShapedArray) shape = list(aval.shape) - named_shape = dict(aval.named_shape) for name, axis in out_axes.items(): shape[axis] *= axis_sizes[name] - named_shape.pop(name, None) # The name might be missing --- it's a broadcast. - return aval.update(shape=tuple(shape), named_shape=named_shape) + return aval.update(shape=tuple(shape)) def mesh_local_to_global(mesh, axes: ArrayMapping, aval): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index aa707386c5db..d311e233ac83 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1321,7 +1321,7 @@ def _create_jaxpr(init_val): raise TypeError(msg.format(cond_tree)) pred_aval = cond_jaxpr.out_avals[0] if (not isinstance(pred_aval, ShapedArray) - or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)): + or pred_aval.strip_weak_type() != ShapedArray((), np.bool_)): msg = "cond_fun must return a boolean scalar, but got output type(s) {}." raise TypeError(msg.format(cond_jaxpr.out_avals)) return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f9fc12ca00b3..c30dc8373ebb 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -59,8 +59,7 @@ from jax._src.lax import slicing from jax._src.lax.utils import ( _input_dtype, dtype_to_string, standard_abstract_eval, - standard_multi_result_abstract_eval, standard_named_shape_rule, - standard_primitive) + standard_multi_result_abstract_eval, standard_primitive) from jax._src import xla_bridge from jax._src.lib import xla_client from jax._src.lib.mlir import ir @@ -2563,7 +2562,7 @@ def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, - _convert_element_type_weak_type_rule, standard_named_shape_rule)) + _convert_element_type_weak_type_rule)) ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) @@ -3360,7 +3359,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions): type(core.get_aval(d).dtype) is core.bint for d in shape)): shape = _broadcast_in_dim_shape_rule( # error checking x, shape=shape, broadcast_dimensions=broadcast_dimensions) - return core.ShapedArray(shape, x.dtype, x.weak_type, x.named_shape) + return core.ShapedArray(shape, x.dtype, x.weak_type) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code @@ -4057,25 +4056,12 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reducer = core.jaxpr_as_fun(jaxpr) return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions) -def _reduce_named_shape_rule(*avals, computation, jaxpr, dimensions): - # TODO(mattjj,frostig): see the TODOs noting limitations/assumptions in - # _reduce_batching_rule. We're making the same assumptions here for now. - num_operands = len(avals) // 2 - operand_avals, init_avals = split_list(avals, [num_operands]) - if any(a.named_shape for a in init_avals): - raise NotImplementedError - named_shapes = [a.named_shape for a in operand_avals] - join = core.join_named_shapes(*(a.named_shape for a in operand_avals)) - return [join] * len(named_shapes) - - reduce_p = core.Primitive('reduce') reduce_p.multiple_results = True reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, - _reduce_dtype_rule, _reduce_weak_type_rule, - _reduce_named_shape_rule)) + _reduce_dtype_rule, _reduce_weak_type_rule)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule @@ -4839,9 +4825,6 @@ def _rng_bit_generator_lowering( return [out_key, out_vals] -def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm): - return [key.named_shape, key.named_shape] - rng_bit_generator_p = Primitive("rng_bit_generator") rng_bit_generator_p.multiple_results = True rng_bit_generator_p.def_impl( @@ -4849,8 +4832,7 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm): rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, - _rng_bit_generator_named_shape_rule)) + _rng_bit_generator_weak_type_rule)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 6e2e6139bde8..0cb4a940da23 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -738,21 +738,15 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups): return [pos_reducer(arg, axes) for arg in args] def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): - # TODO(frostig,mattjj,jekbradbury): maybe check aval names here + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) - named_shapes = [arg.named_shape for arg in args] - named_axes = {axis for axis in axes if not isinstance(axis, int)} - if axis_index_groups is None: - named_shapes = [{name: size for name, size in arg.named_shape.items() - if name not in named_axes} for arg in args] - else: + if axis_index_groups is not None: if len(pos_axes) != 0: raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") out_avals = [ ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes), - arg.dtype, named_shape=named_shape) - for arg, named_shape in zip(args, named_shapes)] + arg.dtype) for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): @@ -1301,12 +1295,7 @@ def _all_gather_effectful_abstract_eval( new_shape[all_gather_dimension] *= axis_size else: new_shape.insert(all_gather_dimension, axis_size) - new_named_shape = {name: size for name, size in x_aval.named_shape.items() - if name not in axis_name} - out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape) - effects = {*map(core.NamedAxisEffect, axis_name)} - return out_aval, effects - + return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): return (psum_scatter(cts, axis_name=axis_name, @@ -1437,15 +1426,7 @@ def _reduce_scatter_effectful_abstract_eval( f"{scatter_dim_input_size} must match shard count " f"{axis_size}") del new_shape[scatter_dimension] - - new_named_shape = { - name: size - for name, size in x_aval.named_shape.items() - if name not in axis_name - } - out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape) - effects = {*map(core.NamedAxisEffect, axis_name)} - return out_aval, effects + return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension, @@ -1633,9 +1614,7 @@ def _axis_index_lowering(ctx, *, axis_name): def _axis_index_effectful_abstract_eval(*, axis_name): frame = core.axis_frame(axis_name) - out_aval = ShapedArray((), np.int32, named_shape={axis_name: frame.size}) - return out_aval, {core.NamedAxisEffect(axis_name)} - + return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} axis_index_p = core.Primitive('axis_index') mlir.register_lowering(axis_index_p, _axis_index_lowering) @@ -1691,14 +1670,7 @@ def _pdot_effectful_abstract_eval( pos_aval = lax.dot_general_p.abstract_eval( x, y, dimension_numbers=[pos_contract, pos_batch], precision=precision, preferred_element_type=None)[0] - common_named_shape = core.join_named_shapes(x.named_shape, y.named_shape) - named_shape = {name: size - for name, size in common_named_shape.items() - if name not in axis_name} - out_aval = pos_aval.update(named_shape=named_shape) - effects = {*map(core.NamedAxisEffect, axis_name)} - return out_aval, effects - + return pos_aval, {*map(core.NamedAxisEffect, axis_name)} def _pdot_vmap_collective_rule(axis_size, frame_name, _, vals_in, dims_in, *, axis_name, pos_contract, pos_batch, precision): diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 7cc15f30c459..01301db1a9a0 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -24,6 +24,8 @@ from jax._src.util import safe_zip from jax._src.lib import xla_client +zip, unsafe_zip = safe_zip, zip + import numpy as np xops = xla_client.ops @@ -35,20 +37,19 @@ def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None, named_shape_rule=None): + weak_type_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule - named_shape_rule = named_shape_rule or standard_named_shape_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule, named_shape_rule)) + weak_type_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - named_shape_rule, *avals, **kwargs): + *avals, **kwargs): assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) @@ -58,8 +59,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, return core.ConcreteArray(out.dtype, out, weak_type=weak_type) elif least_specialized is core.ShapedArray: return core.ShapedArray(shape_rule(*avals, **kwargs), - dtype_rule(*avals, **kwargs), weak_type=weak_type, - named_shape=named_shape_rule(*avals, **kwargs)) + dtype_rule(*avals, **kwargs), weak_type=weak_type) elif least_specialized is core.DShapedArray: shape = shape_rule(*avals, **kwargs) ty = (core.ShapedArray if all(type(d) is int for d in shape) @@ -71,8 +71,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, raise TypeError(avals, least_specialized) def standard_multi_result_abstract_eval( - prim, shape_rule, dtype_rule, weak_type_rule, - named_shape_rule, *avals, **kwargs): + prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs): assert prim.multiple_results assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals least_specialized = max(map(type, avals), key=_get_array_abstraction_level) @@ -80,18 +79,16 @@ def standard_multi_result_abstract_eval( if least_specialized is core.ConcreteArray: out_vals = prim.impl(*[x.val for x in avals], **kwargs) return [core.ConcreteArray(val.dtype, val, weak_type=weak_type) - for val, weak_type in safe_zip(out_vals, weak_types)] + for val, weak_type in zip(out_vals, weak_types)] elif least_specialized is core.ShapedArray: out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) - out_named_shapes = named_shape_rule(*avals, **kwargs) - return [core.ShapedArray(s, d, weak_type=weak_type, named_shape=named_shape) - for s, d, weak_type, named_shape - in safe_zip(out_shapes, out_dtypes, weak_types, out_named_shapes)] + return [core.ShapedArray(s, d, weak_type=weak_type) + for s, d, weak_type in zip(out_shapes, out_dtypes, weak_types)] elif least_specialized is core.UnshapedArray: out_dtypes = dtype_rule(*avals, **kwargs) return [core.UnshapedArray(dtype, weak_type=weak_type) - for dtype, weak_type in safe_zip(out_dtypes, weak_types)] + for dtype, weak_type in zip(out_dtypes, weak_types)] else: raise TypeError(avals, least_specialized) @@ -103,9 +100,6 @@ def translation_rule(ctx, avals_in, avals_out, *args, **kwargs): return [op(*args, **kwargs)] return translation_rule -def standard_named_shape_rule(*avals, **kwargs): - return core.join_named_shapes(*(a.named_shape for a in avals)) - def _standard_weak_type_rule(*avals, **kwargs): return all(aval.weak_type for aval in avals) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index cf245f7927be..7d228e4beef4 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -198,19 +198,19 @@ def init(key: KeyArray, return init @export -def _compute_fans(shape: core.NamedShape, +def _compute_fans(shape: Sequence[int], in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = () - ) -> tuple[Array, Array]: + ) -> tuple[float, float]: """ Compute effective input and output sizes for a linear or convolutional layer. Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the "receptive field" of a convolution (kernel spatial dimensions). """ - if shape.rank <= 1: - raise ValueError(f"Can't compute input and output sizes of a {shape.rank}" + if len(shape) <= 1: + raise ValueError(f"Can't compute input and output sizes of a {len(shape)}" "-dimensional weights tensor. Must be at least 2D.") if isinstance(in_axis, int): @@ -225,13 +225,13 @@ def _compute_fans(shape: core.NamedShape, batch_size = shape[batch_axis] else: batch_size = math.prod([shape[i] for i in batch_axis]) - receptive_field_size = shape.total / in_size / out_size / batch_size + receptive_field_size = math.prod(shape) / in_size / out_size / batch_size fan_in = in_size * receptive_field_size fan_out = out_size * receptive_field_size return fan_in, fan_out def _complex_uniform(key: KeyArray, - shape: Sequence[int] | core.NamedShape, + shape: Sequence[int], dtype: DTypeLikeInexact) -> Array: """ Sample uniform random values within a disk on the complex plane, @@ -245,7 +245,7 @@ def _complex_uniform(key: KeyArray, return r * jnp.exp(1j * theta) def _complex_truncated_normal(key: KeyArray, upper: ArrayLike, - shape: Sequence[int] | core.NamedShape, + shape: Sequence[int], dtype: DTypeLikeInexact) -> Array: """ Sample random values from a centered normal distribution on the complex plane, @@ -317,9 +317,9 @@ def variance_scaling( def init(key: KeyArray, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: + shape = core.canonicalize_shape(shape) dtype = dtypes.canonicalize_dtype(dtype) - named_shape = core.as_named_shape(shape) - fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis) + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis) if mode == "fan_in": denominator = fan_in elif mode == "fan_out": denominator = fan_out elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2 @@ -332,18 +332,18 @@ def init(key: KeyArray, if jnp.issubdtype(dtype, jnp.floating): # constant is stddev of standard normal truncated to (-2, 2) stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) - return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev + return random.truncated_normal(key, -2, 2, shape, dtype) * stddev else: # constant is stddev of complex standard normal truncated to 2 stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype) - return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev + return _complex_truncated_normal(key, 2, shape, dtype) * stddev elif distribution == "normal": - return random.normal(key, named_shape, dtype) * jnp.sqrt(variance) + return random.normal(key, shape, dtype) * jnp.sqrt(variance) elif distribution == "uniform": if jnp.issubdtype(dtype, jnp.floating): - return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance) + return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance) else: - return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance) + return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance) else: raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}") diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 97770ff1aa38..3eea7e22feed 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -148,9 +148,7 @@ def _logical_aval_to_interpret_mode_aval(aval): return aval.update(inner_aval=inner_aval) if isinstance(aval, jax_core.ShapedArray): inner_dtype = _logical_to_interpret_mode_dtype(aval.dtype) - return jax_core.ShapedArray(aval.shape, - inner_dtype, - weak_type=aval.weak_type, named_shape=aval.named_shape) + return jax_core.ShapedArray(aval.shape, inner_dtype, weak_type=aval.weak_type) return aval def _get_next_indices(grid, indices): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a2b5993eabe1..32f856be5d69 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -490,7 +490,7 @@ def eval_shape(*args, **kwargs): p, _ = _infer_params(fun, jit_info, args, kwargs) out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']] # TODO(yashkatariya): Add `Layout` to SDS. - out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s) + out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s) for x, s in zip(p.params['jaxpr'].out_avals, out_s)] return tree_unflatten(p.out_tree, out) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 18ae281629b6..3b179c144b4c 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -44,7 +44,6 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.lax import lax as lax_internal -from jax._src.lax import utils as lax_utils from jax._src.lib import gpu_prng from jax._src.lib import xla_client as xc from jax._src.lib import version as jaxlib_version @@ -611,8 +610,7 @@ def random_fold_in(keys, msgs): def random_fold_in_abstract_eval(keys_aval, msgs_aval): shape = lax_internal.broadcasting_shape_rule( 'random_fold_in', keys_aval, msgs_aval) - named_shape = lax_utils.standard_named_shape_rule(keys_aval, msgs_aval) - return core.ShapedArray(shape, keys_aval.dtype, named_shape=named_shape) + return core.ShapedArray(shape, keys_aval.dtype) @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): @@ -640,19 +638,7 @@ def random_fold_in_lowering(ctx, keys, msgs): def random_bits(keys, bit_width, shape): - shape = core.as_named_shape(shape) - for name, size in shape.named_items: - # TODO(frostig,mattjj,apaszke): Is this real_size check necessary, - # and is it meant to raise a user-facing ValueError? Should it be - # an `assert` (or RuntimeError) instead? Why do we check it in - # calls to `random_bits` instead of a more common paralleism path? - real_size = lax.psum(1, name) - if real_size != size: - raise ValueError(f"The shape of axis {name} was specified as {size}, " - f"but it really is {real_size}") - axis_index = lax.axis_index(name) - keys = random_fold_in(keys, axis_index) - return random_bits_p.bind(keys, bit_width=bit_width, shape=shape.positional) + return random_bits_p.bind(keys, bit_width=bit_width, shape=shape) random_bits_p = core.Primitive('random_bits') ad.defjvp_zero(random_bits_p) @@ -822,8 +808,7 @@ def _threefry2x32_abstract_eval(*args): .format(args)) if all(isinstance(arg, core.ShapedArray) for arg in args): shape = lax_internal.broadcasting_shape_rule(*args) - named_shape = core.join_named_shapes(*(a.named_shape for a in args)) - aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape) + aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32)) else: aval = core.UnshapedArray(jnp.dtype(jnp.uint32)) return (aval,) * 2 diff --git a/jax/_src/random.py b/jax/_src/random.py index 6a0a3c0f9932..a908f9b3a3b7 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -35,7 +35,6 @@ from jax._src import prng from jax._src import xla_bridge from jax._src.api import jit, vmap -from jax._src.core import NamedShape from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -319,12 +318,10 @@ def wrap_key_data(key_bits_array: Array, *, ### random samplers -def _check_shape(name: str, shape: Shape | NamedShape, *param_shapes) -> None: - shape = core.as_named_shape(shape) - +def _check_shape(name: str, shape: Shape, *param_shapes) -> None: if param_shapes: - shape_ = lax.broadcast_shapes(shape.positional, *param_shapes) - if shape.positional != shape_: + shape_ = lax.broadcast_shapes(shape, *param_shapes) # type: ignore + if shape != shape_: msg = ("{} parameter shapes must be broadcast-compatible with shape " "argument, and the result of broadcasting the shapes must equal " "the shape argument, but got result {} for shape argument {}.") @@ -361,7 +358,7 @@ def bits(key: KeyArrayLike, def uniform(key: KeyArrayLike, - shape: Shape | NamedShape = (), + shape: Shape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., maxval: RealArray = 1.) -> Array: @@ -381,12 +378,12 @@ def uniform(key: KeyArrayLike, """ key, _ = _check_prng_key("uniform", key) dtypes.check_user_dtype_supported(dtype) + shape = core.canonicalize_shape(shape) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - shape = core.as_named_shape(shape) return _uniform(key, shape, dtype, minval, maxval) @partial(jit, static_argnums=(1, 2)) @@ -397,8 +394,8 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: minval = lax.convert_element_type(minval, dtype) maxval = lax.convert_element_type(maxval, dtype) - minval = lax.broadcast_to_rank(minval, shape.positional_rank) - maxval = lax.broadcast_to_rank(maxval, shape.positional_rank) + minval = lax.broadcast_to_rank(minval, len(shape)) + maxval = lax.broadcast_to_rank(maxval, len(shape)) finfo = jnp.finfo(dtype) nbits, nmant = finfo.bits, finfo.nmant @@ -427,7 +424,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype) return lax.max( minval, - lax.reshape(floats * (maxval - minval) + minval, shape.positional)) + lax.reshape(floats * (maxval - minval) + minval, shape)) def randint(key: KeyArrayLike, @@ -674,7 +671,7 @@ def choice(key: KeyArrayLike, def normal(key: KeyArrayLike, - shape: Shape | NamedShape = (), + shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample standard normal random values with given shape and float dtype. @@ -696,12 +693,12 @@ def normal(key: KeyArrayLike, A random array with the specified shape and dtype. """ key, _ = _check_prng_key("normal", key) + shape = core.canonicalize_shape(shape) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - shape = core.as_named_shape(shape) return _normal(key, shape, dtype) @partial(jit, static_argnums=(1, 2)) @@ -812,7 +809,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: def truncated_normal(key: KeyArrayLike, lower: RealArray, upper: RealArray, - shape: Shape | NamedShape | None = None, + shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: r"""Sample truncated standard normal random values with given shape and dtype. @@ -847,8 +844,6 @@ def truncated_normal(key: KeyArrayLike, raise ValueError(f"dtype argument to `truncated_normal` must be a float " f"dtype, got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - if shape is not None: - shape = core.as_named_shape(shape) return _truncated_normal(key, lower, upper, shape, dtype) @partial(jit, static_argnums=(3, 4)) @@ -877,7 +872,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: def bernoulli(key: KeyArrayLike, p: RealArray = np.float32(0.5), - shape: Shape | NamedShape | None = None) -> Array: + shape: Shape | None = None) -> Array: r"""Sample Bernoulli random values with given shape and mean. The values are distributed according to the probability mass function: @@ -901,8 +896,6 @@ def bernoulli(key: KeyArrayLike, """ key, _ = _check_prng_key("bernoulli", key) dtype = dtypes.canonicalize_dtype(lax.dtype(p)) - if shape is not None: - shape = core.as_named_shape(shape) if not jnp.issubdtype(dtype, np.floating): msg = "bernoulli probability `p` must have a floating dtype, got {}." raise TypeError(msg.format(dtype)) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index c97888de3560..a71d671c5345 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -218,11 +218,9 @@ def get_ref_state_effects( if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect)) and eff.input_index == i} for i, _ in enumerate(avals)] -def shaped_array_ref(shape: tuple[int, ...], dtype, - weak_type: bool = False, - named_shape = None) -> AbstractRef: - return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type, - named_shape=named_shape)) +def shaped_array_ref( + shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef: + return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type)) def _shard_ref(mesh, names, ref_aval: AbstractRef): del mesh diff --git a/jax/core.py b/jax/core.py index b023d2daf163..80025e8619f3 100644 --- a/jax/core.py +++ b/jax/core.py @@ -42,7 +42,6 @@ MainTrace as MainTrace, MapPrimitive as MapPrimitive, NameGatheringSubst as NameGatheringSubst, - NamedShape as NamedShape, OutDBIdx as OutDBIdx, OutputType as OutputType, ParamDict as ParamDict, @@ -61,7 +60,6 @@ Var as Var, abstract_token as abstract_token, apply_todos as apply_todos, - as_named_shape as as_named_shape, aval_mapping_handlers as aval_mapping_handlers, axis_frame as axis_frame, call as call, @@ -97,7 +95,6 @@ jaxpr_uses_outfeed as jaxpr_uses_outfeed, jaxprs_in_params as jaxprs_in_params, join_effects as join_effects, - join_named_shapes as join_named_shapes, lattice_join as lattice_join, leaked_tracer_error as leaked_tracer_error, literalable_types as literalable_types, diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 893dae286587..4e40a83bfdfa 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -572,9 +572,7 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue,) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape)), - named_shape={k: v for k, v in aval.named_shape.items() - if k not in mesh.shape}) + for i, sz in enumerate(aval.shape))) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array # Type-checking diff --git a/tests/api_test.py b/tests/api_test.py index b5f4ba48e8c8..8f4ebcfe52c0 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2706,6 +2706,8 @@ def __init__(self, *args, **kwargs): self.assertEqual(out_shape.shape, (3,)) def test_eval_shape_names(self): + raise unittest.SkipTest("named shape are deprecated") + def fun(x, y): return lax.psum(x, 'i') + y @@ -6571,6 +6573,7 @@ def f(x): self.assertIn('psum', str(jaxpr)) def test_make_jaxpr_named(self): + raise unittest.SkipTest("named shape are deprecated") def f(x): return x - lax.psum(x, 'i') diff --git a/tests/core_test.py b/tests/core_test.py index cb080692ceb4..0838702c4be6 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -550,32 +550,6 @@ def test_raise_to_shaped_weak_type(self, value, weak_type): aval = core.raise_to_shaped(core.get_aval(value)) self.assertEqual(aval.weak_type, weak_type) - def test_lattice_join_named_shape(self): - aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10}) - self.assertEqual(core.lattice_join(aval1, aval1), aval1) - - aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5}) - expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5}) - self.assertEqual(core.lattice_join(aval1, aval2), expected) - - aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5}) - self.assertRaises(TypeError, lambda: core.lattice_join(aval1, aval3)) - - def test_typecompat_named_shape(self): - aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10}) - aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5}) - self.assertTrue(core.typecompat(aval1, aval2)) - - aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5}) - self.assertFalse(core.typecompat(aval1, aval3)) - - def test_named_shape_comparision(self): - self.assertTrue(core.NamedShape(2, 3) == (2, 3)) - self.assertFalse(core.NamedShape(2, i=3) == (2,)) - self.assertFalse(core.NamedShape(2, i=3) == (2, 3)) - self.assertFalse(core.NamedShape(2, i=3) == None) - self.assertFalse(core.NamedShape() == []) - @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): diff --git a/tests/lax_test.py b/tests/lax_test.py index a567af500176..18c501cba36c 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3287,26 +3287,6 @@ def testLog1pNearOne(self): expected.astype(np.complex64), lax.log1p(np.complex64(1e-5))) -class LaxNamedShapeTest(jtu.JaxTestCase): - - def test_abstract_eval(self): - aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10}) - out, _ = lax.sin_p.abstract_eval(aval1) - self.assertEqual(out, aval1) - - aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10}) - aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5}) - expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5}) - out, _ = lax.add_p.abstract_eval(aval1, aval2) - self.assertEqual(out, expected) - - def test_abstract_eval_collective(self): - with core.extend_axis_env('i', 10, None): - aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5}) - expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5}) - (out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None) - self.assertEqual(out, expected) - class FooTyRules: # handlers