Skip to content

Commit

Permalink
add unique
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Jan 22, 2025
1 parent 81bbb78 commit 993c6a7
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/haliax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from .hof import fold, map, scan, vmap
from .jax_utils import filter_checkpoint
from .ops import clip, isclose, pad_left, trace, tril, triu, where
from .ops import clip, isclose, pad_left, trace, tril, triu, unique, where
from .partitioning import auto_sharded, axis_mapping, fsdp, named_jit, shard, shard_with_axis_mapping
from .specialized_fns import top_k
from .types import Scalar
Expand Down Expand Up @@ -1011,6 +1011,7 @@ def true_divide(x1: NamedOrNumeric, x2: NamedOrNumeric, /) -> NamedOrNumeric:
"vmap",
"trace",
"where",
"unique",
"clip",
"tril",
"triu",
Expand Down
148 changes: 147 additions & 1 deletion src/haliax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike

import haliax

from .axis import Axis, AxisSelector
from .core import NamedArray, NamedOrNumeric, broadcast_arrays, broadcast_arrays_and_return_axes, named
Expand Down Expand Up @@ -156,4 +159,147 @@ def raw_array_or_scalar(x: NamedOrNumeric):
return x


__all__ = ["trace", "where", "tril", "triu", "isclose", "pad_left", "clip"]
@typing.overload
def unique(
array: NamedArray, Unique: Axis, *, axis: AxisSelector | None = None, fill_value: ArrayLike | None = None
) -> NamedArray:
...


@typing.overload
def unique(
array: NamedArray,
Unique: Axis,
*,
return_index: typing.Literal[True],
axis: AxisSelector | None = None,
fill_value: ArrayLike | None = None,
) -> tuple[NamedArray, NamedArray]:
...


@typing.overload
def unique(
array: NamedArray,
Unique: Axis,
*,
return_inverse: typing.Literal[True],
axis: AxisSelector | None = None,
fill_value: ArrayLike | None = None,
) -> tuple[NamedArray, NamedArray]:
...


@typing.overload
def unique(
array: NamedArray,
Unique: Axis,
*,
return_counts: typing.Literal[True],
axis: AxisSelector | None = None,
fill_value: ArrayLike | None = None,
) -> tuple[NamedArray, NamedArray]:
...


@typing.overload
def unique(
array: NamedArray,
Unique: Axis,
*,
return_index: bool = False,
return_inverse: bool = False,
return_counts: bool = False,
axis: AxisSelector | None = None,
fill_value: ArrayLike | None = None,
) -> NamedArray | tuple[NamedArray, ...]:
...


def unique(
array: NamedArray,
Unique: Axis,
*,
return_index: bool = False,
return_inverse: bool = False,
return_counts: bool = False,
axis: AxisSelector | None = None,
fill_value: ArrayLike | None = None,
) -> NamedArray | tuple[NamedArray, ...]:
"""
Like jnp.unique, but with named axes.
Args:
array: The input array.
Unique: The name of the axis that will be created to hold the unique values.
fill_value: The value to use for the fill_value argument of jnp.unique
axis: The axis along which to find unique values.
return_index: If True, return the indices of the unique values.
return_inverse: If True, return the indices of the input array that would reconstruct the unique values.
"""
size = Unique.size

is_multireturn = return_index or return_inverse or return_counts

kwargs = dict(
size=size,
fill_value=fill_value,
return_index=return_index,
return_inverse=return_inverse,
return_counts=return_counts,
)

if axis is not None:
axis_index = array._lookup_indices(axis)
if axis_index is None:
raise ValueError(f"Axis {axis} not found in array. Available axes: {array.axes}")
out = jnp.unique(array.array, axis=axis_index, **kwargs)
else:
out = jnp.unique(array.array, **kwargs)

if is_multireturn:
unique = out[0]
next_index = 1
if return_index:
index = out[next_index]
next_index += 1
if return_inverse:
inverse = out[next_index]
next_index += 1
if return_counts:
counts = out[next_index]
next_index += 1
else:
unique = out

ret = []

if axis is not None:
out_axes = haliax.axis.replace_axis(array.axes, axis, Unique)
else:
out_axes = (Unique,)

unique_values = haliax.named(unique, out_axes)
if not is_multireturn:
return unique_values

ret.append(unique_values)

if return_index:
ret.append(haliax.named(index, Unique))

if return_inverse:
if axis is not None:
assert axis_index is not None
inverse = haliax.named(inverse, array.axes[axis_index])
else:
inverse = haliax.named(inverse, array.axes)
ret.append(inverse)

