diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 842cf695aa..79109ad9b7 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,7 +1,6 @@ from collections.abc import Callable from functools import singledispatch -from numbers import Number -from textwrap import indent +from textwrap import dedent, indent from typing import Any import numba @@ -15,7 +14,6 @@ from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( create_numba_signature, - create_tuple_creator, numba_funcify, numba_njit, use_optimized_cheap_pass, @@ -26,7 +24,7 @@ encode_literals, store_core_outputs, ) -from pytensor.link.utils import compile_function_src, get_name_for_object +from pytensor.link.utils import compile_function_src from pytensor.scalar.basic import ( AND, OR, @@ -34,7 +32,6 @@ Add, Composite, IntDiv, - Mean, Mul, ScalarMaximum, ScalarMinimum, @@ -77,11 +74,6 @@ def scalar_in_place_fn_Sub(op, idx, res, arr): return f"{res}[{idx}] -= {arr}" -@scalar_in_place_fn.register(Mean) -def scalar_in_place_fn_Mean(op, idx, res, arr): - return f"{res}[{idx}] += ({arr} - {res}[{idx}]) / (i + 1)" - - @scalar_in_place_fn.register(Mul) def scalar_in_place_fn_Mul(op, idx, res, arr): return f"{res}[{idx}] *= {arr}" @@ -169,40 +161,32 @@ def create_vectorize_func( return elemwise_fn -def create_axis_reducer( - scalar_op: Op, - identity: np.ndarray | Number, - axis: int, - ndim: int, - dtype: numba.types.Type, +def create_multiaxis_reducer( + scalar_op, + identity, + axes, + ndim, + dtype, keepdims: bool = False, - return_scalar=False, -) -> numba.core.dispatcher.Dispatcher: - r"""Create Python function that performs a NumPy-like reduction on a given axis. +): + r"""Construct a function that reduces multiple axes. The functions generated by this function take the following form: .. code-block:: python - def careduce_axis(x): - res_shape = tuple( - shape[i] if i < axis else shape[i + 1] for i in range(ndim - 1) - ) - res = np.full(res_shape, identity, dtype=dtype) + def careduce_add(x): + # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add" + x_shape = x.shape + res_shape = x_shape[2] + res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype) - x_axis_first = x.transpose(reaxis_first) - - for m in range(x.shape[axis]): - reduce_fn(res, x_axis_first[m], res) - - if keepdims: - return np.expand_dims(res, axis) - else: - return res + for i0 in range(x_shape[0]): + for i1 in range(x_shape[1]): + for i2 in range(x_shape[2]): + res[i2] += x[i0, i1, i2] - - This can be removed/replaced when - https://github.com/numba/numba/issues/4504 is implemented. + return res Parameters ========== @@ -210,25 +194,29 @@ def careduce_axis(x): The scalar :class:`Op` that performs the desired reduction. identity: The identity value for the reduction. - axis: - The axis to reduce. + axes: + The axes to reduce. ndim: - The number of dimensions of the result. + The number of dimensions of the input variable. dtype: The data type of the result. - keepdims: - Determines whether or not the reduced dimension is retained. - - + keepdims: boolean, default False + Whether to keep the reduced dimensions. Returns ======= A Python function that can be JITed. 
""" + # if len(axes) == 1: + # return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) - axis = normalize_axis_index(axis, ndim) + axes = normalize_axis_tuple(axes, ndim) + if keepdims and len(axes) > 1: + raise NotImplementedError( + "Cannot keep multiple dimensions when reducing multiple axes" + ) - reduce_elemwise_fn_name = "careduce_axis" + careduce_fn_name = f"careduce_{scalar_op}" identity = str(identity) if identity == "inf": @@ -241,162 +229,55 @@ def careduce_axis(x): "numba_basic": numba_basic, "out_dtype": dtype, } + complete_reduction = len(axes) == ndim + kept_axis = tuple(i for i in range(ndim) if i not in axes) + + res_indices = [] + arr_indices = [] + for i in range(ndim): + index_label = f"i{i}" + arr_indices.append(index_label) + if i not in axes: + res_indices.append(index_label) + res_indices = ", ".join(res_indices) if res_indices else () + arr_indices = ", ".join(arr_indices) if arr_indices else () + + inplace_update_stmt = scalar_in_place_fn( + scalar_op, res_indices, "res", f"x[{arr_indices}]" + ) - if ndim > 1: - res_shape_tuple_ctor = create_tuple_creator( - lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1 - ) - global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor - - res_indices = [] - arr_indices = [] - count = 0 - - for i in range(ndim): - if i == axis: - arr_indices.append("i") - else: - res_indices.append(f"idx_arr[{count}]") - arr_indices.append(f"idx_arr[{count}]") - count = count + 1 - - res_indices = ", ".join(res_indices) - arr_indices = ", ".join(arr_indices) - - inplace_update_statement = scalar_in_place_fn( - scalar_op, res_indices, "res", f"x[{arr_indices}]" - ) - inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3) - - return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res" - reduce_elemwise_def_src = f""" -def {reduce_elemwise_fn_name}(x): - - x_shape = np.shape(x) - res_shape = res_shape_tuple_ctor(x_shape) - res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype) - - axis_shape = x.shape[{axis}] - - for idx_arr in np.ndindex(res_shape): - for i in range(axis_shape): -{inplace_update_statement} - - return {return_expr} - """ + res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})" + if complete_reduction and ndim > 0: + # We accumulate on a scalar, not an array + res_creator = f"np.asarray({identity}).astype(out_dtype).item()" + inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res") + return_obj = "np.asarray(res)" else: - inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]") - inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2) - - return_expr = "res" if keepdims else "res.item()" - if not return_scalar: - return_expr = f"np.asarray({return_expr})" - reduce_elemwise_def_src = f""" -def {reduce_elemwise_fn_name}(x): - - res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype) - - axis_shape = x.shape[{axis}] - - for i in range(axis_shape): -{inplace_update_statement} - - return {return_expr} + res_creator = ( + f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)" + ) + return_obj = "res" + + if keepdims: + [axis] = axes + return_obj = f"np.expand_dims({return_obj}, {axis})" + + careduce_def_src = dedent( + f""" + def {careduce_fn_name}(x): + x_shape = x.shape + res_shape = {res_shape} + res = {res_creator} """ - - reduce_elemwise_fn_py = compile_function_src( - reduce_elemwise_def_src, reduce_elemwise_fn_name, {**globals(), **global_env} ) - - return 
reduce_elemwise_fn_py - - -def create_multiaxis_reducer( - scalar_op, - identity, - axes, - ndim, - dtype, - input_name="input", - return_scalar=False, -): - r"""Construct a function that reduces multiple axes. - - The functions generated by this function take the following form: - - .. code-block:: python - - def careduce_maximum(input): - axis_0_res = careduce_axes_fn_0(input) - axis_1_res = careduce_axes_fn_1(axis_0_res) - ... - axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res) - return axis_N_res - - The range 0-N is determined by the `axes` argument (i.e. the - axes to be reduced). - - - Parameters - ========== - scalar_op: - The scalar :class:`Op` that performs the desired reduction. - identity: - The identity value for the reduction. - axes: - The axes to reduce. - ndim: - The number of dimensions of the result. - dtype: - The data type of the result. - return_scalar: - If True, return a scalar, otherwise an array. - - Returns - ======= - A Python function that can be JITed. - - """ - if len(axes) == 1: - return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) - - axes = normalize_axis_tuple(axes, ndim) - - careduce_fn_name = f"careduce_{scalar_op}" - global_env = {} - to_reduce = sorted(axes, reverse=True) - careduce_lines_src = [] - var_name = input_name - - for i, axis in enumerate(to_reduce): - careducer_axes_fn_name = f"careduce_axes_fn_{i}" - reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype) - reducer_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - )(reducer_py_fn) - - global_env[careducer_axes_fn_name] = reducer_fn - - ndim -= 1 - last_var_name = var_name - var_name = f"axis_{i}_res" - careduce_lines_src.append( - f"{var_name} = {careducer_axes_fn_name}({last_var_name})" + for axis in range(ndim): + careduce_def_src += indent( + f"for i{axis} in range(x_shape[{axis}]):\n", + " " * (4 + 4 * axis), ) - - careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4) - if not return_scalar: - pre_result = "np.asarray" - post_result = "" - else: - pre_result = "np.asarray" - post_result = ".item()" - - careduce_def_src = f""" -def {careduce_fn_name}({input_name}): -{careduce_assign_lines} - return {pre_result}({var_name}){post_result} - """ + careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim)) + careduce_def_src += "\n\n" + careduce_def_src += indent(f"return {return_obj}", " " * 4) careduce_fn = compile_function_src( careduce_def_src, careduce_fn_name, {**globals(), **global_env} @@ -551,32 +432,29 @@ def ov_elemwise(*inputs): @numba_funcify.register(Sum) def numba_funcify_Sum(op, node, **kwargs): + ndim_input = node.inputs[0].ndim axes = op.axis if axes is None: axes = list(range(node.inputs[0].ndim)) - - axes = tuple(axes) - - ndim_input = node.inputs[0].ndim + else: + axes = normalize_axis_tuple(axes, ndim_input) if hasattr(op, "acc_dtype") and op.acc_dtype is not None: acc_dtype = op.acc_dtype else: acc_dtype = node.outputs[0].type.dtype - np_acc_dtype = np.dtype(acc_dtype) - out_dtype = np.dtype(node.outputs[0].dtype) if ndim_input == len(axes): - - @numba_njit(fastmath=True) + # Slightly faster than `numba_funcify_CAReduce` for this case + @numba_njit(fastmath=config.numba__fastmath) def impl_sum(array): return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) elif len(axes) == 0: - - @numba_njit(fastmath=True) + # These cases should be removed by rewrites! 
+ @numba_njit(fastmath=config.numba__fastmath) def impl_sum(array): return np.asarray(array, dtype=out_dtype) @@ -609,7 +487,6 @@ def numba_funcify_CAReduce(op, node, **kwargs): # Make sure it has the correct dtype scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype) - input_name = get_name_for_object(node.inputs[0]) ndim = node.inputs[0].ndim careduce_py_fn = create_multiaxis_reducer( op.scalar_op, @@ -617,7 +494,6 @@ def numba_funcify_CAReduce(op, node, **kwargs): axes, ndim, np.dtype(node.outputs[0].type.dtype), - input_name=input_name, ) careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False) @@ -730,11 +606,11 @@ def numba_funcify_Softmax(op, node, **kwargs): if axis is not None: axis = normalize_axis_index(axis, x_at.ndim) - reduce_max_py = create_axis_reducer( + reduce_max_py = create_multiaxis_reducer( scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True ) - reduce_sum_py = create_axis_reducer( - add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True + reduce_sum_py = create_multiaxis_reducer( + add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) jit_fn = numba_basic.numba_njit( @@ -767,8 +643,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): axis = op.axis if axis is not None: axis = normalize_axis_index(axis, sm_at.ndim) - reduce_sum_py = create_axis_reducer( - add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True + reduce_sum_py = create_multiaxis_reducer( + add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True ) jit_fn = numba_basic.numba_njit( @@ -799,16 +675,16 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): if axis is not None: axis = normalize_axis_index(axis, x_at.ndim) - reduce_max_py = create_axis_reducer( + reduce_max_py = create_multiaxis_reducer( scalar_maximum, -np.inf, - axis, + (axis,), x_at.ndim, x_dtype, keepdims=True, ) - reduce_sum_py = create_axis_reducer( - add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True + reduce_sum_py = create_multiaxis_reducer( + add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) jit_fn = numba_basic.numba_njit( diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index bb2baf0636..3c33434e56 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1871,32 +1871,6 @@ def L_op(self, inputs, outputs, gout): add = Add(upcast_out, name="add") -class Mean(ScalarOp): - identity = 0 - commutative = True - associative = False - nfunc_spec = ("mean", 2, 1) - nfunc_variadic = "mean" - - def impl(self, *inputs): - return sum(inputs) / len(inputs) - - def c_code(self, node, name, inputs, outputs, sub): - (z,) = outputs - if not inputs: - return f"{z} = 0;" - else: - return f"{z} = ({' + '.join(inputs)}) / ((double) {len(inputs)});" - - def L_op(self, inputs, outputs, gout): - (gz,) = gout - retval = [gz / len(inputs)] * len(inputs) - return retval - - -mean = Mean(float_out, name="mean") - - class Mul(ScalarOp): identity = 1 commutative = True diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 8c86a834ea..efcc2500a7 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1316,63 +1316,7 @@ def complex_from_polar(abs, angle): """Return complex-valued tensor from polar coordinate specification.""" -class Mean(FixedOpCAReduce): - __props__ = ("axis",) - nfunc_spec = ("mean", 1, 1) - - def __init__(self, axis=None): - super().__init__(ps.mean, axis) - assert self.axis is None or len(self.axis) == 1 - - def __str__(self): - if self.axis is not None: - args = ", ".join(str(x) for x in self.axis) - return 
f"Mean{{{args}}}" - else: - return "Mean" - - def _output_dtype(self, idtype): - # we want to protect against overflow - return "float64" - - def perform(self, node, inp, out): - (input,) = inp - (output,) = out - if self.axis is None: - axis = None - else: - axis = self.axis[0] - # numpy.asarray is needed as otherwise we can end up with a - # numpy scalar. - output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis)) - - def c_code(self, node, name, inames, onames, sub): - ret = super().c_code(node, name, inames, onames, sub) - - if self.axis is not None: - return ret - - # TODO: c_code perform support only axis is None - return ( - ret - + f""" - *((double *)PyArray_DATA({onames[0]})) /= PyArray_SIZE({inames[0]}); - """ - ) - - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) - return type(self)(axis=axis) - - -# TODO: implement the grad. When done and tested, you can make this the default -# version. -# def grad(self, (x,), (gout,)): -# import pdb;pdb.set_trace() -# return grad(mean(x, self.axis, op=False),[x]) - - -def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None): +def mean(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): """ Computes the mean value along the given axis(es) of a tensor `input`. @@ -1397,25 +1341,6 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None) be in a float type). If None, then we use the same rules as `sum()`. """ input = as_tensor_variable(input) - if op: - if dtype not in (None, "float64"): - raise NotImplementedError( - "The Mean op does not support the dtype argument, " - "and will always use float64. If you want to specify " - "the dtype, call tensor.mean(..., op=False).", - dtype, - ) - if acc_dtype not in (None, "float64"): - raise NotImplementedError( - "The Mean op does not support the acc_dtype argument, " - "and will always use float64. If you want to specify " - "acc_dtype, call tensor.mean(..., op=False).", - dtype, - ) - out = Mean(axis)(input) - if keepdims: - out = makeKeepDims(input, out, axis) - return out if dtype is not None: # The summation will be done with the specified dtype. 
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 4c13004409..72150b01ae 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -15,15 +15,15 @@ from pytensor.gradient import grad from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph -from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum +from pytensor.tensor.elemwise import CAReduce, DimShuffle +from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from tests.link.numba.test_basic import ( compare_numba_and_py, scalar_my_multi_out, set_test_value, ) -from tests.tensor.test_elemwise import TestElemwise +from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester rng = np.random.default_rng(42849) @@ -249,24 +249,12 @@ def test_Dimshuffle_non_contiguous(): ( lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x), 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), + set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x), 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x), - 0, - set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x), - 0, - set_test_value( - pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), + set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])), ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Sum( @@ -313,6 +301,24 @@ def test_Dimshuffle_non_contiguous(): pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) ), ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Prod( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + (), # Empty axes would normally be rewritten away, but we want to test it still works + set_test_value( + pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Prod( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + None, + set_test_value( + pt.scalar(), np.array(99.0, dtype=config.floatX) + ), # Scalar input would normally be rewritten away, but we want to test it still works + ), ( lambda x, axis=None, dtype=None, acc_dtype=None: Prod( axis=axis, dtype=dtype, acc_dtype=acc_dtype @@ -379,7 +385,7 @@ def test_CAReduce(careduce_fn, axis, v): g = careduce_fn(v, axis=axis) g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( + fn, _ = compare_numba_and_py( g_fg, [ i.tag.test_value @@ -387,6 +393,10 @@ def test_CAReduce(careduce_fn, axis, v): if not isinstance(i, SharedVariable | Constant) ], ) + # Confirm CAReduce is in the compiled function + fn.dprint() + [node] = fn.maker.fgraph.apply_nodes + assert isinstance(node.op, CAReduce) def test_scalar_Elemwise_Clip(): @@ -631,10 +641,10 @@ def test_logsumexp_benchmark(size, axis, benchmark): X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA") # JIT compile first - _ = X_lse_fn(X_val) - res = benchmark(X_lse_fn, X_val) + res = X_lse_fn(X_val) exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) np.testing.assert_array_almost_equal(res, exp_res) + benchmark(X_lse_fn, X_val) def test_fused_elemwise_benchmark(benchmark): 
@@ -665,3 +675,19 @@ def test_elemwise_out_type(): x_val = np.broadcast_to(np.zeros((3,)), (6, 3)) assert func(x_val).shape == (18,) + + +@pytest.mark.parametrize( + "axis", + (0, 1, 2, (0, 1), (0, 2), (1, 2), None), + ids=lambda x: f"axis={x}", +) +@pytest.mark.parametrize( + "c_contiguous", + (True, False), + ids=lambda x: f"c_contiguous={x}", +) +def test_numba_careduce_benchmark(axis, c_contiguous, benchmark): + return careduce_benchmark_tester( + axis, c_contiguous, mode="NUMBA", benchmark=benchmark + ) diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index e648869d4c..5aab9a95cc 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -43,7 +43,6 @@ log1p, log2, log10, - mean, mul, neg, neq, @@ -58,7 +57,7 @@ true_div, uint8, ) -from pytensor.tensor.type import fscalar, imatrix, iscalar, matrix +from pytensor.tensor.type import fscalar, imatrix, matrix from tests.link.test_link import make_function @@ -521,34 +520,6 @@ def test_constant(): assert c.dtype == "float32" -@pytest.mark.parametrize("mode", [Mode("py"), Mode("cvm")]) -def test_mean(mode): - a = iscalar("a") - b = iscalar("b") - z = mean(a, b) - z_fn = pytensor.function([a, b], z, mode=mode) - res = z_fn(1, 1) - assert np.allclose(res, 1.0) - - a = fscalar("a") - b = fscalar("b") - c = fscalar("c") - - z = mean(a, b, c) - - z_fn = pytensor.function([a, b, c], pytensor.grad(z, [a]), mode=mode) - res = z_fn(3, 4, 5) - assert np.allclose(res, 1 / 3) - - z_fn = pytensor.function([a, b, c], pytensor.grad(z, [b]), mode=mode) - res = z_fn(3, 4, 5) - assert np.allclose(res, 1 / 3) - - z = mean() - z_fn = pytensor.function([], z, mode=mode) - assert z_fn() == 0 - - def test_shape(): a = float32("a") assert isinstance(a.type, ScalarType) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 7ccc2fd95c..c1644e41e1 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -983,27 +983,33 @@ def test_CAReduce(self): assert vect_node.inputs[0] is bool_tns -@pytest.mark.parametrize( - "axis", - (0, 1, 2, (0, 1), (0, 2), (1, 2), None), - ids=lambda x: f"axis={x}", -) -@pytest.mark.parametrize( - "c_contiguous", - (True, False), - ids=lambda x: f"c_contiguous={x}", -) -def test_careduce_benchmark(axis, c_contiguous, benchmark): +def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark): N = 256 x_test = np.random.uniform(size=(N, N, N)) transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1) x = pytensor.shared(x_test, name="x", shape=x_test.shape) out = x.transpose(transpose_axis).sum(axis=axis) - fn = pytensor.function([], out) + fn = pytensor.function([], out, mode=mode) np.testing.assert_allclose( fn(), x_test.transpose(transpose_axis).sum(axis=axis), ) benchmark(fn) + + +@pytest.mark.parametrize( + "axis", + (0, 1, 2, (0, 1), (0, 2), (1, 2), None), + ids=lambda x: f"axis={x}", +) +@pytest.mark.parametrize( + "c_contiguous", + (True, False), + ids=lambda x: f"c_contiguous={x}", +) +def test_c_careduce_benchmark(axis, c_contiguous, benchmark): + return careduce_benchmark_tester( + axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark + ) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 14bc2614e3..2d19ef0114 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -40,7 +40,6 @@ Argmax, Dot, Max, - Mean, Prod, ProdWithoutZeros, Sum, @@ -2587,15 +2586,6 @@ def test_mod_compile(): class TestInferShape(utt.InferShapeTester): - def test_Mean(self): - adtens3 = dtensor3() - adtens3_val = random(3, 4, 
5) - aiscal_val = 2 - self._compile_and_check([adtens3], [Mean(None)(adtens3)], [adtens3_val], Mean) - self._compile_and_check( - [adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean - ) - def test_Max(self): adtens3 = dtensor3() adtens3_val = random(4, 5, 3)
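
As a companion to the sketch earlier, here is a rough standalone equivalent of the complete-reduction branch of `create_multiaxis_reducer` (all axes reduced), where the accumulator is a Python scalar rather than an array and the `res[()]` indexing produced by `scalar_in_place_fn` is rewritten to plain `res`. Again illustrative only; `out_dtype` stands in for the dtype normally injected via `global_env`.

    import numpy as np

    out_dtype = np.dtype("float64")

    def careduce_mul(x):
        x_shape = x.shape
        # Complete reduction: accumulate on a scalar; the identity of Mul is 1.0.
        res = np.asarray(1.0).astype(out_dtype).item()
        for i0 in range(x_shape[0]):
            for i1 in range(x_shape[1]):
                # "res[()] *= x[i0, i1]" with "res[()]" replaced by "res"
                res *= x[i0, i1]
        # The scalar accumulator is wrapped back into an array on return.
        return np.asarray(res)

    x_val = np.array([[1.0, 2.0], [3.0, 4.0]])
    np.testing.assert_allclose(careduce_mul(x_val), x_val.prod())
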