Skip to content

Commit

Permalink
Add handlers for new JAX pytree key types.
Browse files Browse the repository at this point in the history
Adds manual handlers for jax.tree_util types SequenceKey, DictKey,
GetAttrKey, and FlattenedIndexKey. These new handlers are needed
because these are no longer ordinary dataclasses in JAX.

Also upgrades the version of JAX used by default with uv, and adds
tests to ensure the new handlers work properly.
  • Loading branch information
danieldjohnson committed Dec 16, 2024
1 parent 25f3abd commit 7f6cb24
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 79 deletions.
55 changes: 55 additions & 0 deletions tests/renderer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,61 @@ def hook_that_crashes(node, path, node_renderer):
]"""),
expected_roundtrip_collapsed="[jax.lax.Precision.HIGHEST]",
),
dict(
testcase_name="jax_ShapeDtypeStruct",
target=jax.ShapeDtypeStruct(shape=(1, 2), dtype=jnp.float32),
expected_collapsed=(
"ShapeDtypeStruct(shape=(1, 2), dtype=dtype('float32'))"
),
expected_expanded=textwrap.dedent("""\
ShapeDtypeStruct(
shape=(1, 2),
dtype=dtype('float32'),
)"""),
expected_roundtrip_collapsed=(
"jax.ShapeDtypeStruct(shape=(1, 2), dtype=np.dtype('float32'))"
),
),
dict(
testcase_name="jax_SequenceKey",
target=jax.tree_util.SequenceKey(42),
expected_collapsed="SequenceKey(idx=42)",
expected_expanded=textwrap.dedent("""\
SequenceKey(
idx=42,
)"""),
expected_roundtrip_collapsed="jax.tree_util.SequenceKey(idx=42)",
),
dict(
testcase_name="jax_DictKey",
target=jax.tree_util.DictKey("a"),
expected_collapsed="DictKey(key='a')",
expected_expanded=textwrap.dedent("""\
DictKey(
key='a',
)"""),
expected_roundtrip_collapsed="jax.tree_util.DictKey(key='a')",
),
dict(
testcase_name="jax_GetAttrKey",
target=jax.tree_util.GetAttrKey("a"),
expected_collapsed="GetAttrKey(name='a')",
expected_expanded=textwrap.dedent("""\
GetAttrKey(
name='a',
)"""),
expected_roundtrip_collapsed="jax.tree_util.GetAttrKey(name='a')",
),
dict(
testcase_name="jax_FlattenedIndexKey",
target=jax.tree_util.FlattenedIndexKey(3),
expected_collapsed="FlattenedIndexKey(key=3)",
expected_expanded=textwrap.dedent("""\
FlattenedIndexKey(
key=3,
)"""),
expected_roundtrip_collapsed="jax.tree_util.FlattenedIndexKey(key=3)",
),
dict(
testcase_name="pytorch_module",
target_builder=fixture_lib.SomePyTorchModule.build,
Expand Down
141 changes: 102 additions & 39 deletions treescope/external/jax_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import functools
import typing
from typing import Mapping
from typing import Any, Mapping, Sequence

import numpy as np
from treescope import canonical_aliases
Expand Down Expand Up @@ -238,10 +238,10 @@ def truncate_array_and_mask(
# each sharding in order to figure out what device order it has, and then
# explicitly request a fully-replicated output that is definitely safe to
# retrieve.
sharding_kwargs["out_shardings"] = (
jax.sharding.GSPMDSharding.get_replicated(
array.sharding._device_assignment # pylint: disable=protected-access
)
sharding_kwargs[
"out_shardings"
] = jax.sharding.GSPMDSharding.get_replicated(
array.sharding._device_assignment # pylint: disable=protected-access
)
if array.size < SUMMARIZE_USING_NUMPY_THRESHOLD and safe_to_summarize(array):
fn = functools.partial(_truncate_part_with_slices, xnp=np)
Expand Down Expand Up @@ -300,39 +300,66 @@ def faster_array_repr(array: jax.Array) -> str:
return f"{prefix}{datastring}, {dtype_str}"


def render_shape_dtype_struct(
node: jax.ShapeDtypeStruct,
path: str | None,
subtree_renderer: renderers.TreescopeSubtreeRenderer,
) -> (
rendering_parts.RenderableTreePart
| rendering_parts.RenderableAndLineAnnotations
| type(NotImplemented)
):
"""Renders jax.ShapeDtypeStruct."""
assert jax is not None, "JAX is not available."
if type(node) is not jax.ShapeDtypeStruct: # pylint: disable=unidiomatic-typecheck
return NotImplemented
attributes = {
"shape": node.shape,
"dtype": node.dtype,
}
if node.sharding is not None:
attributes["sharding"] = node.sharding

# Make sure we can correctly round-trip it. We check because ShapeDtypeStruct
# occasionally adds new attributes for new JAX features.
rebuilt = jax.ShapeDtypeStruct(**attributes)
if rebuilt != node:
return NotImplemented
else:
return repr_lib.render_object_constructor(
object_type=jax.ShapeDtypeStruct,
attributes=attributes,
path=path,
subtree_renderer=subtree_renderer,
roundtrippable=True,
)
def make_checked_dataclasslike_renderer(
cls: type[Any],
fields: Sequence[str],
fields_with_none_default: Sequence[str] = (),
) -> renderers.TreescopeNodeHandler:
"""Builds a roundtrippable renderer for a dataclass-like class.
This function can be used to safely render classes that behave like Python
dataclasses (i.e. they can be roundtripped by calling the constructor with
attributes as keyword arguments). It is robust to potential new attributes
being added by checking that it is possible to rebuild the instance correctly.
This can be ued to render JAX builtin classes.
Args:
cls: The class to render.
fields: A sequence of attribute names to render as keyword args.
fields_with_none_default: A sequence of attribute names to render as keyword
args only if they exist and their value is not None.
Returns:
A node handler for nodes of this type, which returns a simple rendering
whenever the object is correctly described by these attributes.
"""

def render_it(
node: Any,
path: str | None,
subtree_renderer: renderers.TreescopeSubtreeRenderer,
) -> (
rendering_parts.RenderableTreePart
| rendering_parts.RenderableAndLineAnnotations
| type(NotImplemented)
):
if type(node) is not cls: # pylint: disable=unidiomatic-typecheck
raise RuntimeError(f"BAD type {node} {cls}")
return NotImplemented
try:
attributes = {k: getattr(node, k) for k in fields}
except AttributeError:
raise RuntimeError(f"BAD attribute {node} {fields}")
return NotImplemented
for k in fields_with_none_default:
if hasattr(node, k) and getattr(node, k) is not None:
attributes[k] = getattr(node, k)

# Make sure we can correctly round-trip it.
rebuilt = cls(**attributes)
if rebuilt != node:
raise RuntimeError(f"BAD rebuild {node} {rebuilt}")
return NotImplemented
else:
return repr_lib.render_object_constructor(
object_type=cls,
attributes=attributes,
path=path,
subtree_renderer=subtree_renderer,
roundtrippable=True,
)

return render_it


def render_precision(
Expand Down Expand Up @@ -621,7 +648,31 @@ def set_up_treescope():
"Cannot set up JAX support in treescope: JAX cannot be imported."
)
type_registries.TREESCOPE_HANDLER_REGISTRY[jax.ShapeDtypeStruct] = (
render_shape_dtype_struct
make_checked_dataclasslike_renderer(
jax.ShapeDtypeStruct,
fields=("shape", "dtype"),
fields_with_none_default=("sharding",),
)
)
type_registries.TREESCOPE_HANDLER_REGISTRY[jax.tree_util.SequenceKey] = (
make_checked_dataclasslike_renderer(
jax.tree_util.SequenceKey, fields=("idx",)
)
)
type_registries.TREESCOPE_HANDLER_REGISTRY[jax.tree_util.DictKey] = (
make_checked_dataclasslike_renderer(
jax.tree_util.DictKey, fields=("key",)
)
)
type_registries.TREESCOPE_HANDLER_REGISTRY[jax.tree_util.GetAttrKey] = (
make_checked_dataclasslike_renderer(
jax.tree_util.GetAttrKey, fields=("name",)
)
)
type_registries.TREESCOPE_HANDLER_REGISTRY[
jax.tree_util.FlattenedIndexKey
] = make_checked_dataclasslike_renderer(
jax.tree_util.FlattenedIndexKey, fields=("key",)
)
type_registries.TREESCOPE_HANDLER_REGISTRY[jax.lax.Precision] = (
render_precision
Expand All @@ -647,3 +698,15 @@ def set_up_treescope():
canonical_aliases.populate_from_public_api(
jax_api_module, canonical_aliases.prefix_filter("jax")
)

for key_cls_name in [
"SequenceKey",
"DictKey",
"GetAttrKey",
"FlattenedIndexKey",
]:
canonical_aliases.add_alias(
getattr(jax.tree_util, key_cls_name),
canonical_aliases.ModuleAttributePath("jax.tree_util", (key_cls_name,)),
on_conflict="ignore",
)
Loading

0 comments on commit 7f6cb24

Please sign in to comment.