Reduce jitted function overhead #1101
Conversation
Mostly looked at the torch linker, plus some questions around the tests.
x = pt.vector("x")
out = pt.exp(x)

fn = function([x], out, mode="NUMBA")
Should we also test this on torch? Do you want me to do that?
I think there's much more going on with Torch, so we may want to think more carefully about which sort of graph to benchmark, and also about CPU vs GPU. Let's open an issue for when the torch backend is a bit more established?
test_x = np.zeros(1000)
assert np.sum(fn(test_x)) == 1000

benchmark(fn, test_x)
Does the benchmark fail if a certain amount of time elapses?
We have a job that should show a warning, but I don't think it's working. It at least allows us to run git bisect (or something like that) over commits to see whether performance dropped.
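For context, the two diff chunks above assemble into roughly the following benchmark test. The test name, imports, and the pytest-benchmark fixture wiring are assumptions for illustration; only the individual statements come from the diff:

import numpy as np
import pytensor.tensor as pt
from pytensor import function

def test_numba_function_overhead(benchmark):
    x = pt.vector("x")
    out = pt.exp(x)
    fn = function([x], out, mode="NUMBA")

    test_x = np.zeros(1000)
    # exp(0) == 1 for each of the 1000 entries
    assert np.sum(fn(test_x)) == 1000

    # pytest-benchmark repeatedly calls fn(test_x) and records the timings
    benchmark(fn, test_x)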
pytensor/link/pytorch/linker.py (outdated)
# Torch does not accept numpy inputs and may return GPU objects
def fn(*inputs, inner_fn=inner_fn):
    outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
    return tuple(out.cpu() for out in outs)
I have a few PRs open that might clobber this, FYI.
Whichever gets merged first, we can resolve the conflicts afterwards.
inner_fn = torch.compile(fn)

# Torch does not accept numpy inputs and may return GPU objects
def fn(*inputs, inner_fn=inner_fn):
Maybe just use the closure-scoped function?
Wdym?
I think I usually see us just reference inner_fn without declaring it as an optional parameter. I've seen both, but whatever works.
Ah, I think this is a bit faster, but I am not positive.
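For reference, the two options being discussed look roughly like this (a minimal sketch; the names fn_closure and fn_default are illustrative, and the bodies just mirror the wrapper above):

# Option A: reference inner_fn through the closure; it is looked up in the
# enclosing scope (a closure cell) on every call.
def fn_closure(*inputs):
    outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
    return tuple(out.cpu() for out in outs)

# Option B (what the diff does): bind inner_fn as a default keyword argument,
# which captures the reference once at definition time and makes the lookup a
# plain local-variable access, usually marginally faster than the closure.
def fn_default(*inputs, inner_fn=inner_fn):
    outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
    return tuple(out.cpu() for out in outs)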
Force-pushed from 862f158 to b11db68.
def output_filter(self, var: "Variable", out: Any) -> Any:
    if not isinstance(var, np.ndarray) and isinstance(
        var.type, pytensor.tensor.TensorType
    ):
        return var.type.filter(out, allow_downcast=True)

    return out
This was actually wrong; it was probably meant to be if not isinstance(out, np.ndarray). As written, it always triggered var.type.filter, which was quite slow.
Anyway, if numba is returning a non-array where we expected an array, it means something is wrong in our dispatch, and we should fix it there.
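A sketch of the condition as it was presumably intended, checking the runtime value out rather than the symbolic var:

def output_filter(self, var: "Variable", out: Any) -> Any:
    # `var` is a symbolic Variable and never an ndarray, so the original check
    # `not isinstance(var, np.ndarray)` was always true and the slow filter
    # always ran; checking `out` only filters genuinely non-array results.
    if not isinstance(out, np.ndarray) and isinstance(
        var.type, pytensor.tensor.TensorType
    ):
        return var.type.filter(out, allow_downcast=True)

    return out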
Force-pushed from 311f83e to a97734e.
Force-pushed from a97734e to d546c36.
Force-pushed from d546c36 to 4a96d91.
lgtm, I left a dumb comment that you can ignore (the gc stuff was re-implemented in the last commit).
@@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
            compute_map, nodes, input_storage, output_storage, storage_map
        )

        computed, last_user = gc_helper(nodes)

        if self.allow_gc:
Is this gc happening somewhere else now? Why was it here if it could just be removed?
It doesn't do anything for jitted functions, where you don't control intermediate allocations.
This reduces the overhead of the benchmarked numba function from ~10µs to ~2.5µs on my machine, with trust_input=True.
It will hopefully go down further when we remove deprecated function features like output_subset and dict returns.
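For readers who want to reproduce this kind of measurement: trust_input skips per-call input validation on the compiled function. A minimal sketch, assuming the usual attribute-setting approach; the graph and dtype handling here are illustrative, not taken from the PR:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
fn = pytensor.function([x], pt.exp(x), mode="NUMBA")

# Skip the per-call input checking/conversion; the caller is then responsible
# for passing arrays of the expected dtype and number of dimensions.
fn.trust_input = True

fn(np.zeros(1000, dtype=pytensor.config.floatX))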
📚 Documentation preview 📚: https://pytensor--1101.org.readthedocs.build/en/1101/