From 8c3125c172b5868c29fe58863457362cbd144f3d Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 21 Mar 2021 15:53:24 -0700 Subject: [PATCH 1/3] fix convert_element_type on large Py int inputs --- jax/_src/lax/lax.py | 8 +++++++- jax/abstract_arrays.py | 5 ++--- jax/experimental/host_callback.py | 2 ++ tests/api_test.py | 14 ++++++++++++++ tests/host_callback_test.py | 6 +++--- tests/random_test.py | 2 ++ 6 files changed, 30 insertions(+), 7 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4a2887045fb1..4d244f927f30 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2658,6 +2658,12 @@ def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None): ad.defjvp_zero(lt_p) +def _convert_element_type_impl(operand, *, new_dtype, weak_type): + if dtypes.is_python_scalar(operand): + operand = np.asarray(operand, dtype=new_dtype) + return xla.apply_primitive(convert_element_type_p, operand, + new_dtype=new_dtype, weak_type=weak_type) + def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): return operand.shape @@ -2693,7 +2699,7 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type): return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type) convert_element_type_p = core.convert_element_type_p -convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p)) +convert_element_type_p.def_impl(_convert_element_type_impl) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, diff --git a/jax/abstract_arrays.py b/jax/abstract_arrays.py index 734957cdd1b0..57c75a0774c8 100644 --- a/jax/abstract_arrays.py +++ b/jax/abstract_arrays.py @@ -68,9 +68,8 @@ def _zeros_like_python_scalar(t, x): return np.array(0, dtypes.python_scalar_dtypes[t]) def _make_concrete_python_scalar(t, x): - return ConcreteArray( - np.array(x, dtype=dtypes.python_scalar_dtypes[t]), - weak_type=True) + return ConcreteArray(np.array(x, dtype=dtypes.python_scalar_dtypes[t]), + weak_type=True) for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index a84a23b0407f..961ad1cadd1b 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -934,6 +934,7 @@ def _outside_call_jvp_rule(primals, tangents, **params): if not params["identity"]: raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.") tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents)) + tangent_instantiated = tuple(map(ad.replace_float0s, primals, tangent_instantiated)) arg_treedef = params["arg_treedef"] # The argument to the jvp tap is a pair of the tapped primals and tangents @@ -946,6 +947,7 @@ def _outside_call_jvp_rule(primals, tangents, **params): arg_treedef=jvp_arg_treedef, )) out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)]) + out_tangents_tapped = map(ad.recast_to_float0, out_primals_tapped, out_tangents_tapped) return tuple(out_primals_tapped), tuple(out_tangents_tapped) diff --git a/tests/api_test.py b/tests/api_test.py index 4b38afd46111..f40d4c19e880 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2393,6 +2393,20 @@ def f(_): expected = jnp.arange(1) + 1 self.assertAllClose(ans, expected) + def test_large_python_int_to_float(self): + # https://github.com/google/jax/pull/6165 + # We skip checks because otherwise we end up calling valid_jaxtype(2**100), + # which tries to form a ConcreteArray with that value and thus leads to a + # NumPy OverflowError. It's true that 2**100 does not inhabit a jax type, + # but as an issue of Python embedding we can handle operations like + # lax.convert_element_type(2 ** 100, jnp.float32) as in the tests below. + # That is, lax.convert_element_type(2 ** 100, jnp.int32) is an error while + # lax.convert_element_type(2 ** 100, jnp.float32) is not. + with jax.core.skipping_checks(): + jnp.multiply(2 ** 100, 3.) # doesn't crash + out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash + self.assertArraysEqual(out, np.float32(2 ** 100)) + class RematTest(jtu.JaxTestCase): diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index a0d74caca76f..87934ccff2f6 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1028,7 +1028,7 @@ def func(x, yint): 2 ) transforms: ['jvp', 'transpose'] what: pair ( 2.00 - False )""", testing_stream.output) + 0 )""", testing_stream.output) testing_stream.reset() def test_tap_vmap(self): @@ -1590,8 +1590,8 @@ def padded_sum(x): ( 3 ) ) ) ( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) - ( ( False ) - ( False ) ) ) )""", testing_stream.output) + ( ( 0 ) + ( 0 ) ) ) )""", testing_stream.output) testing_stream.reset() # Now with JIT diff --git a/tests/random_test.py b/tests/random_test.py index 41fa226bc71a..3eb3e6b54bf5 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -746,6 +746,8 @@ def f(x): grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2)) def testNoOpByOpUnderHash(self): + if not config.omnistaging_enabled: + raise SkipTest("test requires omnistaging") def fail(*args, **kwargs): assert False apply_primitive, xla.apply_primitive = xla.apply_primitive, fail try: From fe4d12c10fb0cb50b811b10ecd0b417db1cb886f Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 21 Mar 2021 19:38:12 -0700 Subject: [PATCH 2/3] move logic to traceable --- jax/_src/lax/lax.py | 18 +++++++++--------- tests/api_test.py | 14 +++----------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4d244f927f30..e358ab32a28a 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -428,8 +428,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None, if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() - # Note: don't canonicalize old_dtype because x64 context might - # cause un-canonicalized operands to be passed in. + # Don't canonicalize old_dtype because x64 context might cause + # un-canonicalized operands to be passed in. old_dtype = np.result_type(operand) old_weak_type = dtypes.is_weakly_typed(operand) @@ -441,6 +441,12 @@ def convert_element_type(operand: Array, new_dtype: DType = None, msg = "Casting complex values to real discards the imaginary part" warnings.warn(msg, np.ComplexWarning, stacklevel=2) + # Python has big integers, but convert_element_type(2 ** 100, np.float32) need + # not be an error since the target dtype fits the value. Handle this case by + # converting to a NumPy array before calling bind. + if type(operand) is int: + operand = np.asarray(operand, new_dtype) + if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type) and isinstance(operand, (core.Tracer, xla.DeviceArray))): return operand @@ -2658,12 +2664,6 @@ def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None): ad.defjvp_zero(lt_p) -def _convert_element_type_impl(operand, *, new_dtype, weak_type): - if dtypes.is_python_scalar(operand): - operand = np.asarray(operand, dtype=new_dtype) - return xla.apply_primitive(convert_element_type_p, operand, - new_dtype=new_dtype, weak_type=weak_type) - def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): return operand.shape @@ -2699,7 +2699,7 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type): return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type) convert_element_type_p = core.convert_element_type_p -convert_element_type_p.def_impl(_convert_element_type_impl) +convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p)) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, diff --git a/tests/api_test.py b/tests/api_test.py index f40d4c19e880..3bf5a9c31852 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2395,17 +2395,9 @@ def f(_): def test_large_python_int_to_float(self): # https://github.com/google/jax/pull/6165 - # We skip checks because otherwise we end up calling valid_jaxtype(2**100), - # which tries to form a ConcreteArray with that value and thus leads to a - # NumPy OverflowError. It's true that 2**100 does not inhabit a jax type, - # but as an issue of Python embedding we can handle operations like - # lax.convert_element_type(2 ** 100, jnp.float32) as in the tests below. - # That is, lax.convert_element_type(2 ** 100, jnp.int32) is an error while - # lax.convert_element_type(2 ** 100, jnp.float32) is not. - with jax.core.skipping_checks(): - jnp.multiply(2 ** 100, 3.) # doesn't crash - out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash - self.assertArraysEqual(out, np.float32(2 ** 100)) + jnp.multiply(2 ** 100, 3.) # doesn't crash + out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash + self.assertArraysEqual(out, np.float32(2 ** 100)) class RematTest(jtu.JaxTestCase): From 214d273d8cbf05b7c8ccf3b96f778fb581425e52 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 21 Mar 2021 19:41:04 -0700 Subject: [PATCH 3/3] undo changes to host_callback (not needed anymore) --- jax/_src/lax/lax.py | 4 +++- jax/experimental/host_callback.py | 2 -- tests/host_callback_test.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e358ab32a28a..b4a4568f08b9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -443,7 +443,9 @@ def convert_element_type(operand: Array, new_dtype: DType = None, # Python has big integers, but convert_element_type(2 ** 100, np.float32) need # not be an error since the target dtype fits the value. Handle this case by - # converting to a NumPy array before calling bind. + # converting to a NumPy array before calling bind. Without this step, we'd + # first canonicalize the input to a value of dtype int32 or int64, leading to + # an overflow error. if type(operand) is int: operand = np.asarray(operand, new_dtype) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 961ad1cadd1b..a84a23b0407f 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -934,7 +934,6 @@ def _outside_call_jvp_rule(primals, tangents, **params): if not params["identity"]: raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.") tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents)) - tangent_instantiated = tuple(map(ad.replace_float0s, primals, tangent_instantiated)) arg_treedef = params["arg_treedef"] # The argument to the jvp tap is a pair of the tapped primals and tangents @@ -947,7 +946,6 @@ def _outside_call_jvp_rule(primals, tangents, **params): arg_treedef=jvp_arg_treedef, )) out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)]) - out_tangents_tapped = map(ad.recast_to_float0, out_primals_tapped, out_tangents_tapped) return tuple(out_primals_tapped), tuple(out_tangents_tapped) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 87934ccff2f6..a0d74caca76f 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1028,7 +1028,7 @@ def func(x, yint): 2 ) transforms: ['jvp', 'transpose'] what: pair ( 2.00 - 0 )""", testing_stream.output) + False )""", testing_stream.output) testing_stream.reset() def test_tap_vmap(self): @@ -1590,8 +1590,8 @@ def padded_sum(x): ( 3 ) ) ) ( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) - ( ( 0 ) - ( 0 ) ) ) )""", testing_stream.output) + ( ( False ) + ( False ) ) ) )""", testing_stream.output) testing_stream.reset() # Now with JIT