diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py
index 53306d52dc..e2e612ac93 100644
--- a/pytensor/compile/function/types.py
+++ b/pytensor/compile/function/types.py
@@ -393,6 +393,8 @@ def __init__(
         assert len(self.input_storage) == len(self.maker.fgraph.inputs)
         assert len(self.output_storage) == len(self.maker.fgraph.outputs)
 
+        self.has_defaults = any(refeed for _, refeed, _ in self.defaults)
+
         # Group indexes of inputs that are potentially aliased to each other
         # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
         # even though there could be two distinct types that use the same kinds of underlying objects.
@@ -540,14 +542,40 @@ def __contains__(self, item):
         self._value = ValueAttribute()
         self._container = ContainerAttribute()
 
-        # TODO: Get rid of all this `expanded_inputs` nonsense
-        assert len(self.maker.expanded_inputs) == len(self.input_storage)
+        update_storage = [
+            container
+            for inp, container in zip(
+                self.maker.expanded_inputs, input_storage, strict=True
+            )
+            if inp.update is not None
+        ]
+        # Updates are the last inner outputs that are not returned by Function.__call__
+        self.n_returned_outputs = len(self.output_storage) - len(update_storage)
+
+        # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
+        self.update_input_storage: tuple[int, Container] = ()
+        if getattr(vm, "need_update_inputs", True):
+            self.update_input_storage = tuple(
+                zip(
+                    range(self.n_returned_outputs, len(output_storage)),
+                    update_storage,
+                    strict=True,
+                )
+            )
 
-        # This is used only when `vm.need_update_inputs` is `False`, because
-        # we're using one of the VM objects and it is putting updates back into
-        # the input containers all by itself.
-        self.n_returned_outputs = len(self.output_storage) - sum(
-            inp.update is not None for inp in self.maker.expanded_inputs
+        # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
+        # After the call, we want to erase (some of) these references, to allow Python to GC them if unused
+        # Required input containers are the non-default inputs, must always be provided again, so we GC them
+        self.clear_input_storage_data = tuple(
+            container.storage for container in input_storage if container.required
+        )
+        # This is only done when `vm.allow_gc` is True, which can change at runtime.
+        self.clear_output_storage_data = tuple(
+            container.storage
+            for container, variable in zip(
+                self.output_storage, self.maker.fgraph.outputs, strict=True
+            )
+            if variable.owner is not None  # Not a constant output
         )
 
         for node in self.maker.fgraph.apply_nodes:
@@ -747,7 +775,7 @@ def checkSV(sv_ori, sv_rpl):
         elif isinstance(profile, str):
             profile = pytensor.compile.profiling.ProfileStats(message=profile)
 
-        f_cpy = maker.__class__(
+        f_cpy = type(maker)(
             inputs=ins,
             outputs=outs,
             fgraph=fg_cpy,
@@ -765,6 +793,8 @@ def checkSV(sv_ori, sv_rpl):
             # check that.
             accept_inplace=True,
             no_fgraph_prep=True,
+            output_keys=maker.output_keys,
+            name=name,
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
@@ -797,8 +827,6 @@ def checkSV(sv_ori, sv_rpl):
 
         f_cpy.trust_input = self.trust_input
         f_cpy.unpack_single = self.unpack_single
-        f_cpy.name = name
-        f_cpy.maker.fgraph.name = name
         return f_cpy
 
     def _restore_defaults(self):
@@ -808,7 +836,7 @@ def _restore_defaults(self):
                 value = value.storage[0]
             self[i] = value
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, output_subset=None, **kwargs):
         """
         Evaluates value of a function on given arguments.
 
@@ -836,20 +864,21 @@ def __call__(self, *args, **kwargs):
             List of outputs on indices/keys from ``output_subset`` or all of them,
             if ``output_subset`` is not passed.
         """
+        trust_input = self.trust_input
         input_storage = self.input_storage
+        vm = self.vm
         profile = self.profile
 
         if profile:
             t0 = time.perf_counter()
 
-        output_subset = kwargs.pop("output_subset", None)
         if output_subset is not None:
             warnings.warn("output_subset is deprecated.", FutureWarning)
             if self.output_keys is not None:
                 output_subset = [self.output_keys.index(key) for key in output_subset]
 
         # Reinitialize each container's 'provided' counter
-        if self.trust_input:
+        if trust_input:
             for arg_container, arg in zip(input_storage, args, strict=False):
                 arg_container.storage[0] = arg
         else:
@@ -908,7 +937,7 @@ def __call__(self, *args, **kwargs):
             for k, arg in kwargs.items():
                 self[k] = arg
 
-        if not self.trust_input:
+        if not trust_input:
             # Collect aliased inputs among the storage space
             for potential_group in self._potential_aliased_input_groups:
                 args_share_memory: list[list[int]] = []
@@ -960,11 +989,7 @@ def __call__(self, *args, **kwargs):
         if profile:
             t0_fn = time.perf_counter()
         try:
-            outputs = (
-                self.vm()
-                if output_subset is None
-                else self.vm(output_subset=output_subset)
-            )
+            outputs = vm() if output_subset is None else vm(output_subset=output_subset)
         except Exception:
             self._restore_defaults()
             if hasattr(self.vm, "position_of_error"):
@@ -991,39 +1016,23 @@ def __call__(self, *args, **kwargs):
 
         # Retrieve the values that were computed
         if outputs is None:
-            outputs = [x.data for x in self.output_storage]
-
-        # Remove internal references to required inputs.
-        # These cannot be re-used anyway.
-        for arg_container in input_storage:
-            if arg_container.required:
-                arg_container.storage[0] = None
-
-        # if we are allowing garbage collection, remove the
-        # output reference from the internal storage cells
-        if getattr(self.vm, "allow_gc", False):
-            # strict=False because we are in a hot loop
-            for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs, strict=False
-            ):
-                if o_variable.owner is not None:
-                    # this node is the variable of computation
-                    # WARNING: This circumvents the 'readonly' attribute in x
-                    o_container.storage[0] = None
-
-        if getattr(self.vm, "need_update_inputs", True):
-            # Update the inputs that have an update function
-            # strict=False because we are in a hot loop
-            for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, input_storage, strict=False))
-            ):
-                if input.update is not None:
-                    storage.data = outputs.pop()
-        else:
-            outputs = outputs[: self.n_returned_outputs]
+            outputs = [x.storage[0] for x in self.output_storage]
+
+        # Set updates and filter them out from the returned outputs
+        for i, input_storage in self.update_input_storage:
+            input_storage.storage[0] = outputs[i]
+        outputs = outputs[: self.n_returned_outputs]
+
+        # Remove input and output values from storage data
+        for storage_data in self.clear_input_storage_data:
+            storage_data[0] = None
+        if getattr(vm, "allow_gc", False):
+            for storage_data in self.clear_output_storage_data:
+                storage_data[0] = None
 
         # Put default values back in the storage
