Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix tests float16 module losses #1809

Merged
merged 12 commits into from
Aug 19, 2022
41 changes: 33 additions & 8 deletions kornia/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The testing package contains testing-specific utilities."""
import contextlib
import importlib
import math
from abc import ABC, abstractmethod
from copy import deepcopy
from itertools import product
Expand Down Expand Up @@ -76,30 +77,54 @@ def create_random_fundamental_matrix(batch_size, std_val=1e-3):


class BaseTester(ABC):
DTYPE_PRECISIONS = {torch.float16: (1e-3, 1e-3), torch.float32: (1.3e-6, 1e-5), torch.float64: (1e-6, 1e-5)}

@abstractmethod
def test_smoke(self):
def test_smoke(self, device, dtype):
raise NotImplementedError("Implement a stupid routine.")

@abstractmethod
def test_exception(self):
def test_exception(self, device, dtype):
raise NotImplementedError("Implement a stupid routine.")

@abstractmethod
def test_cardinality(self):
def test_cardinality(self, device, dtype):
raise NotImplementedError("Implement a stupid routine.")

@abstractmethod
def test_jit(self):
def test_jit(self, device, dtype):
raise NotImplementedError("Implement a stupid routine.")

@abstractmethod
def test_gradcheck(self):
def test_gradcheck(self, device):
raise NotImplementedError("Implement a stupid routine.")

@abstractmethod
def test_module(self):
def test_module(self, device, dtype):
raise NotImplementedError("Implement a stupid routine.")

def assert_close(
self,
actual: Tensor,
expected: Tensor,
rtol: Optional[float] = None,
atol: Optional[float] = None,
low_tolerance: bool = False,
) -> None:
if 'xla' in actual.device.type or 'xla' in expected.device.type:
rtol, atol = 1e-2, 1e-2

if rtol is None and atol is None:
actual_rtol, actual_atol = self.DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
expected_rtol, expected_atol = self.DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
rtol, atol = max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)

# halve the tolerance if `low_tolerance` is true
rtol = math.sqrt(rtol) if low_tolerance else rtol
atol = math.sqrt(atol) if low_tolerance else atol

return _assert_close(actual, expected, rtol=rtol, atol=atol)


def cartesian_product_of_parameters(**possible_parameters):
"""Create cartesian product of given parameters."""
Expand Down Expand Up @@ -163,7 +188,7 @@ def assert_close(
except ImportError:
# Partial backport of torch.testing.assert_close for torch<1.9
# TODO: remove this branch if kornia relies on torch>=1.9
from torch.testing import assert_allclose as _assert_allclose
from torch.testing import assert_allclose as _assert_close

class UsageError(Exception):
pass
Expand All @@ -177,7 +202,7 @@ def assert_close(
**kwargs: Any,
) -> None:
try:
return _assert_allclose(actual, expected, rtol=rtol, atol=atol, **kwargs)
return _assert_close(actual, expected, rtol=rtol, atol=atol, **kwargs)
except ValueError as error:
raise UsageError(str(error)) from error

Expand Down
Loading