From 7cdf3d5488ea0520f7d48f7bbd3f006acdd12204 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Jun 2024 17:59:49 -0500 Subject: [PATCH 01/32] add missing *args, **kwargs in WalkMapper.map_call --- pytato/transform/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b78c24301..49d553a39 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1186,9 +1186,9 @@ def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None: if not self.visit(expr): return - self.map_function_definition(expr.function) + self.map_function_definition(expr.function, *args, **kwargs) for bnd in expr.bindings.values(): - self.rec(bnd) + self.rec(bnd, *args, **kwargs) self.post_visit(expr) From 3eaf7a1f3af643d28aec53054992a5512ac11f01 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 09:33:07 -0500 Subject: [PATCH 02/32] memoize clone_for_callee --- pytato/distributed/partition.py | 1 + pytato/transform/__init__.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 5865ec491..606d1a203 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -303,6 +303,7 @@ def __init__(self, self.user_input_names: set[str] = set() self.partition_input_name_to_placeholder: dict[str, Placeholder] = {} + @memoize_method def clone_for_callee( self, function: FunctionDefinition) -> _DistributedInputReplacer: # Function definitions aren't allowed to contain receives, diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 49d553a39..527022c0d 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -276,6 +276,7 @@ def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT: def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) + @memoize_method def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: """ @@ -1229,6 +1230,10 @@ def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any super().rec(expr, *args, **kwargs) self._visited_nodes.add(cache_key) + @memoize_method + def clone_for_callee( + self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: + return type(self)() # }}} @@ -1276,6 +1281,7 @@ def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None: super().__init__() self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn + @memoize_method def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: # type-ignore-reason: self.__init__ has a different function signature From 2de93498bb44761c27df1a5b5dd7922493ade72d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 11 Mar 2024 19:19:51 -0500 Subject: [PATCH 03/32] remove default CombineMapper map_call implementation --- pytato/transform/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 527022c0d..bf8843268 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -824,9 +824,9 @@ def map_function_definition(self, expr: FunctionDefinition) -> CombineT: " must override map_function_definition.") def map_call(self, expr: Call) -> CombineT: - return self.combine(self.map_function_definition(expr.function), - *[self.rec(bnd) - for name, bnd in sorted(expr.bindings.items())]) + raise NotImplementedError( + "Mapping calls is context-dependent. Derived classes must override " + "map_call.") def map_named_call_result(self, expr: NamedCallResult) -> CombineT: return self.rec(expr._container) From c0e70cdc8821ae7b14c31acf38e996cdf6bd5896 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 09:51:56 -0500 Subject: [PATCH 04/32] don't memoize map_function_definition in cached walk mappers doesn't make sense, since it doesn't return anything --- pytato/analysis/__init__.py | 6 +++--- pytato/transform/__init__.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 38ed276fe..5a54807ed 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,7 +29,6 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( Array, @@ -447,10 +446,9 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - @memoize_method def map_function_definition(self, /, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> None: - if not self.visit(expr): + if not self.visit(expr) or expr in self._visited_functions: return new_mapper = self.clone_for_callee(expr) @@ -458,6 +456,8 @@ def map_function_definition(self, /, expr: FunctionDefinition, new_mapper(subexpr, *args, **kwargs) self.count += new_mapper.count + self._visited_functions.add(expr) + self.post_visit(expr, *args, **kwargs) def post_visit(self, expr: Any) -> None: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index bf8843268..5678a0bce 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1217,6 +1217,7 @@ class CachedWalkMapper(WalkMapper): def __init__(self) -> None: super().__init__() self._visited_nodes: set[Any] = set() + self._visited_functions: set[FunctionDefinition] = set() def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError @@ -1234,6 +1235,20 @@ def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: return type(self)() + + def map_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr) or expr in self._visited_functions: + return + + new_mapper = self.clone_for_callee(expr) + for subexpr in expr.returns.values(): + new_mapper(subexpr, *args, **kwargs) + + self._visited_functions.add(expr) + + self.post_visit(expr, *args, **kwargs) + # }}} @@ -1261,7 +1276,6 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: def post_visit(self, expr: Any) -> None: self.topological_order.append(expr) - @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. return From d391d450eb1ba5a6ae19cdbe8a69f5ad08fd1ba5 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 10:11:51 -0500 Subject: [PATCH 05/32] support calls in InputGatherer --- pytato/transform/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5678a0bce..e54f6fe1a 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -976,6 +976,12 @@ def map_function_definition(self, expr: FunctionDefinition return frozenset(result) + def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: + return self.combine(self.map_function_definition(expr.function), + *[ + self.rec(bnd) + for name, bnd in sorted(expr.bindings.items())]) + # }}} From 244160c2c65f39294c15b62e78c6fb768fab8ec5 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Jun 2024 12:57:17 -0500 Subject: [PATCH 06/32] make NamedCallResult compatible with attrs cache_hash=True --- pytato/function.py | 16 ++++++---------- pytato/transform/__init__.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/pytato/function.py b/pytato/function.py index 5a4202011..d5d5fb7a0 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -226,6 +226,7 @@ def __call__(self, **kwargs: Array raise NotImplementedError(self.return_type) +@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True) class NamedCallResult(NamedArray): """ One of the arrays that are returned from a call to :class:`FunctionDefinition`. @@ -243,15 +244,6 @@ class NamedCallResult(NamedArray): name: str _mapper_method: ClassVar[str] = "map_named_call_result" - def __init__(self, - call: Call, - name: str) -> None: - super().__init__(call, name, - axes=call.function.returns[name].axes, - tags=call.function.returns[name].tags, - non_equality_tags=( - call.function.returns[name].non_equality_tags)) - def with_tagged_axis(self, iaxis: int, tags: Sequence[Tag] | Tag) -> Array: raise ValueError("Tagging a NamedCallResult's axis is illegal, use" @@ -318,7 +310,11 @@ def __iter__(self) -> Iterator[str]: return iter(self.function.returns) def __getitem__(self, name: str) -> NamedCallResult: - return NamedCallResult(self, name) + return NamedCallResult( + self, name, + axes=self.function.returns[name].axes, + tags=self.function.returns[name].tags, + non_equality_tags=self.function.returns[name].non_equality_tags) def __len__(self) -> int: return len(self.function.returns) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index e54f6fe1a..94b82839c 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -466,7 +466,11 @@ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: def map_named_call_result(self, expr: NamedCallResult) -> Array: call = self.rec(expr._container) assert isinstance(call, Call) - return NamedCallResult(call, expr.name) + return NamedCallResult( + call, expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): @@ -704,7 +708,11 @@ def map_named_call_result(self, expr: NamedCallResult, *args: Any, **kwargs: Any) -> Array: call = self.rec(expr._container, *args, **kwargs) assert isinstance(call, Call) - return NamedCallResult(call, expr.name) + return NamedCallResult( + call, expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) # }}} From 9c2d7b9198f0eb245b63fe5f265a1d09287796dc Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 16:09:40 -0500 Subject: [PATCH 07/32] enable cache_hash on FunctionDefinition --- pytato/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/function.py b/pytato/function.py index d5d5fb7a0..caa968d86 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -92,7 +92,7 @@ class ReturnType(enum.Enum): # eq=False to avoid equality comparison without EqualityMaper -@attrs.define(frozen=True, eq=False, hash=True) +@attrs.define(frozen=True, eq=False, hash=True, cache_hash=True) class FunctionDefinition(Taggable): r""" A function definition that represents its outputs as instances of From 7a5131602ac55c3ef7c2166293cbac4622ca57d4 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 09:45:53 -0500 Subject: [PATCH 08/32] enable calls in DirectPredecessorsGetter --- pytato/analysis/__init__.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5a54807ed..e597a0fd0 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -320,26 +320,26 @@ class DirectPredecessorsGetter(Mapper): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]: + def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]: return frozenset({dim for dim in shape if isinstance(dim, Array)}) - def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]: + def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]: return (frozenset(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> frozenset[Array]: + def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]: return (frozenset(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> frozenset[Array]: + def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]: return (frozenset(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> frozenset[Array]: + def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]: return (frozenset(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]: + def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) @@ -348,7 +348,7 @@ def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]: if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: + def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]: return (frozenset([expr.array]) | frozenset(idx for idx in expr.indices if isinstance(idx, Array)) @@ -359,32 +359,34 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> frozenset[Array]: + ) -> frozenset[ArrayOrNames]: return frozenset([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]: + def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]: + def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[Array]: + ) -> frozenset[ArrayOrNames]: return frozenset([expr.passthrough_data]) - def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]: - raise NotImplementedError( - "DirectPredecessorsGetter does not yet support expressions containing " - "functions.") + def map_call(self, expr: Call) -> frozenset[ArrayOrNames]: + return frozenset(expr.bindings.values()) + + def map_named_call_result( + self, expr: NamedCallResult) -> frozenset[ArrayOrNames]: + return frozenset([expr._container]) # }}} From 072b51a2dc0475fd1fae942b5de4c3d90210625b Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 09:59:27 -0500 Subject: [PATCH 09/32] memoize Call creation --- pytato/function.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/function.py b/pytato/function.py index caa968d86..4d5e718d5 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -63,6 +63,7 @@ import attrs from immutabledict import immutabledict +from pytools import memoize_method from pytools.tag import Tag, Taggable from pytato.array import ( @@ -183,7 +184,8 @@ def _with_new_tags( self: FunctionDefinition, tags: frozenset[Tag]) -> FunctionDefinition: return attrs.evolve(self, tags=tags) - def __call__(self, **kwargs: Array + @memoize_method + def __call__(self, /, **kwargs: Array ) -> Array | tuple[Array, ...] | dict[str, Array]: from pytato.array import _get_default_tags from pytato.utils import are_shapes_equal From fbd0c4025af1e2b01878ea26f93e9e3e96d9eb66 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 10:01:51 -0500 Subject: [PATCH 10/32] make NamedCallResult.call a property --- pytato/function.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytato/function.py b/pytato/function.py index 4d5e718d5..0862ceb29 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -242,7 +242,6 @@ class NamedCallResult(NamedArray): The name by which the returned array is referred to in :attr:`FunctionDefinition.returns`. """ - call: Call name: str _mapper_method: ClassVar[str] = "map_named_call_result" @@ -263,6 +262,11 @@ def without_tags(self, raise ValueError("Untagging a NamedCallResult is illegal, use" " Call.without_tags instead") + @property + def call(self) -> Call: + assert isinstance(self._container, Call) + return self._container + @property def shape(self) -> ShapeType: assert isinstance(self._container, Call) From 6cf3be28cea46612c127782ac16cd282d43b5d41 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 16:54:36 -0500 Subject: [PATCH 11/32] remove redundant NamedCallResult.name (already defined in NamedArray) --- pytato/function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytato/function.py b/pytato/function.py index 0862ceb29..082645500 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -242,7 +242,6 @@ class NamedCallResult(NamedArray): The name by which the returned array is referred to in :attr:`FunctionDefinition.returns`. """ - name: str _mapper_method: ClassVar[str] = "map_named_call_result" def with_tagged_axis(self, iaxis: int, From 4e2aae7d79ef4b242d5b41c0d519197fc0da8075 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 Mar 2024 10:02:39 -0500 Subject: [PATCH 12/32] memoize NamedCallResult creation --- pytato/function.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/function.py b/pytato/function.py index 082645500..cd09d8a40 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -314,6 +314,7 @@ def __contains__(self, name: object) -> bool: def __iter__(self) -> Iterator[str]: return iter(self.function.returns) + @memoize_method def __getitem__(self, name: str) -> NamedCallResult: return NamedCallResult( self, name, From 2b00442ab4047f7d9837eb92edce6f240f01e86a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 11 Mar 2024 19:36:42 -0500 Subject: [PATCH 13/32] fix docstring --- pytato/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index b97125f63..fbe477d06 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -205,7 +205,7 @@ class ExpandedDimsReshape(UniqueTag): class FunctionIdentifier(UniqueTag): """ A tag that can be attached to a :class:`~pytato.function.FunctionDefinition` - node to to describe the function's identifier. One can use this to refer + node to describe the function's identifier. One can use this to refer all instances of :class:`~pytato.function.FunctionDefinition`, for example in transformations.transform.calls.concatenate_calls`. From c5e88ab9b3375280198afb4ae49bd33ecd7505d1 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 11 Mar 2024 16:27:58 -0500 Subject: [PATCH 14/32] remove non-argument placeholder check now done in arraycontext --- pytato/function.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pytato/function.py b/pytato/function.py index cd09d8a40..6e927fda2 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -163,14 +163,6 @@ def _placeholders(self) -> Mapping[str, Placeholder]: if isinstance(arg, Placeholder)}) all_placeholders |= new_placeholders - # FIXME: Need a way to check for *any* captured arrays, not just placeholders - if __debug__: - pl_names = frozenset(arg.name for arg in all_placeholders) - extra_pl_names = pl_names - self.parameters - assert not extra_pl_names, \ - f"Found non-argument placeholder '{next(iter(extra_pl_names))}' " \ - "in function definition." - return immutabledict({arg.name: arg for arg in all_placeholders}) def get_placeholder(self, name: str) -> Placeholder: From 595bcb707b8ed5fb33859511c1555158138c8c85 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Jun 2024 14:37:10 -0500 Subject: [PATCH 15/32] fix equality for FunctionDefinition --- pytato/equality.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytato/equality.py b/pytato/equality.py index 5750d2b93..79d038d72 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -298,6 +298,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any ) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.parameters == expr2.parameters + and expr1.return_type == expr2.return_type and (set(expr1.returns.keys()) == set(expr2.returns.keys())) and all(self.rec(expr1.returns[k], expr2.returns[k]) for k in expr1.returns) @@ -311,6 +312,7 @@ def map_call(self, expr1: Call, expr2: Any) -> bool: and all(self.rec(bnd, expr2.bindings[name]) for name, bnd in expr1.bindings.items()) + and expr1.tags == expr2.tags ) def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: From b310a916cbf182f4f03fac2fa2b37806287c5726 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 13 Jun 2024 14:18:22 -0500 Subject: [PATCH 16/32] add FIXME --- pytato/function.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/function.py b/pytato/function.py index 6e927fda2..81ddf4e0f 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -215,6 +215,7 @@ def __call__(self, /, **kwargs: Array return tuple(call_site[f"_{iarg}"] for iarg in range(len(self.returns))) elif self.return_type == ReturnType.DICT_OF_ARRAYS: + # FIXME: Should this be immutabledict? return {kw: call_site[kw] for kw in self.returns} else: raise NotImplementedError(self.return_type) From 729d72081b176dc19993d7ff7a91ff5498d9f9b3 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 17 Jun 2024 14:20:56 -0500 Subject: [PATCH 17/32] attempt to fix doc warning --- pytato/function.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pytato/function.py b/pytato/function.py index 81ddf4e0f..6a3abb754 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -17,6 +17,17 @@ A type variable corresponding to the return type of the function :func:`pytato.trace_call`. + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: Tag + + See :class:`pytools.tag.Tag`. + +.. class:: AxesT + + A :class:`tuple` of :class:`pytato.array.Axis` objects. """ __copyright__ = """ From cf3d0a9092b55377572cde02d44a9316698bee73 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 6 Jun 2024 16:15:16 -0500 Subject: [PATCH 18/32] don't construct NamedCallResult directly --- pytato/transform/__init__.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 94b82839c..f99a4aca0 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -466,11 +466,7 @@ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: def map_named_call_result(self, expr: NamedCallResult) -> Array: call = self.rec(expr._container) assert isinstance(call, Call) - return NamedCallResult( - call, expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + return call[expr.name] class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): @@ -708,11 +704,7 @@ def map_named_call_result(self, expr: NamedCallResult, *args: Any, **kwargs: Any) -> Array: call = self.rec(expr._container, *args, **kwargs) assert isinstance(call, Call) - return NamedCallResult( - call, expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + return call[expr.name] # }}} From 10e0cefb3effa9ac57b85c351453aac7388d7e2e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 18 Jun 2024 13:23:30 -0500 Subject: [PATCH 19/32] fix mapper method name in UsersCollector --- pytato/transform/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index f99a4aca0..5ea192d5b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1775,7 +1775,7 @@ def map_call(self, expr: Call, *args: Any) -> None: for bnd in expr.bindings.values(): self.rec(bnd) - def map_named_call(self, expr: NamedCallResult, *args: Any) -> None: + def map_named_call_result(self, expr: NamedCallResult, *args: Any) -> None: assert isinstance(expr._container, Call) for bnd in expr._container.bindings.values(): self.node_to_users.setdefault(bnd, set()).add(expr) From 81e49b410c789791d960cb0f80243319a7013cd3 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 14 Jun 2024 17:00:54 -0500 Subject: [PATCH 20/32] add FIXME --- pytato/transform/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5ea192d5b..f84d92585 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -968,6 +968,8 @@ def map_function_definition(self, expr: FunctionDefinition if inp.name in expr.parameters: # drop, reference to argument pass + # FIXME: Checked upon function definition creation in arraycontext + # now, can probably drop this else: raise ValueError("function definition refers to non-argument " f"placeholder named '{inp.name}'") From 769d48e03abc548f7069179623ed291761df9d21 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 10 Jul 2024 11:01:47 -0500 Subject: [PATCH 21/32] Revert "add FIXME" This reverts commit 72bf01daf4757c72560c5eaf937377b57ce6b07d. --- pytato/transform/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index f84d92585..5ea192d5b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -968,8 +968,6 @@ def map_function_definition(self, expr: FunctionDefinition if inp.name in expr.parameters: # drop, reference to argument pass - # FIXME: Checked upon function definition creation in arraycontext - # now, can probably drop this else: raise ValueError("function definition refers to non-argument " f"placeholder named '{inp.name}'") From 6de23ec8e400c0bca96cc575b19da7bb032332d9 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 10 Jul 2024 11:02:06 -0500 Subject: [PATCH 22/32] Revert "remove non-argument placeholder check" This reverts commit 73478bc4ac16877f946ce868317d9d1d10fb2ca6. --- pytato/function.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytato/function.py b/pytato/function.py index 6a3abb754..272378ecf 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -174,6 +174,14 @@ def _placeholders(self) -> Mapping[str, Placeholder]: if isinstance(arg, Placeholder)}) all_placeholders |= new_placeholders + # FIXME: Need a way to check for *any* captured arrays, not just placeholders + if __debug__: + pl_names = frozenset(arg.name for arg in all_placeholders) + extra_pl_names = pl_names - self.parameters + assert not extra_pl_names, \ + f"Found non-argument placeholder '{next(iter(extra_pl_names))}' " \ + "in function definition." + return immutabledict({arg.name: arg for arg in all_placeholders}) def get_placeholder(self, name: str) -> Placeholder: From 3d6998556e76a30b28a1c13c164ddfeef94ed52c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 9 Jul 2024 16:45:16 -0500 Subject: [PATCH 23/32] add some more missing *args, **kwargs to WalkMapper --- pytato/transform/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5ea192d5b..8fbe45ef2 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1180,7 +1180,7 @@ def map_loopy_call(self, expr: LoopyCall, *args: Any, **kwargs: Any) -> None: def map_function_definition(self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> None: - if not self.visit(expr): + if not self.visit(expr, *args, **kwargs): return new_mapper = self.clone_for_callee(expr) @@ -1190,14 +1190,14 @@ def map_function_definition(self, expr: FunctionDefinition, self.post_visit(expr, *args, **kwargs) def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None: - if not self.visit(expr): + if not self.visit(expr, *args, **kwargs): return self.map_function_definition(expr.function, *args, **kwargs) for bnd in expr.bindings.values(): self.rec(bnd, *args, **kwargs) - self.post_visit(expr) + self.post_visit(expr, *args, **kwargs) def map_named_call_result(self, expr: NamedCallResult, *args: Any, **kwargs: Any) -> None: @@ -1244,7 +1244,7 @@ def clone_for_callee( def map_function_definition(self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> None: - if not self.visit(expr) or expr in self._visited_functions: + if not self.visit(expr, *args, **kwargs) or expr in self._visited_functions: return new_mapper = self.clone_for_callee(expr) From 2ebc1b20cddc48a5356556476bde94cc2154b326 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 10 Jul 2024 15:51:34 -0500 Subject: [PATCH 24/32] remove some unnecessary *args, **kwargs --- pytato/analysis/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e597a0fd0..5a6040beb 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -448,19 +448,18 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - def map_function_definition(self, /, expr: FunctionDefinition, - *args: Any, **kwargs: Any) -> None: + def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr) or expr in self._visited_functions: return new_mapper = self.clone_for_callee(expr) for subexpr in expr.returns.values(): - new_mapper(subexpr, *args, **kwargs) + new_mapper(subexpr) self.count += new_mapper.count self._visited_functions.add(expr) - self.post_visit(expr, *args, **kwargs) + self.post_visit(expr) def post_visit(self, expr: Any) -> None: if isinstance(expr, Call): From 45eb68d54d6b0e74776fa9bde8d41c7f798b8c49 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 10 Jul 2024 15:56:53 -0500 Subject: [PATCH 25/32] add get_func_def_cache_key to walk mappers to correctly handle function caching when extra arguments are present --- pytato/analysis/__init__.py | 11 +++++++++-- pytato/codegen.py | 5 ++++- pytato/transform/__init__.py | 19 +++++++++++++------ 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5a6040beb..20e3b813e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -411,6 +411,9 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) + def get_func_def_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def post_visit(self, expr: Any) -> None: self.count += 1 @@ -448,8 +451,12 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) + def get_func_def_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def map_function_definition(self, expr: FunctionDefinition) -> None: - if not self.visit(expr) or expr in self._visited_functions: + cache_key = self.get_func_def_cache_key(expr) + if not self.visit(expr) or cache_key in self._visited_functions: return new_mapper = self.clone_for_callee(expr) @@ -457,7 +464,7 @@ def map_function_definition(self, expr: FunctionDefinition) -> None: new_mapper(subexpr) self.count += new_mapper.count - self._visited_functions.add(expr) + self._visited_functions.add(cache_key) self.post_visit(expr) diff --git a/pytato/codegen.py b/pytato/codegen.py index 0e1126289..8617c0fd0 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -42,7 +42,7 @@ SizeParam, make_dict_of_named_arrays, ) -from pytato.function import NamedCallResult +from pytato.function import FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall from pytato.scalar_expr import IntegralScalarExpression from pytato.target import Target @@ -254,6 +254,9 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) + def get_func_def_cache_key(self, expr: FunctionDefinition) -> int: + return id(expr) + def post_visit(self, expr: Any) -> None: if isinstance(expr, (Placeholder, SizeParam, DataWrapper)): if expr.name is not None: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 8fbe45ef2..eeae459ae 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1222,20 +1222,24 @@ class CachedWalkMapper(WalkMapper): def __init__(self) -> None: super().__init__() - self._visited_nodes: set[Any] = set() - self._visited_functions: set[FunctionDefinition] = set() + self._visited_arrays_or_names: set[Any] = set() + self._visited_functions: set[Any] = set() def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError + def get_func_def_cache_key( + self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any ) -> None: cache_key = self.get_cache_key(expr, *args, **kwargs) - if cache_key in self._visited_nodes: + if cache_key in self._visited_arrays_or_names: return super().rec(expr, *args, **kwargs) - self._visited_nodes.add(cache_key) + self._visited_arrays_or_names.add(cache_key) @memoize_method def clone_for_callee( @@ -1244,14 +1248,17 @@ def clone_for_callee( def map_function_definition(self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> None: - if not self.visit(expr, *args, **kwargs) or expr in self._visited_functions: + cache_key = self.get_func_def_cache_key(expr, *args, **kwargs) + if ( + not self.visit(expr, *args, **kwargs) + or cache_key in self._visited_functions): return new_mapper = self.clone_for_callee(expr) for subexpr in expr.returns.values(): new_mapper(subexpr, *args, **kwargs) - self._visited_functions.add(expr) + self._visited_functions.add(cache_key) self.post_visit(expr, *args, **kwargs) From 950f44f8cfe226e30869b21423106c9949901ae3 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 10 Jul 2024 16:04:37 -0500 Subject: [PATCH 26/32] undo memoizing Call creation --- pytato/function.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/function.py b/pytato/function.py index 272378ecf..e4e0d571e 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -195,8 +195,7 @@ def _with_new_tags( self: FunctionDefinition, tags: frozenset[Tag]) -> FunctionDefinition: return attrs.evolve(self, tags=tags) - @memoize_method - def __call__(self, /, **kwargs: Array + def __call__(self, **kwargs: Array ) -> Array | tuple[Array, ...] | dict[str, Array]: from pytato.array import _get_default_tags from pytato.utils import are_shapes_equal From d1cb54e58811930081aceab47919405c186b76bc Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 10 Jul 2024 16:13:23 -0500 Subject: [PATCH 27/32] don't use regular dict for function call results --- pytato/function.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/function.py b/pytato/function.py index e4e0d571e..5dd5a35f0 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -233,8 +233,7 @@ def __call__(self, **kwargs: Array return tuple(call_site[f"_{iarg}"] for iarg in range(len(self.returns))) elif self.return_type == ReturnType.DICT_OF_ARRAYS: - # FIXME: Should this be immutabledict? - return {kw: call_site[kw] for kw in self.returns} + return immutabledict({kw: call_site[kw] for kw in self.returns}) else: raise NotImplementedError(self.return_type) From ddf546bf966170d9e696c4b7bec600dec564e96a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 10 Jul 2024 16:31:04 -0500 Subject: [PATCH 28/32] fix type annotation for function result --- pytato/function.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytato/function.py b/pytato/function.py index 5dd5a35f0..c79dfcfe3 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -61,7 +61,6 @@ from typing import ( Callable, ClassVar, - Dict, Hashable, Iterable, Iterator, @@ -87,7 +86,7 @@ ) -ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array]) +ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Mapping[str, Array]) # {{{ Call/NamedCallResult @@ -111,7 +110,7 @@ class FunctionDefinition(Taggable): :class:`~pytato.Array` with the inputs being :class:`~pytato.array.Placeholder`\ s. The outputs of the function can be a single :class:`pytato.Array`, a tuple of :class:`pytato.Array`\ s or an - instance of ``Dict[str, Array]``. + instance of ``Mapping[str, Array]``. .. attribute:: parameters @@ -196,7 +195,7 @@ def _with_new_tags( return attrs.evolve(self, tags=tags) def __call__(self, **kwargs: Array - ) -> Array | tuple[Array, ...] | dict[str, Array]: + ) -> Array | tuple[Array, ...] | Mapping[str, Array]: from pytato.array import _get_default_tags from pytato.utils import are_shapes_equal From 1e7c71d00d481b04d938c1b3ea21f8bd3a2f58dd Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 25 Jul 2024 14:19:42 -0500 Subject: [PATCH 29/32] undo memoizing clone_for_callee doesn't avoid retraversal when the same function is encountered inside the bodies of two different functions --- pytato/distributed/partition.py | 1 - pytato/transform/__init__.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 606d1a203..5865ec491 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -303,7 +303,6 @@ def __init__(self, self.user_input_names: set[str] = set() self.partition_input_name_to_placeholder: dict[str, Placeholder] = {} - @memoize_method def clone_for_callee( self, function: FunctionDefinition) -> _DistributedInputReplacer: # Function definitions aren't allowed to contain receives, diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index eeae459ae..ea15f88da 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -276,7 +276,6 @@ def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT: def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) - @memoize_method def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: """ @@ -1241,7 +1240,6 @@ def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any super().rec(expr, *args, **kwargs) self._visited_arrays_or_names.add(cache_key) - @memoize_method def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: return type(self)() @@ -1308,7 +1306,6 @@ def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None: super().__init__() self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn - @memoize_method def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: # type-ignore-reason: self.__init__ has a different function signature From b7d412d65ce0a9d2c0ab9a78a0bfc293c1e100ab Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 25 Jul 2024 23:18:03 -0500 Subject: [PATCH 30/32] add SizeParamGatherer.map_call default map_call implementation was removed from CombineMapper, so it needs to be here --- pytato/transform/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index ea15f88da..00ef86a06 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1005,6 +1005,12 @@ def map_function_definition(self, expr: FunctionDefinition return self.combine(*[self.rec(ret) for ret in expr.returns.values()]) + def map_call(self, expr: Call) -> frozenset[SizeParam]: + return self.combine(self.map_function_definition(expr.function), + *[ + self.rec(bnd) + for name, bnd in sorted(expr.bindings.items())]) + # }}} From c9cda75d28c3eb79c212d432c1c14377fe62f182 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 29 Jul 2024 16:55:28 -0500 Subject: [PATCH 31/32] Revert "add get_func_def_cache_key to walk mappers to correctly handle function caching when extra arguments are present" This reverts commit 45eb68d54d6b0e74776fa9bde8d41c7f798b8c49. --- pytato/analysis/__init__.py | 11 ++--------- pytato/codegen.py | 5 +---- pytato/transform/__init__.py | 19 ++++++------------- 3 files changed, 9 insertions(+), 26 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 20e3b813e..5a6040beb 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -411,9 +411,6 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - def get_func_def_cache_key(self, expr: FunctionDefinition) -> int: - return id(expr) - def post_visit(self, expr: Any) -> None: self.count += 1 @@ -451,12 +448,8 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - def get_func_def_cache_key(self, expr: FunctionDefinition) -> int: - return id(expr) - def map_function_definition(self, expr: FunctionDefinition) -> None: - cache_key = self.get_func_def_cache_key(expr) - if not self.visit(expr) or cache_key in self._visited_functions: + if not self.visit(expr) or expr in self._visited_functions: return new_mapper = self.clone_for_callee(expr) @@ -464,7 +457,7 @@ def map_function_definition(self, expr: FunctionDefinition) -> None: new_mapper(subexpr) self.count += new_mapper.count - self._visited_functions.add(cache_key) + self._visited_functions.add(expr) self.post_visit(expr) diff --git a/pytato/codegen.py b/pytato/codegen.py index 8617c0fd0..0e1126289 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -42,7 +42,7 @@ SizeParam, make_dict_of_named_arrays, ) -from pytato.function import FunctionDefinition, NamedCallResult +from pytato.function import NamedCallResult from pytato.loopy import LoopyCall from pytato.scalar_expr import IntegralScalarExpression from pytato.target import Target @@ -254,9 +254,6 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - def get_func_def_cache_key(self, expr: FunctionDefinition) -> int: - return id(expr) - def post_visit(self, expr: Any) -> None: if isinstance(expr, (Placeholder, SizeParam, DataWrapper)): if expr.name is not None: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 00ef86a06..5e6a9e6ff 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1227,24 +1227,20 @@ class CachedWalkMapper(WalkMapper): def __init__(self) -> None: super().__init__() - self._visited_arrays_or_names: set[Any] = set() - self._visited_functions: set[Any] = set() + self._visited_nodes: set[Any] = set() + self._visited_functions: set[FunctionDefinition] = set() def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError - def get_func_def_cache_key( - self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> Any: - raise NotImplementedError - def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any ) -> None: cache_key = self.get_cache_key(expr, *args, **kwargs) - if cache_key in self._visited_arrays_or_names: + if cache_key in self._visited_nodes: return super().rec(expr, *args, **kwargs) - self._visited_arrays_or_names.add(cache_key) + self._visited_nodes.add(cache_key) def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: @@ -1252,17 +1248,14 @@ def clone_for_callee( def map_function_definition(self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> None: - cache_key = self.get_func_def_cache_key(expr, *args, **kwargs) - if ( - not self.visit(expr, *args, **kwargs) - or cache_key in self._visited_functions): + if not self.visit(expr, *args, **kwargs) or expr in self._visited_functions: return new_mapper = self.clone_for_callee(expr) for subexpr in expr.returns.values(): new_mapper(subexpr, *args, **kwargs) - self._visited_functions.add(cache_key) + self._visited_functions.add(expr) self.post_visit(expr, *args, **kwargs) From 742de0a5fb1a0e8db946e1dea8f3d1ede3160bc3 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 29 Jul 2024 18:08:24 -0500 Subject: [PATCH 32/32] Revert "don't memoize map_function_definition in cached walk mappers" This reverts commit c0e70cdc8821ae7b14c31acf38e996cdf6bd5896. --- pytato/analysis/__init__.py | 6 +++--- pytato/transform/__init__.py | 16 +--------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5a6040beb..072ea8e66 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method from pytato.array import ( Array, @@ -448,8 +449,9 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) + @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> None: - if not self.visit(expr) or expr in self._visited_functions: + if not self.visit(expr): return new_mapper = self.clone_for_callee(expr) @@ -457,8 +459,6 @@ def map_function_definition(self, expr: FunctionDefinition) -> None: new_mapper(subexpr) self.count += new_mapper.count - self._visited_functions.add(expr) - self.post_visit(expr) def post_visit(self, expr: Any) -> None: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5e6a9e6ff..4d389c245 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1228,7 +1228,6 @@ class CachedWalkMapper(WalkMapper): def __init__(self) -> None: super().__init__() self._visited_nodes: set[Any] = set() - self._visited_functions: set[FunctionDefinition] = set() def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError @@ -1245,20 +1244,6 @@ def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any def clone_for_callee( self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: return type(self)() - - def map_function_definition(self, expr: FunctionDefinition, - *args: Any, **kwargs: Any) -> None: - if not self.visit(expr, *args, **kwargs) or expr in self._visited_functions: - return - - new_mapper = self.clone_for_callee(expr) - for subexpr in expr.returns.values(): - new_mapper(subexpr, *args, **kwargs) - - self._visited_functions.add(expr) - - self.post_visit(expr, *args, **kwargs) - # }}} @@ -1286,6 +1271,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: def post_visit(self, expr: Any) -> None: self.topological_order.append(expr) + @memoize_method def map_function_definition(self, expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. return