diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py
index c100e7d31..b5cddaae8 100644
--- a/pytato/analysis/__init__.py
+++ b/pytato/analysis/__init__.py
@@ -323,7 +323,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
 
 # {{{ DirectPredecessorsGetter
 
-class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):
+class DirectPredecessorsGetter(
+        Mapper[
+            FrozenOrderedSet[ArrayOrNames | FunctionDefinition],
+            FrozenOrderedSet[ArrayOrNames],
+            []]):
     """
     Mapper to get the
     `direct predecessors
@@ -334,9 +338,17 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):
 
     We only consider the predecessors of a node in a data-flow sense.
     """
+    def __init__(self, *, include_functions: bool = False) -> None:
+        super().__init__()
+        self.include_functions = include_functions
+
     def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]:
         return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array))
 
+    def map_dict_of_named_arrays(
+            self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]:
+        return FrozenOrderedSet(expr._data.values())
+
     def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]:
         return (FrozenOrderedSet(expr.bindings.values())
                 | self._get_preds_from_shape(expr.shape))
@@ -397,8 +409,17 @@ def map_distributed_send_ref_holder(self,
                                         ) -> FrozenOrderedSet[ArrayOrNames]:
         return FrozenOrderedSet([expr.passthrough_data])
 
-    def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]:
-        return FrozenOrderedSet(expr.bindings.values())
+    def map_call(
+            self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]:
+        result: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \
+            FrozenOrderedSet(expr.bindings.values())
+        if self.include_functions:
+            result = result | FrozenOrderedSet([expr.function])
+        return result
+
+    def map_function_definition(
+            self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]:
+        return FrozenOrderedSet(expr.returns.values())
 
     def map_named_call_result(
             self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]:
@@ -622,11 +643,11 @@ def combine(self, *args: int) -> int:
         return sum(args)
 
     def rec(self, expr: ArrayOrNames) -> int:
-        key = self._cache.get_key(expr)
+        inputs = self._cache.make_inputs(expr)
         try:
-            return self._cache.retrieve(expr, key=key)
+            return self._cache_retrieve(inputs)
         except KeyError:
-            s = super().rec(expr)
+            s = Mapper.rec(self, expr)
             if (
                     isinstance(expr, Array)
                     and (
@@ -636,7 +657,7 @@
             else:
                 result = 0 + s
 
-            self._cache.add(expr, 0, key=key)
+            self._cache_add(inputs, 0)
 
             return result
diff --git a/pytato/array.py b/pytato/array.py
index ca44c2c2b..7de89c158 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -482,8 +482,11 @@ class Axis(Taggable):
     tags: frozenset[Tag]
 
     def _with_new_tags(self, tags: frozenset[Tag]) -> Axis:
-        from dataclasses import replace
-        return replace(self, tags=tags)
+        if tags != self.tags:
+            from dataclasses import replace
+            return replace(self, tags=tags)
+        else:
+            return self
 
 
 @dataclasses.dataclass(frozen=True)
@@ -495,8 +498,11 @@ class ReductionDescriptor(Taggable):
     tags: frozenset[Tag]
 
     def _with_new_tags(self, tags: frozenset[Tag]) -> ReductionDescriptor:
-        from dataclasses import replace
-        return replace(self, tags=tags)
+        if tags != self.tags:
+            from dataclasses import replace
+            return replace(self, tags=tags)
+        else:
+            return self
 
 
 @array_dataclass()
@@ -848,10 +854,14 @@ def with_tagged_axis(self, iaxis: int,
         """
         Returns a copy of *self* with *iaxis*-th axis tagged with *tags*.
         """
-        new_axes = (self.axes[:iaxis]
-                    + (self.axes[iaxis].tagged(tags),)
-                    + self.axes[iaxis+1:])
-        return self.copy(axes=new_axes)
+        new_axis = self.axes[iaxis].tagged(tags)
+        if new_axis is not self.axes[iaxis]:
+            new_axes = (self.axes[:iaxis]
+                        + (new_axis,)
+                        + self.axes[iaxis+1:])
+            return self.copy(axes=new_axes)
+        else:
+            return self
 
     @memoize_method
     def __repr__(self) -> str:
@@ -880,7 +890,10 @@ class _SuppliedAxesAndTagsMixin(Taggable):
                                           default=frozenset())
 
     def _with_new_tags(self: Self, tags: frozenset[Tag]) -> Self:
-        return dataclasses.replace(self, tags=tags)
+        if tags != self.tags:
+            return dataclasses.replace(self, tags=tags)
+        else:
+            return self
 
 
 @dataclasses.dataclass(frozen=True, eq=False, repr=False)
@@ -1129,20 +1142,22 @@ def with_tagged_reduction(self,
                             f" '{self.var_to_reduction_descr.keys()}',"
                             f" got '{reduction_variable}'.")
 
-        assert isinstance(self.var_to_reduction_descr, immutabledict)
-        new_var_to_redn_descr = dict(self.var_to_reduction_descr)
-        new_var_to_redn_descr[reduction_variable] = \
-            self.var_to_reduction_descr[reduction_variable].tagged(tags)
-
-        return type(self)(expr=self.expr,
-                          shape=self.shape,
-                          dtype=self.dtype,
-                          bindings=self.bindings,
-                          axes=self.axes,
-                          var_to_reduction_descr=immutabledict
-                          (new_var_to_redn_descr),
-                          tags=self.tags,
-                          non_equality_tags=self.non_equality_tags)
+        new_redn_descr = self.var_to_reduction_descr[reduction_variable].tagged(tags)
+        if new_redn_descr is not self.var_to_reduction_descr[reduction_variable]:
+            assert isinstance(self.var_to_reduction_descr, immutabledict)
+            new_var_to_redn_descr = dict(self.var_to_reduction_descr)
+            new_var_to_redn_descr[reduction_variable] = new_redn_descr
+            return type(self)(expr=self.expr,
+                              shape=self.shape,
+                              dtype=self.dtype,
+                              bindings=self.bindings,
+                              axes=self.axes,
+                              var_to_reduction_descr=immutabledict
+                              (new_var_to_redn_descr),
+                              tags=self.tags,
+                              non_equality_tags=self.non_equality_tags)
+        else:
+            return self
 
 # }}}
 
@@ -1293,19 +1308,21 @@ def with_tagged_reduction(self,
 
         # }}}
 
-        assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
-        new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
-        new_redn_axis_to_redn_descr[redn_axis] = \
-            self.redn_axis_to_redn_descr[redn_axis].tagged(tags)
-
-        return type(self)(access_descriptors=self.access_descriptors,
-                          args=self.args,
-                          axes=self.axes,
-                          redn_axis_to_redn_descr=immutabledict
-                          (new_redn_axis_to_redn_descr),
-                          tags=self.tags,
-                          non_equality_tags=self.non_equality_tags,
-                          )
+        new_redn_descr = self.redn_axis_to_redn_descr[redn_axis].tagged(tags)
+        if new_redn_descr is not self.redn_axis_to_redn_descr[redn_axis]:
+            assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
+            new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
+            new_redn_axis_to_redn_descr[redn_axis] = new_redn_descr
+            return type(self)(access_descriptors=self.access_descriptors,
+                              args=self.args,
+                              axes=self.axes,
+                              redn_axis_to_redn_descr=immutabledict
+                              (new_redn_axis_to_redn_descr),
+                              tags=self.tags,
+                              non_equality_tags=self.non_equality_tags,
+                              )
+        else:
+            return self
 
 
 EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P<alpha>[a-zA-Z])|(?P<ellipsis>\.\.\.))\s*")
diff --git a/pytato/codegen.py b/pytato/codegen.py
index 86a328929..dd0d4063f 100644
--- a/pytato/codegen.py
+++ b/pytato/codegen.py
@@ -148,47 +148,55 @@ def __init__(
         self.kernels_seen: dict[str, lp.LoopKernel] = kernels_seen or {}
 
    def map_size_param(self, expr: 
SizeParam) -> Array: - name = expr.name - assert name is not None - return SizeParam( # pylint: disable=missing-kwoa - name=name, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + assert expr.name is not None + return expr def map_placeholder(self, expr: Placeholder) -> Array: - name = expr.name - if name is None: - name = self.var_name_gen("_pt_in") - return Placeholder(name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s - for s in expr.shape), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_name = expr.name + if new_name is None: + new_name = self.var_name_gen("_pt_in") + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if ( + new_name is expr.name + and new_shape is expr.shape): + return expr + else: + return Placeholder(name=new_name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: from pytato.target.loopy import LoopyTarget if not isinstance(self.target, LoopyTarget): raise ValueError("Got a LoopyCall for a non-loopy target.") - translation_unit = expr.translation_unit.copy( - target=self.target.get_loopy_target()) + new_target = self.target.get_loopy_target() + + # FIXME: Can't use "is" here because targets aren't unique. Is it OK to + # use the existing target if it's equal to self.target.get_loopy_target()? + # If not, may have to set err_on_no_op_duplication=False + if new_target == expr.translation_unit.target: + new_translation_unit = expr.translation_unit + else: + new_translation_unit = expr.translation_unit.copy(target=new_target) namegen = UniqueNameGenerator(set(self.kernels_seen)) - entrypoint = expr.entrypoint + new_entrypoint = expr.entrypoint # {{{ eliminate callable name collision - for name, clbl in translation_unit.callables_table.items(): + for name, clbl in new_translation_unit.callables_table.items(): if isinstance(clbl, lp.CallableKernel): assert isinstance(name, str) if name in self.kernels_seen and ( - translation_unit[name] != self.kernels_seen[name]): + new_translation_unit[name] != self.kernels_seen[name]): # callee name collision => must rename # {{{ see if it's one of the other kernels for other_knl in self.kernels_seen.values(): - if other_knl.copy(name=name) == translation_unit[name]: + if other_knl.copy(name=name) == new_translation_unit[name]: new_name = other_knl.name break else: @@ -198,37 +206,55 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: # }}} - if name == entrypoint: + if name == new_entrypoint: # if the colliding name is the entrypoint, then rename the # entrypoint as well. 
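+                        # A hedged illustration with a hypothetical kernel
+                        # name: suppose "knl" is already in kernels_seen with
+                        # a different definition. UniqueNameGenerator then
+                        # supplies the replacement name, roughly
+                        #
+                        #   namegen = UniqueNameGenerator({"knl"})
+                        #   namegen("knl")   # -> "knl_0"
+                        #
+                        # and when "knl" is also the entrypoint, the
+                        # entrypoint has to follow the rename: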
- entrypoint = new_name + new_entrypoint = new_name - translation_unit = lp.rename_callable( - translation_unit, name, new_name) + new_translation_unit = lp.rename_callable( + new_translation_unit, name, new_name) name = new_name self.kernels_seen[name] = clbl.subkernel # }}} - bindings: Mapping[str, Any] = immutabledict( + new_bindings: Mapping[str, Any] = immutabledict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - return LoopyCall(translation_unit=translation_unit, - bindings=bindings, - entrypoint=entrypoint, - tags=expr.tags - ) + assert ( + new_entrypoint is expr.entrypoint + or new_entrypoint != expr.entrypoint) + for bnd, new_bnd in zip( + expr.bindings.values(), new_bindings.values(), strict=True): + assert new_bnd is bnd or new_bnd != bnd + + if ( + new_translation_unit == expr.translation_unit + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True)) + and new_entrypoint is expr.entrypoint): + return expr + else: + return LoopyCall(translation_unit=new_translation_unit, + bindings=new_bindings, + entrypoint=new_entrypoint, + tags=expr.tags + ) def map_data_wrapper(self, expr: DataWrapper) -> Array: name = _generate_name_for_temp(expr, self.var_name_gen, "_pt_data") + shape = self.rec_idx_or_size_tuple(expr.shape) self.bound_arguments[name] = expr.data return Placeholder(name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s - for s in expr.shape), + shape=shape, dtype=expr.dtype, axes=expr.axes, tags=expr.tags, diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 27b1e2cee..f3f0ae322 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -294,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: return new_send def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._cache.make_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(inputs) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5b1ba02c4..623644924 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,6 +46,7 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -93,6 +94,7 @@ __doc__ = """ .. autoclass:: Mapper +.. autoclass:: CacheInputs .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapperCache @@ -100,6 +102,7 @@ .. autoclass:: TransformMapperWithExtraArgs .. autoclass:: CopyMapper .. autoclass:: CopyMapperWithExtraArgs +.. autoclass:: Deduplicator .. autoclass:: CombineMapper .. autoclass:: DependencyMapper .. 
autoclass:: InputGatherer @@ -187,6 +190,14 @@ class ForeignObjectError(ValueError): pass +class CacheCollisionError(ValueError): + pass + + +class CacheNoOpDuplicationError(ValueError): + pass + + # {{{ mapper base class ResultT = TypeVar("ResultT") @@ -299,17 +310,38 @@ def __call__( # {{{ CachedMapper -CacheExprT = TypeVar("CacheExprT") +CacheExprT = TypeVar("CacheExprT", ArrayOrNames, FunctionDefinition) CacheResultT = TypeVar("CacheResultT") CacheKeyT: TypeAlias = Hashable +class CacheInputs(Generic[CacheExprT]): + """Data structure for inputs to :class:`CachedMapperCache`.""" + def __init__( + self, + expr: CacheExprT, + key_func: Callable[..., CacheKeyT], + *args: Any, + **kwargs: Any): + self.expr: CacheExprT = expr + self.args: tuple[Any, ...] = args + self.kwargs: dict[str, Any] = kwargs + self._key_func = key_func + + @memoize_method + def _get_key(self) -> CacheKeyT: + return self._key_func(self.expr, *self.args, **self.kwargs) + + @property + def key(self) -> CacheKeyT: + return self._get_key() + + class CachedMapperCache(Generic[CacheExprT, CacheResultT]): """ Cache for mappers. .. automethod:: __init__ - .. method:: get_key Compute the key for an input expression. @@ -319,55 +351,54 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT]): """ def __init__( self, - key_func: Callable[..., CacheKeyT]) -> None: + key_func: Callable[..., CacheKeyT], + err_on_collision: bool) -> None: """ Initialize the cache. :arg key_func: Function to compute a hashable cache key from an input expression and any extra arguments. + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. """ - self.get_key = key_func + self.err_on_collision = err_on_collision + self._key_func = key_func self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + if self.err_on_collision: + self._expr_key_to_expr: dict[CacheKeyT, CacheExprT] = {} + + def make_inputs( + self, expr: CacheExprT, *args: Any, **kwargs: Any + ) -> CacheInputs[CacheExprT]: + return CacheInputs(expr, self._key_func, *args, **kwargs) def add( self, - key_inputs: - CacheExprT - # Currently, Python's type system doesn't have a way to annotate - # containers of args/kwargs (ParamSpec won't work here). So we have - # to fall back to using Any. More details here: - # https://github.com/python/typing/issues/1252 - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - result: CacheResultT, - key: CacheKeyT | None = None) -> CacheResultT: + inputs: CacheInputs[CacheExprT], + result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) + key = inputs.key + + assert key not in self._expr_key_to_result, \ + f"Cache entry is already present for key '{key}'." 
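+        # Hedged usage sketch, with `cache` a hypothetical instance of this
+        # class:
+        #
+        #   inputs = cache.make_inputs(expr)
+        #   cache.add(inputs, result)
+        #   cache.retrieve(cache.make_inputs(expr))   # -> result
+        #
+        # With err_on_collision=True, retrieve() raises CacheCollisionError
+        # for an equal-but-not-identical expression, because the stored
+        # expression is compared with `is` in retrieve().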
self._expr_key_to_result[key] = result + if self.err_on_collision: + self._expr_key_to_expr[key] = inputs.expr return result - def retrieve( - self, - key_inputs: - CacheExprT - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - key: CacheKeyT | None = None) -> CacheResultT: + def retrieve(self, inputs: CacheInputs[CacheExprT]) -> CacheResultT: """Retrieve the cached mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) + key = inputs.key + + result = self._expr_key_to_result[key] + + if self.err_on_collision and inputs.expr is not self._expr_key_to_expr[key]: + raise CacheCollisionError - return self._expr_key_to_result[key] + return result def clear(self) -> None: """Reset the cache.""" @@ -385,6 +416,7 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """ def __init__( self, + err_on_collision: bool | None = None, _cache: CachedMapperCache[ArrayOrNames, ResultT] | None = None, _function_cache: @@ -392,46 +424,79 @@ def __init__( ) -> None: super().__init__() + if err_on_collision is None: + err_on_collision = __debug__ + self._cache: CachedMapperCache[ArrayOrNames, ResultT] = ( _cache if _cache is not None - else CachedMapperCache(self.get_cache_key)) + else CachedMapperCache( + self.get_cache_key, + err_on_collision=err_on_collision)) self._function_cache: CachedMapperCache[ FunctionDefinition, FunctionResultT] = ( _function_cache if _function_cache is not None - else CachedMapperCache(self.get_function_definition_cache_key)) + else CachedMapperCache( + self.get_function_definition_cache_key, + err_on_collision=err_on_collision)) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + return \ + (expr, *args, tuple(sorted(kwargs.items()))) if args or kwargs else expr def get_function_definition_cache_key( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + return \ + (expr, *args, tuple(sorted(kwargs.items()))) if args or kwargs else expr + + def _cache_add( + self, + inputs: CacheInputs[ArrayOrNames], + result: ResultT) -> ResultT: + return self._cache.add(inputs, result) + + def _function_cache_add( + self, + inputs: CacheInputs[FunctionDefinition], + result: FunctionResultT) -> FunctionResultT: + return self._function_cache.add(inputs, result) + + def _cache_retrieve(self, inputs: CacheInputs[ArrayOrNames]) -> ResultT: + try: + return self._cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_retrieve( + self, inputs: CacheInputs[FunctionDefinition]) -> FunctionResultT: + try: + return self._function_cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: - key = self._cache.get_key(expr, *args, **kwargs) + inputs = self._cache.make_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve((expr, args, kwargs), key=key) + return self._cache_retrieve(inputs) except KeyError: - return self._cache.add( - (expr, args, kwargs), - super().rec(expr, *args, **kwargs), - key=key) + return 
self._cache_add(inputs, super().rec(expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: - key = self._function_cache.get_key(expr, *args, **kwargs) + inputs = self._function_cache.make_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve((expr, args, kwargs), key=key) + return self._function_cache_retrieve(inputs) except KeyError: - return self._function_cache.add( - (expr, args, kwargs), - super().rec_function_definition(expr, *args, **kwargs), - key=key) + return self._function_cache_add( + inputs, super().rec_function_definition(expr, *args, **kwargs)) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -439,8 +504,10 @@ def clone_for_callee( Called to clone *self* before starting traversal of a :class:`pytato.function.FunctionDefinition`. """ - # Functions are cached globally, but arrays aren't - return type(self)(_function_cache=self._function_cache) + return type(self)( + err_on_collision=self._cache.err_on_collision, + # Functions are cached globally, but arrays aren't + _function_cache=self._function_cache) # }}} @@ -448,7 +515,81 @@ def clone_for_callee( # {{{ TransformMapper class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT]): - pass + """ + Cache for :class:`TransformMapper` and :class:`TransformMapperWithExtraArgs`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + key_func: Callable[..., CacheKeyT], + err_on_collision: bool, + err_on_no_op_duplication: bool) -> None: + """ + Initialize the cache. + + :arg key_func: Function to compute a hashable cache key from an input + expression and any extra arguments. + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(key_func, err_on_collision=err_on_collision) + + self.err_on_no_op_duplication = err_on_no_op_duplication + + self._result_key_to_result: dict[CacheKeyT, CacheExprT] = {} + + def add( + self, + inputs: CacheInputs[CacheExprT], + result: CacheExprT, + result_key: CacheKeyT | None = None) -> CacheExprT: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._expr_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + if result_key is None: + result_key = self._key_func(result) + + try: + result = self._result_key_to_result[result_key] + except KeyError: + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + if ( + self.err_on_no_op_duplication + and hash(result_key) == hash(key) + and result_key == key + and result is not inputs.expr + # Need this check in order to handle input DAGs that have existing + # duplicates. Deduplication will potentially replace predecessors + # of `expr` with cached versions, producing a new `result` that has + # the same cache key as `expr`. 
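+                    # Concrete picture (hedged): a mapper that returns
+                    #
+                    #   dataclasses.replace(expr)   # equal, yet a new instance
+                    #
+                    # for an expr whose predecessors all mapped to themselves
+                    # satisfies every clause here: key and hash match, the
+                    # result is not the input, and the zip below finds no
+                    # replaced predecessors -- a pure no-op duplication.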
+ and all( + result_pred is pred + for pred, result_pred in zip( + pred_getter(inputs.expr), + pred_getter(result), + strict=True))): + raise CacheNoOpDuplicationError from None + + self._result_key_to_result[result_key] = result + + self._expr_key_to_result[key] = result + if self.err_on_collision: + self._expr_key_to_expr[key] = inputs.expr + + return result class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): @@ -458,13 +599,84 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): Enables certain operations that can only be done if the mapping results are also arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool | None = None, + err_on_no_op_duplication: bool | None = None, _cache: TransformMapperCache[ArrayOrNames] | None = None, _function_cache: TransformMapperCache[FunctionDefinition] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if err_on_collision is None: + err_on_collision = __debug__ + if err_on_no_op_duplication is None: + err_on_no_op_duplication = __debug__ + + if _cache is None: + _cache = TransformMapperCache( + self.get_cache_key, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + if _function_cache is None: + _function_cache = TransformMapperCache( + self.get_function_definition_cache_key, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def get_cache_key(self, expr: ArrayOrNames) -> CacheKeyT: + return expr + + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> CacheKeyT: + return expr + + def _cache_add( + self, + inputs: CacheInputs[ArrayOrNames], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputs[FunctionDefinition], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + _function_cache=function_cache) # }}} @@ -480,14 +692,79 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. + + .. automethod:: __init__ + .. 
automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool | None = None, + err_on_no_op_duplication: bool | None = None, _cache: TransformMapperCache[ArrayOrNames] | None = None, _function_cache: TransformMapperCache[FunctionDefinition] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if err_on_collision is None: + err_on_collision = __debug__ + if err_on_no_op_duplication is None: + err_on_no_op_duplication = __debug__ + + if _cache is None: + _cache = TransformMapperCache( + self.get_cache_key, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + if _function_cache is None: + _function_cache = TransformMapperCache( + self.get_function_definition_cache_key, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputs[ArrayOrNames], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputs[FunctionDefinition], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + _function_cache=function_cache) # }}} @@ -507,63 +784,106 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] ) -> tuple[IndexOrShapeExpr, ...]: # type-ignore-reason: apparently mypy cannot substitute typevars # here. 
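+        # Contract (per this change): if rec() maps every Array element to
+        # itself, the *original* tuple object comes back, so callers can use
+        # `is` to detect no-op mappings. Hedged sketch, with `n` an
+        # Array-valued dimension:
+        #
+        #   shape = (n, 3)
+        #   mapper.rec_idx_or_size_tuple(shape) is shape   # True if rec(n) is n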
- return tuple(self.rec(s) if isinstance(s, Array) else s # type: ignore[misc] - for s in situp) + new_situp = tuple( + self.rec(s) if isinstance(s, Array) else s + for s in situp) + if all(new_s is s for s, new_s in zip(situp, new_situp, strict=True)): + return situp + else: + return new_situp # type: ignore[return-value] def map_index_lambda(self, expr: IndexLambda) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + new_shape = self.rec_idx_or_size_tuple(expr.shape) + new_bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr) for name, subexpr in sorted(expr.bindings.items())}) - return IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, - bindings=bindings, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if ( + new_shape is expr.shape + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder) -> Array: assert expr.name is not None - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return Placeholder(name=expr.name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> Array: - arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate) -> Array: - arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) - return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Concatenate(arrays=new_arrays, axis=expr.axis, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll) -> Array: - return Roll(array=_verify_is_array(self.rec(expr.array)), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + if new_ary is expr.array: + return expr + else: + return Roll(array=new_ary, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: - return 
AxisPermutation(array=_verify_is_array(self.rec(expr.array)), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + if new_ary is expr.array: + return expr + else: + return AxisPermutation(array=new_ary, + axis_permutation=expr.axis_permutation, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: - return type(expr)(_verify_is_array(self.rec(expr.array)), - indices=self.rec_idx_or_size_tuple(expr.indices), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + new_indices = self.rec_idx_or_size_tuple(expr.indices) + if new_ary is expr.array and new_indices is expr.indices: + return expr + else: + return type(expr)(new_ary, + indices=new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex) -> Array: return self._map_index_base(expr) @@ -579,91 +899,132 @@ def map_non_contiguous_advanced_index(self, return self._map_index_base(expr) def map_data_wrapper(self, expr: DataWrapper) -> Array: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr.shape), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return DataWrapper( + data=expr.data, + shape=new_shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam) -> Array: assert expr.name is not None - return SizeParam( - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + return expr def map_einsum(self, expr: Einsum) -> Array: - return Einsum(expr.access_descriptors, - tuple(_verify_is_array(self.rec(arg)) for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_named_array(self, expr: NamedArray) -> Array: - container = self.rec(expr._container) - assert isinstance(container, AbstractResultWithNamedArrays) - return type(expr)(container, - expr.name, + new_args = tuple(_verify_is_array(self.rec(arg)) for arg in expr.args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + return expr + else: + return Einsum(expr.access_descriptors, + new_args, axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_named_array(self, expr: NamedArray) -> Array: + new_container = self.rec(expr._container) + assert isinstance(new_container, AbstractResultWithNamedArrays) + if new_container is expr._container: + return expr + else: + return type(expr)(new_container, + expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> DictOfNamedArrays: - return DictOfNamedArrays({key: _verify_is_array(self.rec(val.expr)) - for key, val in expr.items()}, - tags=expr.tags - ) + new_data = { + key: _verify_is_array(self.rec(val.expr)) + for key, val in expr.items()} + if all( + new_data_val is val.expr + for val, new_data_val in zip( + expr.values(), + new_data.values(), + strict=True)): + return expr + else: + return 
DictOfNamedArrays(new_data, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + new_bindings: Mapping[Any, Any] = immutabledict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - - return LoopyCall(translation_unit=expr.translation_unit, - bindings=bindings, - entrypoint=expr.entrypoint, - tags=expr.tags, - ) + if all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True)): + return expr + else: + return LoopyCall(translation_unit=expr.translation_unit, + bindings=new_bindings, + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: - rec_container = self.rec(expr._container) - assert isinstance(rec_container, LoopyCall) - return LoopyCallResult( - _container=rec_container, - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_container = self.rec(expr._container) + assert isinstance(new_container, LoopyCall) + if new_container is expr._container: + return expr + else: + return LoopyCallResult( + _container=new_container, + name=expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: - return Reshape(_verify_is_array(self.rec(expr.array)), - newshape=self.rec_idx_or_size_tuple(expr.newshape), - order=expr.order, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + new_newshape = self.rec_idx_or_size_tuple(expr.newshape) + if new_ary is expr.array and new_newshape is expr.newshape: + return expr + else: + return Reshape(new_ary, + newshape=new_newshape, + order=expr.order, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> Array: - return DistributedSendRefHolder( - send=DistributedSend( - data=_verify_is_array(self.rec(expr.send.data)), - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag), - passthrough_data=_verify_is_array(self.rec(expr.passthrough_data)), - ) + new_send_data = _verify_is_array(self.rec(expr.send.data)) + if new_send_data is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + data=new_send_data, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag) + new_passthrough = _verify_is_array(self.rec(expr.passthrough_data)) + if new_send is expr.send and new_passthrough is expr.passthrough_data: + return expr + else: + return DistributedSendRefHolder(new_send, new_passthrough) def map_distributed_recv(self, expr: DistributedRecv) -> Array: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=new_shape, dtype=expr.dtype, tags=expr.tags, + axes=expr.axes, non_equality_tags=expr.non_equality_tags) def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinition: @@ -672,19 +1033,37 @@ def map_function_definition(self, new_mapper = self.clone_for_callee(expr) new_returns = {name: new_mapper(ret) for 
name, ret in expr.returns.items()} - return dataclasses.replace(expr, returns=immutabledict(new_returns)) + if all( + new_ret is ret + for ret, new_ret in zip( + expr.returns.values(), + new_returns.values(), + strict=True)): + return expr + else: + return dataclasses.replace(expr, returns=immutabledict(new_returns)) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return Call(self.rec_function_definition(expr.function), - immutabledict({name: self.rec(bnd) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, - ) + new_function = self.rec_function_definition(expr.function) + new_bindings = { + name: _verify_is_array(self.rec(bnd)) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) def map_named_call_result(self, expr: NamedCallResult) -> Array: - call = self.rec(expr._container) - assert isinstance(call, Call) - return call[expr.name] + new_call = self.rec(expr._container) + assert isinstance(new_call, Call) + return new_call[expr.name] class CopyMapperWithExtraArgs(TransformMapperWithExtraArgs[P]): @@ -708,70 +1087,104 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...], def map_index_lambda(self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + new_bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr, *args, **kwargs) for name, subexpr in sorted(expr.bindings.items())}) - return IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr.shape, - *args, **kwargs), - dtype=expr.dtype, - bindings=bindings, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if ( + new_shape is expr.shape + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder, *args: P.args, **kwargs: P.kwargs) -> Array: assert expr.name is not None - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr.shape, - *args, **kwargs), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return Placeholder(name=expr.name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple( + new_arrays: tuple[Array, ...] 
= tuple( _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple( + new_arrays: tuple[Array, ...] = tuple( _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) - return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Concatenate(arrays=new_arrays, axis=expr.axis, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array: - return Roll(array=_verify_is_array(self.rec(expr.array, *args, **kwargs)), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + if new_ary is expr.array: + return expr + else: + return Roll(array=new_ary, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> Array: - return AxisPermutation(array=_verify_is_array( - self.rec(expr.array, *args, **kwargs)), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + if new_ary is expr.array: + return expr + else: + return AxisPermutation(array=new_ary, + axis_permutation=expr.axis_permutation, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase, *args: P.args, **kwargs: P.kwargs) -> Array: assert isinstance(expr, _SuppliedAxesAndTagsMixin) - return type(expr)(_verify_is_array(self.rec(expr.array, *args, **kwargs)), - indices=self.rec_idx_or_size_tuple(expr.indices, - *args, **kwargs), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + new_indices = self.rec_idx_or_size_tuple(expr.indices, *args, **kwargs) + if new_ary is expr.array and new_indices is expr.indices: + return expr + else: + return type(expr)(new_ary, + indices=new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> Array: @@ -792,98 +1205,142 @@ def map_non_contiguous_advanced_index(self, def map_data_wrapper(self, expr: DataWrapper, *args: P.args, **kwargs: P.kwargs) -> Array: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return DataWrapper( + data=expr.data, + shape=new_shape, + axes=expr.axes, + tags=expr.tags, + 
non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam, *args: P.args, **kwargs: P.kwargs) -> Array: assert expr.name is not None - return SizeParam(name=expr.name, axes=expr.axes, tags=expr.tags) + return expr def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> Array: - return Einsum(expr.access_descriptors, - tuple(_verify_is_array( - self.rec(arg, *args, **kwargs)) for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_named_array(self, - expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: - container = self.rec(expr._container, *args, **kwargs) - assert isinstance(container, AbstractResultWithNamedArrays) - return type(expr)(container, - expr.name, + new_args: tuple[Array, ...] = tuple( + _verify_is_array(self.rec(arg, *args, **kwargs)) for arg in expr.args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + return expr + else: + return Einsum(expr.access_descriptors, + new_args, axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_named_array(self, + expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: + new_container = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_container, AbstractResultWithNamedArrays) + if new_container is expr._container: + return expr + else: + return type(expr)(new_container, + expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: P.args, **kwargs: P.kwargs ) -> DictOfNamedArrays: - return DictOfNamedArrays({key: _verify_is_array( - self.rec(val.expr, *args, **kwargs)) - for key, val in expr.items()}, - tags=expr.tags, - ) + new_data = { + key: _verify_is_array(self.rec(val.expr, *args, **kwargs)) + for key, val in expr.items()} + if all( + new_data_val is val.expr + for val, new_data_val in zip( + expr.values(), + new_data.values(), + strict=True)): + return expr + else: + return DictOfNamedArrays(new_data, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall, *args: P.args, **kwargs: P.kwargs) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + new_bindings: Mapping[Any, Any] = immutabledict( {name: (self.rec(subexpr, *args, **kwargs) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - - return LoopyCall(translation_unit=expr.translation_unit, - bindings=bindings, - entrypoint=expr.entrypoint, - tags=expr.tags, - ) + if all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True)): + return expr + else: + return LoopyCall(translation_unit=expr.translation_unit, + bindings=new_bindings, + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult, *args: P.args, **kwargs: P.kwargs) -> Array: - rec_loopy_call = self.rec(expr._container, *args, **kwargs) - assert isinstance(rec_loopy_call, LoopyCall) - return LoopyCallResult( - _container=rec_loopy_call, - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_container = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_container, LoopyCall) + if new_container is expr._container: + return expr + else: + return LoopyCallResult( + _container=new_container, + 
name=expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> Array: - return Reshape(_verify_is_array(self.rec(expr.array, *args, **kwargs)), - newshape=self.rec_idx_or_size_tuple(expr.newshape, - *args, **kwargs), - order=expr.order, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + new_newshape = self.rec_idx_or_size_tuple(expr.newshape, *args, **kwargs) + if new_ary is expr.array and new_newshape is expr.newshape: + return expr + else: + return Reshape(new_ary, + newshape=new_newshape, + order=expr.order, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, *args: P.args, **kwargs: P.kwargs) -> Array: - return DistributedSendRefHolder( - send=DistributedSend( - data=_verify_is_array(self.rec(expr.send.data, *args, **kwargs)), - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag), - passthrough_data=_verify_is_array( - self.rec(expr.passthrough_data, *args, **kwargs))) + new_send_data = _verify_is_array(self.rec(expr.send.data, *args, **kwargs)) + if new_send_data is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + data=new_send_data, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag) + new_passthrough = _verify_is_array( + self.rec(expr.passthrough_data, *args, **kwargs)) + if new_send is expr.send and new_passthrough is expr.passthrough_data: + return expr + else: + return DistributedSendRefHolder(new_send, new_passthrough) def map_distributed_recv(self, expr: DistributedRecv, *args: P.args, **kwargs: P.kwargs) -> Array: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=new_shape, dtype=expr.dtype, tags=expr.tags, + axes=expr.axes, non_equality_tags=expr.non_equality_tags) def map_function_definition( self, expr: FunctionDefinition, @@ -895,17 +1352,49 @@ def map_function_definition( def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> AbstractResultWithNamedArrays: - return Call(self.rec_function_definition(expr.function, *args, **kwargs), - immutabledict({name: self.rec(bnd, *args, **kwargs) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, - ) + new_function = self.rec_function_definition(expr.function, *args, **kwargs) + new_bindings = { + name: self.rec(bnd, *args, **kwargs) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) def map_named_call_result(self, expr: NamedCallResult, *args: P.args, **kwargs: P.kwargs) -> Array: - call = self.rec(expr._container, *args, **kwargs) - assert isinstance(call, Call) - return call[expr.name] + new_call = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_call, Call) + return new_call[expr.name] + +# }}} + + +# 
{{{ Deduplicator + +class Deduplicator(CopyMapper): + """Removes duplicate nodes from an expression.""" + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None + ) -> None: + super().__init__( + err_on_collision=False, err_on_no_op_duplication=False, + _cache=_cache, + _function_cache=_function_cache) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition]", self._function_cache)) # }}} @@ -1027,7 +1516,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> ResultT: # {{{ DependencyMapper -class DependencyMapper(CombineMapper[R, R]): +class DependencyMapper(CombineMapper[R, Never]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of :class:`pytato.array.Array`'s it depends on. @@ -1089,14 +1578,10 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) - def map_function_definition(self, expr: FunctionDefinition) -> R: + def map_call(self, expr: Call) -> R: # do not include arrays from the function's body as it would involve # putting arrays from different namespaces into the same collection. - return frozenset() - - def map_call(self, expr: Call) -> R: - return self.combine(self.rec_function_definition(expr.function), - *[self.rec(bnd) for bnd in expr.bindings.values()]) + return self.combine(*[self.rec(bnd) for bnd in expr.bindings.values()]) def map_named_call_result(self, expr: NamedCallResult) -> R: return self.rec(expr._container) @@ -1512,6 +1997,7 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, + # FIXME: Should map_fn be applied to functions too? map_fn: Callable[[ArrayOrNames], ArrayOrNames], _cache: TransformMapperCache[ArrayOrNames] | None = None, _function_cache: TransformMapperCache[FunctionDefinition] | None = None @@ -1527,12 +2013,11 @@ def clone_for_callee( "TransformMapperCache[FunctionDefinition]", self._function_cache)) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._cache.make_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(inputs) except KeyError: - return self._cache.add( - expr, super().rec(self.map_fn(expr)), key=key) + return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} @@ -1549,6 +2034,88 @@ class MPMSMaterializerAccumulator: expr: Array +class MPMSMaterializerCache( + CachedMapperCache[ArrayOrNames, MPMSMaterializerAccumulator]): + """ + Cache for :class:`MPMSMaterializer`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + key_func: Callable[[ArrayOrNames], CacheKeyT], + result_key_func: Callable[[MPMSMaterializerAccumulator], CacheKeyT], + err_on_collision: bool, + err_on_no_op_duplication: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. 
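+        :arg key_func: Function to compute a hashable cache key from an input
+            expression.
+        :arg result_key_func: Function to compute a hashable cache key from a
+            mapping result; results are deduplicated by this key, mirroring
+            :meth:`TransformMapperCache.add`.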
+ """ + super().__init__( + key_func, + err_on_collision=err_on_collision) + + self.err_on_no_op_duplication = err_on_no_op_duplication + self.get_result_key = result_key_func + + self._result_key_to_result: dict[ + CacheKeyT, MPMSMaterializerAccumulator] = {} + + def add( + self, + inputs: CacheInputs[ArrayOrNames], + result: MPMSMaterializerAccumulator, + result_key: CacheKeyT | None = None) -> MPMSMaterializerAccumulator: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._expr_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + if result_key is None: + result_key = self.get_result_key(result) + + try: + result = self._result_key_to_result[result_key] + except KeyError: + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + if ( + self.err_on_no_op_duplication + and hash(result_key) == hash(key) + and result_key == key + and result.expr is not inputs.expr + # Need this check in order to handle input DAGs that have existing + # duplicates. Deduplication will potentially replace predecessors + # of `expr` with cached versions, producing a new `result` that has + # the same cache key as `expr`. + and all( + result_pred is pred + for pred, result_pred in zip( + pred_getter(inputs.expr), + pred_getter(result.expr), + strict=True))): + raise CacheNoOpDuplicationError from None + + self._result_key_to_result[result_key] = result + + self._expr_key_to_result[key] = result + if self.err_on_collision: + self._expr_key_to_expr[key] = inputs.expr + + return result + + def _materialize_if_mpms(expr: Array, nsuccessors: int, predecessors: Iterable[MPMSMaterializerAccumulator] @@ -1566,13 +2133,16 @@ def _materialize_if_mpms(expr: Array, for pred in predecessors), frozenset()) if nsuccessors > 1 and len(materialized_predecessors) > 1: - new_expr = expr.tagged(ImplStored()) + if not expr.tags_of_type(ImplStored): + new_expr = expr.tagged(ImplStored()) + else: + new_expr = expr return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr) else: return MPMSMaterializerAccumulator(materialized_predecessors, expr) -class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): +class MPMSMaterializer(CachedMapper[MPMSMaterializerAccumulator, Never, []]): """ See :func:`materialize_with_mpms` for an explanation. @@ -1581,17 +2151,49 @@ class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): A mapping from a node in the expression graph (i.e. an :class:`~pytato.Array`) to its number of successors. 
""" - def __init__(self, nsuccessors: Mapping[Array, int]): - super().__init__() + def __init__( + self, + nsuccessors: Mapping[Array, int], + _cache: MPMSMaterializerCache | None = None): + err_on_collision = __debug__ + err_on_no_op_duplication = __debug__ + + if _cache is None: + _cache = MPMSMaterializerCache( + key_func=self.get_cache_key, + result_key_func=self.get_result_cache_key, + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + # Does not support functions, so function_cache is ignored + super().__init__(err_on_collision=err_on_collision, _cache=_cache) + self.nsuccessors = nsuccessors - self.cache: dict[ArrayOrNames, MPMSMaterializerAccumulator] = {} - def rec(self, expr: ArrayOrNames) -> MPMSMaterializerAccumulator: - if expr in self.cache: - return self.cache[expr] - result: MPMSMaterializerAccumulator = super().rec(expr) - self.cache[expr] = result - return result + def get_cache_key(self, expr: ArrayOrNames) -> CacheKeyT: + return expr + + def get_result_cache_key(self, result: MPMSMaterializerAccumulator) -> CacheKeyT: + return result.expr + + def _cache_add( + self, + inputs: CacheInputs[ArrayOrNames], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + try: + return self._cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee( + self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + raise AssertionError("Control shouldn't reach this point.") def _map_input_base(self, expr: InputArgumentBase ) -> MPMSMaterializerAccumulator: @@ -1606,26 +2208,46 @@ def map_named_array(self, expr: NamedArray) -> MPMSMaterializerAccumulator: " supported for now.") def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: + # FIXME: Why were these being sorted? 
children_rec = {bnd_name: self.rec(bnd) - for bnd_name, bnd in sorted(expr.bindings.items())} + # for bnd_name, bnd in sorted(expr.bindings.items())} + for bnd_name, bnd in expr.bindings.items()} + new_children: Mapping[str, Array] = immutabledict({ + bnd_name: bnd.expr + # for bnd_name, bnd in sorted(children_rec.items())}) + for bnd_name, bnd in children_rec.items()}) + + if all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_children.values(), + strict=True)): + new_expr = expr + else: + new_expr = IndexLambda( + expr=expr.expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=new_children, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - new_expr = IndexLambda(expr=expr.expr, - shape=expr.shape, - dtype=expr.dtype, - bindings=immutabledict({bnd_name: bnd.expr - for bnd_name, bnd in sorted(children_rec.items())}), - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], children_rec.values()) def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] - new_expr = Stack(tuple(ary.expr for ary in rec_arrays), - expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(ary.expr for ary in rec_arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + new_expr = expr + else: + new_expr = Stack(new_arrays, expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1633,29 +2255,44 @@ def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] - new_expr = Concatenate(tuple(ary.expr for ary in rec_arrays), - expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(ary.expr for ary in rec_arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + new_expr = expr + else: + new_expr = Concatenate(new_arrays, + expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], rec_arrays) def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) def map_axis_permutation(self, expr: AxisPermutation ) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, 
self.nsuccessors[expr], (rec_array,)) @@ -1665,16 +2302,23 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: rec_indices = {i: self.rec(idx) for i, idx in enumerate(expr.indices) if isinstance(idx, Array)} - - new_expr = type(expr)(rec_array.expr, - tuple(rec_indices[i].expr - if i in rec_indices - else expr.indices[i] - for i in range( - len(expr.indices))), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_indices = tuple(rec_indices[i].expr + if i in rec_indices + else expr.indices[i] + for i in range( + len(expr.indices))) + if ( + rec_array.expr is expr.array + and all( + new_idx is idx + for idx, new_idx in zip(expr.indices, new_indices, strict=True))): + new_expr = expr + else: + new_expr = type(expr)(rec_array.expr, + new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1687,26 +2331,35 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Reshape(rec_array.expr, expr.newshape, - expr.order, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = Reshape(rec_array.expr, expr.newshape, + expr.order, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: - rec_arrays = [self.rec(ary) for ary in expr.args] - new_expr = Einsum(expr.access_descriptors, - tuple(ary.expr for ary in rec_arrays), - expr.redn_axis_to_redn_descr, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + rec_args = [self.rec(ary) for ary in expr.args] + new_args = tuple(ary.expr for ary in rec_args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + new_expr = expr + else: + new_expr = Einsum(expr.access_descriptors, + new_args, + expr.redn_axis_to_redn_descr, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], - rec_arrays) + rec_args) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays ) -> MPMSMaterializerAccumulator: @@ -1719,15 +2372,21 @@ def map_loopy_call_result(self, expr: NamedArray) -> MPMSMaterializerAccumulator def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder ) -> MPMSMaterializerAccumulator: - rec_passthrough = self.rec(expr.passthrough_data) rec_send_data = self.rec(expr.send.data) - new_expr = DistributedSendRefHolder( - send=DistributedSend(rec_send_data.expr, - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag, - tags=expr.send.tags), - passthrough_data=rec_passthrough.expr, - ) + if rec_send_data.expr is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + rec_send_data.expr, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag, + tags=expr.send.tags) + rec_passthrough = self.rec(expr.passthrough_data) + if new_send is expr.send and rec_passthrough.expr is expr.passthrough_data: + new_expr = expr + else: + new_expr = DistributedSendRefHolder(new_send, rec_passthrough.expr) + return MPMSMaterializerAccumulator( rec_passthrough.materialized_predecessors, new_expr) diff --git a/pytato/transform/calls.py 
b/pytato/transform/calls.py index 34f89cbc1..a6f7bc134 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -30,7 +30,9 @@ """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast + +from typing_extensions import Self from pytato.array import ( AbstractResultWithNamedArrays, @@ -38,9 +40,14 @@ DictOfNamedArrays, Placeholder, ) -from pytato.function import Call, NamedCallResult +from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.tags import InlineCallTag -from pytato.transform import ArrayOrNames, CopyMapper, _verify_is_array +from pytato.transform import ( + ArrayOrNames, + CopyMapper, + TransformMapperCache, + _verify_is_array, +) if TYPE_CHECKING: @@ -55,6 +62,12 @@ class PlaceholderSubstitutor(CopyMapper): A mapping from the placeholder name to the array that it is to be substituted with. + + .. note:: + + This mapper does not deduplicate subexpressions that occur in both the mapped + expression and the substitutions. Callers must follow up with a + :class:`pytato.transform.Deduplicator` if duplicates need to be removed. """ def __init__(self, substitutions: Mapping[str, Array]) -> None: @@ -63,32 +76,51 @@ def __init__(self, substitutions: Mapping[str, Array]) -> None: self.substitutions = substitutions def map_placeholder(self, expr: Placeholder) -> Array: + # Can't call rec() to remove duplicates here, because the substituted-in + # expression may contain unrelated placeholders whose names collide with + # the ones being replaced. return self.substitutions[expr.name] - def map_named_call_result(self, expr: NamedCallResult) -> NamedCallResult: - raise NotImplementedError( - "PlaceholderSubstitutor does not support functions.") + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + # Only operates within the current stack frame + return expr class Inliner(CopyMapper): """ Primary mapper for :func:`inline_calls`. """ - def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - # inline call sites within the callee.
- new_expr = super().map_call(expr) - assert isinstance(new_expr, Call) + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None + ) -> None: + # Must disable collision/duplication checking because we're combining + # expressions that were previously in two different call stack frames + # (and were thus cached separately) + super().__init__( + err_on_collision=False, + err_on_no_op_duplication=False, + _cache=_cache, + _function_cache=_function_cache) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition]", self._function_cache)) + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: if expr.tags_of_type(InlineCallTag): - substitutor = PlaceholderSubstitutor(new_expr.bindings) + substitutor = PlaceholderSubstitutor(expr.bindings) return DictOfNamedArrays( - {name: _verify_is_array(substitutor.rec(ret)) - for name, ret in new_expr.function.returns.items()}, - tags=new_expr.tags + {name: _verify_is_array(self.rec(substitutor(ret))) + for name, ret in expr.function.returns.items()}, + tags=expr.tags ) else: - return new_expr + return super().map_call(expr) def map_named_call_result(self, expr: NamedCallResult) -> Array: new_call_or_inlined_expr = self.rec(expr._container) @@ -104,7 +136,11 @@ class InlineMarker(CopyMapper): Primary mapper for :func:`tag_all_calls_to_be_inlined`. """ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return super().map_call(expr).tagged(InlineCallTag()) + rec_expr = super().map_call(expr) + if rec_expr.tags_of_type(InlineCallTag): + return rec_expr + else: + return rec_expr.tagged(InlineCallTag()) def inline_calls(expr: ArrayOrNames) -> ArrayOrNames: diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 507a450cd..a082dea02 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -62,6 +62,8 @@ if TYPE_CHECKING: + from collections.abc import Mapping + import numpy as np @@ -257,9 +259,13 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: class ToIndexLambdaMixin: def _rec_shape(self, shape: ShapeType) -> ShapeType: - return tuple(self.rec(s) if isinstance(s, Array) - else s - for s in shape) + new_shape = tuple( + self.rec(s) if isinstance(s, Array) else s + for s in shape) + if all(new_s is s for s, new_s in zip(shape, new_shape, strict=True)): + return shape + else: + return new_shape if TYPE_CHECKING: def rec( @@ -270,17 +276,29 @@ def rec( return super().rec( # type: ignore[no-any-return,misc] expr, *args, **kwargs) - def map_index_lambda(self, expr: IndexLambda) -> IndexLambda: - return IndexLambda(expr=expr.expr, - shape=self._rec_shape(expr.shape), - dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) - for name, bnd - in sorted(expr.bindings.items())}), - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + def map_index_lambda(self, expr: IndexLambda) -> Array: + new_shape = self._rec_shape(expr.shape) + new_bindings: Mapping[str, Array] = immutabledict({ + name: self.rec(subexpr) + for name, subexpr in sorted(expr.bindings.items())}) + if ( + new_shape is expr.shape + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + 
else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> IndexLambda: subscript = tuple(prim.Variable(f"_{i}") diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index e654e8b51..564da41c6 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -465,9 +465,9 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: return result def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._cache.make_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache_retrieve(inputs) except KeyError: result = Mapper.rec(self, expr) if not isinstance( @@ -475,7 +475,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(expr, result, key=key) + return self._cache_add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( diff --git a/pytato/transform/remove_broadcasts_einsum.py b/pytato/transform/remove_broadcasts_einsum.py index 2d8f7e0f0..6755631fb 100644 --- a/pytato/transform/remove_broadcasts_einsum.py +++ b/pytato/transform/remove_broadcasts_einsum.py @@ -31,43 +31,95 @@ from typing import cast from pytato.array import Array, Einsum, EinsumAxisDescriptor -from pytato.transform import CopyMapper, MappedT, _verify_is_array +from pytato.transform import ( + ArrayOrNames, + CopyMapperWithExtraArgs, + MappedT, + Mapper, + _verify_is_array, +) from pytato.utils import are_shape_components_equal -class EinsumWithNoBroadcastsRewriter(CopyMapper): - def map_einsum(self, expr: Einsum) -> Array: +class EinsumWithNoBroadcastsRewriter(CopyMapperWithExtraArgs[[tuple[int, ...] | None]]): + def _squeeze_axes( + self, + expr: Array, + axes_to_squeeze: tuple[int, ...] | None = None) -> Array: + result = ( + expr[ + tuple( + slice(None) if idim not in axes_to_squeeze else 0 + for idim in range(expr.ndim))] + if axes_to_squeeze else expr) + return result + + def rec( + self, + expr: ArrayOrNames, + axes_to_squeeze: tuple[int, ...] | None = None) -> ArrayOrNames: + inputs = self._cache.make_inputs(expr, axes_to_squeeze) + try: + return self._cache_retrieve(inputs) + except KeyError: + rec_result: ArrayOrNames = Mapper.rec(self, expr, None) + result: ArrayOrNames + if isinstance(expr, Array): + result = self._squeeze_axes( + _verify_is_array(rec_result), + axes_to_squeeze) + else: + result = rec_result + return self._cache_add(inputs, result) + + def map_einsum( + self, expr: Einsum, axes_to_squeeze: tuple[int, ...] 
| None) -> Array: new_args: list[Array] = [] new_access_descriptors: list[tuple[EinsumAxisDescriptor, ...]] = [] descr_to_axis_len = expr._access_descr_to_axis_len() - for acc_descrs, arg in zip(expr.access_descriptors, expr.args, strict=True): - arg = _verify_is_array(self.rec(arg)) - axes_to_squeeze: list[int] = [] + for arg, acc_descrs in zip(expr.args, expr.access_descriptors, strict=True): + axes_to_squeeze_list: list[int] = [] for idim, acc_descr in enumerate(acc_descrs): if not are_shape_components_equal(arg.shape[idim], descr_to_axis_len[acc_descr]): assert are_shape_components_equal(arg.shape[idim], 1) - axes_to_squeeze.append(idim) + axes_to_squeeze_list.append(idim) + axes_to_squeeze = tuple(axes_to_squeeze_list) if axes_to_squeeze: - arg = arg[tuple(slice(None) if idim not in axes_to_squeeze else 0 - for idim in range(arg.ndim))] - acc_descrs = tuple(acc_descr + new_arg = _verify_is_array(self.rec(arg, axes_to_squeeze)) + new_acc_descrs = tuple(acc_descr for idim, acc_descr in enumerate(acc_descrs) if idim not in axes_to_squeeze) + else: + new_arg = _verify_is_array(self.rec(arg)) + new_acc_descrs = acc_descrs - new_args.append(arg) - new_access_descriptors.append(acc_descrs) + new_args.append(new_arg) + new_access_descriptors.append(new_acc_descrs) assert len(new_args) == len(expr.args) assert len(new_access_descriptors) == len(expr.access_descriptors) - return Einsum(tuple(new_access_descriptors), - tuple(new_args), - expr.redn_axis_to_redn_descr, - tags=expr.tags, - axes=expr.axes,) + if ( + all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)) + and all( + new_acc_descr is acc_descr + for acc_descr, new_acc_descr in zip( + expr.access_descriptors, + new_access_descriptors, + strict=True))): + return expr + else: + return Einsum(tuple(new_access_descriptors), + tuple(new_args), + axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: @@ -97,6 +149,6 @@ def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: alter its value. """ mapper = EinsumWithNoBroadcastsRewriter() - return cast("MappedT", mapper(expr)) + return cast("MappedT", mapper(expr, None)) # vim:fdm=marker diff --git a/pytato/utils.py b/pytato/utils.py index 31247897d..77cecc3bd 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -340,8 +340,10 @@ def are_shape_components_equal( if isinstance(dim1, INT_CLASSES) and isinstance(dim2, INT_CLASSES): return dim1 == dim2 + from pytato.transform import Deduplicator dim1_minus_dim2 = dim1 - dim2 assert isinstance(dim1_minus_dim2, Array) + dim1_minus_dim2 = Deduplicator()(dim1_minus_dim2) from pytato.transform import InputGatherer inputs = InputGatherer()(dim1_minus_dim2) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index c0c3e7945..7420d1708 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -178,9 +178,10 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" +# FIXME: Make this inherit from CachedWalkMapper instead? 
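+# Note: err_on_collision=False below, presumably because graphs handed to the
+# visualizer may still contain equal-but-not-identical nodes; these collapse
+# into a single dot node instead of being rejected as cache collisions.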
class ArrayToDotNodeInfoMapper(CachedMapper[None, None, []]): def __init__(self) -> None: - super().__init__() + super().__init__(err_on_collision=False) self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {} self.functions: set[FunctionDefinition] = set() diff --git a/test/test_apps.py b/test/test_apps.py index f39be848c..0809172a0 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -39,7 +39,7 @@ from pytools.tag import Tag, tag_dataclass import pytato as pt -from pytato.transform import CopyMapper, WalkMapper +from pytato.transform import CopyMapper, Deduplicator, WalkMapper # {{{ Trace an FFT @@ -78,40 +78,21 @@ def map_constant(self, expr): class FFTRealizationMapper(CopyMapper): - def __init__(self, fft_vec_gatherer): - super().__init__() - - self.fft_vec_gatherer = fft_vec_gatherer - - self.old_array_to_new_array = {} - levels = sorted(fft_vec_gatherer.level_to_arrays, reverse=True) - - lev = 0 - arrays = fft_vec_gatherer.level_to_arrays[lev] - self.finalized = False - - for lev in levels: - arrays = fft_vec_gatherer.level_to_arrays[lev] - rec_arrays = [self.rec(ary) for ary in arrays] - # reset cache so that the partial subs are not stored - self._cache.clear() - lev_array = pt.concatenate(rec_arrays, axis=0) - assert lev_array.shape == (fft_vec_gatherer.n,) - - startidx = 0 - for array in arrays: - size = array.shape[0] - sub_array = lev_array[startidx:startidx+size] - startidx += size - self.old_array_to_new_array[array] = sub_array - - assert startidx == fft_vec_gatherer.n - self.finalized = True + def __init__(self, old_array_to_new_array): + # Must use err_on_no_op_duplication=False, because the use of ConstantSizer + # in map_index_lambda creates IndexLambdas that differ only in the type of + # their contained constants, which changes their identity but not their + # equality + super().__init__(err_on_no_op_duplication=False) + self.old_array_to_new_array = old_array_to_new_array def map_index_lambda(self, expr): tags = expr.tags_of_type(FFTIntermediate) - if tags and (self.finalized or expr in self.old_array_to_new_array): - return self.old_array_to_new_array[expr] + if tags: + try: + return self.old_array_to_new_array[expr] + except KeyError: + pass return super().map_index_lambda( expr.copy(expr=ConstantSizer()(expr.expr))) @@ -122,6 +103,29 @@ def map_concatenate(self, expr): (ImplStored(), PrefixNamed("concat"))) +def make_fft_realization_mapper(fft_vec_gatherer): + old_array_to_new_array = {} + levels = sorted(fft_vec_gatherer.level_to_arrays, reverse=True) + + for lev in levels: + lev_mapper = FFTRealizationMapper(old_array_to_new_array) + arrays = fft_vec_gatherer.level_to_arrays[lev] + rec_arrays = [lev_mapper(ary) for ary in arrays] + lev_array = pt.concatenate(rec_arrays, axis=0) + assert lev_array.shape == (fft_vec_gatherer.n,) + + startidx = 0 + for array in arrays: + size = array.shape[0] + sub_array = lev_array[startidx:startidx+size] + startidx += size + old_array_to_new_array[array] = sub_array + + assert startidx == fft_vec_gatherer.n + + return FFTRealizationMapper(old_array_to_new_array) + + def test_trace_fft(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -134,10 +138,11 @@ def test_trace_fft(ctx_factory): wrap_intermediate_with_level=( lambda level, ary: ary.tagged(FFTIntermediate(level)))) + result = Deduplicator()(result) fft_vec_gatherer = FFTVectorGatherer(n) fft_vec_gatherer(result) - mapper = FFTRealizationMapper(fft_vec_gatherer) + mapper = make_fft_realization_mapper(fft_vec_gatherer) result = mapper(result) diff --git 
a/test/test_codegen.py b/test/test_codegen.py index 0c6972cf6..a4ca538de 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -926,7 +926,7 @@ def _get_x_shape(_m, n_): x = pt.make_data_wrapper(x_in, shape=_get_x_shape(m, n)) np_out = np.einsum("ij, j -> i", A_in, x_in) - pt_expr = pt.einsum("ij, j -> i", A, x) + pt_expr = pt.transform.Deduplicator()(pt.einsum("ij, j -> i", A, x)) _, (pt_out,) = pt.generate_loopy(pt_expr)(cq, m=m_in, n=n_in) @@ -1582,8 +1582,9 @@ def get_np_input_args(): np_inputs = get_np_input_args() np_result = kernel(np, **np_inputs) - pt_dag = kernel(pt, **{kw: pt.make_data_wrapper(arg) - for kw, arg in np_inputs.items()}) + pt_dag = pt.transform.Deduplicator()( + kernel(pt, **{kw: pt.make_data_wrapper(arg) + for kw, arg in np_inputs.items()})) knl = pt.generate_loopy(pt_dag, options=lp.Options(write_code=True)) @@ -1621,7 +1622,8 @@ def test_zero_size_cl_array_dedup(ctx_factory): dedup_dw_out, count_duplicates=True) # 'x2' would be merged with 'x1' as both of them point to the same data # 'x3' would be merged with 'x4' as both of them point to the same data - assert num_nodes_new == (num_nodes_old - 2) + # '2*x2' would be merged with '2*x1' as they are identical expressions + assert num_nodes_new == (num_nodes_old - 3) # {{{ test_deterministic_codegen @@ -1938,10 +1940,12 @@ def build_expression(tracer): "baz": 65 * twice_x, "quux": 7 * twice_x_2} - result_with_functions = pt.tag_all_calls_to_be_inlined( - pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) - result_without_functions = pt.make_dict_of_named_arrays( - build_expression(lambda fn, *args: fn(*args))) + expr = pt.transform.Deduplicator()( + pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) + + result_with_functions = pt.tag_all_calls_to_be_inlined(expr) + result_without_functions = pt.transform.Deduplicator()( + pt.make_dict_of_named_arrays(build_expression(lambda fn, *args: fn(*args)))) # test that visualizing graphs with functions works dot = pt.get_dot_graph(result_with_functions) diff --git a/test/test_distributed.py b/test/test_distributed.py index d78479e08..65214c4b0 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -555,12 +555,13 @@ def _test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory): x_np = rng.random((10, 4)) x = pt.make_data_wrapper(cla.to_device(queue, x_np)) y = 2 * x + ones = pt.ones(10) send1 = pt.staple_distributed_send( y, dest_rank=1, comm_tag=42, - stapled_to=pt.ones(10)) + stapled_to=ones) send2 = pt.staple_distributed_send( y, dest_rank=2, comm_tag=42, - stapled_to=pt.ones(10)) + stapled_to=ones) z = 4 * y dag = pt.make_dict_of_named_arrays({"z": z, "send1": send1, "send2": send2}) else: diff --git a/test/test_pytato.py b/test/test_pytato.py index da176f124..7fe5a3b49 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -29,6 +29,7 @@ import dataclasses import sys +from contextlib import contextmanager import numpy as np import pytest @@ -723,7 +724,7 @@ def test_small_dag_with_duplicates_count(): # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + pt.transform.DependencyMapper(err_on_collision=False)(dag)) assert node_count - num_duplicates == get_num_nodes( dag, count_duplicates=False) @@ -760,7 +761,7 @@ def test_large_dag_with_duplicates_count(): # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + 
pt.transform.DependencyMapper(err_on_collision=False)(dag)) assert node_count - num_duplicates == get_num_nodes( dag, count_duplicates=False) @@ -805,6 +806,8 @@ def post_visit(self, expr): assert expr.name == "x" expr, inp = construct_intestine_graph() + expr = pt.transform.Deduplicator()(expr) + result = pt.transform.rec_get_user_nodes(expr, inp) SubexprRecorder()(expr) @@ -932,112 +935,118 @@ def test_einsum_dot_axes_has_correct_dim(): assert len(einsum.axes) == einsum.ndim -def test_created_at(): - pt.set_traceback_tag_enabled() +@contextmanager +def enable_traceback_tag(): + try: + pt.set_traceback_tag_enabled(True) + yield + finally: + pt.set_traceback_tag_enabled(False) - a = pt.make_placeholder("a", (10, 10), "float64") - b = pt.make_placeholder("b", (10, 10), "float64") - # res1 and res2 are defined on different lines and should have different - # CreatedAt tags. - res1 = a+b - res2 = a+b +def test_created_at(): + with enable_traceback_tag(): + a = pt.make_placeholder("a", (10, 10), "float64") + b = pt.make_placeholder("b", (10, 10), "float64") + + # res1 and res2 are defined on different lines and should have different + # CreatedAt tags. + res1 = a+b + res2 = a+b - # res3 and res4 are defined on the same line and should have the same - # CreatedAt tags. - res3 = a+b; res4 = a+b # noqa: E702 + # res3 and res4 are defined on the same line and should have the same + # CreatedAt tags. + res3 = a+b; res4 = a+b # noqa: E702 - # {{{ Check that CreatedAt tags are handled correctly for equality/hashing + # {{{ Check that CreatedAt tags are handled correctly for equality/hashing - assert res1 == res2 == res3 == res4 - assert hash(res1) == hash(res2) == hash(res3) == hash(res4) + assert res1 == res2 == res3 == res4 + assert hash(res1) == hash(res2) == hash(res3) == hash(res4) - assert res1.non_equality_tags != res2.non_equality_tags - assert res3.non_equality_tags == res4.non_equality_tags - assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) - assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) + assert res1.non_equality_tags != res2.non_equality_tags + assert res3.non_equality_tags == res4.non_equality_tags + assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) + assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) - assert res1.tags == res2.tags == res3.tags == res4.tags - assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) + assert res1.tags == res2.tags == res3.tags == res4.tags + assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) - # }}} + # }}} - from pytato.tags import CreatedAt + from pytato.tags import CreatedAt - created_tag = frozenset({tag - for tag in res1.non_equality_tags - if isinstance(tag, CreatedAt)}) + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) - assert len(created_tag) == 1 + assert len(created_tag) == 1 - # {{{ Make sure the function name appears in the traceback + # {{{ Make sure the function name appears in the traceback - tag, = created_tag + tag, = created_tag - found = False + found = False - stacksummary = tag.traceback.to_stacksummary() - assert len(stacksummary) > 10 + stacksummary = tag.traceback.to_stacksummary() + assert len(stacksummary) > 10 - for frame in tag.traceback.frames: - if frame.name == "test_created_at" and "a+b" in frame.line: - found = True - break + for frame in tag.traceback.frames: + if frame.name == "test_created_at" and "a+b" in frame.line: + found = True + break - 
assert found + assert found - # }}} + # }}} - # {{{ Make sure that CreatedAt tags are in the visualization + # {{{ Make sure that CreatedAt tags are in the visualization - from pytato.visualization import get_dot_graph - s = get_dot_graph(res1) - assert "test_created_at" in s - assert "a+b" in s + from pytato.visualization import get_dot_graph + s = get_dot_graph(res1) + assert "test_created_at" in s + assert "a+b" in s - # }}} + # }}} - # {{{ Make sure only a single CreatedAt tag is created + # {{{ Make sure only a single CreatedAt tag is created - old_tag = tag + old_tag = tag - res1 = res1 + res2 + res1 = res1 + res2 - created_tag = frozenset({tag - for tag in res1.non_equality_tags - if isinstance(tag, CreatedAt)}) + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) - assert len(created_tag) == 1 + assert len(created_tag) == 1 - tag, = created_tag + tag, = created_tag - # Tag should be recreated - assert tag != old_tag + # Tag should be recreated + assert tag != old_tag - # }}} + # }}} - # {{{ Make sure that copying preserves the tag + # {{{ Make sure that copying preserves the tag - old_tag = tag + old_tag = tag - res1_new = pt.transform.map_and_copy(res1, lambda x: x) + res1_new = pt.transform.Deduplicator()(res1) - created_tag = frozenset({tag - for tag in res1_new.non_equality_tags - if isinstance(tag, CreatedAt)}) + created_tag = frozenset({tag + for tag in res1_new.non_equality_tags + if isinstance(tag, CreatedAt)}) - assert len(created_tag) == 1 + assert len(created_tag) == 1 - tag, = created_tag + tag, = created_tag - assert old_tag == tag + assert old_tag == tag - # }}} + # }}} # {{{ Test disabling traceback creation - pt.set_traceback_tag_enabled(False) - a = pt.make_placeholder("a", (10, 10), "float64") created_tag = frozenset({tag @@ -1160,7 +1169,7 @@ class ExistentTag(Tag): out = make_random_dag(rdagc_pt).tagged(ExistentTag()) - dag = pt.make_dict_of_named_arrays({"out": out}) + dag = pt.transform.Deduplicator()(pt.make_dict_of_named_arrays({"out": out})) # get_num_nodes() returns an extra DictOfNamedArrays node assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) diff --git a/test/testlib.py b/test/testlib.py index a28dec67e..7d58df480 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -101,6 +101,7 @@ def __init__( rng: np.random.Generator, axis_len: int, use_numpy: bool, + allow_duplicate_nodes: bool = False, additional_generators: ( Sequence[tuple[int, Callable[[RandomDAGContext], Array]]] | None) = None @@ -115,6 +116,7 @@ def __init__( self.axis_len = axis_len self.past_results: list[Array] = [] self.use_numpy = use_numpy + self.allow_duplicate_nodes = allow_duplicate_nodes if additional_generators is None: additional_generators = [] @@ -156,6 +158,14 @@ def make_random_reshape( def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: + if not rdagc.use_numpy and not rdagc.allow_duplicate_nodes: + def dedup(expr: Array) -> Array: + return pt.transform._verify_is_array(pt.transform.Deduplicator()(expr)) + + else: + def dedup(expr: Array) -> Array: + return expr + rng = rdagc.rng max_prob_hardcoded = 1500 @@ -166,7 +176,7 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: v = rng.integers(0, max_prob_hardcoded + additional_prob) if v < 600: - return make_random_constant(rdagc, naxes=rng.integers(1, 3)) + return dedup(make_random_constant(rdagc, naxes=rng.integers(1, 3))) elif v < 1000: op1 = make_random_dag(rdagc) @@ -189,9 +199,9 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> 
Any: # just inserted a few new 1-long axes. Those need to go before we # return. if which_op in ["maximum", "minimum"]: - return rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2)) + return dedup(rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2))) else: - return rdagc.np.squeeze(which_op(op1, op2)) + return dedup(rdagc.np.squeeze(which_op(op1, op2))) elif v < 1075: op1 = make_random_dag(rdagc) @@ -199,24 +209,26 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: if op1.ndim <= 1 and op2.ndim <= 1: continue - return op1 @ op2 + return dedup(op1 @ op2) elif v < 1275: if not rdagc.past_results: continue - return rdagc.past_results[rng.integers(0, len(rdagc.past_results))] + return dedup( + rdagc.past_results[rng.integers(0, len(rdagc.past_results))]) elif v < max_prob_hardcoded: result = make_random_dag(rdagc) - return rdagc.np.transpose( + return dedup( + rdagc.np.transpose( result, - tuple(rng.permuted(list(range(result.ndim))))) + tuple(rng.permuted(list(range(result.ndim)))))) else: base_prob = max_prob_hardcoded for fake_prob, gen_func in rdagc.additional_generators: if base_prob <= v < base_prob + fake_prob: - return gen_func(rdagc) + return dedup(gen_func(rdagc)) base_prob += fake_prob @@ -237,6 +249,14 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: of the array are of length :attr:`RandomDAGContext.axis_len` (there is at least one axis, but arbitrarily more may be present). """ + if not rdagc.use_numpy and not rdagc.allow_duplicate_nodes: + def dedup(expr: Array) -> Array: + return pt.transform._verify_is_array(pt.transform.Deduplicator()(expr)) + + else: + def dedup(expr: Array) -> Array: + return expr + rng = rdagc.rng result = make_random_dag_inner(rdagc) @@ -248,14 +268,15 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: subscript[rng.integers(0, result.ndim)] = int( rng.integers(0, rdagc.axis_len)) - return result[tuple(subscript)] + return dedup(result[tuple(subscript)]) elif v == 1: # reduce away an axis # FIXME do reductions other than sum? - return rdagc.np.sum( - result, axis=int(rng.integers(0, result.ndim))) + return dedup( + rdagc.np.sum( + result, axis=int(rng.integers(0, result.ndim)))) else: raise AssertionError() @@ -275,7 +296,8 @@ def get_random_pt_dag(seed: int, Sequence[tuple[int, Callable[[RandomDAGContext], Array]]] | None) = None, axis_len: int = 4, - convert_dws_to_placeholders: bool = False + convert_dws_to_placeholders: bool = False, + allow_duplicate_nodes: bool = False ) -> pt.DictOfNamedArrays: if additional_generators is None: additional_generators = [] @@ -286,6 +308,7 @@ def get_random_pt_dag(seed: int, rdagc_comm = RandomDAGContext(np.random.default_rng(seed=seed), axis_len=axis_len, use_numpy=False, + allow_duplicate_nodes=allow_duplicate_nodes, additional_generators=additional_generators) dag = pt.make_dict_of_named_arrays({"result": make_random_dag(rdagc_comm)})