From 4f7d7096ea98fa1285b50a9d583373b5963d425d Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Thu, 31 Oct 2024 11:20:15 +0100
Subject: [PATCH] Simplify `local_[mul|div]_switch_sink` and fix downcasting bug

---
 pytensor/tensor/rewriting/math.py   | 183 +++++++++++-----------------
 tests/tensor/rewriting/test_math.py |  25 +++-
 2 files changed, 93 insertions(+), 115 deletions(-)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 2e30e1399b..68cc0e5e96 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node):
     part of the graph.
 
     """
-    for idx, i in enumerate(node.inputs):
-        if i.owner and i.owner.op == switch:
-            switch_node = i.owner
-            try:
-                if (
-                    get_underlying_scalar_constant_value(
-                        switch_node.inputs[1], only_process_constants=True
-                    )
-                    == 0.0
-                ):
-                    listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
-                    fmul = mul(*([*listmul, switch_node.inputs[2]]))
-
-                    # Copy over stacktrace for elementwise multiplication op
-                    # from previous elementwise multiplication op.
-                    # An error in the multiplication (e.g. errors due to
-                    # inconsistent shapes), will point to the
-                    # multiplication op.
-                    copy_stack_trace(node.outputs, fmul)
-
-                    fct = [switch(switch_node.inputs[0], 0, fmul)]
-                    fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                    # Copy over stacktrace for switch op from both previous
-                    # elementwise multiplication op and previous switch op,
-                    # because an error in this part can be caused by either
-                    # of the two previous ops.
-                    copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                    return fct
-            except NotScalarConstantError:
-                pass
-            try:
-                if (
-                    get_underlying_scalar_constant_value(
-                        switch_node.inputs[2], only_process_constants=True
-                    )
-                    == 0.0
-                ):
-                    listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
-                    fmul = mul(*([*listmul, switch_node.inputs[1]]))
-                    # Copy over stacktrace for elementwise multiplication op
-                    # from previous elementwise multiplication op.
-                    # An error in the multiplication (e.g. errors due to
-                    # inconsistent shapes), will point to the
-                    # multiplication op.
-                    copy_stack_trace(node.outputs, fmul)
-
-                    fct = [switch(switch_node.inputs[0], fmul, 0)]
-                    fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                    # Copy over stacktrace for switch op from both previous
-                    # elementwise multiplication op and previous switch op,
-                    # because an error in this part can be caused by either
-                    # of the two previous ops.
-                    copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                    return fct
-            except NotScalarConstantError:
-                pass
-    return False
+    for mul_inp_idx, mul_inp in enumerate(node.inputs):
+        if mul_inp.owner and mul_inp.owner.op == switch:
+            switch_node = mul_inp.owner
+            # Look for a zero as the first or second branch of the switch
+            for branch in range(2):
+                zero_switch_input = switch_node.inputs[1 + branch]
+                if not get_unique_constant_value(zero_switch_input) == 0.0:
+                    continue
+
+                switch_cond = switch_node.inputs[0]
+                other_switch_input = switch_node.inputs[1 + (1 - branch)]
+
+                listmul = list(node.inputs)
+                listmul[mul_inp_idx] = other_switch_input
+                fmul = mul(*listmul)
+
+                # Copy over stacktrace for elementwise multiplication op
+                # from previous elementwise multiplication op.
+                # An error in the multiplication (e.g. errors due to
+                # inconsistent shapes), will point to the
+                # multiplication op.
+                copy_stack_trace(node.outputs, fmul)
+
+                if branch == 0:
+                    fct = switch(switch_cond, zero_switch_input, fmul)
+                else:
+                    fct = switch(switch_cond, fmul, zero_switch_input)
+
+                # Tell debug_mode that the output is correct, even if NaNs disappear
+                fct.tag.values_eq_approx = values_eq_approx_remove_nan
+
+                # Copy over stacktrace for switch op from both previous
+                # elementwise multiplication op and previous switch op,
+                # because an error in this part can be caused by either
+                # of the two previous ops.
+                copy_stack_trace(node.outputs + switch_node.outputs, fct)
+                return [fct]
 
 
 @register_canonicalize
@@ -699,62 +677,39 @@ def local_div_switch_sink(fgraph, node):
     See `local_mul_switch_sink` for more details.
 
     """
-    op = node.op
-    if node.inputs[0].owner and node.inputs[0].owner.op == switch:
-        switch_node = node.inputs[0].owner
-        try:
-            if (
-                get_underlying_scalar_constant_value(
-                    switch_node.inputs[1], only_process_constants=True
-                )
-                == 0.0
-            ):
-                fdiv = op(switch_node.inputs[2], node.inputs[1])
-                # Copy over stacktrace for elementwise division op
-                # from previous elementwise multiplication op.
-                # An error in the division (e.g. errors due to
-                # inconsistent shapes or division by zero),
-                # will point to the new division op.
-                copy_stack_trace(node.outputs, fdiv)
-
-                fct = [switch(switch_node.inputs[0], 0, fdiv)]
-                fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                # Copy over stacktrace for switch op from both previous
-                # elementwise division op and previous switch op,
-                # because an error in this part can be caused by either
-                # of the two previous ops.
-                copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                return fct
-        except NotScalarConstantError:
-            pass
-        try:
-            if (
-                get_underlying_scalar_constant_value(
-                    switch_node.inputs[2], only_process_constants=True
-                )
-                == 0.0
-            ):
-                fdiv = op(switch_node.inputs[1], node.inputs[1])
-                # Copy over stacktrace for elementwise division op
-                # from previous elementwise multiplication op.
-                # An error in the division (e.g. errors due to
-                # inconsistent shapes or division by zero),
-                # will point to the new division op.
-                copy_stack_trace(node.outputs, fdiv)
-
-                fct = [switch(switch_node.inputs[0], fdiv, 0)]
-                fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
+    num, denom = node.inputs
 
-                # Copy over stacktrace for switch op from both previous
-                # elementwise division op and previous switch op,
-                # because an error in this part can be caused by either
-                # of the two previous ops.
-                copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                return fct
-        except NotScalarConstantError:
-            pass
-    return False
+    if num.owner and num.owner.op == switch:
+        switch_node = num.owner
+        # Look for a zero as the first or second branch of the switch
+        for branch in range(2):
+            zero_switch_input = switch_node.inputs[1 + branch]
+            if not get_unique_constant_value(zero_switch_input) == 0.0:
+                continue
+
+            switch_cond = switch_node.inputs[0]
+            other_switch_input = switch_node.inputs[1 + (1 - branch)]
+
+            fdiv = node.op(other_switch_input, denom)
+
+            # Copy over stacktrace for elementwise division op
+            # from the previous elementwise division op.
+            # An error in the division (e.g. errors due to
+            # inconsistent shapes or division by zero),
+            # will point to the new division op.
+            copy_stack_trace(node.outputs, fdiv)
+
+            fct = switch(switch_cond, zero_switch_input, fdiv)
+
+            # Tell debug_mode that the output is correct, even if NaNs disappear
+            fct.tag.values_eq_approx = values_eq_approx_remove_nan
+
+            # Copy over stacktrace for switch op from both previous
+            # elementwise division op and previous switch op,
+            # because an error in this part can be caused by either
+            # of the two previous ops.
+            copy_stack_trace(node.outputs + switch_node.outputs, fct)
+            return [fct]
 
 
 class AlgebraicCanonizer(NodeRewriter):
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 019833a9d5..1212ee4fbd 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -97,9 +97,11 @@ from pytensor.tensor.rewriting.math import (
     compute_mul,
     is_1pexp,
+    local_div_switch_sink,
     local_grad_log_erfc_neg,
     local_greedy_distributor,
     local_mul_canonizer,
+    local_mul_switch_sink,
     local_reduce_chain,
     local_sum_prod_of_mul_or_div,
     mul_canonizer,
@@ -2115,7 +2117,6 @@ def test_local_mul_switch_sink(self):
         f = self.function_remove_nan([x], pytensor.gradient.grad(y, x), self.mode)
         assert f(5) == 1, f(5)
 
-    @pytest.mark.slow
     def test_local_div_switch_sink(self):
         c = dscalar()
         idx = 0
@@ -2149,6 +2150,28 @@ def test_local_div_switch_sink(self):
                     ].size
                 idx += 1
 
+    @pytest.mark.parametrize(
+        "op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
+    )
+    def test_local_mul_div_switch_sink_cast(self, op, rewrite):
+        """Check that we don't downcast during the rewrite.
+
+        Regression test for: https://github.com/pymc-devs/pytensor/issues/1037
+        """
+        cond = scalar("cond", dtype="bool")
+        # The zero branch upcasts the output, so we can't ignore its dtype
+        zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
+        other_branch = scalar("other_branch", dtype="float32")
+        outer_var = scalar("mul_var", dtype="bool")
+
+        out = op(switch(cond, zero_branch, other_branch), outer_var)
+        fgraph = FunctionGraph(outputs=[out], clone=False)
+        [new_out] = rewrite.transform(fgraph, out.owner)
+        assert new_out.type.dtype == out.type.dtype
+
+        expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
+        assert equal_computations([new_out], [expected_out])
+
     @pytest.mark.skipif(
         config.cxx == "",
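
Editor's note: the sketch below is not part of the patch. It is a minimal illustration, assuming a standard PyTensor installation and made-up variable names, of what `local_mul_switch_sink` accomplishes: the multiplication is sunk into the non-zero branch of the `switch`, so the zero branch no longer turns `0 * inf` into `nan`. After this change the rewrite also reuses the original zero branch instead of a literal `0`, which is how the output dtype is preserved.

import numpy as np
import pytensor
import pytensor.tensor as pt

cond = pt.scalar("cond", dtype="bool")
x = pt.scalar("x")
y = pt.scalar("y")

# Before canonicalization: mul(switch(cond, 0, x), y).
# After local_mul_switch_sink: switch(cond, 0, mul(x, y)).
out = pt.switch(cond, 0, x) * y
f = pytensor.function([cond, x, y], out)

# With the rewrite applied, the zero branch short-circuits the non-finite
# value, so this should print 0.0 rather than nan.
print(f(True, 2.0, np.inf))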