This repository has been archived by the owner on Jan 12, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from PhoenixDL/sptatial0
Resize, Zoom, Crop Transforms
- Loading branch information
Showing
24 changed files
with
800 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from ._version import get_versions | ||
__version__ = get_versions()['version'] | ||
del get_versions | ||
|
||
from rising.interface import AbstractMixin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
class AbstractMixin(object): | ||
def __init__(self, *args, **kwargs): | ||
""" | ||
This class implements an interface which handles non processed arguments. | ||
Subclass all classes which mixin additional methods and attributes | ||
to existing classes with multiple inheritance from this class as backup | ||
for handling additional arguments. | ||
Parameters | ||
---------- | ||
kwargs: | ||
keyword arguments saved to object if it is the last class before object. | ||
Otherwise forwarded to next class. | ||
""" | ||
mro = type(self).mro() | ||
mro_idx = mro.index(AbstractMixin) | ||
# +2 because index starts at 0 and only one more class should be called | ||
if mro_idx + 2 == len(mro): | ||
# only object init is missing | ||
super().__init__() | ||
for key, item in kwargs.items(): | ||
setattr(self, key, item) | ||
else: | ||
# class is not last before object -> forward arguments | ||
super().__init__(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from rising.transforms.abstract import * | ||
from rising.transforms.channel import * | ||
from rising.transforms.compose import * | ||
from rising.transforms.crop import * | ||
from rising.transforms.format import * | ||
from rising.transforms.intensity import * | ||
from rising.transforms.kernel import * | ||
from rising.transforms.spatial import * | ||
from rising.transforms.utility import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__all__ = [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from typing import Sequence, Union | ||
from rising.transforms.abstract import BaseTransform, RandomProcess | ||
from rising.transforms.functional.crop import random_crop, center_crop | ||
|
||
__all__ = ["CenterCrop", "RandomCrop", "CenterCropRandomSize", "RandomCropRandomSize"] | ||
|
||
|
||
class CenterCrop(BaseTransform): | ||
def __init__(self, size: Union[int, Sequence[int]], keys: Sequence = ('data',), | ||
grad: bool = False, **kwargs): | ||
""" | ||
Apply augment_fn to keys | ||
Parameters | ||
---------- | ||
size: Union[int, Sequence[int]] | ||
size of crop | ||
keys: Sequence | ||
keys which should be augmented | ||
grad: bool | ||
enable gradient computation inside transformation | ||
kwargs: | ||
keyword arguments passed to augment_fn | ||
""" | ||
super().__init__(augment_fn=center_crop, size=size, keys=keys, | ||
grad=grad, **kwargs) | ||
|
||
|
||
class RandomCrop(BaseTransform): | ||
def __init__(self, size: Union[int, Sequence[int]], dist: Union[int, Sequence[int]] = 0, | ||
keys: Sequence = ('data',), grad: bool = False, **kwargs): | ||
""" | ||
Apply augment_fn to keys | ||
Parameters | ||
---------- | ||
size: Union[int, Sequence[int]] | ||
size of crop | ||
dist: Union[int, Sequence[int]] | ||
minimum distance to border. By default zero | ||
keys: Sequence | ||
keys which should be augmented | ||
grad: bool | ||
enable gradient computation inside transformation | ||
kwargs: | ||
keyword arguments passed to augment_fn | ||
""" | ||
super().__init__(augment_fn=random_crop, size=size, dist=dist, | ||
keys=keys, grad=grad, **kwargs) | ||
|
||
|
||
class CenterCropRandomSize(RandomProcess, BaseTransform): | ||
def __init__(self, random_args: Union[Sequence, Sequence[Sequence]], | ||
random_mode: str = "randrange", keys: Sequence = ('data',), | ||
grad: bool = False, **kwargs): | ||
""" | ||
Apply augment_fn to keys | ||
Parameters | ||
---------- | ||
random_args: Union[Sequence, Sequence[Sequence]] | ||
positional arguments passed for random function. If Sequence[Sequence] | ||
is provided, a random value for each item in the outer. This can be | ||
used to set different ranges for different axis. | ||
random_mode: str | ||
specifies distribution which should be used to sample additive value | ||
keys: Sequence | ||
keys which should be augmented | ||
grad: bool | ||
enable gradient computation inside transformation | ||
kwargs: | ||
keyword arguments passed to augment_fn | ||
""" | ||
super().__init__(augment_fn=center_crop, random_mode=random_mode, | ||
random_args=random_args, keys=keys, grad=grad, **kwargs) | ||
|
||
def forward(self, **data) -> dict: | ||
""" | ||
Augment data | ||
Parameters | ||
---------- | ||
data: dict | ||
input batch | ||
Returns | ||
------- | ||
dict | ||
augmented data | ||
""" | ||
self.kwargs["size"] = self.rand() | ||
return super().forward(**data) | ||
|
||
|
||
class RandomCropRandomSize(RandomProcess, BaseTransform): | ||
def __init__(self, random_args: Union[Sequence, Sequence[Sequence]], | ||
random_mode: str = "randrange", dist: Union[int, Sequence[int]] = 0, | ||
keys: Sequence = ('data',), grad: bool = False, **kwargs): | ||
""" | ||
Apply augment_fn to keys | ||
Parameters | ||
---------- | ||
random_mode: str | ||
specifies distribution which should be used to sample additive value | ||
random_args: Union[Sequence, Sequence[Sequence]] | ||
positional arguments passed for random function. If Sequence[Sequence] | ||
is provided, a random value for each item in the outer. This can be | ||
used to set different ranges for different axis. | ||
keys: Sequence | ||
keys which should be augmented | ||
grad: bool | ||
enable gradient computation inside transformation | ||
kwargs: | ||
keyword arguments passed to augment_fn | ||
""" | ||
super().__init__(augment_fn=random_crop, random_mode=random_mode, | ||
random_args=random_args, dist=dist, | ||
keys=keys, grad=grad, **kwargs) | ||
|
||
def forward(self, **data) -> dict: | ||
""" | ||
Augment data | ||
Parameters | ||
---------- | ||
data: dict | ||
input batch | ||
Returns | ||
------- | ||
dict | ||
augmented data | ||
""" | ||
self.kwargs["size"] = self.rand() | ||
return super().forward(**data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.