From 214d273d8cbf05b7c8ccf3b96f778fb581425e52 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 21 Mar 2021 19:41:04 -0700 Subject: [PATCH] undo changes to host_callback (not needed anymore) --- jax/_src/lax/lax.py | 4 +++- jax/experimental/host_callback.py | 2 -- tests/host_callback_test.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e358ab32a28a..b4a4568f08b9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -443,7 +443,9 @@ def convert_element_type(operand: Array, new_dtype: DType = None, # Python has big integers, but convert_element_type(2 ** 100, np.float32) need # not be an error since the target dtype fits the value. Handle this case by - # converting to a NumPy array before calling bind. + # converting to a NumPy array before calling bind. Without this step, we'd + # first canonicalize the input to a value of dtype int32 or int64, leading to + # an overflow error. if type(operand) is int: operand = np.asarray(operand, new_dtype) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 961ad1cadd1b..a84a23b0407f 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -934,7 +934,6 @@ def _outside_call_jvp_rule(primals, tangents, **params): if not params["identity"]: raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.") tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents)) - tangent_instantiated = tuple(map(ad.replace_float0s, primals, tangent_instantiated)) arg_treedef = params["arg_treedef"] # The argument to the jvp tap is a pair of the tapped primals and tangents @@ -947,7 +946,6 @@ def _outside_call_jvp_rule(primals, tangents, **params): arg_treedef=jvp_arg_treedef, )) out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)]) - out_tangents_tapped = map(ad.recast_to_float0, out_primals_tapped, out_tangents_tapped) return tuple(out_primals_tapped), tuple(out_tangents_tapped) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 87934ccff2f6..a0d74caca76f 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1028,7 +1028,7 @@ def func(x, yint): 2 ) transforms: ['jvp', 'transpose'] what: pair ( 2.00 - 0 )""", testing_stream.output) + False )""", testing_stream.output) testing_stream.reset() def test_tap_vmap(self): @@ -1590,8 +1590,8 @@ def padded_sum(x): ( 3 ) ) ) ( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) - ( ( 0 ) - ( 0 ) ) ) )""", testing_stream.output) + ( ( False ) + ( False ) ) ) )""", testing_stream.output) testing_stream.reset() # Now with JIT