Skip to content

Commit

Permalink
Remove the raising to high level operator within Unify Axis (#565)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoskelo authored Feb 5, 2025
1 parent c187f16 commit a34a9e5
Show file tree
Hide file tree
Showing 5 changed files with 380 additions and 336 deletions.
5 changes: 3 additions & 2 deletions pytato/scalar_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression:
for name, bound in expr.bounds.items()}))


IDX_LAMBDA_RE = re.compile(r"_r?(0|([1-9][0-9]*))")
IDX_LAMBDA_REDUCTION_AXIS_INDEX = re.compile(r"^(_r?(?P<index>0|[1-9][0-9]*))$")
IDX_LAMBDA_AXIS_INDEX = re.compile(r"^(_(?P<index>0|[1-9][0-9]*))$")


class DependencyMapper(DependencyMapperBase[P]):
Expand All @@ -185,7 +186,7 @@ def map_variable(self,
expr: prim.Variable, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
if ((not self.include_idx_lambda_indices)
and IDX_LAMBDA_RE.fullmatch(str(expr))):
and IDX_LAMBDA_REDUCTION_AXIS_INDEX.fullmatch(str(expr))):
return set()
else:
return super().map_variable(expr, *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pytato/transform/lower_to_index_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def map_einsum(self, expr: Einsum) -> IndexLambda:
def map_roll(self, expr: Roll) -> IndexLambda:
from pytato.utils import dim_to_index_lambda_components

index_expr: prim.Expression = prim.Variable("_in0")
index_expr: prim.ExpressionNode = prim.Variable("_in0")
indices: list[ArithmeticExpression] = [
prim.Variable(f"_{d}") for d in range(expr.ndim)]
axis = expr.axis
Expand Down
Loading

0 comments on commit a34a9e5

Please sign in to comment.