Skip to content

Commit

Permalink
Allow 2d vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
mtar committed Dec 13, 2021
1 parent 1d1434d commit 8ddef62
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

def cross(x1: DNDarray, x2: DNDarray, axis: int = -1) -> DNDarray:
"""
Returns the cross product.
Returns the cross product. 2D vectors will we converted to 3D.
Parameters
----------
Expand Down Expand Up @@ -74,6 +74,19 @@ def cross(x1: DNDarray, x2: DNDarray, axis: int = -1) -> DNDarray:
sanitation.sanitize_in(x1)
sanitation.sanitize_in(x2)

# 2d -> 3d vector
if x1.shape[axis] == 2:
shape = tuple(1 if i == axis else j for i, j in enumerate(x1.shape))
x1 = manipulations.concatenate(
[x1, factories.zeros(shape, dtype=x1.dtype, device=x1.device)]
)

if x2.shape[axis] == 2:
shape = tuple(1 if i == axis else j for i, j in enumerate(x2.shape))
x2 = manipulations.concatenate(
[x2, factories.zeros(shape, dtype=x2.dtype, device=x2.device)]
)

if x1.gshape != x2.gshape:
raise ValueError(
"'x1' and 'x2' must have the same shape, {} != {}".format(x1.gshape, x2.gshape)
Expand Down Expand Up @@ -101,9 +114,9 @@ def cross(x1: DNDarray, x2: DNDarray, axis: int = -1) -> DNDarray:
raise ValueError(
"The computation of the cross product with vectors along the split axis is not supported."
)
else:
x1.balance_()
x2.balance_()

if not (x1.is_balanced and x2.is_balanced):
x2 = manipulations.redistribute(x2, x2.lshape_map, x1.lshape_map)

promoted = torch.promote_types(x1.larray.dtype, x2.larray.dtype)

Expand Down

0 comments on commit 8ddef62

Please sign in to comment.