Skip to content

Commit

Permalink
[JAX] Use a hand-written Python wrapper for CompiledFunction.
Browse files Browse the repository at this point in the history
A hand-written Python wrapper gives us a number of benefits:
* it makes it cheaper to cast `self` to a `CompiledFunction` in the `__call__` method, since we no longer need to use pybind11's dynamic casting logic.
* we can avoid building an empty kwarg dict even when no kwargs are passed, which is something pybind11 does behind the scenes.
* the change allows us to prepare for using faster calling conventions, e.g. METH_FASTCALL in Python 3.7 or newer, which would allow the interpreter to avoid forming a tuple for the positional arguments.

In passing, remove unused arguments from PopulateCacheEntry.

Benchmarks (N.B. incorporating the PyBuffer hand-written PyObject change as well):
```
name                 old cpu/op  new cpu/op  delta
jit_simple_dispatch  11.3µs ± 6%  10.0µs ± 8%  -11.17%  (p=0.000 n=20+20)

name                 old time/op             new time/op             delta
jit_simple_dispatch  11.3µs ± 6%             10.0µs ±10%  -11.25%        (p=0.000 n=20+19)
```

PiperOrigin-RevId: 366498581
  • Loading branch information
hawkinsp authored and ChexDev [email protected] committed Apr 6, 2021
1 parent 67a8c95 commit 39f2ccc
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ def _is_traceable(fn):

tokens = (
"_python_jit.", # PyJIT in Python ver. < 3.7
"_cpp_jit.", # CppJIT in Python ver. < 3.7
"_cpp_jit.", # CppJIT in Python ver. < 3.7 (deprecated)
".reraise_with_filtered_traceback", # JIT in Python ver. >= 3.7
"CompiledFunction", # C++ JIT in jaxlib 0.1.66 or newer.
"pmap.", # pmap
"vmap.", # vmap
)
Expand Down

0 comments on commit 39f2ccc

Please sign in to comment.