Skip to content

Commit

Permalink
undo changes to host_callback (not needed anymore)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 22, 2021
1 parent fe4d12c commit 214d273
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 3 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,9 @@ def convert_element_type(operand: Array, new_dtype: DType = None,

# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
# not be an error since the target dtype fits the value. Handle this case by
# converting to a NumPy array before calling bind.
# converting to a NumPy array before calling bind. Without this step, we'd
# first canonicalize the input to a value of dtype int32 or int64, leading to
# an overflow error.
if type(operand) is int:
operand = np.asarray(operand, new_dtype)

Expand Down
2 changes: 0 additions & 2 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,6 @@ def _outside_call_jvp_rule(primals, tangents, **params):
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents))
tangent_instantiated = tuple(map(ad.replace_float0s, primals, tangent_instantiated))

arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
Expand All @@ -947,7 +946,6 @@ def _outside_call_jvp_rule(primals, tangents, **params):
arg_treedef=jvp_arg_treedef,
))
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
out_tangents_tapped = map(ad.recast_to_float0, out_primals_tapped, out_tangents_tapped)
return tuple(out_primals_tapped), tuple(out_tangents_tapped)


Expand Down
6 changes: 3 additions & 3 deletions tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ def func(x, yint):
2 )
transforms: ['jvp', 'transpose'] what: pair
( 2.00
0 )""", testing_stream.output)
False )""", testing_stream.output)
testing_stream.reset()

def test_tap_vmap(self):
Expand Down Expand Up @@ -1590,8 +1590,8 @@ def padded_sum(x):
( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4]
[0. 0.2 0.4 0.6 0.8] )
( ( 0 )
( 0 ) ) ) )""", testing_stream.output)
( ( False )
( False ) ) ) )""", testing_stream.output)
testing_stream.reset()

# Now with JIT
Expand Down

0 comments on commit 214d273

Please sign in to comment.