-        self._restore_defaults()
+        if self.has_defaults:
+            self._restore_defaults()
 
         if profile:
             dt_call = time.perf_counter() - t0
@@ -1031,33 +1040,29 @@ def __call__(self, *args, **kwargs):
             self.maker.mode.call_time += dt_call
             profile.fct_callcount += 1
             profile.fct_call_time += dt_call
-            if hasattr(self.vm, "update_profile"):
-                self.vm.update_profile(profile)
+            if hasattr(vm, "update_profile"):
+                vm.update_profile(profile)
             if profile.ignore_first_call:
                 profile.reset()
                 profile.ignore_first_call = False
 
         if self.return_none:
             return None
-        elif self.unpack_single and len(outputs) == 1 and output_subset is None:
-            return outputs[0]
-        else:
-            if self.output_keys is not None:
-                assert len(self.output_keys) == len(outputs)
-
-                if output_subset is None:
-                    # strict=False because we are in a hot loop
-                    return dict(zip(self.output_keys, outputs, strict=False))
-                else:
-                    return {
-                        self.output_keys[index]: outputs[index]
-                        for index in output_subset
-                    }
+
+        if output_subset is not None:
+            outputs = [outputs[i] for i in output_subset]
 
-            if output_subset is None:
-                return outputs
+        if self.output_keys is None:
+            if self.unpack_single:
+                [out] = outputs
+                return out
             else:
