diff --git a/pyproject.toml b/pyproject.toml index 665a83d81..b91a3568e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "bidict", "immutabledict", "loopy>=2020.2", - "pytools>=2024.1.14", + "pytools>=2024.1.21", "pymbolic>=2024.2", "typing_extensions>=4", ] diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 130db8b7a..8cd520d94 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -716,11 +716,19 @@ def unify_axes_tags( equations_collector.equations ) - for tag, var in equations_collector.known_tag_to_var.items(): - if isinstance(tag, AxisIgnoredForPropagationTag): - continue + ignored_vars = set({ + tag_var for tag, tag_var in equations_collector.known_tag_to_var.items() + if isinstance(tag, AxisIgnoredForPropagationTag) + }) + + ignored_vars.update({ + ax_var for (ary, ax), ax_var in equations_collector.axis_to_var.items() + if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag) + }) - reachable_nodes = get_reachable_nodes(propagation_graph, var) + for tag, var in equations_collector.known_tag_to_var.items(): + reachable_nodes = get_reachable_nodes(propagation_graph, var, + exclude_nodes=ignored_vars) for reachable_var in (reachable_nodes - known_tag_vars): axis_to_solved_tags.setdefault( equations_collector.axis_to_var.inverse[reachable_var], diff --git a/test/test_pytato.py b/test/test_pytato.py index 38e2fda5e..45d333c32 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1324,6 +1324,42 @@ def test_unify_axes_tags(): # }}} +def test_ignoring_axes_during_propagation(): + from pytools.tag import UniqueTag + + from pytato.transform.metadata import AxisIgnoredForPropagationTag + + class ElementAxisTag(UniqueTag): + pass + + class DOFAxisTagX(UniqueTag): + pass + + class DOFAxisTagY(UniqueTag): + pass + + a = pt.make_placeholder("a", (4, 4)) + a = a.with_tagged_axis(0, AxisIgnoredForPropagationTag()) + a = a.with_tagged_axis(1, AxisIgnoredForPropagationTag()) + + u = pt.make_placeholder("u", (128, 4, 4)) + u = u.with_tagged_axis(0, ElementAxisTag()) + u = u.with_tagged_axis(1, DOFAxisTagX()) + u = u.with_tagged_axis(2, DOFAxisTagY()) + + u_x = pt.einsum("il,elj->eij", a, u) + u_y = pt.einsum("jl,eil->eij", a, u) + + expr = u_x + u_y + + unified = pt.unify_axes_tags(expr) + iax_to_tags = {i: unified.axes[i].tags for i in range(len(unified.axes))} + + assert iax_to_tags[0] == frozenset([ElementAxisTag()]) + assert iax_to_tags[1] == frozenset([DOFAxisTagX()]) + assert iax_to_tags[2] == frozenset([DOFAxisTagY()]) + + def test_rewrite_einsums_with_no_broadcasts(): a = pt.make_placeholder("a", (10, 4, 1)) b = pt.make_placeholder("b", (10, 1, 4))