Reduce jitted function overhead #1101
@@ -1,26 +1,9 @@
-from typing import TYPE_CHECKING, Any
-
-import numpy as np
-
-import pytensor
 from pytensor.link.basic import JITLinker


-if TYPE_CHECKING:
-    from pytensor.graph.basic import Variable
-
-
 class NumbaLinker(JITLinker):
     """A `Linker` that JIT-compiles NumPy-based operations using Numba."""

-    def output_filter(self, var: "Variable", out: Any) -> Any:
-        if not isinstance(var, np.ndarray) and isinstance(
-            var.type, pytensor.tensor.TensorType
-        ):
-            return var.type.filter(out, allow_downcast=True)
-
-        return out
-
Comment on lines -16 to -23:

This was actually wrong; it probably was meant to be `isinstance(out, np.ndarray)`. Anyway, if numba is returning a non-array where we expected an array, it means something is wrong in our dispatch, and we should fix it there.
     def fgraph_convert(self, fgraph, **kwargs):
         from pytensor.link.numba.dispatch import numba_funcify
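For context on why dropping this Python-level output filter matters for call overhead, here is a minimal sketch (plain numba, not PyTensor's actual linker code; all names are illustrative) of the fixed per-call cost an interpreter-level output check adds on top of a jitted function:

import timeit

import numba
import numpy as np


@numba.njit
def jitted(x):
    return np.exp(x)


def with_output_filter(x):
    # Emulate a per-call output check/cast like the removed output_filter
    out = jitted(x)
    if not isinstance(out, np.ndarray):
        out = np.asarray(out)
    return out


x = np.zeros(10)
jitted(x)  # warm-up call triggers compilation
print("direct:  ", timeit.timeit(lambda: jitted(x), number=100_000))
print("filtered:", timeit.timeit(lambda: with_output_filter(x), number=100_000))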
@@ -1,7 +1,3 @@
-import copy
-from typing import Any
-
-from pytensor.graph.basic import Variable
 from pytensor.link.basic import JITLinker
 from pytensor.link.utils import unique_name_generator
@@ -13,14 +9,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.gen_functors = []

-    def input_filter(self, inp: Any) -> Any:
-        from pytensor.link.pytorch.dispatch import pytorch_typify
-
-        return pytorch_typify(inp)
-
-    def output_filter(self, var: Variable, out: Any) -> Any:
-        return out.cpu()
-
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.pytorch.dispatch import pytorch_funcify
@@ -49,6 +37,8 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch

+        from pytensor.link.pytorch.dispatch import pytorch_typify
+
         class wrapper:
             """
             Pytorch would fail compiling our method when trying
@@ -62,7 +52,7 @@ class wrapper:

             def __init__(self, fn, gen_functors):
                 self.fn = torch.compile(fn)
-                self.gen_functors = copy.copy(gen_functors)
+                self.gen_functors = gen_functors.copy()

             def __call__(self, *args, **kwargs):
                 import pytensor.link.utils
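Aside (not part of the diff): for a list like gen_functors, `gen_functors.copy()` is the same shallow copy as `copy.copy(gen_functors)`, which is what allows the module-level `copy` import to be dropped above. A quick check, with hypothetical list contents:

import copy

gen_functors = [("$fn_1", print), ("$fn_2", len)]  # placeholder contents

a = copy.copy(gen_functors)  # old spelling
b = gen_functors.copy()      # new spelling

assert a == b == gen_functors
assert a is not gen_functors and b is not gen_functors

b.append(("$fn_3", sum))  # mutating the copy leaves the original untouched
assert len(gen_functors) == 2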
@@ -83,9 +73,15 @@ def __call__(self, *args, **kwargs):
             def __del__(self):
                 del self.gen_functors

-        res = wrapper(fn, self.gen_functors)
+        inner_fn = wrapper(fn, self.gen_functors)
         self.gen_functors = []
-        return res
+
+        # Torch does not accept numpy inputs and may return GPU objects
+        def fn(*inputs, inner_fn=inner_fn):
Review thread on the added `def fn(*inputs, inner_fn=inner_fn):` line:
- Maybe just use the closure scoped function?
- Wdym?
- I think I usually see us just reference
- Ah, I think this is a bit faster, but I am not positive
+            outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
+            return tuple(out.cpu().numpy() for out in outs)
+
+        return fn

     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
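On the default-argument question in the thread above, a hedged micro-benchmark (my assumption of what "a bit faster" refers to): in CPython, binding the callable as a default argument makes it a local variable, which is marginally cheaper to look up on each call than a free variable from the enclosing scope.

import timeit


def make_callers(inner_fn):
    def call_via_closure(x):
        return inner_fn(x)  # free variable: looked up in the closure cell

    def call_via_default(x, inner_fn=inner_fn):
        return inner_fn(x)  # default argument: looked up as a fast local

    return call_via_closure, call_via_default


closure_call, default_call = make_callers(abs)
print("closure:", timeit.timeit(lambda: closure_call(-1), number=1_000_000))
print("default:", timeit.timeit(lambda: default_call(-1), number=1_000_000))

The difference is tiny per call, but it is exactly the kind of fixed overhead this PR is trying to shave off.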
@@ -882,3 +882,20 @@ def test_cache_warning_suppressed():
     x_test = np.random.uniform(size=5)
     np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
+
+
+@pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))
+def test_function_overhead(mode, benchmark):
+    x = pt.vector("x")
+    out = pt.exp(x)
+
+    fn = function([x], out, mode="NUMBA")
Review thread on `fn = function([x], out, mode="NUMBA")`:
- Should we also test this on torch? Do you want me to do that?
- I think there's much more going on with Torch, so we may want to think more carefully about which sort of graph to benchmark, and also about CPU vs. GPU. Let's open an issue for when the torch backend is a bit more established?
+    if mode == "trust_input":
+        fn.trust_input = True
+    elif mode == "direct":
+        fn = fn.vm.jit_fn
+
+    test_x = np.zeros(1000)
+    assert np.sum(fn(test_x)) == 1000
+
+    benchmark(fn, test_x)
Review thread on `benchmark(fn, test_x)`:
- Does benchmark fail if some amount of time is elapsed?
- We have a job that should show a warning, but I don't think it's working. At the least it lets us do a git bisect (or something like that) over commits to see whether performance dropped.
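For reference, the three parametrized modes exercise these call paths. The sketch below mirrors the test (assuming `pt` is `pytensor.tensor`, `function` is `pytensor.function`, and the NUMBA backend is available); it is an illustration, not part of the PR:

import numpy as np

import pytensor.tensor as pt
from pytensor import function

x = pt.vector("x")
fn = function([x], pt.exp(x), mode="NUMBA")
test_x = np.zeros(1000)

fn(test_x)            # "default": full Function wrapper, including input validation
fn.trust_input = True
fn(test_x)            # "trust_input": skip the Python-level input checks
fn.vm.jit_fn(test_x)  # "direct": call the underlying jitted function (raw outputs)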
Review thread:
- Is this gc happening somewhere else now? Why was it here if it could just be removed?
- It doesn't do anything for jitted functions, where you don't control intermediate allocations.