Reduce jitted function overhead #1101

Merged: 4 commits, Nov 29, 2024
Changes from 1 commit
52 changes: 10 additions & 42 deletions pytensor/link/basic.py
@@ -653,41 +653,36 @@ def create_jitable_thunk(
)

thunk_inputs = self.create_thunk_inputs(storage_map)

thunks = []

thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

fgraph_jit = self.jit_compile(converted_fgraph)

def thunk(
fgraph=self.fgraph,
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
try:
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
except Exception:
# TODO: Should we add a fake node that combines all outputs,
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)

# strict=False because we are in a hot loop
for o_var, o_storage, o_val in zip(
fgraph.outputs, thunk_outputs, outputs, strict=False
):
compute_map[o_var][0] = True
o_storage[0] = self.output_filter(o_var, o_val)
return outputs
for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
o_storage[0] = o_val

thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunk.lazy = False

thunks.append(thunk)
thunks = [thunk]

return thunks, output_nodes, fgraph_jit

def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
nodes = self.schedule(fgraph)
no_recycling = self.no_recycling

input_storage, output_storage, storage_map = map_storage(
fgraph, nodes, input_storage, output_storage, storage_map
@@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
compute_map, nodes, input_storage, output_storage, storage_map
)

computed, last_user = gc_helper(nodes)

if self.allow_gc:
Member: Is this gc happening somewhere else now? Why was it here if it could just be removed?

Member Author: It doesn't do anything for jitted functions, where you don't control intermediate allocations.

post_thunk_old_storage = [
[
storage_map[input]
for input in node.inputs
if (input in computed)
and (input not in fgraph.outputs)
and (node == last_user[input])
]
for node in nodes
]
else:
post_thunk_old_storage = None

if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
]

fn = streamline(
fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
)

[fn] = thunks
fn.jit_fn = jit_fn
fn.allow_gc = self.allow_gc
fn.storage_map = storage_map
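Context for the allow_gc exchange above: a minimal, self-contained sketch (hypothetical function and values, not PyTensor's actual internals) of the one-element storage-cell convention the simplified thunk relies on. Because the whole FunctionGraph executes as a single jitted call, intermediates never land in per-node storage cells, so the old post_thunk_old_storage bookkeeping has nothing to free.

```python
import numpy as np


# Hypothetical stand-in for a jitted FunctionGraph: every intermediate
# lives only inside this one call, never in a per-node storage cell.
def fgraph_jit(x):
    tmp = np.exp(x)  # intermediate allocation controlled by the JIT backend
    return (tmp * 2.0,)


# Storage cells in the style of pytensor.link: one-element lists shared
# between the caller and the thunk.
thunk_inputs = [[np.zeros(3)]]
thunk_outputs = [[None]]


def thunk(fgraph_jit=fgraph_jit, thunk_inputs=thunk_inputs, thunk_outputs=thunk_outputs):
    # Read raw values out of the input cells and call the jitted function.
    outputs = fgraph_jit(*(cell[0] for cell in thunk_inputs))
    # Write results back into the output cells; there are no per-node
    # intermediate cells left over, which is why the allow_gc logic could go.
    for o_storage, o_val in zip(thunk_outputs, outputs):
        o_storage[0] = o_val


thunk()
print(thunk_outputs[0][0])  # -> array([2., 2., 2.])
```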
17 changes: 0 additions & 17 deletions pytensor/link/numba/linker.py
@@ -1,26 +1,9 @@
from typing import TYPE_CHECKING, Any

import numpy as np

import pytensor
from pytensor.link.basic import JITLinker


if TYPE_CHECKING:
from pytensor.graph.basic import Variable


class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""

def output_filter(self, var: "Variable", out: Any) -> Any:
if not isinstance(var, np.ndarray) and isinstance(
var.type, pytensor.tensor.TensorType
):
return var.type.filter(out, allow_downcast=True)

return out

Comment on lines -16 to -23
Member Author (@ricardoV94, Nov 23, 2024): This was actually wrong; it was probably meant to be if not isinstance(out, np.ndarray). As written, it always triggered var.type.filter, which was quite slow.

Anyway, if Numba is returning a non-array where we expected an array, it means something is wrong in our dispatch, and we should fix it there.

def fgraph_convert(self, fgraph, **kwargs):
from pytensor.link.numba.dispatch import numba_funcify

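To make the author's point above concrete, here is a sketch of what the removed output_filter was presumably meant to do (check the returned value rather than the Variable). This is a hypothetical reconstruction for illustration, not code from the PR:

```python
import numpy as np

import pytensor


def output_filter(var, out):
    # Presumed intent: only fall back to the (slow) type filter when the
    # backend did NOT already return an ndarray for a TensorType variable.
    if not isinstance(out, np.ndarray) and isinstance(
        var.type, pytensor.tensor.TensorType
    ):
        return var.type.filter(out, allow_downcast=True)
    return out
```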
26 changes: 11 additions & 15 deletions pytensor/link/pytorch/linker.py
@@ -1,7 +1,3 @@
import copy
from typing import Any

from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker
from pytensor.link.utils import unique_name_generator

@@ -13,14 +9,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []

def input_filter(self, inp: Any) -> Any:
from pytensor.link.pytorch.dispatch import pytorch_typify

return pytorch_typify(inp)

def output_filter(self, var: Variable, out: Any) -> Any:
return out.cpu()

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify

@@ -49,6 +37,8 @@ def conversion_func_register(*args, **kwargs):
def jit_compile(self, fn):
import torch

from pytensor.link.pytorch.dispatch import pytorch_typify

class wrapper:
"""
Pytorch would fail compiling our method when trying
@@ -62,7 +52,7 @@ class wrapper:

def __init__(self, fn, gen_functors):
self.fn = torch.compile(fn)
self.gen_functors = copy.copy(gen_functors)
self.gen_functors = gen_functors.copy()

def __call__(self, *args, **kwargs):
import pytensor.link.utils
@@ -83,9 +73,15 @@ def __call__(self, *args, **kwargs):
def __del__(self):
del self.gen_functors

res = wrapper(fn, self.gen_functors)
inner_fn = wrapper(fn, self.gen_functors)
self.gen_functors = []
return res

# Torch does not accept numpy inputs and may return GPU objects
def fn(*inputs, inner_fn=inner_fn):
Contributor: Maybe just use the closure-scoped function?

Member Author: Wdym?

Contributor: I think I usually see us just reference inner_fn without declaring an optional param. I think I've seen both, but whatever works.

Member Author: Ah, I think this is a bit faster, but I am not positive.

outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
return tuple(out.cpu().numpy() for out in outs)

return fn

def create_thunk_inputs(self, storage_map):
thunk_inputs = []
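Regarding the "optional param vs. closure" exchange above: a small standalone sketch (hypothetical toy functions) of the two call styles. Binding inner_fn as a default argument resolves it as a fast local at call time instead of a closure lookup, which is the small speed-up the author is alluding to; in practice the difference is usually tiny.

```python
import timeit


def make_callers():
    def inner_fn(x):
        return x + 1

    # Style 1: reference the closure-scoped function (looked up on each call).
    def call_via_closure(x):
        return inner_fn(x)

    # Style 2: bind it once as a default argument (loaded as a local).
    def call_via_default(x, inner_fn=inner_fn):
        return inner_fn(x)

    return call_via_closure, call_via_default


call_via_closure, call_via_default = make_callers()
print(timeit.timeit(lambda: call_via_closure(0), number=1_000_000))
print(timeit.timeit(lambda: call_via_default(0), number=1_000_000))
```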
17 changes: 17 additions & 0 deletions tests/link/numba/test_basic.py
@@ -882,3 +882,20 @@ def test_cache_warning_suppressed():

x_test = np.random.uniform(size=5)
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)


@pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))
def test_function_overhead(mode, benchmark):
x = pt.vector("x")
out = pt.exp(x)

fn = function([x], out, mode="NUMBA")
Contributor: Should we also test this on torch? Do you want me to do that?

Member Author: I think there's much more going on with Torch, so we may want to think more carefully about which sort of graph to benchmark, and also CPU vs. GPU. Let's open an issue for when the torch backend is a bit more established?

if mode == "trust_input":
fn.trust_input = True
elif mode == "direct":
fn = fn.vm.jit_fn

test_x = np.zeros(1000)
assert np.sum(fn(test_x)) == 1000

benchmark(fn, test_x)
Contributor: Does benchmark fail if some amount of time has elapsed?

Member Author: We have a job that should show a warning, but I don't think it's working. It at least allows us to do a git bisect (or something like that) over commits to see if performance dropped.
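For reference, a minimal standalone sketch of the three call paths the parametrized benchmark above exercises, assuming the Numba backend is installed; it mirrors the test body rather than adding anything new:

```python
import numpy as np

import pytensor.tensor as pt
from pytensor import function

x = pt.vector("x")
fn = function([x], pt.exp(x), mode="NUMBA")
x_test = np.zeros(1000)

fn(x_test)             # "default": full input validation on every call
fn.trust_input = True  # "trust_input": skip input checking/filtering
fn(x_test)
jit_fn = fn.vm.jit_fn  # "direct": the underlying Numba-compiled callable
jit_fn(x_test)         # bypasses the Function wrapper entirely
```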

29 changes: 13 additions & 16 deletions tests/link/pytorch/test_basic.py
@@ -53,8 +53,6 @@ def compare_pytorch_and_py(
assert_fn: func, opt
Assert function used to check for equality between python and pytorch. If not
provided uses np.testing.assert_allclose
must_be_device_array: Bool
Checks if torch.device.type is cuda


"""
@@ -66,20 +64,19 @@
pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode)
pytorch_res = pytensor_torch_fn(*test_inputs)

if must_be_device_array:
if isinstance(pytorch_res, list):
assert all(isinstance(res, torch.Tensor) for res in pytorch_res)
else:
assert pytorch_res.device.type == "cuda"
if isinstance(pytorch_res, list):
assert all(isinstance(res, np.ndarray) for res in pytorch_res)
else:
assert isinstance(pytorch_res, np.ndarray)

pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs)

if len(fgraph.outputs) > 1:
for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True):
assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
assert_fn(pytorch_res_i, py_res_i)
else:
assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
assert_fn(pytorch_res[0], py_res[0])

return pytensor_torch_fn, pytorch_res

@@ -162,23 +159,23 @@ def test_shared(device):
pytensor_torch_fn = function([], a, mode="PYTORCH")
pytorch_res = pytensor_torch_fn()

assert isinstance(pytorch_res, torch.Tensor)
assert isinstance(pytorch_res, np.ndarray)
assert isinstance(a.get_value(), np.ndarray)
np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())
np.testing.assert_allclose(pytorch_res, a.get_value())

pytensor_torch_fn = function([], a * 2, mode="PYTORCH")
pytorch_res = pytensor_torch_fn()

assert isinstance(pytorch_res, torch.Tensor)
assert isinstance(pytorch_res, np.ndarray)
assert isinstance(a.get_value(), np.ndarray)
np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)
np.testing.assert_allclose(pytorch_res, a.get_value() * 2)

new_a_value = np.array([3, 4, 5], dtype=config.floatX)
a.set_value(new_a_value)

pytorch_res = pytensor_torch_fn()
assert isinstance(pytorch_res, torch.Tensor)
np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)
assert isinstance(pytorch_res, np.ndarray)
np.testing.assert_allclose(pytorch_res, new_a_value * 2)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -225,7 +222,7 @@ def test_alloc_and_empty():
fn = function([dim1], out, mode=pytorch_mode)
res = fn(7)
assert res.shape == (5, 7, 3)
assert res.dtype == torch.float32
assert res.dtype == np.float32

v = vector("v", shape=(3,), dtype="float64")
out = alloc(v, dim0, dim1, 3)
2 changes: 1 addition & 1 deletion tests/link/pytorch/test_elemwise.py
@@ -152,7 +152,7 @@ def test_cast():
_, [res] = compare_pytorch_and_py(
fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
)
assert res.dtype == torch.int32
assert res.dtype == np.int32


def test_vmap_elemwise():