Removed math Ops Arg[Max] and Dot
twaclaw committed Jul 10, 2024
1 parent 2005fda commit a334c6a
Showing 2 changed files with 1 addition and 118 deletions.
72 changes: 0 additions & 72 deletions pytensor/link/pytorch/dispatch/nlinalg.py
@@ -1,8 +1,6 @@
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Argmax, Dot, Max
from pytensor.tensor.nlinalg import (
    SVD,
    Det,
@@ -85,14 +83,6 @@ def qr_full(x):
    return qr_full


@pytorch_funcify.register(Dot)
def pytorch_funcify_Dot(op, **kwargs):
    def dot(x, y):
        return torch.dot(x, y)

    return dot


@pytorch_funcify.register(MatrixPinv)
def pytorch_funcify_Pinv(op, **kwargs):
    hermitian = op.hermitian
@@ -103,71 +93,9 @@ def pinv(x):
    return pinv


@pytorch_funcify.register(BatchedDot)
def pytorch_funcify_BatchedDot(op, **kwargs):
    def batched_dot(a, b):
        if a.shape[0] != b.shape[0]:
            raise TypeError("Shapes must match in the 0-th dimension")
        return torch.matmul(a, b)

    return batched_dot


@pytorch_funcify.register(KroneckerProduct)
def pytorch_funcify_KroneckerProduct(op, **kwargs):
    def _kron(x, y):
        return torch.kron(x, y)

    return _kron


@pytorch_funcify.register(Max)
def pytorch_funcify_Max(op, **kwargs):
    axis = op.axis

    def max(x):
        if axis is None:
            max_res = torch.max(x.flatten())
            return max_res

        # PyTorch doesn't support multiple axes for max;
        # this is a work-around
        axes = [int(ax) for ax in axis]

        new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item()
        keep_axes = [i for i in range(x.ndim) if i not in axes]
        permute_order = keep_axes + axes
        permuted_x = x.permute(*permute_order)
        kept_shape = permuted_x.shape[: len(keep_axes)]

        new_shape = (*kept_shape, new_dim)
        reshaped_x = permuted_x.reshape(new_shape)
        max_res, _ = torch.max(reshaped_x, dim=-1)
        return max_res

    return max


@pytorch_funcify.register(Argmax)
def pytorch_funcify_Argmax(op, **kwargs):
    axis = op.axis

    def argmax(x):
        if axis is None:
            return torch.argmax(x.view(-1))

        # PyTorch doesn't support multiple axes for argmax;
        # this is a work-around
        axes = [int(ax) for ax in axis]

        new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item()
        keep_axes = [i for i in range(x.ndim) if i not in axes]
        permute_order = keep_axes + axes
        permuted_x = x.permute(*permute_order)
        kept_shape = permuted_x.shape[: len(keep_axes)]

        new_shape = (*kept_shape, new_dim)
        reshaped_x = permuted_x.reshape(new_shape)
        return torch.argmax(reshaped_x, dim=-1)

    return argmax
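
Note on the removed Max and Argmax dispatchers: torch.max and torch.argmax accept only a single dim, so the deleted code collapsed all reduced axes into one trailing dimension and then reduced over dim=-1. A minimal, self-contained sketch of that trick in plain PyTorch (the helper name argmax_over_axes is illustrative and not part of PyTensor):

import torch


def argmax_over_axes(x: torch.Tensor, axes: list[int]) -> torch.Tensor:
    # Move the kept axes to the front, flatten the reduced axes into a
    # single trailing dimension, then take argmax over that dimension.
    keep_axes = [i for i in range(x.ndim) if i not in axes]
    permuted = x.permute(*keep_axes, *axes)
    kept_shape = permuted.shape[: len(keep_axes)]
    new_dim = 1
    for ax in axes:
        new_dim *= x.size(ax)
    reshaped = permuted.reshape(*kept_shape, new_dim)
    return torch.argmax(reshaped, dim=-1)


# Example: reducing a (4, 3, 2) tensor over axes 0 and 2 leaves shape (3,).
x = torch.arange(24).reshape(4, 3, 2)
print(argmax_over_axes(x, [0, 2]).shape)  # torch.Size([3])
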
47 changes: 1 addition & 46 deletions tests/link/pytorch/test_nlinalg.py
@@ -4,11 +4,8 @@
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import blas as pt_blas
from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.math import argmax, dot, max
from pytensor.tensor.type import matrix, tensor3, vector
from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@@ -23,27 +20,6 @@ def matrix_test():
    return (x, test_value)


def test_BatchedDot():
    # tensor3 . tensor3
    a = tensor3("a")
    a.tag.test_value = (
        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
    )
    b = tensor3("b")
    b.tag.test_value = (
        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
    )
    out = pt_blas.BatchedDot()(a, b)
    fgraph = FunctionGraph([a, b], [out])
    compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # A dimension mismatch should raise a TypeError for compatibility
    inputs = [get_test_value(a)[:-1], get_test_value(b)]
    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode="PYTORCH")
    with pytest.raises(TypeError):
        pytensor_jax_fn(*inputs)


@pytest.mark.parametrize(
"func",
(
@@ -147,24 +123,3 @@ def test_kron():
    y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)

    compare_pytorch_and_py(fgraph, [x_np, y_np])


@pytest.mark.parametrize("func", (max, argmax))
@pytest.mark.parametrize("axis", [None, [0], [0, 1], [0, 2], [0, 1, 2]])
def test_max_and_argmax(func, axis):
x = tensor3("x")
np.random.seed(42)
test_value = np.random.randint(0, 20, (4, 3, 2))

out = func(x, axis=axis)
out_fg = FunctionGraph([x], [out])
compare_pytorch_and_py(out_fg, [test_value])


def test_dot():
x = vector("x")
test_value = np.array([1, 2, 3])

out = dot(x, x)
out_fg = FunctionGraph([x], [out])
compare_pytorch_and_py(out_fg, [test_value])
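
For context on the removed Dot and BatchedDot dispatchers and their tests: torch.dot only handles 1-D tensors, while torch.matmul broadcasts a matrix product over a shared leading batch dimension. A short, self-contained sketch of both behaviours (shapes chosen to mirror the deleted test_BatchedDot test):

import torch

# torch.dot is restricted to 1-D tensors, which is what the removed Dot
# dispatcher mapped to.
v = torch.tensor([1.0, 2.0, 3.0])
print(torch.dot(v, v))  # tensor(14.)

# torch.matmul performs a batched matrix product over the shared leading
# dimension, which is what the removed BatchedDot dispatcher mapped to
# (after checking that the batch sizes agree).
a = torch.randn(10, 5, 3)
b = torch.randn(10, 3, 2)
print(torch.matmul(a, b).shape)  # torch.Size([10, 5, 2])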
