Skip to content

Commit

Permalink
Add additional error fallbacks during rendering.
Browse files Browse the repository at this point in the history
- Avoid calling __repr__ in the renderer last-resort fallback, since
  this is user-defined and can still throw an exception for invalid
  objects.
- Catch errors while rendering deferred parts, to avoid breaking the
  IPython display integration if a deferred part raises an exception.
  (This can happen due to bugs in array rendering adapters, for instance.)
  Before, errors in deferred parts could bubble out to IPython, which would
  fall back to text rendering, even though a partial HTML output was already
  shown.
- Add tests for custom handlers and error checking logic.

PiperOrigin-RevId: 668126136
  • Loading branch information
danieldjohnson authored and Treescope Developers committed Aug 27, 2024
1 parent 5e429cb commit 1b8eca6
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 18 deletions.
43 changes: 43 additions & 0 deletions tests/fixtures/treescope_examples_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import jax
import torch
import treescope


class MyTestEnum(enum.Enum):
Expand Down Expand Up @@ -122,6 +123,7 @@ def wrapped_function():


class SomethingCallable:

def __call__(self, value: int) -> int:
return value + 1

Expand Down Expand Up @@ -175,6 +177,47 @@ def __repr__(self):
return "Non-idiomatic\nmultiline\nobject"


class ObjectWithCustomHandler:

def __treescope_repr__(self, path, subtree_renderer):
del subtree_renderer
return treescope.rendering_parts.text(
f"<ObjectWithCustomHandler custom rendering! Path: {repr(path)}>"
)


class ObjectWithCustomHandlerThatThrows:

def __treescope_repr__(self, path, subtree_renderer):
del path, subtree_renderer
raise RuntimeError("Simulated treescope_repr failure!")

def __repr__(self):
return "<Fallback repr for ObjectWithCustomHandlerThatThrows>"


class ObjectWithReprThatThrows:

def __repr__(self):
raise RuntimeError("Simulated repr failure!")


class ObjectWithCustomHandlerThatThrowsDeferred:

def __treescope_repr__(self, path, subtree_renderer):
del path, subtree_renderer
def _internal_main_thunk(layout_decision):
del layout_decision
raise RuntimeError("Simulated deferred treescope_repr failure!")

return treescope.lowering.maybe_defer_rendering(
main_thunk=_internal_main_thunk,
placeholder_thunk=lambda: treescope.rendering_parts.text(
"<deferred placeholder>"
),
)


class SomePyTorchModule(torch.nn.Module):
"""A basic PyTorch module to test rendering."""

Expand Down
122 changes: 112 additions & 10 deletions tests/renderer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,17 +395,13 @@ def hook_that_crashes(node, path, node_renderer):
testcase_name="well_known_function",
target=treescope.render_to_text,
expected_collapsed="render_to_text",
expected_roundtrip_collapsed=(
"treescope.render_to_text"
),
expected_roundtrip_collapsed="treescope.render_to_text",
),
dict(
testcase_name="well_known_type",
target=treescope.IPythonVisualization,
expected_collapsed="IPythonVisualization",
expected_roundtrip_collapsed=(
"treescope.IPythonVisualization"
),
expected_roundtrip_collapsed="treescope.IPythonVisualization",
),
dict(
testcase_name="ast_nodes",
Expand All @@ -426,6 +422,21 @@ def hook_that_crashes(node, path, node_renderer):
),
)"""),
),
dict(
testcase_name="custom_handler",
target=[fixture_lib.ObjectWithCustomHandler()],
expected_collapsed=(
"[<ObjectWithCustomHandler custom rendering! Path: '[0]'>]"
),
),
dict(
testcase_name="custom_handler_that_throws",
target=[fixture_lib.ObjectWithCustomHandlerThatThrows()],
ignore_exceptions=True,
expected_collapsed=(
"[<Fallback repr for ObjectWithCustomHandlerThatThrows>]"
),
),
dict(
testcase_name="dtype_standard",
target=np.dtype(np.float32),
Expand Down Expand Up @@ -533,6 +544,7 @@ def test_object_rendering(
expected_roundtrip: str | None = None,
expected_roundtrip_collapsed: str | None = None,
expand_depth: int = 1,
ignore_exceptions: bool = False,
):
if target_builder is not None:
assert target is None
Expand All @@ -541,7 +553,9 @@ def test_object_rendering(
renderer = treescope.active_renderer.get()
# Render it to IR.
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
renderer.to_foldable_representation(
target, ignore_exceptions=ignore_exceptions
)
)

# Collapse all foldables.
Expand Down Expand Up @@ -671,6 +685,69 @@ def test_fallback_repr_one_line(self):
]"""),
)

