Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid duplicating arrays #515

Draft
wants to merge 29 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
3b6efd7
add CacheInputs to simplify cache key handling logic
majosm Feb 7, 2025
4037cdd
change a few more Hashable to CacheKeyT
majosm Jan 28, 2025
b260d7c
use expr instead of tuple for cache key in CachedMapper when no extra…
majosm Feb 5, 2025
b73ee5d
add map_dict_of_named_arrays to DirectPredecessorsGetter
majosm Sep 20, 2024
e3bd62c
support functions as inputs and outputs in DirectPredecessorsGetter
majosm Sep 24, 2024
dd047ab
add collision/duplication checks to CachedMapper/TransformMapper/Tran…
majosm Aug 29, 2024
d989c10
add result deduplication to transform mappers
majosm Sep 24, 2024
8d2acd7
add FIXME
majosm Sep 5, 2024
ea04a0a
avoid unnecessary duplication in CopyMapper/CopyMapperWithExtraArgs
majosm Jun 10, 2024
6e164c6
add Deduplicator
majosm Sep 20, 2024
1ad8438
avoid unnecessary duplication in InlineMarker
majosm Jun 11, 2024
f7bfc99
avoid duplication in tagged() for Axis/ReductionDescriptor/_SuppliedA…
majosm Aug 27, 2024
386e87d
avoid duplication in Array.with_tagged_axis
majosm Jun 11, 2024
6f4e5c7
avoid duplication in with_tagged_reduction for IndexLambda/Einsum
majosm Jun 11, 2024
ba4ff37
attempt to avoid duplication in CodeGenPreprocessor
majosm Jun 10, 2024
a097e85
limit PlaceholderSubstitutor to one call stack frame
majosm Jul 3, 2024
b3a7457
tweak Inliner/PlaceholderSubstitutor implementations
majosm Jul 12, 2024
397fcff
use context manager to avoid leaking traceback tag setting in test
majosm Jul 16, 2024
6f0c329
refactor FFTRealizationMapper to avoid resetting cache in __init__
majosm Jul 16, 2024
bc7bb73
add allow_duplicate_nodes option to RandomDAGContext in tests
majosm Aug 27, 2024
6d8c6bd
fix some more tests
majosm Aug 27, 2024
41c035a
use Mapper.rec instead of super().rec in CachedMapAndCopyMapper and T…
majosm Sep 24, 2024
fc2ffca
add assertions to check for double-caching
majosm Sep 24, 2024
835e3de
don't check for collisions in ArrayToDotNodeInfoMapper
majosm Sep 24, 2024
86e9b79
avoid duplication in MPMSMaterializer
majosm Sep 18, 2024
82261c6
move key func definitions out of MPMSMaterializerCache
majosm Feb 5, 2025
e50bcae
avoid duplicates in EinsumWithNoBroadcastsRewriter
majosm Feb 7, 2025
714ec28
forbid DependencyMapper from being called on functions
majosm Feb 7, 2025
8ac0b8d
enable err_on_* checks if __debug__
majosm Sep 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:

# {{{ DirectPredecessorsGetter

class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):
class DirectPredecessorsGetter(
Mapper[
FrozenOrderedSet[ArrayOrNames | FunctionDefinition],
FrozenOrderedSet[ArrayOrNames],
[]]):
"""
Mapper to get the
`direct predecessors
Expand All @@ -334,9 +338,17 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):

We only consider the predecessors of a nodes in a data-flow sense.
"""
def __init__(self, *, include_functions: bool = False) -> None:
    # include_functions: when True, map_call also reports the called
    # FunctionDefinition as a direct predecessor of a Call node.
    super().__init__()
    self.include_functions = include_functions

def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]:
    # Only Array-valued (symbolic) dimensions contribute dataflow
    # predecessors; concrete integer dimensions are skipped.
    return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array))

def map_dict_of_named_arrays(
        self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]:
    # The direct predecessors of a DictOfNamedArrays are its member arrays.
    return FrozenOrderedSet(expr._data.values())

def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]:
    # Predecessors are the bound input arrays plus any Array-valued
    # dimensions appearing in the result shape.
    return (FrozenOrderedSet(expr.bindings.values())
            | self._get_preds_from_shape(expr.shape))
Expand Down Expand Up @@ -397,8 +409,17 @@ def map_distributed_send_ref_holder(self,
) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet([expr.passthrough_data])

def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet(expr.bindings.values())
def map_call(
        self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]:
    # Predecessors are the call's argument bindings; if include_functions
    # was set at construction time, the called FunctionDefinition is
    # included as well.
    result: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \
        FrozenOrderedSet(expr.bindings.values())
    if self.include_functions:
        result = result | FrozenOrderedSet([expr.function])
    return result

def map_function_definition(
        self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]:
    # A function definition's direct predecessors are its returned arrays.
    return FrozenOrderedSet(expr.returns.values())

