From 60029e7d098b86b6cee82b5c52a33f30fa7cd4fc Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 5 Oct 2023 10:53:08 -0700 Subject: [PATCH] lax.abs: better error for unsigned inputs --- jax/_src/lax/lax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e45de759eab0..a0e213b3a2fb 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1710,6 +1710,7 @@ def _nary_lower_hlo(op: Callable, ctx, _complex_elem_types = {np.float32, np.float64} _int = {np.integer} _bool = {np.bool_} +_signedint = {np.signedinteger} _num = _int | _float | _complex _any = _int | _float | _complex | _bool @@ -1944,7 +1945,7 @@ def _conj_transpose_rule(t, x, *, input_dtype): ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p) ad.primitive_transposes[conj_p] = _conj_transpose_rule -abs_p = unop(_complex_basetype, _num, 'abs') +abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs') mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.AbsOp)) def _abs_jvp_rule(g, ans, x):