From 1b8eca655c4ec99c9a66951cac7152629da928f9 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Tue, 27 Aug 2024 13:54:21 -0700 Subject: [PATCH] Add additional error fallbacks during rendering. - Avoid calling __repr__ in the renderer last-resort fallback, since this is user-defined and can still throw an exception for invalid objects. - Catch errors while rendering deferred parts, to avoid breaking the IPython display integration if a deferred part raises an exception. (This can happen due to bugs in array rendering adapters, for instance.) Before, errors in deferred parts could bubble out to IPython, which would fall back to text rendering, even though a partial HTML output was already shown. - Add tests for custom handlers and error checking logic. PiperOrigin-RevId: 668126136 --- tests/fixtures/treescope_examples_fixture.py | 43 ++++++ tests/renderer_test.py | 122 ++++++++++++++++-- .../_internal/api/ipython_integration.py | 1 + treescope/lowering.py | 36 +++++- treescope/renderers.py | 14 +- 5 files changed, 198 insertions(+), 18 deletions(-) diff --git a/tests/fixtures/treescope_examples_fixture.py b/tests/fixtures/treescope_examples_fixture.py index 57d9df5..70627a7 100644 --- a/tests/fixtures/treescope_examples_fixture.py +++ b/tests/fixtures/treescope_examples_fixture.py @@ -27,6 +27,7 @@ import jax import torch +import treescope class MyTestEnum(enum.Enum): @@ -122,6 +123,7 @@ def wrapped_function(): class SomethingCallable: + def __call__(self, value: int) -> int: return value + 1 @@ -175,6 +177,47 @@ def __repr__(self): return "Non-idiomatic\nmultiline\nobject" +class ObjectWithCustomHandler: + + def __treescope_repr__(self, path, subtree_renderer): + del subtree_renderer + return treescope.rendering_parts.text( + f"" + ) + + +class ObjectWithCustomHandlerThatThrows: + + def __treescope_repr__(self, path, subtree_renderer): + del path, subtree_renderer + raise RuntimeError("Simulated treescope_repr failure!") + + def __repr__(self): + return "" + + +class ObjectWithReprThatThrows: + + def __repr__(self): + raise RuntimeError("Simulated repr failure!") + + +class ObjectWithCustomHandlerThatThrowsDeferred: + + def __treescope_repr__(self, path, subtree_renderer): + del path, subtree_renderer + def _internal_main_thunk(layout_decision): + del layout_decision + raise RuntimeError("Simulated deferred treescope_repr failure!") + + return treescope.lowering.maybe_defer_rendering( + main_thunk=_internal_main_thunk, + placeholder_thunk=lambda: treescope.rendering_parts.text( + "" + ), + ) + + class SomePyTorchModule(torch.nn.Module): """A basic PyTorch module to test rendering.""" diff --git a/tests/renderer_test.py b/tests/renderer_test.py index 3c6bdca..5b3b1be 100644 --- a/tests/renderer_test.py +++ b/tests/renderer_test.py @@ -395,17 +395,13 @@ def hook_that_crashes(node, path, node_renderer): testcase_name="well_known_function", target=treescope.render_to_text, expected_collapsed="render_to_text", - expected_roundtrip_collapsed=( - "treescope.render_to_text" - ), + expected_roundtrip_collapsed="treescope.render_to_text", ), dict( testcase_name="well_known_type", target=treescope.IPythonVisualization, expected_collapsed="IPythonVisualization", - expected_roundtrip_collapsed=( - "treescope.IPythonVisualization" - ), + expected_roundtrip_collapsed="treescope.IPythonVisualization", ), dict( testcase_name="ast_nodes", @@ -426,6 +422,21 @@ def hook_that_crashes(node, path, node_renderer): ), )"""), ), + dict( + testcase_name="custom_handler", + target=[fixture_lib.ObjectWithCustomHandler()], + expected_collapsed=( + "[]" + ), + ), + dict( + testcase_name="custom_handler_that_throws", + target=[fixture_lib.ObjectWithCustomHandlerThatThrows()], + ignore_exceptions=True, + expected_collapsed=( + "[]" + ), + ), dict( testcase_name="dtype_standard", target=np.dtype(np.float32), @@ -533,6 +544,7 @@ def test_object_rendering( expected_roundtrip: str | None = None, expected_roundtrip_collapsed: str | None = None, expand_depth: int = 1, + ignore_exceptions: bool = False, ): if target_builder is not None: assert target is None @@ -541,7 +553,9 @@ def test_object_rendering( renderer = treescope.active_renderer.get() # Render it to IR. rendering = rendering_parts.build_full_line_with_annotations( - renderer.to_foldable_representation(target) + renderer.to_foldable_representation( + target, ignore_exceptions=ignore_exceptions + ) ) # Collapse all foldables. @@ -671,6 +685,69 @@ def test_fallback_repr_one_line(self): ]"""), ) + def test_fallback_repr_after_error(self): + target = [fixture_lib.ObjectWithCustomHandlerThatThrows()] + renderer = treescope.active_renderer.get() + + with self.assertRaisesWithLiteralMatch( + RuntimeError, "Simulated treescope_repr failure!" + ): + renderer.to_foldable_representation(target) + + rendering = rendering_parts.build_full_line_with_annotations( + renderer.to_foldable_representation(target, ignore_exceptions=True) + ) + + layout_algorithms.expand_to_depth(rendering, 0) + self.assertEqual( + lowering.render_to_text_as_root(rendering), + "[]", + ) + layout_algorithms.expand_to_depth(rendering, 2) + self.assertEqual( + lowering.render_to_text_as_root(rendering), + textwrap.dedent(f"""\ + [ + , # {object.__repr__(target[0])} + ]"""), + ) + + def test_ignore_exceptions_in_deferred(self): + target = [fixture_lib.ObjectWithCustomHandlerThatThrowsDeferred()] + renderer = treescope.active_renderer.get() + + with self.assertRaisesWithLiteralMatch( + RuntimeError, "Simulated deferred treescope_repr failure!" + ): + renderer.to_foldable_representation(target) + + with lowering.collecting_deferred_renderings() as deferreds: + foldable_ir = rendering_parts.build_full_line_with_annotations( + renderer.to_foldable_representation(target) + ) + + # It's difficult to test the IPython wrapper so we instead test the internal + # helper function that produces the streaming HTML output. + html_parts = lowering._render_to_html_as_root_streaming( + root_node=foldable_ir, + roundtrip=False, + deferreds=deferreds, + ignore_exceptions=True, + ) + self.assertContainsInOrder( + [ + "[", + "<RuntimeError during deferred rendering", + "Traceback", + "in _internal_main_thunk", + "raise RuntimeError", + "RuntimeError: Simulated deferred treescope_repr failure!", + ">", + "]", + ], + "".join(html_parts), + ) + def test_fallback_repr_multiline_idiomatic(self): target = [fixture_lib.UnknownObjectWithMultiLineRepr()] renderer = treescope.active_renderer.get() @@ -738,6 +815,33 @@ def test_fallback_repr_basic(self): ]"""), ) + def test_failsafe_for_throw_in_repr(self): + target = [fixture_lib.ObjectWithReprThatThrows()] + renderer = treescope.active_renderer.get() + + with self.assertRaisesWithLiteralMatch( + RuntimeError, "Simulated repr failure!" + ): + renderer.to_foldable_representation(target) + + rendering = rendering_parts.build_full_line_with_annotations( + renderer.to_foldable_representation(target, ignore_exceptions=True) + ) + + layout_algorithms.expand_to_depth(rendering, 0) + self.assertEqual( + lowering.render_to_text_as_root(rendering), + f"[{object.__repr__(target[0])}]", + ) + layout_algorithms.expand_to_depth(rendering, 2) + self.assertEqual( + lowering.render_to_text_as_root(rendering), + textwrap.dedent(f"""\ + [ + {object.__repr__(target[0])}, # Error occured while formatting this object. + ]"""), + ) + def test_shared_values(self): shared = ["bar"] target = [shared, shared, {"foo": shared}] @@ -803,9 +907,7 @@ def inner_autovisualizer(node, path): ), ) - with treescope.active_autovisualizer.set_scoped( - autovisualizer_for_test - ): + with treescope.active_autovisualizer.set_scoped(autovisualizer_for_test): renderer = treescope.active_renderer.get() rendering = rendering_parts.build_full_line_with_annotations( renderer.to_foldable_representation(target) diff --git a/treescope/_internal/api/ipython_integration.py b/treescope/_internal/api/ipython_integration.py index 05f9f42..b52a1af 100644 --- a/treescope/_internal/api/ipython_integration.py +++ b/treescope/_internal/api/ipython_integration.py @@ -94,6 +94,7 @@ def _display_and_maybe_steal( roundtrip=roundtrip_mode, compressed=compress_html, stealable=stealable, + ignore_exceptions=ignore_exceptions, ) else: rendering = lowering.render_to_html_as_root( diff --git a/treescope/lowering.py b/treescope/lowering.py index 0bdd9d5..1924826 100644 --- a/treescope/lowering.py +++ b/treescope/lowering.py @@ -23,6 +23,7 @@ import contextlib import io import json +import traceback from typing import Any, Callable, Iterator, Sequence import uuid @@ -162,8 +163,10 @@ def render_to_text_as_root( Text for the rendered node. """ if strip_whitespace_lines and not strip_trailing_whitespace: - raise ValueError("strip_whitespace_lines must be False if " - "strip_trailing_whitespace is False.") + raise ValueError( + "strip_whitespace_lines must be False if " + "strip_trailing_whitespace is False." + ) stream = io.StringIO() root_node.render_to_text( @@ -223,6 +226,7 @@ def _render_to_html_as_root_streaming( root_node: rendering_parts.RenderableTreePart, roundtrip: bool, deferreds: Sequence[foldable_impl.DeferredWithThunk], + ignore_exceptions: bool = False, ) -> Iterator[str]: """Helper function: renders a root node to HTML one step at a time. @@ -230,6 +234,8 @@ def _render_to_html_as_root_streaming( root_node: The root node to render. roundtrip: Whether to render in roundtrip mode. deferreds: Sequence of deferred objects to render and splice in. + ignore_exceptions: Whether to ignore exceptions during deferred rendering, + replacing them with error markers. Yields: HTML source for the rendered node, followed by logic to substitute each @@ -347,7 +353,26 @@ def _render_one( layout_decision = deferred.placeholder.child.get_expand_state() else: layout_decision = None - replacement_part = deferred.thunk(layout_decision) + try: + replacement_part = deferred.thunk(layout_decision) + except Exception as e: # pylint: disable=broad-except + if not ignore_exceptions: + raise + exc_child = rendering_parts.fold_condition( + expanded=rendering_parts.indented_children( + [rendering_parts.text(traceback.format_exc())] + ), + ) + replacement_part = rendering_parts.error_color( + rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.text( + f"<{type(e).__name__} during deferred rendering" + ), + contents=rendering_parts.siblings( + exc_child, rendering_parts.text(">") + ), + ).renderable + ) _render_one( replacement_part, deferred.placeholder.saved_at_beginning_of_line, @@ -409,6 +434,7 @@ def display_streaming_as_root( roundtrip: bool = False, compressed: bool = True, stealable: bool = False, + ignore_exceptions: bool = False, ) -> str | None: """Displays a root node in an IPython notebook in a streaming fashion. @@ -419,6 +445,8 @@ def display_streaming_as_root( compressed: Whether to compress the HTML for display. stealable: Whether to return an extra HTML snippet that allows the streaming rendering to be relocated after it is shown. + ignore_exceptions: Whether to ignore exceptions during deferred rendering, + replacing them with error markers. Returns: If ``stealable`` is True, a final HTML snippet which, if inserted into a @@ -431,7 +459,7 @@ def display_streaming_as_root( import IPython.display # pylint: disable=g-import-not-at-top render_iterator = _render_to_html_as_root_streaming( - root_node, roundtrip, deferreds + root_node, roundtrip, deferreds, ignore_exceptions=ignore_exceptions ) encapsulated_iterator = html_encapsulation.encapsulate_streaming_html( render_iterator, compress=compressed, stealable=stealable diff --git a/treescope/renderers.py b/treescope/renderers.py index ee6b46b..78bc440 100644 --- a/treescope/renderers.py +++ b/treescope/renderers.py @@ -328,12 +328,18 @@ def _render_subtree( f"No handler registered for a node of type {type(node)}." ) # Fall back to a basic `repr` so that we still render something even - # without a handler for it. + # without a handler for it. We use the object repr because a custom + # repr may still raise an exception if the object is in an invalid + # state. return rendering_parts.RenderableAndLineAnnotations( - renderable=rendering_parts.abbreviation_color( - rendering_parts.text(repr(node)) + renderable=rendering_parts.error_color( + rendering_parts.text(object.__repr__(node)) + ), + annotations=rendering_parts.comment_color( + rendering_parts.text( + " # Error occured while formatting this object." + ) ), - annotations=rendering_parts.empty_part(), ) else: raise ValueError(