Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Address PR comments, add AbstractMixin, and define __all__ across transform modules
Browse files Browse the repository at this point in the history
  • Loading branch information
mibaumgartner committed Dec 2, 2019
1 parent 989404c commit d9caee1
Show file tree
Hide file tree
Showing 22 changed files with 118 additions and 8 deletions.
2 changes: 2 additions & 0 deletions rising/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions

from rising.interface import AbstractMixin
25 changes: 25 additions & 0 deletions rising/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
class AbstractMixin(object):
    """
    Backup initializer for cooperative multiple inheritance.

    Mix this class into others so that keyword arguments which no class in
    the MRO consumed are still handled: if this mixin is the last class
    before ``object``, leftover keyword arguments are attached to the
    instance as attributes; otherwise they are forwarded along the MRO.
    """

    def __init__(self, *args, **kwargs):
        """
        Dispatch leftover constructor arguments.

        Parameters
        ----------
        kwargs:
            Saved onto the instance as attributes when this mixin is the
            last class before ``object`` in the MRO; otherwise forwarded
            to the next class. Positional ``args`` are forwarded in the
            latter case and dropped in the former.
        """
        resolution_order = type(self).mro()
        position = resolution_order.index(AbstractMixin)
        # The mixin is "terminal" when only ``object`` follows it in the MRO.
        terminal = position == len(resolution_order) - 2
        if terminal:
            super().__init__()
            for name, value in kwargs.items():
                setattr(self, name, value)
        else:
            # Not the last class before object -> keep the chain going.
            super().__init__(*args, **kwargs)
2 changes: 1 addition & 1 deletion rising/ops/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def torch_one_hot(target: torch.Tensor, num_classes: int = None) -> torch.Tensor
one hot encoded tensor
"""
if num_classes is None:
num_classes = target.max() + 1
num_classes = target.max() - target.min() + 1
dtype, device = target.dtype, target.device
target_onehot = torch.zeros(*target.shape, num_classes,
dtype=dtype, device=device)
Expand Down
9 changes: 9 additions & 0 deletions rising/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from rising.transforms.abstract import *
from rising.transforms.channel import *
from rising.transforms.compose import *
from rising.transforms.crop import *
from rising.transforms.format import *
from rising.transforms.intensity import *
from rising.transforms.kernel import *
from rising.transforms.spatial import *
from rising.transforms.utility import *
7 changes: 5 additions & 2 deletions rising/transforms/abstract.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import torch
import typing
import random
import importlib
from typing import Callable, Union, Sequence, Any

from rising import AbstractMixin
from rising.utils import check_scalar

__all__ = ["AbstractTransform", "BaseTransform", "PerSampleTransform",
"PerChannelTransform", "RandomDimsTransform", "RandomProcess"]


augment_callable = Callable[[torch.Tensor], Any]
augment_axis_callable = Callable[[torch.Tensor, Union[float, Sequence]], Any]
Expand Down Expand Up @@ -232,7 +235,7 @@ def forward(self, **data) -> dict:
return data


class RandomProcess:
class RandomProcess(AbstractMixin):
def __init__(self, *args, random_mode: str,
random_args: Union[Sequence, Sequence[Sequence]] = (),
random_module: str = "random", rand_seq: bool = True,
Expand Down
1 change: 1 addition & 0 deletions rising/transforms/channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__all__ = []
2 changes: 2 additions & 0 deletions rising/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from rising.utils import check_scalar
from .abstract import AbstractTransform, RandomProcess

__all__ = ["Compose", "DropoutCompose"]


class Compose(AbstractTransform):
def __init__(self, *transforms):
Expand Down
6 changes: 4 additions & 2 deletions rising/transforms/crop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Sequence, Union
from .abstract import BaseTransform, RandomProcess
from .functional.crop import *
from rising.transforms.abstract import BaseTransform, RandomProcess
from rising.transforms.functional.crop import random_crop, center_crop

__all__ = ["CenterCrop", "RandomCrop", "CenterCropRandomSize", "RandomCropRandomSize"]


class CenterCrop(BaseTransform):
Expand Down
2 changes: 2 additions & 0 deletions rising/transforms/format.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .abstract import AbstractTransform

__all__ = ["MapToSeq", "SeqToMap"]


class MapToSeq(AbstractTransform):
def __init__(self, *keys, grad: bool = False, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion rising/transforms/functional/crop.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# TODO: progressive resizing
import torch
import random
from typing import Union, Sequence

from rising.utils import check_scalar

__all__ = ["crop", "center_crop", "random_crop"]


def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int]):
"""
Expand Down
3 changes: 3 additions & 0 deletions rising/transforms/functional/intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from rising.utils import check_scalar

__all__ = ["norm_range", "norm_min_max", "norm_zero_mean_unit_std", "norm_mean_std",
"add_noise", "add_value", "gamma_correction", "scale_by_value"]


