Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More fixes for functions #503

Merged
merged 32 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7cdf3d5
add missing *args, **kwargs in WalkMapper.map_call
majosm Jun 7, 2024
3eaf7a1
memoize clone_for_callee
majosm Mar 12, 2024
2de9349
remove default CombineMapper map_call implementation
majosm Mar 12, 2024
c0e70cd
don't memoize map_function_definition in cached walk mappers
majosm Mar 12, 2024
d391d45
support calls in InputGatherer
majosm Mar 12, 2024
244160c
make NamedCallResult compatible with attrs cache_hash=True
majosm Jun 7, 2024
9c2d7b9
enable cache_hash on FunctionDefinition
majosm Mar 12, 2024
7a51316
enable calls in DirectPredecessorsGetter
majosm Mar 12, 2024
072b51a
memoize Call creation
majosm Mar 12, 2024
fbd0c40
make NamedCallResult.call a property
majosm Mar 12, 2024
6cf3be2
remove redundant NamedCallResult.name (already defined in NamedArray)
majosm Mar 12, 2024
4e2aae7
memoize NamedCallResult creation
majosm Mar 12, 2024
2b00442
fix docstring
majosm Mar 12, 2024
c5e88ab
remove non-argument placeholder check
majosm Mar 11, 2024
595bcb7
fix equality for FunctionDefinition
majosm Jun 11, 2024
b310a91
add FIXME
majosm Jun 13, 2024
729d720
attempt to fix doc warning
majosm Jun 17, 2024
cf3d0a9
don't construct NamedCallResult directly
majosm Jun 6, 2024
10e0cef
fix mapper method name in UsersCollector
majosm Jun 18, 2024
81e49b4
add FIXME
majosm Jun 14, 2024
769d48e
Revert "add FIXME"
majosm Jul 10, 2024
6de23ec
Revert "remove non-argument placeholder check"
majosm Jul 10, 2024
3d69985
add some more missing *args, **kwargs to WalkMapper
majosm Jul 9, 2024
2ebc1b2
remove some unnecessary *args, **kwargs
majosm Jul 10, 2024
45eb68d
add get_func_def_cache_key to walk mappers to correctly handle functi…
majosm Jul 10, 2024
950f44f
undo memoizing Call creation
majosm Jul 10, 2024
d1cb54e
don't use regular dict for function call results
majosm Jul 10, 2024
ddf546b
fix type annotation for function result
majosm Jul 10, 2024
1e7c71d
undo memoizing clone_for_callee
majosm Jul 25, 2024
b7d412d
add SizeParamGatherer.map_call
majosm Jul 26, 2024
c9cda75
Revert "add get_func_def_cache_key to walk mappers to correctly handl…
majosm Jul 29, 2024
742de0a
Revert "don't memoize map_function_definition in cached walk mappers"
majosm Jul 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from typing import TYPE_CHECKING, Any, Mapping

from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

