Adding conditionals for torch #939
Comments
Hey @ricardoV94 , could I get some clarity on scalar loop? I was under the impression that it might just work (I don't see any explicit tests for numba or jax). What is the work needed for scalar loop? Here is an example test I wrote, which may also be invalid:

```python
import numpy as np

from pytensor import function
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop

# pytorch_mode is the PyTorch compilation mode used by the backend tests,
# assumed to be available in scope here.


def test_ScalarOp():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    const = float64("const")
    x = x0 + const
    op = ScalarLoop(init=[x0], constant=[const], update=[x])
    x = op(n_steps, x0, const)
    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
    np.testing.assert_allclose(fn(5, 0, 1), 5)
    np.testing.assert_allclose(fn(5, 0, 2), 10)
    np.testing.assert_allclose(fn(4, 3, -1), -1)
```

It fails with:

```
op = ScalarLoop(), node = ScalarLoop(n_steps, x0, const)
kwargs = {'input_storage': [[None], [None], [None]], 'output_storage': [[None]], 'storage_map': {ScalarLoop.0: [None], const: [None], x0: [None], n_steps: [None]}}
nfunc_spec = None

    @pytorch_funcify.register(ScalarOp)
    def pytorch_funcify_ScalarOp(op, node, **kwargs):
        """Return pytorch function that implements the same computation as the Scalar Op.
        This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does,
        even though it's dispatched on the Scalar Op.
        """
        nfunc_spec = getattr(op, "nfunc_spec", None)
        if nfunc_spec is None:
>           raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
E           NotImplementedError: Dispatch not implemented for Scalar Op ScalarLoop

pytensor/link/pytorch/dispatch/scalar.py:19: NotImplementedError
```
You haven't seen JAX/Numba code because scalar loop isn't yet supported in those backends either. I suggest checking the perform method to get an idea of how the Op works.
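For reference, what ScalarLoop computes is just "apply the scalar update n_steps times, carrying the state forward and passing the constants unchanged on every step". Below is a minimal, self-contained sketch of that behaviour with PyTorch tensors; it is not the actual dispatch code, and `update_fn` is a hypothetical stand-in for the compiled inner update graph:

```python
import torch


def scalar_loop(update_fn, n_steps, init, constants=()):
    # Repeatedly apply the update to the carried state, the way
    # ScalarLoop.perform steps through its inner graph n_steps times.
    carry = init
    for _ in range(int(n_steps)):
        carry = update_fn(carry, *constants)
    return carry


# Mirrors the test above: x(t+1) = x(t) + const, starting from x0 = 0.
result = scalar_loop(lambda x, c: x + c, 5, torch.tensor(0.0), (torch.tensor(2.0),))
print(result)  # tensor(10.)
```

A real dispatch would additionally need to handle multiple carried states and ScalarLoop's optional until condition.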
For Blockwise you should be able to use
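That comment is truncated in the thread, so the concrete suggestion is missing; the following is only my assumption about one plausible approach. Blockwise wraps a core op and broadcasts it over leading batch dimensions, and torch.vmap maps a function over a batch dimension, so stacking one vmap per batch dimension reproduces that behaviour. A sketch (`blockwise_like` is an illustrative helper, not backend code):

```python
import torch


def blockwise_like(core_fn, n_batch_dims):
    # Vectorize `core_fn` over `n_batch_dims` leading batch dimensions by
    # applying torch.vmap once per batch dimension.
    fn = core_fn
    for _ in range(n_batch_dims):
        fn = torch.vmap(fn)
    return fn


# Core op: (3, 4) @ (4, 5) matmul, broadcast over one leading batch dimension.
a = torch.randn(7, 3, 4)
b = torch.randn(7, 4, 5)
batched_matmul = blockwise_like(torch.matmul, n_batch_dims=1)
print(batched_matmul(a, b).shape)  # torch.Size([7, 3, 5])
```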
I'm gonna close this out for now. We have the larger list of ops and I'm not actively working on the scan op.
Description
Add the branching ops (conditionals) for the PyTorch backend.
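Assuming "branching ops" refers to PyTensor's conditionals, i.e. the lazy IfElse and the elementwise Switch (the issue title says "conditionals for torch"), their semantics map onto PyTorch roughly as follows. This is an illustrative sketch, not the backend's dispatch code, and the helper names are hypothetical:

```python
import torch


def switch_like(cond, a, b):
    # Elementwise conditional: pick from `a` where `cond` is true, else `b`.
    # torch.where is the natural PyTorch counterpart of an elementwise switch.
    return torch.where(cond.bool(), a, b)


def ifelse_like(cond, then_branch, else_branch):
    # Lazy conditional on a scalar condition: only one branch is evaluated.
    # Using callables for the branches is an illustrative choice, not the
    # backend's actual calling convention.
    return then_branch() if bool(cond) else else_branch()


x = torch.tensor([1.0, -2.0, 3.0])
print(switch_like(x > 0, x, torch.zeros_like(x)))  # tensor([1., 0., 3.])
print(ifelse_like(torch.tensor(True), lambda: x * 2, lambda: x - 1))
```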