Skip to content

Commit

Permalink
add rotation and translation classmethods in se3 and so3 (kornia#2001)
Browse files Browse the repository at this point in the history
* add rot_x

* implement roty, rotz, trans, transx, transy, transz

* small improvement

* improve quaternion exposure

* Apply suggestions from code review

Co-authored-by: Christie Jacob <[email protected]>

Co-authored-by: Christie Jacob <[email protected]>
  • Loading branch information
edgarriba and cjpurackal authored Nov 15, 2022
1 parent 9b38e07 commit 15a4a32
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 8 deletions.
80 changes: 78 additions & 2 deletions kornia/geometry/liegroup/se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# https://github.com/strasdat/Sophus/blob/master/sympy/sophus/se3.py
from typing import Optional

from kornia.core import Module, Parameter, Tensor, concatenate, eye, pad, tensor, where
from kornia.core import Module, Parameter, Tensor, concatenate, eye, pad, stack, tensor, where, zeros_like
from kornia.geometry.liegroup.so3 import So3
from kornia.geometry.linalg import batched_dot_product
from kornia.testing import KORNIA_CHECK_SHAPE, KORNIA_CHECK_TYPE
from kornia.testing import KORNIA_CHECK, KORNIA_CHECK_SAME_DEVICES, KORNIA_CHECK_SHAPE, KORNIA_CHECK_TYPE


class Se3(Module):
Expand Down Expand Up @@ -247,3 +247,79 @@ def inverse(self) -> 'Se3':
"""
r_inv = self.r.inverse()
return Se3(r_inv, r_inv * (-1 * self.t))

@classmethod
def rot_x(cls, x: Tensor) -> "Se3":
"""Construct a x-axis rotation.
Args:
x: the x-axis rotation angle.
"""
zs = zeros_like(x)
return cls(So3.rot_x(x), stack((zs, zs, zs), -1))

@classmethod
def rot_y(cls, y: Tensor) -> "Se3":
"""Construct a y-axis rotation.
Args:
y: the y-axis rotation angle.
"""
zs = zeros_like(y)
return cls(So3.rot_y(y), stack((zs, zs, zs), -1))

@classmethod
def rot_z(cls, z: Tensor) -> "Se3":
"""Construct a z-axis rotation.
Args:
z: the z-axis rotation angle.
"""
zs = zeros_like(z)
return cls(So3.rot_z(z), stack((zs, zs, zs), -1))

@classmethod
def trans(cls, x: Tensor, y: Tensor, z: Tensor) -> "Se3":
"""Construct a translation only Se3 instance.
Args:
x: the x-axis translation.
y: the y-axis translation.
z: the z-axis translation.
"""
KORNIA_CHECK(x.shape == y.shape)
KORNIA_CHECK(y.shape == z.shape)
KORNIA_CHECK_SAME_DEVICES([x, y, z])
batch_size = x.shape[0] if len(x.shape) > 0 else None
rotation = So3.identity(batch_size, x.device, x.dtype)
return cls(rotation, stack((x, y, z), -1))

@classmethod
def trans_x(cls, x: Tensor) -> "Se3":
"""Construct a x-axis translation.
Args:
x: the x-axis translation.
"""
zs = zeros_like(x)
return cls.trans(x, zs, zs)

@classmethod
def trans_y(cls, y: Tensor) -> "Se3":
"""Construct a y-axis translation.
Args:
y: the y-axis translation.
"""
zs = zeros_like(y)
return cls.trans(zs, y, zs)

@classmethod
def trans_z(cls, z: Tensor) -> "Se3":
"""Construct a z-axis translation.
Args:
z: the z-axis translation.
"""
zs = zeros_like(z)
return cls.trans(zs, zs, z)
30 changes: 30 additions & 0 deletions kornia/geometry/liegroup/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,33 @@ def inverse(self) -> 'So3':
tensor([1., -0., -0., -0.], requires_grad=True)
"""
return So3(self.q.conj())

@classmethod
def rot_x(cls, x: Tensor) -> "So3":
"""Construct a x-axis rotation.
Args:
x: the x-axis rotation angle.
"""
zs = zeros_like(x)
return cls.exp(stack((x, zs, zs), -1))

@classmethod
def rot_y(cls, y: Tensor) -> "So3":
"""Construct a z-axis rotation.
Args:
y: the y-axis rotation angle.
"""
zs = zeros_like(y)
return cls.exp(stack((zs, y, zs), -1))

@classmethod
def rot_z(cls, z: Tensor) -> "So3":
"""Construct a z-axis rotation.
Args:
z: the z-axis rotation angle.
"""
zs = zeros_like(z)
return cls.exp(stack((zs, zs, z), -1))
9 changes: 3 additions & 6 deletions kornia/geometry/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,9 @@ def data(self) -> Tensor:
return self._data

@property
def coeffs(self) -> Tensor:
"""Return the underlying data with shape :math:`(B, 4)`.
Alias for :func:`~kornia.geometry.quaternion.Quaternion.data`
"""
return self._data
def coeffs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Return a tuple with the underlying coefficients in WXYZ order."""
return self.w, self.x, self.y, self.z

@property
def real(self) -> Tensor:
Expand Down
63 changes: 63 additions & 0 deletions test/geometry/liegroup/test_se3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from kornia.geometry.conversions import QuaternionCoeffOrder, euler_from_quaternion, rotation_matrix_to_quaternion
from kornia.geometry.liegroup import Se3, So3
from kornia.geometry.quaternion import Quaternion
from kornia.testing import BaseTester
Expand Down Expand Up @@ -149,3 +150,65 @@ def test_inverse(self, device, dtype, batch_size):
sinv = Se3(rot, t).inverse()
self.assert_close(sinv.r.inverse().q.data, q.data)
self.assert_close(sinv.t, sinv.r * (-1 * t))

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_rot_x(self, device, dtype, batch_size):
x = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
se3 = Se3.rot_x(x)
quat = rotation_matrix_to_quaternion(se3.so3.matrix(), order=QuaternionCoeffOrder.WXYZ)
quat = Quaternion(quat)
roll, _, _ = euler_from_quaternion(*quat.coeffs)
self.assert_close(x, roll)
self.assert_close(se3.t, torch.zeros_like(se3.t))

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_rot_y(self, device, dtype, batch_size):
y = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
se3 = Se3.rot_y(y)
quat = rotation_matrix_to_quaternion(se3.so3.matrix(), order=QuaternionCoeffOrder.WXYZ)
quat = Quaternion(quat)
_, pitch, _ = euler_from_quaternion(*quat.coeffs)
self.assert_close(y, pitch)
self.assert_close(se3.t, torch.zeros_like(se3.t))

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_rot_z(self, device, dtype, batch_size):
z = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
se3 = Se3.rot_z(z)
quat = rotation_matrix_to_quaternion(se3.so3.matrix(), order=QuaternionCoeffOrder.WXYZ)
quat = Quaternion(quat)
_, _, yaw = euler_from_quaternion(*quat.coeffs)
self.assert_close(z, yaw)
self.assert_close(se3.t, torch.zeros_like(se3.t))

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_trans(self, device, dtype, batch_size):
trans = self._make_rand_data(device, dtype, batch_size, dims=3)
x, y, z = trans[..., 0], trans[..., 1], trans[..., 2]
se3 = Se3.trans(x, y, z)
self.assert_close(se3.t, trans)
self.assert_close(se3.so3.matrix(), So3.identity(batch_size, device, dtype).matrix())

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_trans_x(self, device, dtype, batch_size):
x = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
zs = torch.zeros_like(x)
se3 = Se3.trans_x(x)
self.assert_close(se3.t, torch.stack((x, zs, zs), -1))
self.assert_close(se3.so3.matrix(), So3.identity(batch_size, device, dtype).matrix())

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_trans_y(self, device, dtype, batch_size):
y = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
zs = torch.zeros_like(y)
se3 = Se3.trans_y(y)
self.assert_close(se3.t, torch.stack((zs, y, zs), -1))
self.assert_close(se3.so3.matrix(), So3.identity(batch_size, device, dtype).matrix())

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_trans_z(self, device, dtype, batch_size):
z = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
zs = torch.zeros_like(z)
se3 = Se3.trans_z(z)
self.assert_close(se3.t, torch.stack((zs, zs, z), -1))
self.assert_close(se3.so3.matrix(), So3.identity(batch_size, device, dtype).matrix())
22 changes: 22 additions & 0 deletions test/geometry/liegroup/test_so3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from kornia.geometry.conversions import euler_from_quaternion
from kornia.geometry.liegroup import So3
from kornia.geometry.quaternion import Quaternion
from kornia.testing import BaseTester
Expand Down Expand Up @@ -186,3 +187,24 @@ def test_inverse(self, device, dtype, batch_size):
q = Quaternion.random(batch_size, device, dtype)
self.assert_close(So3(q).inverse().inverse().q.data, q.data)
self.assert_close(So3(q).inverse().inverse().matrix(), So3(q).matrix())

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_rot_x(self, device, dtype, batch_size):
x = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
so3 = So3.rot_x(x)
roll, _, _ = euler_from_quaternion(*so3.q.coeffs)
self.assert_close(x, roll)

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_rot_y(self, device, dtype, batch_size):
y = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
so3 = So3.rot_y(y)
_, pitch, _ = euler_from_quaternion(*so3.q.coeffs)
self.assert_close(y, pitch)

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_rot_z(self, device, dtype, batch_size):
z = self._make_rand_data(device, dtype, batch_size, dims=1).squeeze(-1)
so3 = So3.rot_z(z)
_, _, yaw = euler_from_quaternion(*so3.q.coeffs)
self.assert_close(z, yaw)

0 comments on commit 15a4a32

Please sign in to comment.