def map_named_call_result(
self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]:
Expand Down Expand Up @@ -622,11 +643,11 @@ def combine(self, *args: int) -> int:
return sum(args)

def rec(self, expr: ArrayOrNames) -> int:
key = self._cache.get_key(expr)
inputs = self._cache.make_inputs(expr)
try:
return self._cache.retrieve(expr, key=key)
return self._cache_retrieve(inputs)
except KeyError:
s = super().rec(expr)
s = Mapper.rec(self, expr)
if (
isinstance(expr, Array)
and (
Expand All @@ -636,7 +657,7 @@ def rec(self, expr: ArrayOrNames) -> int:
else:
result = 0 + s

self._cache.add(expr, 0, key=key)
self._cache_add(inputs, 0)
return result


Expand Down
89 changes: 53 additions & 36 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,11 @@ class Axis(Taggable):
tags: frozenset[Tag]

def _with_new_tags(self, tags: frozenset[Tag]) -> Axis:
from dataclasses import replace
return replace(self, tags=tags)
if tags != self.tags:
from dataclasses import replace
return replace(self, tags=tags)
else:
return self


@dataclasses.dataclass(frozen=True)
Expand All @@ -495,8 +498,11 @@ class ReductionDescriptor(Taggable):
tags: frozenset[Tag]

def _with_new_tags(self, tags: frozenset[Tag]) -> ReductionDescriptor:
from dataclasses import replace
return replace(self, tags=tags)
if tags != self.tags:
from dataclasses import replace
return replace(self, tags=tags)
else:
return self


@array_dataclass()
Expand Down Expand Up @@ -848,10 +854,14 @@ def with_tagged_axis(self, iaxis: int,
"""
Returns a copy of *self* with *iaxis*-th axis tagged with *tags*.
"""
new_axes = (self.axes[:iaxis]
+ (self.axes[iaxis].tagged(tags),)
+ self.axes[iaxis+1:])
return self.copy(axes=new_axes)
new_axis = self.axes[iaxis].tagged(tags)
if new_axis is not self.axes[iaxis]:
new_axes = (self.axes[:iaxis]
+ (self.axes[iaxis].tagged(tags),)
+ self.axes[iaxis+1:])
return self.copy(axes=new_axes)
else:
return self

@memoize_method
def __repr__(self) -> str:
Expand Down Expand Up @@ -880,7 +890,10 @@ class _SuppliedAxesAndTagsMixin(Taggable):
default=frozenset())

def _with_new_tags(self: Self, tags: frozenset[Tag]) -> Self:
return dataclasses.replace(self, tags=tags)
if tags != self.tags:
return dataclasses.replace(self, tags=tags)
else:
return self


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
Expand Down Expand Up @@ -1129,20 +1142,22 @@ def with_tagged_reduction(self,
f" '{self.var_to_reduction_descr.keys()}',"
f" got '{reduction_variable}'.")

assert isinstance(self.var_to_reduction_descr, immutabledict)
new_var_to_redn_descr = dict(self.var_to_reduction_descr)
new_var_to_redn_descr[reduction_variable] = \
self.var_to_reduction_descr[reduction_variable].tagged(tags)

return type(self)(expr=self.expr,
shape=self.shape,
dtype=self.dtype,
bindings=self.bindings,
axes=self.axes,
var_to_reduction_descr=immutabledict
(new_var_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags)
new_redn_descr = self.var_to_reduction_descr[reduction_variable].tagged(tags)
if new_redn_descr is not self.var_to_reduction_descr[reduction_variable]:
assert isinstance(self.var_to_reduction_descr, immutabledict)
new_var_to_redn_descr = dict(self.var_to_reduction_descr)
new_var_to_redn_descr[reduction_variable] = new_redn_descr
return type(self)(expr=self.expr,
shape=self.shape,
dtype=self.dtype,
bindings=self.bindings,
axes=self.axes,
var_to_reduction_descr=immutabledict
(new_var_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags)
else:
return self

# }}}

Expand Down Expand Up @@ -1293,19 +1308,21 @@ def with_tagged_reduction(self,

# }}}

assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
new_redn_axis_to_redn_descr[redn_axis] = \
self.redn_axis_to_redn_descr[redn_axis].tagged(tags)

return type(self)(access_descriptors=self.access_descriptors,
args=self.args,
axes=self.axes,
redn_axis_to_redn_descr=immutabledict
(new_redn_axis_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags,
)
new_redn_descr = self.redn_axis_to_redn_descr[redn_axis].tagged(tags)
if new_redn_descr is not self.redn_axis_to_redn_descr[redn_axis]:
assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
new_redn_axis_to_redn_descr[redn_axis] = new_redn_descr
return type(self)(access_descriptors=self.access_descriptors,
args=self.args,
axes=self.axes,
redn_axis_to_redn_descr=immutabledict
(new_redn_axis_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags,
)
else:
return self


EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P<alpha>[a-zA-Z])|(?P<ellipsis>\.\.\.))\s*")
Expand Down
94 changes: 60 additions & 34 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,47 +148,55 @@ def __init__(
self.kernels_seen: dict[str, lp.LoopKernel] = kernels_seen or {}

def map_size_param(self, expr: SizeParam) -> Array:
name = expr.name
assert name is not None
return SizeParam( # pylint: disable=missing-kwoa
name=name,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)
assert expr.name is not None
return expr

def map_placeholder(self, expr: Placeholder) -> Array:
name = expr.name
if name is None:
name = self.var_name_gen("_pt_in")
return Placeholder(name=name,
shape=tuple(self.rec(s) if isinstance(s, Array) else s
for s in expr.shape),
dtype=expr.dtype,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)
new_name = expr.name
if new_name is None:
new_name = self.var_name_gen("_pt_in")
new_shape = self.rec_idx_or_size_tuple(expr.shape)
if (
new_name is expr.name
and new_shape is expr.shape):
return expr
else:
return Placeholder(name=new_name,
shape=new_shape,
dtype=expr.dtype,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
from pytato.target.loopy import LoopyTarget
if not isinstance(self.target, LoopyTarget):
raise ValueError("Got a LoopyCall for a non-loopy target.")
translation_unit = expr.translation_unit.copy(
target=self.target.get_loopy_target())
new_target = self.target.get_loopy_target()

# FIXME: Can't use "is" here because targets aren't unique. Is it OK to
# use the existing target if it's equal to self.target.get_loopy_target()?
# If not, may have to set err_on_no_op_duplication=False
if new_target == expr.translation_unit.target:
new_translation_unit = expr.translation_unit
else:
new_translation_unit = expr.translation_unit.copy(target=new_target)
namegen = UniqueNameGenerator(set(self.kernels_seen))
entrypoint = expr.entrypoint
new_entrypoint = expr.entrypoint

# {{{ eliminate callable name collision

for name, clbl in translation_unit.callables_table.items():
for name, clbl in new_translation_unit.callables_table.items():
if isinstance(clbl, lp.CallableKernel):
assert isinstance(name, str)
if name in self.kernels_seen and (
translation_unit[name] != self.kernels_seen[name]):
new_translation_unit[name] != self.kernels_seen[name]):
# callee name collision => must rename

# {{{ see if it's one of the other kernels

for other_knl in self.kernels_seen.values():
if other_knl.copy(name=name) == translation_unit[name]:
if other_knl.copy(name=name) == new_translation_unit[name]:
new_name = other_knl.name
break
else:
Expand All @@ -198,37 +206,55 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:

# }}}

if name == entrypoint:
if name == new_entrypoint:
# if the colliding name is the entrypoint, then rename the
# entrypoint as well.
entrypoint = new_name
new_entrypoint = new_name

translation_unit = lp.rename_callable(
translation_unit, name, new_name)
new_translation_unit = lp.rename_callable(
new_translation_unit, name, new_name)
name = new_name

self.kernels_seen[name] = clbl.subkernel

# }}}

bindings: Mapping[str, Any] = immutabledict(
new_bindings: Mapping[str, Any] = immutabledict(
{name: (self.rec(subexpr) if isinstance(subexpr, Array)
else subexpr)
for name, subexpr in sorted(expr.bindings.items())})

return LoopyCall(translation_unit=translation_unit,
bindings=bindings,
entrypoint=entrypoint,
tags=expr.tags
)
assert (
new_entrypoint is expr.entrypoint
or new_entrypoint != expr.entrypoint)
for bnd, new_bnd in zip(
expr.bindings.values(), new_bindings.values(), strict=True):
assert new_bnd is bnd or new_bnd != bnd

if (
new_translation_unit == expr.translation_unit
and all(
new_bnd is bnd
for bnd, new_bnd in zip(
expr.bindings.values(),
new_bindings.values(),
strict=True))
and new_entrypoint is expr.entrypoint):
return expr
else:
return LoopyCall(translation_unit=new_translation_unit,
bindings=new_bindings,
entrypoint=new_entrypoint,
tags=expr.tags
)

def map_data_wrapper(self, expr: DataWrapper) -> Array:
name = _generate_name_for_temp(expr, self.var_name_gen, "_pt_data")
shape = self.rec_idx_or_size_tuple(expr.shape)

self.bound_arguments[name] = expr.data
return Placeholder(name=name,
shape=tuple(self.rec(s) if isinstance(s, Array) else s
for s in expr.shape),
shape=shape,
dtype=expr.dtype,
axes=expr.axes,
tags=expr.tags,
Expand Down
4 changes: 2 additions & 2 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
return new_send

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
key = self._cache.get_key(expr)
inputs = self._cache.make_inputs(expr)
try:
return self._cache.retrieve(expr, key=key)
return self._cache_retrieve(inputs)
except KeyError:
pass

Expand Down
Loading
Loading