def test_fallback_repr_after_error(self):
target = [fixture_lib.ObjectWithCustomHandlerThatThrows()]
renderer = treescope.active_renderer.get()

with self.assertRaisesWithLiteralMatch(
RuntimeError, "Simulated treescope_repr failure!"
):
renderer.to_foldable_representation(target)

rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target, ignore_exceptions=True)
)

layout_algorithms.expand_to_depth(rendering, 0)
self.assertEqual(
lowering.render_to_text_as_root(rendering),
"[<Fallback repr for ObjectWithCustomHandlerThatThrows>]",
)
layout_algorithms.expand_to_depth(rendering, 2)
self.assertEqual(
lowering.render_to_text_as_root(rendering),
textwrap.dedent(f"""\
[
<Fallback repr for ObjectWithCustomHandlerThatThrows>, # {object.__repr__(target[0])}
]"""),
)

def test_ignore_exceptions_in_deferred(self):
target = [fixture_lib.ObjectWithCustomHandlerThatThrowsDeferred()]
renderer = treescope.active_renderer.get()

with self.assertRaisesWithLiteralMatch(
RuntimeError, "Simulated deferred treescope_repr failure!"
):
renderer.to_foldable_representation(target)

with lowering.collecting_deferred_renderings() as deferreds:
foldable_ir = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
)

# It's difficult to test the IPython wrapper so we instead test the internal
# helper function that produces the streaming HTML output.
html_parts = lowering._render_to_html_as_root_streaming(
root_node=foldable_ir,
roundtrip=False,
deferreds=deferreds,
ignore_exceptions=True,
)
self.assertContainsInOrder(
[
"[",
"&lt;RuntimeError during deferred rendering",
"Traceback",
"in _internal_main_thunk",
"raise RuntimeError",
"RuntimeError: Simulated deferred treescope_repr failure!",
"&gt;",
"]",
],
"".join(html_parts),
)

def test_fallback_repr_multiline_idiomatic(self):
target = [fixture_lib.UnknownObjectWithMultiLineRepr()]
renderer = treescope.active_renderer.get()
Expand Down Expand Up @@ -738,6 +815,33 @@ def test_fallback_repr_basic(self):
]"""),
)

def test_failsafe_for_throw_in_repr(self):
target = [fixture_lib.ObjectWithReprThatThrows()]
renderer = treescope.active_renderer.get()

with self.assertRaisesWithLiteralMatch(
RuntimeError, "Simulated repr failure!"
):
renderer.to_foldable_representation(target)

rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target, ignore_exceptions=True)
)

layout_algorithms.expand_to_depth(rendering, 0)
self.assertEqual(
lowering.render_to_text_as_root(rendering),
f"[{object.__repr__(target[0])}]",
)
layout_algorithms.expand_to_depth(rendering, 2)
self.assertEqual(
lowering.render_to_text_as_root(rendering),
textwrap.dedent(f"""\
[
{object.__repr__(target[0])}, # Error occured while formatting this object.
]"""),
)

def test_shared_values(self):
shared = ["bar"]
target = [shared, shared, {"foo": shared}]
Expand Down Expand Up @@ -803,9 +907,7 @@ def inner_autovisualizer(node, path):
),
)

with treescope.active_autovisualizer.set_scoped(
autovisualizer_for_test
):
with treescope.active_autovisualizer.set_scoped(autovisualizer_for_test):
renderer = treescope.active_renderer.get()
rendering = rendering_parts.build_full_line_with_annotations(
renderer.to_foldable_representation(target)
Expand Down
1 change: 1 addition & 0 deletions treescope/_internal/api/ipython_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _display_and_maybe_steal(
roundtrip=roundtrip_mode,
compressed=compress_html,
stealable=stealable,
ignore_exceptions=ignore_exceptions,
)
else:
rendering = lowering.render_to_html_as_root(
Expand Down
36 changes: 32 additions & 4 deletions treescope/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import contextlib
import io
import json
import traceback
from typing import Any, Callable, Iterator, Sequence
import uuid

Expand Down Expand Up @@ -162,8 +163,10 @@ def render_to_text_as_root(
Text for the rendered node.
"""
if strip_whitespace_lines and not strip_trailing_whitespace:
raise ValueError("strip_whitespace_lines must be False if "
"strip_trailing_whitespace is False.")
raise ValueError(
"strip_whitespace_lines must be False if "
"strip_trailing_whitespace is False."
)

