Skip to content

Commit

Permalink
Don't include hidden arrays when counting max_total_source_arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed May 16, 2024
1 parent 2e3a935 commit e2f8cc1
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ def is_fusable(node_dict):
return is_primitive_op(node_dict) and node_dict["primitive_op"].fusable


def num_source_arrays(dag, name):
    """Return the number of (non-hidden) arrays that are inputs to an op.

    Hidden arrays are used for internal bookkeeping, are very small virtual arrays
    (empty, or offsets for example), and are not shown on the plan visualization.
    For these reasons they shouldn't count towards ``max_total_source_arrays``.

    Parameters
    ----------
    dag
        Graph exposing a networkx-style ``nodes`` view, where each array node
        carries a ``"hidden"`` attribute.
    name
        Name of the op node whose input arrays are counted.

    Returns
    -------
    int
        Count of predecessors of ``name`` whose ``"hidden"`` attribute is falsy;
        ``0`` when the op has no predecessors.

    Raises
    ------
    KeyError
        If a predecessor node has no ``"hidden"`` attribute.
    """
    # Index the graph's node view directly instead of copying the attributes of
    # every node in the graph into a dict on each call; only the predecessors
    # of ``name`` are ever looked up.
    node_attrs = dag.nodes
    return sum(
        1
        for array in predecessors_unordered(dag, name)
        if not node_attrs[array]["hidden"]
    )


def can_fuse_predecessors(
dag,
name,
Expand Down Expand Up @@ -145,9 +158,7 @@ def can_fuse_predecessors(
# the fused function would be more than an allowed maximum, then don't fuse
if len(list(predecessor_ops(dag, name))) > 1:
total_source_arrays = sum(
len(list(predecessors_unordered(dag, pre)))
if is_primitive_op(nodes[pre])
else 1
num_source_arrays(dag, pre) if is_primitive_op(nodes[pre]) else 1
for pre in predecessor_ops(dag, name)
)
if total_source_arrays > max_total_source_arrays:
Expand Down

0 comments on commit e2f8cc1

Please sign in to comment.