Fall back to CPU in GPU kernels #4380
We implement some custom XLA calls that are currently only available on CPU. Until we have a full GPU implementation, we would like to support calling our custom calls from jitted GPU functions by copying the data to the host and invoking the CPU custom call targets. Is this possible somehow?

```python
def hybrid_func(x):
    # do math on GPU
    x = x + 1
    # use our custom call that is executed on CPU:
    # this copies x to the host, does the computation, and copies back to device
    x = my_custom_xla_op(x)
    # continue with GPU stuff
    # ...
    return x

hybrid_func = jax.jit(hybrid_func, device=jax.devices('gpu')[0])
hybrid_func(jnp.zeros(10))
```
Replies: 1 comment
I think the custom call targets you register on the GPU backend are really host functions (that get device pointers, and can call device kernels). So it would be fine to pull a value back to the host, perform computation there (e.g. by calling a CPU custom call target function), and push it back to the device, all within a function registered as a GPU custom call.
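As an aside, the same copy-to-host-and-back pattern can be expressed without registering a custom call at all, using `jax.pure_callback` to run an ordinary host function from inside a jitted computation. The sketch below is illustrative, not the asker's actual op: `cpu_op` is a hypothetical stand-in for the existing CPU implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np

def cpu_op(x):
    # Runs on the host as plain NumPy; stands in for the CPU custom call target.
    return np.asarray(x) * 2.0

def hybrid_func(x):
    x = x + 1  # runs on the accelerator
    # pure_callback copies x to the host, runs cpu_op there, and copies the
    # result (whose shape/dtype must be declared up front) back to the device.
    x = jax.pure_callback(cpu_op, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return x - 1  # back on the accelerator

out = jax.jit(hybrid_func)(jnp.zeros(3))  # (0 + 1) * 2 - 1 == 1 elementwise
```

The trade-off versus a hand-written GPU custom call is that the callback is opaque to XLA, so it blocks fusion across the call site, but it avoids writing and registering any backend-specific code.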