-
Notifications
You must be signed in to change notification settings - Fork 710
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix onnx export by rewriting GaussianBlur (#476)
* Fix onnx export by rewriting GaussianBlur * Address codacy complaints. Reame variable to something other than `input` * Move GaussianBlur2d to anomalib.post_processing * Move blur to `anomlib.models.components.filters`
- Loading branch information
Showing
7 changed files
with
123 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Implements filters used by models.""" | ||
|
||
from .blur import GaussianBlur2d | ||
|
||
__all__ = ["GaussianBlur2d"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Gaussian blurring via pytorch.""" | ||
from typing import Tuple, Union | ||
|
||
from kornia.filters import get_gaussian_kernel2d | ||
from kornia.filters.filter import _compute_padding | ||
from kornia.filters.kernels import normalize_kernel2d | ||
from torch import Tensor, nn | ||
from torch.nn import functional as F | ||
|
||
|
||
class GaussianBlur2d(nn.Module): | ||
"""Compute GaussianBlur in 2d. | ||
Makes use of kornia functions, but most notably the kernel is not computed | ||
during the forward pass, and does not depend on the input size. As a caveat, | ||
the number of channels that are expected have to be provided during initialization. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
kernel_size: Union[Tuple[int, int], int], | ||
sigma: Union[Tuple[float, float], float], | ||
channels: int, | ||
normalize: bool = True, | ||
border_type: str = "reflect", | ||
padding: str = "same", | ||
) -> None: | ||
"""Initialize model, setup kernel etc.. | ||
Args: | ||
kernel_size (Union[Tuple[int, int], int]): size of the Gaussian kernel to use. | ||
sigma (Union[Tuple[float, float], float]): standard deviation to use for constructing the Gaussian kernel. | ||
channels (int): channels of the input | ||
normalize (bool, optional): Whether to normalize the kernel or not (i.e. all elements sum to 1). | ||
Defaults to True. | ||
border_type (str, optional): Border type to use for padding of the input. Defaults to "reflect". | ||
padding (str, optional): Type of padding to apply. Defaults to "same". | ||
""" | ||
super().__init__() | ||
kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) | ||
sigma = sigma if isinstance(sigma, tuple) else (sigma, sigma) | ||
self.kernel: Tensor | ||
self.register_buffer("kernel", get_gaussian_kernel2d(kernel_size=kernel_size, sigma=sigma)) | ||
if normalize: | ||
self.kernel = normalize_kernel2d(self.kernel) | ||
self.channels = channels | ||
self.kernel.unsqueeze_(0).unsqueeze_(0) | ||
self.kernel = self.kernel.expand(self.channels, -1, -1, -1) | ||
self.border_type = border_type | ||
self.padding = padding | ||
self.height, self.width = self.kernel.shape[-2:] | ||
self.padding_shape = _compute_padding([self.height, self.width]) | ||
|
||
def forward(self, input_tensor: Tensor) -> Tensor: | ||
"""Blur the input with the computed Gaussian. | ||
Args: | ||
input_tensor (Tensor): Input tensor to be blurred. | ||
Returns: | ||
Tensor: Blurred output tensor. | ||
""" | ||
batch, channel, height, width = input_tensor.size() | ||
|
||
if self.padding == "same": | ||
input_tensor = F.pad(input_tensor, self.padding_shape, mode=self.border_type) | ||
|
||
# convolve the tensor with the kernel. | ||
output = F.conv2d(input_tensor, self.kernel, groups=self.channels, padding=0, stride=1) | ||
|
||
if self.padding == "same": | ||
out = output.view(batch, channel, height, width) | ||
else: | ||
out = output.view(batch, channel, height - self.height + 1, width - self.width + 1) | ||
|
||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Test individual components.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import pytest | ||
import torch | ||
from kornia.filters import GaussianBlur2d as korniaGaussianBlur2d | ||
|
||
from anomalib.models.components import GaussianBlur2d | ||
|
||
|
||
@pytest.mark.parametrize("kernel_size", [(33, 33), (9, 9), (11, 5), (3, 3)]) | ||
@pytest.mark.parametrize("sigma", [(4.0, 4.0), (1.9, 3.0), (2.0, 1.5)]) | ||
@pytest.mark.parametrize("channels", list(range(1, 6))) | ||
def test_blur_equivalence(kernel_size, sigma, channels): | ||
for _ in range(10): | ||
input_tensor = torch.randn((3, channels, 128, 128)) | ||
kornia = korniaGaussianBlur2d(kernel_size, sigma, separable=False) | ||
blur_kornia = kornia(input_tensor) | ||
gaussian = GaussianBlur2d(kernel_size, sigma, channels) | ||
blur_gaussian = gaussian(input_tensor) | ||
torch.testing.assert_allclose(blur_kornia, blur_gaussian) |
e19428f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#1798 error is seems with the GaussianBlur2d with channels dimensions in the tensor is note equal and it seems working with the Pacthcore() . Can you explain with this Pacthcore library how it works/?