[feat] Added inverse for augmentations. #1013
Conversation
Let's get this merged if the API is okay for everyone, so that we can include it in next week's release. /cc @vadimkantorov
kornia/augmentation/augmentation.py
Outdated
def inverse_transform(
    self, input: torch.Tensor, transform: Optional[torch.Tensor] = None,
    size: Optional[Tuple[int, int]] = None, **kwargs
) -> torch.Tensor:
    return hflip(input)
Should this be based on the transformation matrix for hflip and vflip? The direct hflip will be much faster.
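For reference, a minimal sketch of the two options (the 3x3 flip-matrix convention below is an assumption for illustration, not taken from this PR):

import torch
from kornia.geometry.transform import hflip, warp_perspective

B, C, H, W = 1, 3, 4, 5
img = torch.rand(B, C, H, W)
flipped = hflip(img)

# Option 1: direct flip. hflip is its own inverse and needs no resampling.
restored = hflip(flipped)
assert torch.allclose(restored, img)

# Option 2: matrix-based. Build a 3x3 horizontal-flip matrix (x -> W-1-x,
# an assumed pixel-center convention) and warp with its inverse. This goes
# through grid_sample, so it is slower than the direct flip.
flip_mat = torch.tensor([[[-1.0, 0.0, W - 1.0],
                          [0.0, 1.0, 0.0],
                          [0.0, 0.0, 1.0]]])
restored_warp = warp_perspective(flipped, flip_mat.inverse(), dsize=(H, W))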
How is this different from what we already have?
No, no. It is for the inverse: it restores the flip that has been applied.
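A quick usage sketch of what that means (assuming the `inverse` entry point this PR introduces; the exact public name was still being settled here):

import torch
from kornia.augmentation import RandomHorizontalFlip

aug = RandomHorizontalFlip(p=1.0)
img = torch.rand(2, 3, 8, 8)

out = aug(img)               # forward: flips the whole batch (p=1.0)
restored = aug.inverse(out)  # inverse: undoes the recorded flip
assert torch.allclose(restored, img)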
good overall - add tests
kornia/augmentation/augmentation.py
Outdated
def inverse_transform(
    self, input: torch.Tensor, transform: Optional[torch.Tensor] = None,
    size: Optional[Tuple[int, int]] = None, **kwargs
) -> torch.Tensor:
Could you add some docs here? What are the kwargs? Could you make them explicit?
I was thinking about whether some of the parameters need to be overridden during the inverse conversion.
Say I used bilinear for the forward pass, but I may want to use nearest for the inverse? I have not decided whether I need to do some overriding here; I need some real-world feedback.
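For example, something like this from the user side (the `mode` pass-through via **kwargs is exactly the undecided part, so treat it as hypothetical):

import torch
from kornia.augmentation import RandomAffine

aug = RandomAffine(degrees=30.0, p=1.0)  # forward warps with bilinear by default
mask = torch.randint(0, 5, (1, 1, 32, 32)).float()  # integer class-id mask

warped = aug(mask)
# Bilinear interpolation on the way back would blend class ids;
# overriding to nearest keeps them discrete.
restored = aug.inverse(warped, mode='nearest')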
kornia/augmentation/augmentation.py
Outdated
align_corners = self.align_corners if "align_corners" not in kwargs else kwargs['align_corners']
padding_mode = 'zeros' if "padding_mode" not in kwargs else kwargs['padding_mode']
transform = cast(torch.Tensor, transform)
return affine(input, transform.inverse()[..., :2, :3], mode, padding_mode, align_corners)
Remove the boilerplate code and reuse the affine function in both places (or use warp_affine directly).
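i.e. roughly (a sketch only; `_inverse_by_affine` and the flags dict are hypothetical names, not existing kornia API):

from typing import Any, Dict

import torch
from kornia.geometry.transform import affine

def _inverse_by_affine(input: torch.Tensor, transform: torch.Tensor,
                       flags: Dict[str, Any]) -> torch.Tensor:
    # Shared helper: invert the recorded 3x3 transform once and warp back,
    # instead of repeating the same affine(...) call in every augmentation.
    return affine(input, transform.inverse()[..., :2, :3],
                  mode=flags['mode'], padding_mode=flags['padding_mode'],
                  align_corners=flags['align_corners'])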
) -> torch.Tensor:
    if self.cropping_mode != 'resample':
        raise NotImplementedError(
            f"`inverse` is only applicable for resample cropping mode. Got {self.cropping_mode}.")
    mode = self.resample.name.lower() if "mode" not in kwargs else kwargs['mode']
There is a lot of boilerplate across augmentations for the flags, too.
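One way to factor that out (a sketch; `_resolve_flags` is a hypothetical helper, not existing kornia code):

from typing import Any, Dict

def _resolve_flags(defaults: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    # Merge per-call overrides over stored defaults, replacing the repeated
    # `x = self.x if "x" not in kwargs else kwargs['x']` lines.
    return {**defaults, **{k: v for k, v in kwargs.items() if k in defaults}}

defaults = {'mode': 'bilinear', 'padding_mode': 'zeros', 'align_corners': True}
flags = _resolve_flags(defaults, mode='nearest')
# -> {'mode': 'nearest', 'padding_mode': 'zeros', 'align_corners': True}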
    to the batch form (False). Default: False.
"""

def compute_transformation(self, input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
I guess params could be Optional.
params = self._params
if 'batch_prob' not in params:
    params['batch_prob'] = torch.tensor([True] * batch_shape[0])
    warnings.warn("`batch_prob` is not found in params. Will assume applying on all data.")
Do you have any mechanism to make the warnings optional?
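Python's standard warnings machinery already gives users an opt-out without any extra flag on the augmentation, e.g.:

import warnings

def _forward_with_defaults() -> None:
    # Stand-in for the call that emits the warning in this PR.
    warnings.warn("`batch_prob` is not found in params. Will assume applying on all data.")

# Silence it locally:
with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    _forward_with_defaults()

# Or silence it globally, once per process:
warnings.filterwarnings('ignore', message='`batch_prob` is not found.*')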
Should be okay now.
LGTM. Mind providing examples later to promote this feature?
Description
Related to #1000.
@edgarriba, @ducha-aiki I am not sure whether we should enable the inverse operation for perspective and cropping transforms.
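For the crop case, the resample cropping mode in this PR already records an invertible 3x3 transform, so an inverse could warp the crop back into a canvas of the original size. A rough sketch (whether `inverse` recovers the original size automatically or needs the `size` hint from the signature above is part of the open question):

import torch
from kornia.augmentation import RandomResizedCrop

aug = RandomResizedCrop((16, 16), cropping_mode='resample')
img = torch.rand(1, 3, 32, 32)

cropped = aug(img)                              # (1, 3, 16, 16)
restored = aug.inverse(cropped, size=(32, 32))  # back to (1, 3, 32, 32); zeros outside the crop
assert restored.shape == (1, 3, 32, 32)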