Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Use a hand-written Python wrapper for CompiledFunction.
A hand-written Python wrapper gives us a number of benefits: * it makes it cheaper to cast `self` to a `CompiledFunction` in the `__call__` method, since we no longer need to use pybind11's dynamic casting logic. * we can avoid building an empty kwarg dict even when no kwargs are passed, which is something pybind11 does behind the scenes. * the change allows us to prepare for using faster calling conventions, e.g. METH_FASTCALL in Python 3.7 or newer, which would allow the interpreter to avoid forming a tuple for the positional arguments. In passing, remove unused arguments from PopulateCacheEntry. Benchmarks (N.B. incorporating the PyBuffer hand-written PyObject change as well): ``` name old cpu/op new cpu/op delta jit_simple_dispatch 11.3µs ± 6% 10.0µs ± 8% -11.17% (p=0.000 n=20+20) name old time/op new time/op delta jit_simple_dispatch 11.3µs ± 6% 10.0µs ±10% -11.25% (p=0.000 n=20+19) ``` PiperOrigin-RevId: 366498581
- Loading branch information