-                return [outputs[i] for i in output_subset]
+                return outputs
+        else:
+            output_keys = self.output_keys
+            if output_subset is not None:
+                output_keys = [output_keys[i] for i in output_subset]
+            return dict(zip(output_keys, outputs, strict=True))
 
     value = property(
         lambda self: self._value,
@@ -1077,9 +1082,10 @@ def free(self):
         # 1.no allow_gc return False
         # 2.has allow_gc, if allow_gc is False, return True
        if not getattr(self.vm, "allow_gc", True):
-            for key in self.vm.storage_map:
-                if not isinstance(key, Constant):
-                    self.vm.storage_map[key][0] = None
+            storage_map = self.vm.storage_map
+            for key, value in storage_map.items():
+                if key.owner is not None:  # Not a constant
+                    value[0] = None
 
         for node in self.nodes_with_inner_function:
             if hasattr(node.fn, "free"):
@@ -1091,10 +1097,6 @@ def get_shared(self):
         """
         return [i.variable for i in self.maker.inputs if i.implicit]
 
-    def sync_shared(self):
-        # NOTE: sync was needed on old gpu backend
-        pass
-
     def dprint(self, **kwargs):
         """Debug print itself
 
diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py
index daeaa5740f..9cf34983f2 100644
--- a/pytensor/link/basic.py
+++ b/pytensor/link/basic.py
@@ -653,41 +653,36 @@ def create_jitable_thunk(
         )
 
         thunk_inputs = self.create_thunk_inputs(storage_map)
-
-        thunks = []
-
         thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
-
         fgraph_jit = self.jit_compile(converted_fgraph)
 
         def thunk(
-            fgraph=self.fgraph,
             fgraph_jit=fgraph_jit,
             thunk_inputs=thunk_inputs,
             thunk_outputs=thunk_outputs,
         ):
-            outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
+            try:
+                outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
+            except Exception:
+                # TODO: Should we add a fake node that combines all outputs,
+                #   since the error may come from any of them?
+                raise_with_op(self.fgraph, output_nodes[0], thunk)
 
             # strict=False because we are in a hot loop
-            for o_var, o_storage, o_val in zip(
-                fgraph.outputs, thunk_outputs, outputs, strict=False
-            ):
-                compute_map[o_var][0] = True
-                o_storage[0] = self.output_filter(o_var, o_val)
-            return outputs
+            for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
+                o_storage[0] = o_val
 
         thunk.inputs = thunk_inputs
         thunk.outputs = thunk_outputs
         thunk.lazy = False
 
-        thunks.append(thunk)
+        thunks = [thunk]
 
         return thunks, output_nodes, fgraph_jit
 
     def make_all(self, input_storage=None, output_storage=None, storage_map=None):
         fgraph = self.fgraph
         nodes = self.schedule(fgraph)
-        no_recycling = self.no_recycling
 
         input_storage, output_storage, storage_map = map_storage(
@@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
             compute_map, nodes, input_storage, output_storage, storage_map
         )
 
-        computed, last_user = gc_helper(nodes)
-
-        if self.allow_gc:
-            post_thunk_old_storage = [
-                [
-                    storage_map[input]
-                    for input in node.inputs
-                    if (input in computed)
-                    and (input not in fgraph.outputs)
-                    and (node == last_user[input])
-                ]
-                for node in nodes
-            ]
-        else:
-            post_thunk_old_storage = None
-
-        if no_recycling is True:
-            no_recycling = list(storage_map.values())
-            no_recycling = difference(no_recycling, input_storage)
-        else:
-            no_recycling = [
-                storage_map[r] for r in no_recycling if r not in fgraph.inputs
-            ]
-
-        fn = streamline(
-            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
-        )
-
+        [fn] = thunks
         fn.jit_fn = jit_fn
         fn.allow_gc = self.allow_gc
         fn.storage_map = storage_map
diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py
index 2450b24150..06370b4514 100644
--- a/pytensor/link/jax/linker.py
+++ b/pytensor/link/jax/linker.py
@@ -3,7 +3,6 @@
 from numpy.random import Generator, RandomState
 
 from pytensor.compile.sharedvalue import SharedVariable, shared
-from pytensor.graph.basic import Constant
 from pytensor.link.basic import JITLinker
 
 
@@ -72,12 +71,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
     def jit_compile(self, fn):
         import jax
 
-        # I suppose we can consider `Constant`s to be "static" according to
-        # JAX.
-        static_argnums = [
-            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
-        ]
-        return jax.jit(fn, static_argnums=static_argnums)
+        return jax.jit(fn)
 
     def create_thunk_inputs(self, storage_map):
         from pytensor.link.jax.dispatch import jax_typify
diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py
index f120706f3b..553c5ef217 100644
--- a/pytensor/link/numba/linker.py
+++ b/pytensor/link/numba/linker.py
@@ -1,26 +1,9 @@
-from typing import TYPE_CHECKING, Any
-
-import numpy as np
-
-import pytensor
 from pytensor.link.basic import JITLinker
 
 
-if TYPE_CHECKING:
-    from pytensor.graph.basic import Variable
-
-
 class NumbaLinker(JITLinker):
     """A `Linker` that JIT-compiles NumPy-based operations using Numba."""
 
-    def output_filter(self, var: "Variable", out: Any) -> Any:
-        if not isinstance(var, np.ndarray) and isinstance(
-            var.type, pytensor.tensor.TensorType
-        ):
-            return var.type.filter(out, allow_downcast=True)
-
-        return out
-
     def fgraph_convert(self, fgraph, **kwargs):
         from pytensor.link.numba.dispatch import numba_funcify
 
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index ec26fd252f..ac0b0c8c02 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -1,7 +1,3 @@
-import copy
-from typing import Any
-
-from pytensor.graph.basic import Variable
 from pytensor.link.basic import JITLinker
 from pytensor.link.utils import unique_name_generator
 
@@ -13,14 +9,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.gen_functors = []
 
-    def input_filter(self, inp: Any) -> Any:
-        from pytensor.link.pytorch.dispatch import pytorch_typify
-
-        return pytorch_typify(inp)
-
-    def output_filter(self, var: Variable, out: Any) -> Any:
-        return out.cpu()
-
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.pytorch.dispatch import pytorch_funcify
 
@@ -49,6 +37,8 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch
 
+        from pytensor.link.pytorch.dispatch import pytorch_typify
+
         class wrapper:
             """
             Pytorch would fail compiling our method when trying
@@ -62,7 +52,7 @@ class wrapper:
 
             def __init__(self, fn, gen_functors):
                 self.fn = torch.compile(fn)
-                self.gen_functors = copy.copy(gen_functors)
+                self.gen_functors = gen_functors.copy()
 
             def __call__(self, *args, **kwargs):
                 import pytensor.link.utils
@@ -83,9 +73,15 @@ def __call__(self, *args, **kwargs):
             def __del__(self):
                 del self.gen_functors
 
-        res = wrapper(fn, self.gen_functors)
+        inner_fn = wrapper(fn, self.gen_functors)
         self.gen_functors = []
-        return res
+
+        # Torch does not accept numpy inputs and may return GPU objects
+        def fn(*inputs, inner_fn=inner_fn):
+            outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
+            return tuple(out.cpu().numpy() for out in outs)
+
+        return fn
 
     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py
index 68070654d4..d0f748f3e7 100644
--- a/tests/link/jax/test_basic.py
+++ b/tests/link/jax/test_basic.py
@@ -76,7 +76,7 @@ def compare_jax_and_py(
         if isinstance(jax_res, list):
             assert all(isinstance(res, jax.Array) for res in jax_res)
         else:
-            assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
+            assert isinstance(jax_res, jax.Array)
 
     pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
     py_res = pytensor_py_fn(*test_inputs)
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index dfadc58a69..a4c585da42 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -882,3 +882,20 @@ def test_cache_warning_suppressed():
 
     x_test = np.random.uniform(size=5)
     np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
+
+
+@pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))
+def test_function_overhead(mode, benchmark):
+    x = pt.vector("x")
+    out = pt.exp(x)
+
+    fn = function([x], out, mode="NUMBA")
+    if mode == "trust_input":
+        fn.trust_input = True
+    elif mode == "direct":
+        fn = fn.vm.jit_fn
+
+    test_x = np.zeros(1000)
+    assert np.sum(fn(test_x)) == 1000
+
+    benchmark(fn, test_x)
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index 25827d23f9..4757143465 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -53,8 +53,6 @@ def compare_pytorch_and_py(
     assert_fn: func, opt
         Assert function used to check for equality between python and pytorch. If not
        provided uses np.testing.assert_allclose
-    must_be_device_array: Bool
-        Checks if torch.device.type is cuda
 
     """
 
@@ -66,20 +64,19 @@ def compare_pytorch_and_py(
 
     pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode)
     pytorch_res = pytensor_torch_fn(*test_inputs)
 
-    if must_be_device_array:
-        if isinstance(pytorch_res, list):
-            assert all(isinstance(res, torch.Tensor) for res in pytorch_res)
-        else:
-            assert pytorch_res.device.type == "cuda"
+    if isinstance(pytorch_res, list):
+        assert all(isinstance(res, np.ndarray) for res in pytorch_res)
+    else:
+        assert isinstance(pytorch_res, np.ndarray)
 
     pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
     py_res = pytensor_py_fn(*test_inputs)
 
     if len(fgraph.outputs) > 1:
         for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True):
