Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cross #850

Merged
merged 14 commits into from
Jan 11, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
- [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__`
- [#885](https://github.com/helmholtz-analytics/heat/pull/885) New `DNDarray` method `conj`

# Feature additions
### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
- [#850](https://github.com/helmholtz-analytics/heat/pull/850) New feature `cross`
### Logical
- [#862](https://github.com/helmholtz-analytics/heat/pull/862) New feature `signbit`
### Manipulations
Expand Down
115 changes: 115 additions & 0 deletions heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
from .. import rounding
from .. import sanitation
from .. import statistics
from .. import stride_tricks
from .. import types

__all__ = [
"cross",
"dot",
"matmul",
"matrix_norm",
Expand All @@ -40,6 +42,119 @@
]


def cross(
    a: DNDarray, b: DNDarray, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int = -1
) -> DNDarray:
    """
    Returns the cross product of two (arrays of) vectors. 2D vectors will be
    converted to 3D (by appending a zero z-component) before the product is
    computed; if both inputs are 2D, only the z-component of the result is
    returned.

    Parameters
    ----------
    a : DNDarray
        First input array.
    b : DNDarray
        Second input array. Must have the same shape as 'a'.
    axisa: int
        Axis of `a` that defines the vector(s). By default, the last axis.
    axisb: int
        Axis of `b` that defines the vector(s). By default, the last axis.
    axisc: int
        Axis of the output containing the cross product vector(s). By default, the last axis.
    axis : int
        Axis that defines the vectors for which to compute the cross product. Overrides `axisa`, `axisb` and `axisc`. Default: -1

    Raises
    ------
    ValueError
        If the two input arrays don't match in shape, split, device, or comm. If the vectors are along the split axis.
    TypeError
        If 'axis' is not an integer.

    Examples
    --------
    >>> a = ht.eye(3)
    >>> b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
    >>> cross = ht.cross(a, b)
    DNDarray([[0., 0., 1.],
              [1., 0., 0.],
              [0., 1., 0.]], dtype=ht.float32, device=cpu:0, split=None)
    """
    # Validate that both inputs are proper DNDarrays.
    sanitation.sanitize_in(a)
    sanitation.sanitize_in(b)

    # Cross-device / cross-communicator products are not supported.
    if a.device != b.device:
        raise ValueError(
            "'a' and 'b' must have the same device type, {} != {}".format(a.device, b.device)
        )
    if a.comm != b.comm:  # pragma: no cover
        raise ValueError("'a' and 'b' must have the same comm, {} != {}".format(a.comm, b.comm))

    a_2d, b_2d = False, False
    a_shape, b_shape = list(a.shape), list(b.shape)

    # `axis` overrides the per-array axes: take this path when `axis` was
    # explicitly set (!= -1) or when axisa, axisb, axisc and axis all agree.
    if not axis == -1 or torch.unique(torch.tensor([axisa, axisb, axisc, axis])).numel() == 1:
        axis = stride_tricks.sanitize_axis(a.shape, axis)
        axisa, axisb, axisc = (axis,) * 3
    else:
        # Normalize each axis (e.g. resolve negative indices) independently.
        # NOTE(review): axisc is sanitized against a.shape; presumably fine
        # since the output rank matches the (broadcast) input rank — confirm.
        axisa = stride_tricks.sanitize_axis(a.shape, axisa)
        axisb = stride_tricks.sanitize_axis(b.shape, axisb)
        axisc = stride_tricks.sanitize_axis(a.shape, axisc)

    # The vector components must be process-local; computing across the
    # distribution axis would require communication that is not implemented.
    if a.split == axisa or b.split == axisb:
        raise ValueError(
            "The computation of the cross product with vectors along the split axis is not supported."
        )

    # all dimensions except axisa, axisb must be broadcastable
    del a_shape[axisa], b_shape[axisb]
    output_shape = stride_tricks.broadcast_shape(a_shape, b_shape)

    # 2d -> 3d vector: pad the 2-element vector axis with a zero z-component
    # so torch.cross (which requires 3 components) can be used.
    if a.shape[axisa] == 2:
        a_2d = True
        shape = tuple(1 if i == axisa else j for i, j in enumerate(a.shape))
        a = manipulations.concatenate(
            [a, factories.zeros(shape, dtype=a.dtype, device=a.device)], axis=axisa
        )
    if b.shape[axisb] == 2:
        b_2d = True
        shape = tuple(1 if i == axisb else j for i, j in enumerate(b.shape))
        b = manipulations.concatenate(
            [b, factories.zeros(shape, dtype=b.dtype, device=b.device)], axis=axisb
        )

    # Move both vector axes to the requested output axis so the local
    # torch.cross can operate along a single common dimension.
    if axisc != axisa:
        a = manipulations.moveaxis(a, axisa, axisc)

    if axisc != axisb:
        b = manipulations.moveaxis(b, axisb, axisc)

    axis = axisc

    # by now split axes must be aligned
    if a.split != b.split:
        raise ValueError("'a' and 'b' must have the same split, {} != {}".format(a.split, b.split))

    if not (a.is_balanced and b.is_balanced):
        # TODO: replace with sanitize_redistribute after #888 is merged
        # Align b's local chunks with a's so the element-wise torch.cross
        # below operates on matching local slices.
        b = manipulations.redistribute(b, b.lshape_map, a.lshape_map)

    # Promote both operands to a common torch dtype before the local product.
    promoted = torch.promote_types(a.larray.dtype, b.larray.dtype)

    ret = torch.cross(a.larray.type(promoted), b.larray.type(promoted), dim=axis)

    # if both vector axes have dimension 2, return the z-component of the cross product
    if a_2d and b_2d:
        # Select index -1 (the z-component) along axisc, full slices elsewhere.
        # NOTE(review): indexing a torch tensor with a list of slices is
        # deprecated in newer torch; tuple(z_slice) is the canonical form — confirm.
        z_slice = [slice(None, None, None)] * ret.ndim
        z_slice[axisc] = -1
        ret = ret[z_slice]
    else:
        # Re-insert the 3-element vector axis into the broadcast output shape.
        output_shape = output_shape[:axis] + (3,) + output_shape[axis:]

    # Wrap the local torch result back into a distributed DNDarray.
    ret = DNDarray(ret, output_shape, types.heat_type_of(ret), a.split, a.device, a.comm, True)
    return ret


def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDarray, float]:
"""
Returns the dot product of two ``DNDarrays``.
Expand Down
82 changes: 82 additions & 0 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Type
import torch
import os
import unittest
Expand All @@ -8,6 +9,87 @@


class TestLinalgBasics(TestCase):
def test_cross(self):
a = ht.eye(3)
b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])

# different types
cross = ht.cross(a, b)
self.assertEqual(cross.shape, a.shape)
self.assertEqual(cross.dtype, a.dtype)
self.assertEqual(cross.split, a.split)
self.assertEqual(cross.comm, a.comm)
self.assertEqual(cross.device, a.device)
self.assertTrue(ht.equal(cross, ht.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])))

# axis
a = ht.eye(3, split=0)
b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=ht.float, split=0)

cross = ht.cross(a, b)
self.assertEqual(cross.shape, a.shape)
self.assertEqual(cross.dtype, a.dtype)
self.assertEqual(cross.split, a.split)
self.assertEqual(cross.comm, a.comm)
self.assertEqual(cross.device, a.device)
self.assertTrue(ht.equal(cross, ht.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])))

a = ht.eye(3, dtype=ht.int8, split=1)
b = ht.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=ht.int8, split=1)

cross = ht.cross(a, b, axis=0)
self.assertEqual(cross.shape, a.shape)
self.assertEqual(cross.dtype, a.dtype)
self.assertEqual(cross.split, a.split)
self.assertEqual(cross.comm, a.comm)
self.assertEqual(cross.device, a.device)
self.assertTrue(ht.equal(cross, ht.array([[0, 0, -1], [-1, 0, 0], [0, -1, 0]])))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add tests for axisa, axisb, 2D with 3D cross product

# test axisa, axisb, axisc
np.random.seed(42)
np_a = np.random.randn(40, 3, 50)
np_b = np.random.randn(3, 40, 50)
np_cross = np.cross(np_a, np_b, axisa=1, axisb=0)

a = ht.array(np_a, split=0)
b = ht.array(np_b, split=1)
cross = ht.cross(a, b, axisa=1, axisb=0)
self.assert_array_equal(cross, np_cross)

cross_axisc = ht.cross(a, b, axisa=1, axisb=0, axisc=1)
np_cross_axisc = np.cross(np_a, np_b, axisa=1, axisb=0, axisc=1)
self.assert_array_equal(cross_axisc, np_cross_axisc)

# test vector axes with 2 elements
b_2d = ht.array(np_b[:-1, :, :], split=1)
cross_3d_2d = ht.cross(a, b_2d, axisa=1, axisb=0)
np_cross_3d_2d = np.cross(np_a, np_b[:-1, :, :], axisa=1, axisb=0)
self.assert_array_equal(cross_3d_2d, np_cross_3d_2d)

a_2d = ht.array(np_a[:, :-1, :], split=0)
cross_2d_3d = ht.cross(a_2d, b, axisa=1, axisb=0)
np_cross_2d_3d = np.cross(np_a[:, :-1, :], np_b, axisa=1, axisb=0)
self.assert_array_equal(cross_2d_3d, np_cross_2d_3d)

cross_z_comp = ht.cross(a_2d, b_2d, axisa=1, axisb=0)
np_cross_z_comp = np.cross(np_a[:, :-1, :], np_b[:-1, :, :], axisa=1, axisb=0)
self.assert_array_equal(cross_z_comp, np_cross_z_comp)

a_wrong_split = ht.array(np_a[:, :-1, :], split=2)
with self.assertRaises(ValueError):
ht.cross(a_wrong_split, b, axisa=1, axisb=0)
with self.assertRaises(ValueError):
ht.cross(ht.eye(3), ht.eye(4))
with self.assertRaises(ValueError):
ht.cross(ht.eye(3, split=0), ht.eye(3, split=1))
if torch.cuda.is_available():
with self.assertRaises(ValueError):
ht.cross(ht.eye(3, device="gpu"), ht.eye(3, device="cpu"))
with self.assertRaises(TypeError):
ht.cross(ht.eye(3), ht.eye(3), axis="wasd")
with self.assertRaises(ValueError):
ht.cross(ht.eye(3, split=0), ht.eye(3, split=0), axis=0)

def test_dot(self):
# ONLY TESTING CORRECTNESS! ALL CALLS IN DOT ARE PREVIOUSLY TESTED
# cases to test:
Expand Down