from pytato.array import (
Array,
Expand Down Expand Up @@ -321,26 +320,26 @@ class DirectPredecessorsGetter(Mapper):

We only consider the predecessors of a nodes in a data-flow sense.
"""
def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]:
def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]:
return frozenset({dim for dim in shape if isinstance(dim, Array)})

def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]:
def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]:
return (frozenset(expr.bindings.values())
| self._get_preds_from_shape(expr.shape))

def map_stack(self, expr: Stack) -> frozenset[Array]:
def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_concatenate(self, expr: Concatenate) -> frozenset[Array]:
def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_einsum(self, expr: Einsum) -> frozenset[Array]:
def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]:
return (frozenset(expr.args)
| self._get_preds_from_shape(expr.shape))

def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]:
def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]:
from pytato.loopy import LoopyCall, LoopyCallResult
assert isinstance(expr, LoopyCallResult)
assert isinstance(expr._container, LoopyCall)
Expand All @@ -349,7 +348,7 @@ def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]:
if isinstance(ary, Array))
| self._get_preds_from_shape(expr.shape))

def _map_index_base(self, expr: IndexBase) -> frozenset[Array]:
def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]:
return (frozenset([expr.array])
| frozenset(idx for idx in expr.indices
if isinstance(idx, Array))
Expand All @@ -360,32 +359,34 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]:
map_non_contiguous_advanced_index = _map_index_base

def _map_index_remapping_base(self, expr: IndexRemappingBase
) -> frozenset[Array]:
) -> frozenset[ArrayOrNames]:
return frozenset([expr.array])

map_roll = _map_index_remapping_base
map_axis_permutation = _map_index_remapping_base
map_reshape = _map_index_remapping_base

def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]:
def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape)

map_placeholder = _map_input_base
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]:
def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape)

def map_distributed_send_ref_holder(self,
expr: DistributedSendRefHolder
) -> frozenset[Array]:
) -> frozenset[ArrayOrNames]:
return frozenset([expr.passthrough_data])

def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]:
raise NotImplementedError(
"DirectPredecessorsGetter does not yet support expressions containing "
"functions.")
def map_call(self, expr: Call) -> frozenset[ArrayOrNames]:
return frozenset(expr.bindings.values())

def map_named_call_result(
self, expr: NamedCallResult) -> frozenset[ArrayOrNames]:
return frozenset([expr._container])


# }}}
Expand All @@ -410,6 +411,9 @@ def __init__(self) -> None:
def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

def get_func_def_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
self.count += 1

Expand Down Expand Up @@ -447,18 +451,22 @@ def __init__(self) -> None:
def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

@memoize_method
def map_function_definition(self, /, expr: FunctionDefinition,
*args: Any, **kwargs: Any) -> None:
if not self.visit(expr):
def get_func_def_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def map_function_definition(self, expr: FunctionDefinition) -> None:
cache_key = self.get_func_def_cache_key(expr)
if not self.visit(expr) or cache_key in self._visited_functions:
return

new_mapper = self.clone_for_callee(expr)
for subexpr in expr.returns.values():
new_mapper(subexpr, *args, **kwargs)
new_mapper(subexpr)
self.count += new_mapper.count

self.post_visit(expr, *args, **kwargs)
self._visited_functions.add(cache_key)

self.post_visit(expr)

def post_visit(self, expr: Any) -> None:
if isinstance(expr, Call):
Expand Down
5 changes: 4 additions & 1 deletion pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
SizeParam,
make_dict_of_named_arrays,
)
from pytato.function import NamedCallResult
from pytato.function import FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall
from pytato.scalar_expr import IntegralScalarExpression
from pytato.target import Target
Expand Down Expand Up @@ -254,6 +254,9 @@ def __init__(self) -> None:
def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

def get_func_def_cache_key(self, expr: FunctionDefinition) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
if isinstance(expr, (Placeholder, SizeParam, DataWrapper)):
if expr.name is not None:
Expand Down
2 changes: 2 additions & 0 deletions pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
) -> bool:
return (expr1.__class__ is expr2.__class__
and expr1.parameters == expr2.parameters
and expr1.return_type == expr2.return_type
and (set(expr1.returns.keys()) == set(expr2.returns.keys()))
and all(self.rec(expr1.returns[k], expr2.returns[k])
for k in expr1.returns)
Expand All @@ -311,6 +312,7 @@ def map_call(self, expr1: Call, expr2: Any) -> bool:
and all(self.rec(bnd,
expr2.bindings[name])
for name, bnd in expr1.bindings.items())
and expr1.tags == expr2.tags
)

def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool:
Expand Down
47 changes: 29 additions & 18 deletions pytato/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@

A type variable corresponding to the return type of the function
:func:`pytato.trace_call`.

Internal stuff that is only here because the documentation tool wants it
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. class:: Tag

See :class:`pytools.tag.Tag`.

.. class:: AxesT

A :class:`tuple` of :class:`pytato.array.Axis` objects.
"""

__copyright__ = """
Expand Down Expand Up @@ -50,7 +61,6 @@
from typing import (
Callable,
ClassVar,
Dict,
Hashable,
Iterable,
Iterator,
Expand All @@ -63,6 +73,7 @@
import attrs
from immutabledict import immutabledict

from pytools import memoize_method
from pytools.tag import Tag, Taggable

from pytato.array import (
Expand All @@ -75,7 +86,7 @@
)


ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array])
ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Mapping[str, Array])


# {{{ Call/NamedCallResult
Expand All @@ -92,14 +103,14 @@ class ReturnType(enum.Enum):


# eq=False to avoid equality comparison without EqualityMaper
@attrs.define(frozen=True, eq=False, hash=True)
@attrs.define(frozen=True, eq=False, hash=True, cache_hash=True)
class FunctionDefinition(Taggable):
r"""
A function definition that represents its outputs as instances of
:class:`~pytato.Array` with the inputs being
:class:`~pytato.array.Placeholder`\ s. The outputs of the function
can be a single :class:`pytato.Array`, a tuple of :class:`pytato.Array`\ s or an
instance of ``Dict[str, Array]``.
instance of ``Mapping[str, Array]``.

.. attribute:: parameters

Expand Down Expand Up @@ -184,7 +195,7 @@ def _with_new_tags(
return attrs.evolve(self, tags=tags)

def __call__(self, **kwargs: Array
) -> Array | tuple[Array, ...] | dict[str, Array]:
) -> Array | tuple[Array, ...] | Mapping[str, Array]:
from pytato.array import _get_default_tags
from pytato.utils import are_shapes_equal

Expand Down Expand Up @@ -221,11 +232,12 @@ def __call__(self, **kwargs: Array
return tuple(call_site[f"_{iarg}"]
for iarg in range(len(self.returns)))
elif self.return_type == ReturnType.DICT_OF_ARRAYS:
return {kw: call_site[kw] for kw in self.returns}
return immutabledict({kw: call_site[kw] for kw in self.returns})
else:
raise NotImplementedError(self.return_type)


@attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True)
class NamedCallResult(NamedArray):
"""
One of the arrays that are returned from a call to :class:`FunctionDefinition`.
Expand All @@ -239,19 +251,8 @@ class NamedCallResult(NamedArray):
The name by which the returned array is referred to in
:attr:`FunctionDefinition.returns`.
"""
call: Call
name: str
_mapper_method: ClassVar[str] = "map_named_call_result"

def __init__(self,
call: Call,
name: str) -> None:
super().__init__(call, name,
axes=call.function.returns[name].axes,
tags=call.function.returns[name].tags,
non_equality_tags=(
call.function.returns[name].non_equality_tags))

def with_tagged_axis(self, iaxis: int,
tags: Sequence[Tag] | Tag) -> Array:
raise ValueError("Tagging a NamedCallResult's axis is illegal, use"
Expand All @@ -269,6 +270,11 @@ def without_tags(self,
raise ValueError("Untagging a NamedCallResult is illegal, use"
" Call.without_tags instead")

@property
def call(self) -> Call:
assert isinstance(self._container, Call)
return self._container

@property
def shape(self) -> ShapeType:
assert isinstance(self._container, Call)
Expand Down Expand Up @@ -317,8 +323,13 @@ def __contains__(self, name: object) -> bool:
def __iter__(self) -> Iterator[str]:
return iter(self.function.returns)

@memoize_method
def __getitem__(self, name: str) -> NamedCallResult:
return NamedCallResult(self, name)
return NamedCallResult(
self, name,
axes=self.function.returns[name].axes,
tags=self.function.returns[name].tags,
non_equality_tags=self.function.returns[name].non_equality_tags)

def __len__(self) -> int:
return len(self.function.returns)
Expand Down
2 changes: 1 addition & 1 deletion pytato/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ class ExpandedDimsReshape(UniqueTag):
class FunctionIdentifier(UniqueTag):
"""
A tag that can be attached to a :class:`~pytato.function.FunctionDefinition`
node to to describe the function's identifier. One can use this to refer
node to describe the function's identifier. One can use this to refer
all instances of :class:`~pytato.function.FunctionDefinition`, for example in
transformations.transform.calls.concatenate_calls`.

Expand Down
Loading
Loading