-            assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
+            assert_fn(pytorch_res_i, py_res_i)
     else:
-        assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
+        assert_fn(pytorch_res[0], py_res[0])
 
     return pytensor_torch_fn, pytorch_res
@@ -162,23 +159,23 @@ def test_shared(device):
     pytensor_torch_fn = function([], a, mode="PYTORCH")
     pytorch_res = pytensor_torch_fn()
 
-    assert isinstance(pytorch_res, torch.Tensor)
+    assert isinstance(pytorch_res, np.ndarray)
     assert isinstance(a.get_value(), np.ndarray)
-    np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())
+    np.testing.assert_allclose(pytorch_res, a.get_value())
 
     pytensor_torch_fn = function([], a * 2, mode="PYTORCH")
     pytorch_res = pytensor_torch_fn()
 
-    assert isinstance(pytorch_res, torch.Tensor)
+    assert isinstance(pytorch_res, np.ndarray)
     assert isinstance(a.get_value(), np.ndarray)
-    np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)
+    np.testing.assert_allclose(pytorch_res, a.get_value() * 2)
 
     new_a_value = np.array([3, 4, 5], dtype=config.floatX)
     a.set_value(new_a_value)
 
     pytorch_res = pytensor_torch_fn()
-    assert isinstance(pytorch_res, torch.Tensor)
-    np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)
+    assert isinstance(pytorch_res, np.ndarray)
+    np.testing.assert_allclose(pytorch_res, new_a_value * 2)
 
 
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -225,7 +222,7 @@ def test_alloc_and_empty():
     fn = function([dim1], out, mode=pytorch_mode)
     res = fn(7)
     assert res.shape == (5, 7, 3)
-    assert res.dtype == torch.float32
+    assert res.dtype == np.float32
 
     v = vector("v", shape=(3,), dtype="float64")
     out = alloc(v, dim0, dim1, 3)
diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
index 20c98094c1..2a9cf39c99 100644
--- a/tests/link/pytorch/test_elemwise.py
+++ b/tests/link/pytorch/test_elemwise.py
@@ -152,7 +152,7 @@ def test_cast():
     _, [res] = compare_pytorch_and_py(
         fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
     )
-    assert res.dtype == torch.int32
+    assert res.dtype == np.int32
 
 
 def test_vmap_elemwise():