add cross #850

Merged: 14 commits, Jan 11, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -3,6 +3,9 @@
## Bug Fixes
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension

# Feature additions
## Linear Algebra
- [#850](https://github.com/helmholtz-analytics/heat/pull/850) New Feature `cross`

# v1.1.0

85 changes: 84 additions & 1 deletion heat/core/linalg/basics.py
@@ -15,9 +15,92 @@
from .. import factories
from .. import manipulations
from .. import sanitation
from .. import stride_tricks
from .. import types

__all__ = ["dot", "matmul", "norm", "outer", "projection", "trace", "transpose", "tril", "triu"]
__all__ = [
    "cross",
    "dot",
    "matmul",
    "norm",
    "outer",
    "projection",
    "trace",
    "transpose",
    "tril",
    "triu",
]


def cross(x1: DNDarray, x2: DNDarray, axis: int = -1) -> DNDarray:
    """
    Returns the cross product.

    Parameters
    ----------
    x1 : DNDarray
        First input array.
    x2 : DNDarray
        Second input array. Must have the same shape as 'x1'.
    axis : int
        Axis that defines the vectors for which to compute the cross product. Default: -1

    Raises
    ------
    ValueError
        If the two input arrays don't match in shape, split, device, or comm. If the vectors are along the split axis.
    TypeError
        If 'axis' is not an integer.

    Examples
    --------
    >>> a = ht.eye(3)
    >>> b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
    >>> cross = ht.cross(a, b)
    DNDarray([[0., 0., 1.],
              [1., 0., 0.],
              [0., 1., 0.]], dtype=ht.float32, device=cpu:0, split=None)
    """
    sanitation.sanitize_in(x1)
    sanitation.sanitize_in(x2)

    if x1.gshape != x2.gshape:
Contributor

shapes should be the same except for the axes that define the vectors. Example:

import numpy as np
a = np.arange(2*12).reshape(2,-1,3)
b = np.arange(2*4*2).reshape(2,4,2)
np.cross(a,b)
array([[[  -2,    0,    0],
        [ -15,   10,    1],
        [ -40,   32,    2],
        [ -77,   66,    3]],

       [[-126,  112,    4],
        [-187,  170,    5],
        [-260,  240,    6],
        [-345,  322,    7]]])

# different vector axis for b
b = np.transpose(b, (0, 2, 1))
b.shape
(2, 2, 4)
a.shape
(2, 4, 3)
np.cross(a, b, axisb=1)
array([[[  -2,    0,    0],
        [ -15,   10,    1],
        [ -40,   32,    2],
        [ -77,   66,    3]],

       [[-126,  112,    4],
        [-187,  170,    5],
        [-260,  240,    6],
        [-345,  322,    7]]])

        raise ValueError(
            "'x1' and 'x2' must have the same shape, {} != {}".format(x1.gshape, x2.gshape)
Contributor

I think the check here should be whether a.shape and b.shape are broadcastable after purging axisa, axisb
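
One way to read this suggestion is sketched below in plain Python. It is not Heat's or NumPy's implementation: `purge_axis`, `broadcastable`, and `cross_shapes_compatible` are hypothetical helper names, and the check assumes the usual NumPy-style broadcasting rule (dimensions aligned from the right must be equal or 1).

```python
from itertools import zip_longest
from typing import Tuple


def purge_axis(shape: Tuple[int, ...], axis: int) -> Tuple[int, ...]:
    """Return `shape` without the entry at `axis` (negative values allowed)."""
    axis = axis % len(shape)
    return shape[:axis] + shape[axis + 1:]


def broadcastable(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> bool:
    """NumPy-style rule: aligned from the right, sizes must be equal or 1."""
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    return all(m == n or m == 1 or n == 1 for m, n in pairs)


def cross_shapes_compatible(
    shape_a: Tuple[int, ...],
    shape_b: Tuple[int, ...],
    axisa: int = -1,
    axisb: int = -1,
) -> bool:
    """Hypothetical check: vector axes hold 2 or 3 elements, the rest broadcasts."""
    if shape_a[axisa] not in (2, 3) or shape_b[axisb] not in (2, 3):
        return False
    return broadcastable(purge_axis(shape_a, axisa), purge_axis(shape_b, axisb))


# Shapes taken from the NumPy example earlier in this thread.
print(cross_shapes_compatible((2, 4, 3), (2, 4, 2)))           # 3D with 2D vectors
print(cross_shapes_compatible((2, 4, 3), (2, 2, 4), axisb=1))  # vector axis moved
print(cross_shapes_compatible((3, 3), (4, 4)))                 # 4-element vectors
```

With the shapes from the example above, the first two calls report compatible inputs, while `(3, 3)` against `(4, 4)` is rejected because a vector axis of length 4 is not a valid cross-product input.
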

        )
    if x1.split != x2.split:
        raise ValueError(
            "'x1' and 'x2' must have the same split, {} != {}".format(x1.split, x2.split)
Contributor

the splits must match after purging the vector axis

Member

I don't know what you mean here. Is this error message not sufficient?
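
A possible interpretation of "the splits must match after purging the vector axis", sketched in plain Python. This is an assumption about what the reviewer meant, not something the PR implements: `split_after_purging` is a hypothetical helper, and separate vector axes per input (`axisa`, `axisb`) do not exist in this version of `cross`.

```python
from typing import Optional


def split_after_purging(split: Optional[int], axis: int, ndim: int) -> Optional[int]:
    """Index of the split axis once the vector axis has been removed."""
    if split is None:
        return None  # not distributed, nothing to shift
    axis = axis % ndim
    split = split % ndim
    if split == axis:
        raise ValueError("the vector axis must not be the split axis")
    # Axes located after the removed vector axis move one position to the left.
    return split - 1 if split > axis else split


# x1: shape (2, 4, 3), vectors on axis 2, split=1
# x2: shape (2, 2, 4), vectors on axis 1, split=2
s1 = split_after_purging(1, 2, 3)
s2 = split_after_purging(2, 1, 3)
print(s1 == s2)
```

Under this reading both arrays are distributed along the same remaining axis, although a plain `x1.split != x2.split` comparison would reject them.
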

        )
Member

Is there an open issue for this? We should probably have this functionality.

Member

Could we use halos to do this?

    if x1.device != x2.device:
        raise ValueError(
            "'x1' and 'x2' must have the same device type, {} != {}".format(x1.device, x2.device)
Contributor

Why? One of them will be copied to the CPU, but the operation can still be performed, right?

Collaborator Author

I can remove the check. PyTorch will throw an error about the device mismatch then.

        )
    if x1.comm != x2.comm:  # pragma: no cover
        raise ValueError("'x1' and 'x2' must have the same comm, {} != {}".format(x1.comm, x2.comm))
Member

This may result in errors. In most places we assume that all DNDarrays are mapped across all processes in the same way. I am very uncertain whether this will cause errors at scale.


    if not isinstance(axis, int):
        try:
            axis = int(axis)
        except Exception:
            raise TypeError("'axis' must be an integer.")

    axis = stride_tricks.sanitize_axis(x1.shape, axis)

    if x1.split == axis:
        raise ValueError(
            "The computation of the cross product with vectors along the split axis is not supported."
        )
    else:
        x1.balance_()
        x2.balance_()
Contributor

why?
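
For the axis handling just before the `balance_()` calls, the following plain-Python sketch shows the two steps involved: coercing `axis` to an integer (raising `TypeError` on failure) and mapping a negative axis to its non-negative position. `normalize_axis` is a hypothetical stand-in and makes no claim about how `stride_tricks.sanitize_axis` is actually implemented.

```python
def normalize_axis(ndim: int, axis) -> int:
    """Coerce `axis` to int and map it into the range [0, ndim)."""
    if not isinstance(axis, int):
        try:
            axis = int(axis)
        except (TypeError, ValueError):
            # Mirrors the TypeError raised by the PR for non-integer input.
            raise TypeError("'axis' must be an integer.") from None
    if not -ndim <= axis < ndim:
        raise ValueError("axis {} is out of bounds for {} dimensions".format(axis, ndim))
    return axis % ndim


last = normalize_axis(2, -1)  # default axis=-1 on a 2-D array
print(last)
```

After normalization, a comparison such as `x1.split == axis` works for both `axis=-1` and `axis=1` on a 2-D array, which is why the sanitizing step has to come before the split check.
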


    promoted = torch.promote_types(x1.larray.dtype, x2.larray.dtype)

    ret = torch.cross(x1.larray.type(promoted), x2.larray.type(promoted), dim=axis)
Member

So if I'm reading this correctly, this is a limited cross product function. Do you have plans to implement more cases?

Collaborator Author

@mtar mtar Aug 27, 2021

What other cases? The only restriction I made deliberately was not allowing the split axis of the DNDarray as the axis argument. What do you miss?

Contributor

It actually doesn't make any sense to split along the vector axis, since it can only have 2 or 3 elements. What's missing is the ability to set the axisa and axisb parameters, and to compute cross products of 2D with 3D vectors.
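
The "2D with 3D" case mentioned here can be illustrated for a single pair of vectors in plain Python. This is a sketch, not code from the PR: `cross_padded` is a hypothetical name, and it assumes the convention that a 2-element vector is treated as having a zero third component. Unlike `np.cross`, which returns only the z-component when both inputs have two elements, this sketch always returns three components.

```python
from typing import List, Sequence


def cross_padded(u: Sequence[float], v: Sequence[float]) -> List[float]:
    """Cross product of two vectors with 2 or 3 elements each."""
    if len(u) not in (2, 3) or len(v) not in (2, 3):
        raise ValueError("vectors must have 2 or 3 elements")
    # Pad a missing third component with zero.
    u0, u1, u2 = list(u) + [0] * (3 - len(u))
    v0, v1, v2 = list(v) + [0] * (3 - len(v))
    return [u1 * v2 - u2 * v1, u2 * v0 - u0 * v2, u0 * v1 - u1 * v0]


# First two vector pairs of the NumPy example earlier in this thread.
print(cross_padded([0, 1, 2], [0, 1]))  # [-2, 0, 0]
print(cross_padded([3, 4, 5], [2, 3]))  # [-15, 10, 1]
```

The two results agree with the first two rows of the `np.cross(a, b)` output quoted in the review.
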

    ret = DNDarray(ret, x1.gshape, types.heat_type_of(ret), x1.split, x1.device, x1.comm, True)

    return ret


def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDarray, float]:
49 changes: 49 additions & 0 deletions heat/core/linalg/tests/test_basics.py
@@ -1,3 +1,4 @@
from typing import Type
import torch
import os
import unittest
@@ -8,6 +9,54 @@


class TestLinalgBasics(TestCase):
    def test_cross(self):
        a = ht.eye(3)
        b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])

        # different types
        cross = ht.cross(a, b)
        self.assertEqual(cross.shape, a.shape)
        self.assertEqual(cross.dtype, a.dtype)
        self.assertEqual(cross.split, a.split)
        self.assertEqual(cross.comm, a.comm)
        self.assertEqual(cross.device, a.device)
        self.assertTrue(ht.equal(cross, ht.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])))

        # axis
        a = ht.eye(3, split=0)
        b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=ht.float, split=0)

        cross = ht.cross(a, b)
        self.assertEqual(cross.shape, a.shape)
        self.assertEqual(cross.dtype, a.dtype)
        self.assertEqual(cross.split, a.split)
        self.assertEqual(cross.comm, a.comm)
        self.assertEqual(cross.device, a.device)
        self.assertTrue(ht.equal(cross, ht.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])))

        a = ht.eye(3, dtype=ht.int8, split=1)
        b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=ht.int8, split=1)

        cross = ht.cross(a, b, axis=0)
        self.assertEqual(cross.shape, a.shape)
        self.assertEqual(cross.dtype, a.dtype)
        self.assertEqual(cross.split, a.split)
        self.assertEqual(cross.comm, a.comm)
        self.assertEqual(cross.device, a.device)
        self.assertTrue(ht.equal(cross, ht.array([[0, 0, -1], [-1, 0, 0], [0, -1, 0]])))

Contributor

Add tests for axisa, axisb, and the 2D with 3D cross product.

        with self.assertRaises(ValueError):
            ht.cross(ht.eye(3), ht.eye(4))
        with self.assertRaises(ValueError):
            ht.cross(ht.eye(3, split=0), ht.eye(3, split=1))
        if torch.cuda.is_available():
            with self.assertRaises(ValueError):
                ht.cross(ht.eye(3, device="gpu"), ht.eye(3, device="cpu"))
        with self.assertRaises(TypeError):
            ht.cross(ht.eye(3), ht.eye(3), axis="wasd")
        with self.assertRaises(ValueError):
            ht.cross(ht.eye(3, split=0), ht.eye(3, split=0), axis=0)
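
The expected values in `test_cross` can be reproduced by hand without Heat. The sketch below uses plain Python lists; `cross3`, `cross_rows`, `cross_columns`, and `transpose` are hypothetical helpers introduced only for illustration. It shows why the default `axis=-1` (vectors are rows) and `axis=0` (vectors are columns) give results of opposite sign for this particular pair of matrices.

```python
from typing import List

Matrix = List[List[int]]


def cross3(u: List[int], v: List[int]) -> List[int]:
    """Cross product of two 3-element vectors."""
    return [u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def cross_rows(a: Matrix, b: Matrix) -> Matrix:
    """Vectors along the last axis (the rows)."""
    return [cross3(r, s) for r, s in zip(a, b)]


def cross_columns(a: Matrix, b: Matrix) -> Matrix:
    """Vectors along axis 0 (the columns)."""
    return transpose(cross_rows(transpose(a), transpose(b)))


a = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]  # identity, like ht.eye(3)
b = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]

print(cross_rows(a, b))     # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
print(cross_columns(a, b))  # [[0, 0, -1], [-1, 0, 0], [0, -1, 0]]
```

Row-wise, the pairs are e0 x e1, e1 x e2, and e2 x e0, which follow the cyclic order. Column-wise, the columns of `b` are e2, e0, and e1, so the pairs are e0 x e2, e1 x e0, and e2 x e1, each the negative of a cyclic product.
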

    def test_dot(self):
        # ONLY TESTING CORRECTNESS! ALL CALLS IN DOT ARE PREVIOUSLY TESTED
        # cases to test: