Skip to content

Commit

Permalink
Simplify local_[mul|div]_switch_sink and fix downcasting bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 31, 2024
1 parent 4d0aa3f commit 4f7d709
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 115 deletions.
183 changes: 69 additions & 114 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node):
part of the graph.
"""
for idx, i in enumerate(node.inputs):
if i.owner and i.owner.op == switch:
switch_node = i.owner
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True
)
== 0.0
):
listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
fmul = mul(*([*listmul, switch_node.inputs[2]]))

# Copy over stacktrace for elementwise multiplication op
# from previous elementwise multiplication op.
# An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the
# multiplication op.
copy_stack_trace(node.outputs, fmul)

fct = [switch(switch_node.inputs[0], 0, fmul)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan

# Copy over stacktrace for switch op from both previous
# elementwise multiplication op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True
)
== 0.0
):
listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
fmul = mul(*([*listmul, switch_node.inputs[1]]))
# Copy over stacktrace for elementwise multiplication op
# from previous elementwise multiplication op.
# An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the
# multiplication op.
copy_stack_trace(node.outputs, fmul)

fct = [switch(switch_node.inputs[0], fmul, 0)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan

# Copy over stacktrace for switch op from both previous
# elementwise multiplication op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
return False
for mul_inp_idx, mul_inp in enumerate(node.inputs):
if mul_inp.owner and mul_inp.owner.op == switch:
switch_node = mul_inp.owner
# Look for a zero as the first or second branch of the switch
for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch]
if not get_unique_constant_value(zero_switch_input) == 0.0:
continue

switch_cond = switch_node.inputs[0]
other_switch_input = switch_node.inputs[1 + (1 - branch)]

listmul = list(node.inputs)
listmul[mul_inp_idx] = other_switch_input
fmul = mul(*listmul)

# Copy over stacktrace for elementwise multiplication op
# from previous elementwise multiplication op.
# An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the
# multiplication op.
copy_stack_trace(node.outputs, fmul)

if branch == 0:
fct = switch(switch_cond, zero_switch_input, fmul)
else:
fct = switch(switch_cond, fmul, zero_switch_input)

# Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan

# Copy over stacktrace for switch op from both previous
# elementwise multiplication op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return [fct]


@register_canonicalize
Expand All @@ -699,62 +677,39 @@ def local_div_switch_sink(fgraph, node):
See `local_mul_switch_sink` for more details.
"""
op = node.op
if node.inputs[0].owner and node.inputs[0].owner.op == switch:
switch_node = node.inputs[0].owner
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True
)
== 0.0
):
fdiv = op(switch_node.inputs[2], node.inputs[1])
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)

fct = [switch(switch_node.inputs[0], 0, fdiv)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan

# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True
)
== 0.0
):
fdiv = op(switch_node.inputs[1], node.inputs[1])
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)

fct = [switch(switch_node.inputs[0], fdiv, 0)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
num, denom = node.inputs

# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
return False
if num.owner and num.owner.op == switch:
switch_node = num.owner
# Look for a zero as the first or second branch of the switch
for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch]
if not get_unique_constant_value(zero_switch_input) == 0.0:
continue

switch_cond = switch_node.inputs[0]
other_switch_input = switch_node.inputs[1 + (1 - branch)]

fdiv = node.op(other_switch_input, denom)

# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)

fct = switch(switch_cond, zero_switch_input, fdiv)

# Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan

# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return [fct]


class AlgebraicCanonizer(NodeRewriter):
Expand Down
25 changes: 24 additions & 1 deletion tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@
from pytensor.tensor.rewriting.math import (
compute_mul,
is_1pexp,
local_div_switch_sink,
local_grad_log_erfc_neg,
local_greedy_distributor,
local_mul_canonizer,
local_mul_switch_sink,
local_reduce_chain,
local_sum_prod_of_mul_or_div,
mul_canonizer,
Expand Down Expand Up @@ -2115,7 +2117,6 @@ def test_local_mul_switch_sink(self):
f = self.function_remove_nan([x], pytensor.gradient.grad(y, x), self.mode)
assert f(5) == 1, f(5)

@pytest.mark.slow
def test_local_div_switch_sink(self):
c = dscalar()
idx = 0
Expand Down Expand Up @@ -2149,6 +2150,28 @@ def test_local_div_switch_sink(self):
].size
idx += 1

@pytest.mark.parametrize(
"op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
)
def test_local_mul_div_switch_sink_cast(self, op, rewrite):
"""Check that we don't downcast during the rewrite.
Regression test for: https://github.com/pymc-devs/pytensor/issues/1037
"""
cond = scalar("cond", dtype="bool")
# The zero branch upcasts the output, so we can't ignore its dtype
zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
other_branch = scalar("other_branch", dtype="float32")
outer_var = scalar("mul_var", dtype="bool")

out = op(switch(cond, zero_branch, other_branch), outer_var)
fgraph = FunctionGraph(outputs=[out], clone=False)
[new_out] = rewrite.transform(fgraph, out.owner)
assert new_out.type.dtype == out.type.dtype

expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
assert equal_computations([new_out], [expected_out])


@pytest.mark.skipif(
config.cxx == "",
Expand Down

0 comments on commit 4f7d709

Please sign in to comment.