Merge pull request #4 from PhoenixDL/sptatial0
Resize, Zoom, Crop Transforms
mibaumgartner authored Dec 2, 2019
2 parents f825221 + d9caee1 commit 79583d5
Showing 24 changed files with 800 additions and 47 deletions.
2 changes: 2 additions & 0 deletions rising/__init__.py
@@ -1,3 +1,5 @@
 from ._version import get_versions
 __version__ = get_versions()['version']
 del get_versions
+
+from rising.interface import AbstractMixin
25 changes: 25 additions & 0 deletions rising/interface.py
@@ -0,0 +1,25 @@
class AbstractMixin(object):
    def __init__(self, *args, **kwargs):
        """
        This class implements an interface which handles non-processed arguments.
        Subclass all classes which mix in additional methods and attributes
        to existing classes with multiple inheritance from this class as backup
        for handling additional arguments.

        Parameters
        ----------
        kwargs:
            keyword arguments saved to object if it is the last class before object.
            Otherwise forwarded to next class.
        """
        mro = type(self).mro()
        mro_idx = mro.index(AbstractMixin)
        # +2 because index starts at 0 and only one more class should be called
        if mro_idx + 2 == len(mro):
            # only object init is missing
            super().__init__()
            for key, item in kwargs.items():
                setattr(self, key, item)
        else:
            # class is not last before object -> forward arguments
            super().__init__(*args, **kwargs)
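
Usage sketch (not part of the diff; Base, Mixed and Standalone are made-up classes to illustrate the MRO bookkeeping): when AbstractMixin is the last class before object, leftover kwargs become attributes; otherwise everything is forwarded down the MRO.

from rising.interface import AbstractMixin


class Base:
    def __init__(self, a: int):
        self.a = a


class Mixed(AbstractMixin, Base):
    # MRO: Mixed -> AbstractMixin -> Base -> object, so AbstractMixin
    # is not last before object and forwards its arguments to Base.__init__
    pass


class Standalone(AbstractMixin):
    # MRO: Standalone -> AbstractMixin -> object, so leftover kwargs
    # are stored as attributes instead
    pass


assert Mixed(a=1).a == 1
assert Standalone(b=2).b == 2
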
6 changes: 4 additions & 2 deletions rising/ops/tensor.py
@@ -2,7 +2,7 @@
 import numpy as np


-def torch_one_hot(target: torch.Tensor, num_classes: int) -> torch.Tensor:
+def torch_one_hot(target: torch.Tensor, num_classes: int = None) -> torch.Tensor:
     """
     Compute one hot encoding of input tensor

@@ -11,13 +11,15 @@ def torch_one_hot(target: torch.Tensor, num_classes: int) -> torch.Tensor:
     target: torch.Tensor
         tensor to be converted
     num_classes: int
-        number of classes
+        number of classes. If :param:`num_classes` is None, the maximum of target is used

     Returns
     -------
     torch.Tensor
         one hot encoded tensor
     """
+    if num_classes is None:
+        num_classes = target.max() - target.min() + 1
     dtype, device = target.dtype, target.device
     target_onehot = torch.zeros(*target.shape, num_classes,
                                 dtype=dtype, device=device)
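
Usage sketch (not part of the diff; it assumes the scatter of ones happens in the part of the function cut off above):

import torch
from rising.ops.tensor import torch_one_hot

target = torch.tensor([0, 2, 1])
onehot = torch_one_hot(target)  # num_classes inferred as max - min + 1 = 3
# expected shape: (3, 3), first row [1, 0, 0]
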
9 changes: 9 additions & 0 deletions rising/transforms/__init__.py
@@ -0,0 +1,9 @@
from rising.transforms.abstract import *
from rising.transforms.channel import *
from rising.transforms.compose import *
from rising.transforms.crop import *
from rising.transforms.format import *
from rising.transforms.intensity import *
from rising.transforms.kernel import *
from rising.transforms.spatial import *
from rising.transforms.utility import *
34 changes: 23 additions & 11 deletions rising/transforms/abstract.py
@@ -1,11 +1,14 @@
 import torch
 import typing
 import random
 import importlib
 from typing import Callable, Union, Sequence, Any

+from rising import AbstractMixin
 from rising.utils import check_scalar

+__all__ = ["AbstractTransform", "BaseTransform", "PerSampleTransform",
+           "PerChannelTransform", "RandomDimsTransform", "RandomProcess"]


 augment_callable = Callable[[torch.Tensor], Any]
 augment_axis_callable = Callable[[torch.Tensor, Union[float, Sequence]], Any]

@@ -232,9 +235,10 @@ def forward(self, **data) -> dict:
         return data


-class RandomProcess:
-    def __init__(self, *args, random_mode: str, random_args: Sequence = (),
-                 random_kwargs: dict = None, random_module: str = "random",
+class RandomProcess(AbstractMixin):
+    def __init__(self, *args, random_mode: str,
+                 random_args: Union[Sequence, Sequence[Sequence]] = (),
+                 random_module: str = "random", rand_seq: bool = True,
                  **kwargs):
         """
         Saves specified function to generate random values to current class.

@@ -244,18 +248,21 @@ def __init__(self, *args, random_mode: str, random_args: Sequence = (),
         ----------
         random_mode: str
             specifies distribution which should be used to sample additive value
-        random_args: Sequence
-            positional arguments passed for random function
-        random_kwargs: dict
-            keyword arguments for random function
+        random_args: Union[Sequence, Sequence[Sequence]]
+            positional arguments passed for random function. If Sequence[Sequence]
+            is provided, a random value for each item in the outer
+            Sequence is generated
         random_module: str
             module from where function random function should be imported
+        rand_seq: bool
+            if enabled, multiple random values are generated if :param:`random_args`
+            is of type Sequence[Sequence]
         """
         super().__init__(*args, **kwargs)
         self.random_module = random_module
         self.random_mode = random_mode
-        self.ranndom_args = random_args
-        self.random_kwargs = {} if random_kwargs is None else random_kwargs
+        self.random_args = random_args
+        self.rand_seq = rand_seq

     def rand(self, **kwargs):
         """

@@ -266,7 +273,12 @@ def rand(self, **kwargs):
         Any
             object generated from function
         """
-        return self.random_fn(*self.ranndom_args, **self.random_kwargs, **kwargs)
+        if (self.rand_seq and len(self.random_args) > 0 and
+                isinstance(self.random_args[0], Sequence)):
+            val = tuple(self.random_fn(*args, **kwargs) for args in self.random_args)
+        else:
+            val = self.random_fn(*self.random_args, **kwargs)
+        return val

     @property
     def random_mode(self) -> str:
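
Usage sketch of the new Sequence[Sequence] behaviour of rand() (not part of the diff; the RandomValue subclass is a hypothetical stand-in, and it assumes the random_mode property cut off above resolves random_fn from random_module):

from rising.transforms.abstract import RandomProcess


class RandomValue(RandomProcess):
    pass


# one (start, stop) pair per axis -> one draw per inner sequence
proc = RandomValue(random_mode="randint", random_args=((10, 20), (30, 40)))
print(proc.rand())  # e.g. (14, 37)

# a flat sequence is passed to the random function as-is
proc = RandomValue(random_mode="uniform", random_args=(0.0, 1.0))
print(proc.rand())  # e.g. 0.42
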
1 change: 1 addition & 0 deletions rising/transforms/channel.py
@@ -0,0 +1 @@
__all__ = []
10 changes: 5 additions & 5 deletions rising/transforms/compose.py
@@ -2,6 +2,8 @@
 from rising.utils import check_scalar
 from .abstract import AbstractTransform, RandomProcess

+__all__ = ["Compose", "DropoutCompose"]
+

 class Compose(AbstractTransform):
     def __init__(self, *transforms):

@@ -40,7 +42,7 @@ def forward(self, **data) -> dict:
 class DropoutCompose(RandomProcess, Compose):
     def __init__(self, *transforms, dropout: Union[float, Sequence[float]] = 0.5,
                  random_mode: str = "random", random_args: Sequence = (),
-                 random_kwargs: dict = None, random_module: str = "random", **kwargs):
+                 random_module: str = "random", **kwargs):
         """
         Compose multiple transforms to one

@@ -56,8 +58,6 @@ def __init__(self, *transforms, dropout: Union[float, Sequence[float]] = 0.5,
             specifies distribution which should be used to sample additive value
         random_args: Sequence
             positional arguments passed for random function
-        random_kwargs: dict
-            keyword arguments for random function
         random_module: str
             module from where function random function should be imported

@@ -67,8 +67,8 @@ def __init__(self, *transforms, dropout: Union[float, Sequence[float]] = 0.5,
             if dropout is a sequence it must have the same length as transforms
         """
         super().__init__(*transforms, random_mode=random_mode,
-                         random_kwargs=random_kwargs, random_args=random_args,
-                         random_module=random_module, **kwargs)
+                         random_args=random_args, random_module=random_module,
+                         rand_seq=False, **kwargs)
         if check_scalar(dropout):
             dropout = [dropout] * len(self.transforms)
         if len(dropout) != len(self.transforms):
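
Usage sketch (not part of the diff; the exact drop behaviour lives in the forward method not shown here, so the dropout comment is an assumption):

from rising.transforms import CenterCrop, DropoutCompose

compose = DropoutCompose(
    CenterCrop(size=32),
    CenterCrop(size=16),
    dropout=0.5,  # presumably the per-transform probability of being skipped
)
# scalar dropout is broadcast to [0.5, 0.5], one entry per transform

Note the base order DropoutCompose(RandomProcess, Compose): since RandomProcess now inherits AbstractMixin, the leftover *transforms travel down the MRO to Compose.__init__ instead of being swallowed.
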
136 changes: 136 additions & 0 deletions rising/transforms/crop.py
@@ -0,0 +1,136 @@
from typing import Sequence, Union

from rising.transforms.abstract import BaseTransform, RandomProcess
from rising.transforms.functional.crop import random_crop, center_crop

__all__ = ["CenterCrop", "RandomCrop", "CenterCropRandomSize", "RandomCropRandomSize"]


class CenterCrop(BaseTransform):
    def __init__(self, size: Union[int, Sequence[int]], keys: Sequence = ('data',),
                 grad: bool = False, **kwargs):
        """
        Apply augment_fn to keys

        Parameters
        ----------
        size: Union[int, Sequence[int]]
            size of crop
        keys: Sequence
            keys which should be augmented
        grad: bool
            enable gradient computation inside transformation
        kwargs:
            keyword arguments passed to augment_fn
        """
        super().__init__(augment_fn=center_crop, size=size, keys=keys,
                         grad=grad, **kwargs)


class RandomCrop(BaseTransform):
    def __init__(self, size: Union[int, Sequence[int]], dist: Union[int, Sequence[int]] = 0,
                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
        """
        Apply augment_fn to keys

        Parameters
        ----------
        size: Union[int, Sequence[int]]
            size of crop
        dist: Union[int, Sequence[int]]
            minimum distance to border. By default zero
        keys: Sequence
            keys which should be augmented
        grad: bool
            enable gradient computation inside transformation
        kwargs:
            keyword arguments passed to augment_fn
        """
        super().__init__(augment_fn=random_crop, size=size, dist=dist,
                         keys=keys, grad=grad, **kwargs)


class CenterCropRandomSize(RandomProcess, BaseTransform):
    def __init__(self, random_args: Union[Sequence, Sequence[Sequence]],
                 random_mode: str = "randrange", keys: Sequence = ('data',),
                 grad: bool = False, **kwargs):
        """
        Apply augment_fn to keys

        Parameters
        ----------
        random_args: Union[Sequence, Sequence[Sequence]]
            positional arguments passed for random function. If Sequence[Sequence]
            is provided, a random value is generated for each item in the outer
            Sequence. This can be used to set different ranges for different axes.
        random_mode: str
            specifies distribution which should be used to sample additive value
        keys: Sequence
            keys which should be augmented
        grad: bool
            enable gradient computation inside transformation
        kwargs:
            keyword arguments passed to augment_fn
        """
        super().__init__(augment_fn=center_crop, random_mode=random_mode,
                         random_args=random_args, keys=keys, grad=grad, **kwargs)

    def forward(self, **data) -> dict:
        """
        Augment data

        Parameters
        ----------
        data: dict
            input batch

        Returns
        -------
        dict
            augmented data
        """
        self.kwargs["size"] = self.rand()
        return super().forward(**data)


class RandomCropRandomSize(RandomProcess, BaseTransform):
    def __init__(self, random_args: Union[Sequence, Sequence[Sequence]],
                 random_mode: str = "randrange", dist: Union[int, Sequence[int]] = 0,
                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
        """
        Apply augment_fn to keys

        Parameters
        ----------
        random_mode: str
            specifies distribution which should be used to sample additive value
        random_args: Union[Sequence, Sequence[Sequence]]
            positional arguments passed for random function. If Sequence[Sequence]
            is provided, a random value is generated for each item in the outer
            Sequence. This can be used to set different ranges for different axes.
        keys: Sequence
            keys which should be augmented
        grad: bool
            enable gradient computation inside transformation
        kwargs:
            keyword arguments passed to augment_fn
        """
        super().__init__(augment_fn=random_crop, random_mode=random_mode,
                         random_args=random_args, dist=dist,
                         keys=keys, grad=grad, **kwargs)

    def forward(self, **data) -> dict:
        """
        Augment data

        Parameters
        ----------
        data: dict
            input batch

        Returns
        -------
        dict
            augmented data
        """
        self.kwargs["size"] = self.rand()
        return super().forward(**data)
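
Usage sketch for the random-size variants (not part of the diff; shapes and ranges are made up, and random_mode defaults to randrange from Python's random module):

import torch
from rising.transforms import RandomCropRandomSize

# one (start, stop) range per spatial axis
trafo = RandomCropRandomSize(random_args=((32, 65), (64, 129)))
# each forward call samples a fresh size, e.g. (48, 97), before cropping:
# out = trafo(data=torch.rand(4, 1, 128, 256))
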
2 changes: 2 additions & 0 deletions rising/transforms/format.py
@@ -1,5 +1,7 @@
 from .abstract import AbstractTransform

+__all__ = ["MapToSeq", "SeqToMap"]
+

 class MapToSeq(AbstractTransform):
     def __init__(self, *keys, grad: bool = False, **kwargs):
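
The format module's diff is truncated above; based only on the names and the visible signature, a plausible usage sketch (the conversion semantics are an assumption, not shown in this diff):

from rising.transforms import MapToSeq

# presumably maps a keyword batch to the positional order given by *keys
to_seq = MapToSeq('data', 'label', grad=False)
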