def norm_range(data: torch.Tensor, min: float, max: float,
per_channel: bool = True, out: torch.Tensor = None) -> torch.Tensor:
Expand Down
3 changes: 2 additions & 1 deletion rising/transforms/functional/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Sequence, Union

from rising.utils import check_scalar
from rising.ops import torch_one_hot

__all__ = ["mirror", "rot90", "resize"]


def mirror(data: torch.Tensor, dims: Union[int, Sequence[int]]) -> torch.Tensor:
Expand Down
9 changes: 8 additions & 1 deletion rising/transforms/intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@

from .abstract import BaseTransform, PerSampleTransform, AbstractTransform, \
PerChannelTransform, RandomProcess
from .functional.intensity import *
from rising.utils import check_scalar
from rising.transforms.functional.intensity import norm_range, norm_min_max, norm_mean_std, \
norm_zero_mean_unit_std, add_noise, gamma_correction, add_value, scale_by_value

__all__ = ["ClampTransform", "NormRangeTransform", "NormMinMaxTransform",
"NormZeroMeanUnitStdTransform", "NormMeanStdTransform", "NoiseTransform",
"GaussianNoiseTransform", "ExponentialNoiseTransform", "GammaCorrectionTransform",
"RandomValuePerChannelTransform", "RandomAddValue", "RandomScaleValue"]


class ClampTransform(BaseTransform):
Expand Down
2 changes: 2 additions & 0 deletions rising/transforms/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .abstract import AbstractTransform
from rising.utils import check_scalar

__all__ = ["KernelTransform", "GaussianSmoothingTransform"]


class KernelTransform(AbstractTransform):
def __init__(self, in_channels: int, kernel_size: Union[int, Sequence], dim: int = 2,
Expand Down
3 changes: 3 additions & 0 deletions rising/transforms/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from .functional.spatial import *

__all__ = ["MirrorTransform", "Rot90Transform", "ResizeTransform",
"ZoomTransform", "ProgressiveResize", "SizeStepScheduler"]

schduler_type = Callable[[int], Union[int, Sequence[int]]]


Expand Down
2 changes: 2 additions & 0 deletions rising/transforms/utility.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .abstract import AbstractTransform

__all__ = ["DoNothingTransform"]


class DoNothingTransform(AbstractTransform):
def __init__(self, grad: bool = False, **kwargs):
Expand Down
40 changes: 40 additions & 0 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest
from rising import AbstractMixin


class Abstract(object):
    """Fixture base that swallows keyword arguments instead of forwarding them."""

    def __init__(self, **kwargs):
        # Deliberately drop kwargs: object.__init__ is invoked bare.
        super().__init__()
        self.abstract = True


class AbstractForward(object):
    """Fixture base that forwards keyword arguments along the MRO."""

    def __init__(self, **kwargs):
        # Pass kwargs on so a cooperative mixin further down can consume them.
        super().__init__(**kwargs)
        self.abstract = True


class PreMix(AbstractMixin, Abstract):
    """Mixin placed before the concrete base: extra kwargs are forwarded, not stored."""

    def __init__(self, *args, **kwargs):
        # Delegate straight into the cooperative MRO chain.
        super().__init__(*args, **kwargs)


class PostMix(AbstractForward, AbstractMixin):
    """Mixin placed after the forwarding base: extra kwargs land on the instance."""

    def __init__(self, *args, **kwargs):
        # Delegate straight into the cooperative MRO chain.
        super().__init__(*args, **kwargs)


class MyTestCase(unittest.TestCase):
    """Checks for AbstractMixin's two placement modes in the MRO."""

    def test_pre_mix(self):
        # Mixin before the concrete base: unknown kwargs must be forwarded away.
        instance = PreMix(a=True)
        self.assertFalse(hasattr(instance, "a"))
        self.assertTrue(instance.abstract)

    def test_post_mix(self):
        # Mixin last before object: unknown kwargs become attributes.
        instance = PostMix(a=True)
        self.assertTrue(instance.a)
        self.assertTrue(instance.abstract)


if __name__ == "__main__":
    # Allow running this test module directly.
    unittest.main()
1 change: 1 addition & 0 deletions tests/test_transforms/test_abstract_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from unittest.mock import Mock, call
import torch
import random

from rising.transforms.abstract import *

Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms/test_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random

from rising.transforms.crop import *
from rising.transforms.functional.crop import random_crop, center_crop


class TestCrop(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms/test_functional/test_crop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import torch
import random

from rising.transforms.functional.crop import *

Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms/test_functional/test_intensity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import torch
from math import isclose

from rising.transforms.functional.intensity import *
Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms/test_spatial_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from tests.test_transforms import chech_data_preservation
from rising.transforms.spatial import *
from rising.transforms.functional.spatial import resize


class TestSpatialTransforms(unittest.TestCase):
Expand Down

0 comments on commit d9caee1

Please sign in to comment.