
Reduce jitted function overhead #1101

Merged: 4 commits into pymc-devs:main on Nov 29, 2024

Conversation

@ricardoV94 (Member) commented Nov 22, 2024:

This reduces the overhead of the benchmarked numba function from ~10µs to ~2.5µs on my machine, with trust_input=True.

It will hopefully go down further when we remove deprecated function features like output_subset and dict returns.
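For context, a minimal sketch of how call overhead like this can be measured (not from the PR; the timing harness is illustrative). Setting trust_input=True skips per-call input validation on the compiled Function:

import timeit

import numpy as np
import pytensor.tensor as pt
from pytensor import function

x = pt.vector("x")
fn = function([x], pt.exp(x), mode="NUMBA")

test_x = np.zeros(1000)
fn(test_x)  # warm up, so numba compilation is excluded from the timing

fn.trust_input = True  # skip per-call input validation/conversion
per_call = timeit.timeit(lambda: fn(test_x), number=10_000) / 10_000
print(f"{per_call * 1e6:.2f} us per call")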


📚 Documentation preview 📚: https://pytensor--1101.org.readthedocs.build/en/1101/

@ricardoV94 changed the title from "Reduce jit fn overhead" to "Reduce jitted function overhead" on Nov 22, 2024
@Ch0ronomato (Contributor) left a comment:

Mostly looked at the torch linker and some q's around the tests.

import numpy as np
import pytensor.tensor as pt
from pytensor import function

x = pt.vector("x")
out = pt.exp(x)

fn = function([x], out, mode="NUMBA")
@Ch0ronomato (Contributor):

Should we also test this on torch? Do you want me to do that?

@ricardoV94 (Member, Author):

I think there's much more going on with Torch; we may want to think more carefully about which sort of graph to benchmark, and also about CPU/GPU. Let's open an issue for when the torch backend is a bit more established?

test_x = np.zeros(1000)
assert np.sum(fn(test_x)) == 1000

benchmark(fn, test_x)  # benchmark is the pytest-benchmark fixture
@Ch0ronomato (Contributor):

Does benchmark fail if too much time elapses?

@ricardoV94 (Member, Author):

We have a job that should show a warning, but I don't think it's working. At the least, it allows us to do a git bisect (or something like that) over commits to see whether performance dropped.

# Torch does not accept numpy inputs and may return GPU objects
def fn(*inputs, inner_fn=inner_fn):
    outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
    return tuple(out.cpu() for out in outs)
@Ch0ronomato (Contributor):

I have a few PRs open that might clobber this, FYI.

@ricardoV94 (Member, Author):

Whichever gets merged first; we can resolve the conflicts after.

inner_fn = torch.compile(fn)

# Torch does not accept numpy inputs and may return GPU objects
def fn(*inputs, inner_fn=inner_fn):
@Ch0ronomato (Contributor):

Maybe just use the closure scoped function?

@ricardoV94 (Member, Author):

Wdym?

@Ch0ronomato (Contributor):

I think I usually see us just reference inner_fn without declaring an optional param. I think I've seen both, but whatever works.

@ricardoV94 (Member, Author):

Ah, I think this (binding it as a default argument) is a bit faster, but I am not positive.
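For context, a hypothetical micro-benchmark (not from the PR; abs stands in for the compiled inner function) contrasting the two patterns. A default argument is bound once at definition time and read as a fast local, so each call avoids a closure-cell lookup:

import timeit

def make_wrappers(inner_fn):
    def via_closure(*inputs):
        # inner_fn is looked up through the closure cell on every call
        return inner_fn(*inputs)

    def via_default(*inputs, inner_fn=inner_fn):
        # inner_fn was bound at definition time and is read as a local
        return inner_fn(*inputs)

    return via_closure, via_default

via_closure, via_default = make_wrappers(abs)
print(timeit.timeit(lambda: via_closure(-1.0), number=1_000_000))
print(timeit.timeit(lambda: via_default(-1.0), number=1_000_000))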

@ricardoV94 force-pushed the reduce_jit_fn_overhead branch from 862f158 to b11db68 on November 23, 2024 at 09:21
@ricardoV94 requested a review from aseyboldt on November 23, 2024 at 09:28
Comment on lines -16 to -23
def output_filter(self, var: "Variable", out: Any) -> Any:
    if not isinstance(var, np.ndarray) and isinstance(
        var.type, pytensor.tensor.TensorType
    ):
        return var.type.filter(out, allow_downcast=True)

    return out

@ricardoV94 (Member, Author) commented Nov 23, 2024:

This was actually wrong; it probably was meant to be if not isinstance(out, np.ndarray). As written, it always triggered var.type.filter, which was quite slow.

Anyway, if numba returns a non-array where we expected an array, it means something is wrong in our dispatch, and we should fix it there.
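For reference, a minimal sketch of what the check was presumably meant to be; illustrative only, since the PR removes the method entirely rather than fixing it:

def output_filter(self, var: "Variable", out: Any) -> Any:
    # Only run the (slow) type filter when the output is not already an ndarray
    if not isinstance(out, np.ndarray) and isinstance(
        var.type, pytensor.tensor.TensorType
    ):
        return var.type.filter(out, allow_downcast=True)

    return out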

@ricardoV94 force-pushed the reduce_jit_fn_overhead branch 4 times, most recently from 311f83e to a97734e on November 25, 2024 at 16:17
@ricardoV94 force-pushed the reduce_jit_fn_overhead branch from a97734e to d546c36 on November 25, 2024 at 16:38
@jessegrabowski (Member) left a comment:

lgtm, I left a dumb comment that you can ignore (the gc stuff was re-implemented in the last commit).

@@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
            compute_map, nodes, input_storage, output_storage, storage_map
        )

        computed, last_user = gc_helper(nodes)

        if self.allow_gc:
@jessegrabowski (Member):

Is this gc happening somewhere else now? Why was it here if it could just be removed?

@ricardoV94 (Member, Author):

It doesn't do anything for jitted functions, where you don't control intermediate allocations.

@ricardoV94 merged commit 1a3af4b into pymc-devs:main on Nov 29, 2024 (59 of 60 checks passed)