stream = io.StringIO()
root_node.render_to_text(
Expand Down Expand Up @@ -223,13 +226,16 @@ def _render_to_html_as_root_streaming(
root_node: rendering_parts.RenderableTreePart,
roundtrip: bool,
deferreds: Sequence[foldable_impl.DeferredWithThunk],
ignore_exceptions: bool = False,
) -> Iterator[str]:
"""Helper function: renders a root node to HTML one step at a time.
Args:
root_node: The root node to render.
roundtrip: Whether to render in roundtrip mode.
deferreds: Sequence of deferred objects to render and splice in.
ignore_exceptions: Whether to ignore exceptions during deferred rendering,
replacing them with error markers.
Yields:
HTML source for the rendered node, followed by logic to substitute each
Expand Down Expand Up @@ -347,7 +353,26 @@ def _render_one(
layout_decision = deferred.placeholder.child.get_expand_state()
else:
layout_decision = None
replacement_part = deferred.thunk(layout_decision)
try:
replacement_part = deferred.thunk(layout_decision)
except Exception as e: # pylint: disable=broad-except
if not ignore_exceptions:
raise
exc_child = rendering_parts.fold_condition(
expanded=rendering_parts.indented_children(
[rendering_parts.text(traceback.format_exc())]
),
)
replacement_part = rendering_parts.error_color(
rendering_parts.build_custom_foldable_tree_node(
label=rendering_parts.text(
f"<{type(e).__name__} during deferred rendering"
),
contents=rendering_parts.siblings(
exc_child, rendering_parts.text(">")
),
).renderable
)
_render_one(
replacement_part,
deferred.placeholder.saved_at_beginning_of_line,
Expand Down Expand Up @@ -409,6 +434,7 @@ def display_streaming_as_root(
roundtrip: bool = False,
compressed: bool = True,
stealable: bool = False,
ignore_exceptions: bool = False,
) -> str | None:
"""Displays a root node in an IPython notebook in a streaming fashion.
Expand All @@ -419,6 +445,8 @@ def display_streaming_as_root(
compressed: Whether to compress the HTML for display.
stealable: Whether to return an extra HTML snippet that allows the streaming
rendering to be relocated after it is shown.
ignore_exceptions: Whether to ignore exceptions during deferred rendering,
replacing them with error markers.
Returns:
If ``stealable`` is True, a final HTML snippet which, if inserted into a
Expand All @@ -431,7 +459,7 @@ def display_streaming_as_root(
import IPython.display # pylint: disable=g-import-not-at-top

render_iterator = _render_to_html_as_root_streaming(
root_node, roundtrip, deferreds
root_node, roundtrip, deferreds, ignore_exceptions=ignore_exceptions
)
encapsulated_iterator = html_encapsulation.encapsulate_streaming_html(
render_iterator, compress=compressed, stealable=stealable
Expand Down
14 changes: 10 additions & 4 deletions treescope/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,18 @@ def _render_subtree(
f"No handler registered for a node of type {type(node)}."
)
# Fall back to a basic `repr` so that we still render something even
# without a handler for it.
# without a handler for it. We use the object repr because a custom
# repr may still raise an exception if the object is in an invalid
# state.
return rendering_parts.RenderableAndLineAnnotations(
renderable=rendering_parts.abbreviation_color(
rendering_parts.text(repr(node))
renderable=rendering_parts.error_color(
rendering_parts.text(object.__repr__(node))
),
annotations=rendering_parts.comment_color(
rendering_parts.text(
" # Error occured while formatting this object."
)
),
annotations=rendering_parts.empty_part(),
)
else:
raise ValueError(
Expand Down

0 comments on commit 1b8eca6

Please sign in to comment.