From fe4d12c10fb0cb50b811b10ecd0b417db1cb886f Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 21 Mar 2021 19:38:12 -0700 Subject: [PATCH] move logic to traceable --- jax/_src/lax/lax.py | 18 +++++++++--------- tests/api_test.py | 14 +++----------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4d244f927f30..e358ab32a28a 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -428,8 +428,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None, if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() - # Note: don't canonicalize old_dtype because x64 context might - # cause un-canonicalized operands to be passed in. + # Don't canonicalize old_dtype because x64 context might cause + # un-canonicalized operands to be passed in. old_dtype = np.result_type(operand) old_weak_type = dtypes.is_weakly_typed(operand) @@ -441,6 +441,12 @@ def convert_element_type(operand: Array, new_dtype: DType = None, msg = "Casting complex values to real discards the imaginary part" warnings.warn(msg, np.ComplexWarning, stacklevel=2) + # Python has big integers, but convert_element_type(2 ** 100, np.float32) need + # not be an error since the target dtype fits the value. Handle this case by + # converting to a NumPy array before calling bind. + if type(operand) is int: + operand = np.asarray(operand, new_dtype) + if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type) and isinstance(operand, (core.Tracer, xla.DeviceArray))): return operand @@ -2658,12 +2664,6 @@ def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None): ad.defjvp_zero(lt_p) -def _convert_element_type_impl(operand, *, new_dtype, weak_type): - if dtypes.is_python_scalar(operand): - operand = np.asarray(operand, dtype=new_dtype) - return xla.apply_primitive(convert_element_type_p, operand, - new_dtype=new_dtype, weak_type=weak_type) - def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): return operand.shape @@ -2699,7 +2699,7 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type): return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type) convert_element_type_p = core.convert_element_type_p -convert_element_type_p.def_impl(_convert_element_type_impl) +convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p)) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, diff --git a/tests/api_test.py b/tests/api_test.py index f40d4c19e880..3bf5a9c31852 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2395,17 +2395,9 @@ def f(_): def test_large_python_int_to_float(self): # https://github.com/google/jax/pull/6165 - # We skip checks because otherwise we end up calling valid_jaxtype(2**100), - # which tries to form a ConcreteArray with that value and thus leads to a - # NumPy OverflowError. It's true that 2**100 does not inhabit a jax type, - # but as an issue of Python embedding we can handle operations like - # lax.convert_element_type(2 ** 100, jnp.float32) as in the tests below. - # That is, lax.convert_element_type(2 ** 100, jnp.int32) is an error while - # lax.convert_element_type(2 ** 100, jnp.float32) is not. - with jax.core.skipping_checks(): - jnp.multiply(2 ** 100, 3.) # doesn't crash - out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash - self.assertArraysEqual(out, np.float32(2 ** 100)) + jnp.multiply(2 ** 100, 3.) # doesn't crash + out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash + self.assertArraysEqual(out, np.float32(2 ** 100)) class RematTest(jtu.JaxTestCase):