Skip to content

Commit

Permalink
move logic to traceable
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 22, 2021
1 parent 8c3125c commit fe4d12c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 20 deletions.
18 changes: 9 additions & 9 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()

# Note: don't canonicalize old_dtype because x64 context might
# cause un-canonicalized operands to be passed in.
# Don't canonicalize old_dtype because x64 context might cause
# un-canonicalized operands to be passed in.
old_dtype = np.result_type(operand)
old_weak_type = dtypes.is_weakly_typed(operand)

Expand All @@ -441,6 +441,12 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, np.ComplexWarning, stacklevel=2)

# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
# not be an error since the target dtype fits the value. Handle this case by
# converting to a NumPy array before calling bind.
if type(operand) is int:
operand = np.asarray(operand, new_dtype)

if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
return operand
Expand Down Expand Up @@ -2658,12 +2664,6 @@ def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
ad.defjvp_zero(lt_p)


def _convert_element_type_impl(operand, *, new_dtype, weak_type):
if dtypes.is_python_scalar(operand):
operand = np.asarray(operand, dtype=new_dtype)
return xla.apply_primitive(convert_element_type_p, operand,
new_dtype=new_dtype, weak_type=weak_type)

def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
return operand.shape

Expand Down Expand Up @@ -2699,7 +2699,7 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type)

convert_element_type_p = core.convert_element_type_p
convert_element_type_p.def_impl(_convert_element_type_impl)
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
Expand Down
14 changes: 3 additions & 11 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2395,17 +2395,9 @@ def f(_):

def test_large_python_int_to_float(self):
# https://github.com/google/jax/pull/6165
# We skip checks because otherwise we end up calling valid_jaxtype(2**100),
# which tries to form a ConcreteArray with that value and thus leads to a
# NumPy OverflowError. It's true that 2**100 does not inhabit a jax type,
# but as an issue of Python embedding we can handle operations like
# lax.convert_element_type(2 ** 100, jnp.float32) as in the tests below.
# That is, lax.convert_element_type(2 ** 100, jnp.int32) is an error while
# lax.convert_element_type(2 ** 100, jnp.float32) is not.
with jax.core.skipping_checks():
jnp.multiply(2 ** 100, 3.) # doesn't crash
out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash
self.assertArraysEqual(out, np.float32(2 ** 100))
jnp.multiply(2 ** 100, 3.) # doesn't crash
out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash
self.assertArraysEqual(out, np.float32(2 ** 100))


class RematTest(jtu.JaxTestCase):
Expand Down

0 comments on commit fe4d12c

Please sign in to comment.