if return_counts:
ret.append(haliax.named(counts, Unique))

return tuple(ret)


__all__ = ["trace", "where", "tril", "triu", "isclose", "pad_left", "clip", "unique"]
109 changes: 109 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,112 @@ def test_reductions_produce_scalar_named_arrays_when_None_axis():
# But if we specify axes, we always get a NamedArray, even if it's a scalar
assert isinstance(hax.mean(named1, axis=("Height", "Width")), NamedArray)
assert hax.mean(named1, axis=("Height", "Width")).axes == ()


def test_unique():
# named version of this test:
# >>> M = jnp.array([[1, 2],
# ... [2, 3],
# ... [1, 2]])
# >>> jnp.unique(M)
# Array([1, 2, 3], dtype=int32)

Height = Axis("Height", 3)
Width = Axis("Width", 2)

named1 = hax.named([[1, 2], [2, 3], [1, 2]], (Height, Width))

U = Axis("U", 3)

named2 = hax.unique(named1, U)

assert jnp.all(jnp.equal(named2.array, jnp.array([1, 2, 3])))

# If you pass an ``axis`` keyword, you can find unique *slices* of the array along
# that axis:
#
# >>> jnp.unique(M, axis=0)
# Array([[1, 2],
# [2, 3]], dtype=int32)

U2 = Axis("U2", 2)
named3 = hax.unique(named1, U2, axis=Height)
assert jnp.all(jnp.equal(named3.array, jnp.array([[1, 2], [2, 3]])))

# >>> x = jnp.array([3, 4, 1, 3, 1])
# >>> values, indices = jnp.unique(x, return_index=True)
# >>> print(values)
# [1 3 4]
# >>> print(indices)
# [2 0 1]
# >>> jnp.all(values == x[indices])
# Array(True, dtype=bool)

x = hax.named([3, 4, 1, 3, 1], ("Height",))
U3 = Axis("U3", 3)
values, indices = hax.unique(x, U3, return_index=True)

assert jnp.all(jnp.equal(values.array, jnp.array([1, 3, 4])))
assert jnp.all(jnp.equal(indices.array, jnp.array([2, 0, 1])))

assert jnp.all(jnp.equal(values.array, x[{"Height": indices}].array))

# If you set ``return_inverse=True``, then ``unique`` returns the indices within the
# unique values for every entry in the input array:
#
# >>> x = jnp.array([3, 4, 1, 3, 1])
# >>> values, inverse = jnp.unique(x, return_inverse=True)
# >>> print(values)
# [1 3 4]
# >>> print(inverse)
# [1 2 0 1 0]
# >>> jnp.all(values[inverse] == x)
# Array(True, dtype=bool)

values, inverse = hax.unique(x, U3, return_inverse=True)

assert jnp.all(jnp.equal(values.array, jnp.array([1, 3, 4])))
assert jnp.all(jnp.equal(inverse.array, jnp.array([1, 2, 0, 1, 0])))

# In multiple dimensions, the input can be reconstructed using
# :func:`jax.numpy.take`:
#
# >>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
# >>> jnp.all(jnp.take(values, inverse, axis=0) == M)
# Array(True, dtype=bool)
#

values, inverse = hax.unique(named1, U3, axis=Height, return_inverse=True)

assert jnp.all((values[{"U3": inverse}] == named1).array)

# **Returning counts**
# If you set ``return_counts=True``, then ``unique`` returns the number of occurrences
# within the input for every unique value:
#
# >>> x = jnp.array([3, 4, 1, 3, 1])
# >>> values, counts = jnp.unique(x, return_counts=True)
# >>> print(values)
# [1 3 4]
# >>> print(counts)
# [2 2 1]
#
# For multi-dimensional arrays, this also returns a 1D array of counts
# indicating number of occurrences along the specified axis:
#
# >>> values, counts = jnp.unique(M, axis=0, return_counts=True)
# >>> print(values)
# [[1 2]
# [2 3]]
# >>> print(counts)
# [2 1]

values, counts = hax.unique(x, U3, return_counts=True)

assert jnp.all(jnp.equal(values.array, jnp.array([1, 3, 4])))
assert jnp.all(jnp.equal(counts.array, jnp.array([2, 2, 1])))

values, counts = hax.unique(named1, U2, axis=Height, return_counts=True)

assert jnp.all(jnp.equal(values.array, jnp.array([[1, 2], [2, 3]])))
assert jnp.all(jnp.equal(counts.array, jnp.array([2, 1])))

0 comments on commit 993c6a7

Please sign in to comment.