forked from DawyD/bm3d-gpu
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
34 changed files
with
2,434 additions
and
51,991 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
BasedOnStyle: Google | ||
IndentWidth: 2 | ||
ColumnLimit: 80 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
*.so | ||
*.egg-info | ||
*.o | ||
*.ninja* | ||
*.egg | ||
.vscode | ||
*.pyc |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .bm3d import BM3D |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
from torch.nn import Module, Parameter | ||
from torch.autograd import Function | ||
|
||
import bm3d_cuda | ||
|
||
|
||
class bm3d_function(Function): | ||
|
||
@staticmethod | ||
def forward(ctx, im, variance, two_step=True): | ||
''' | ||
im: 4D tensor with shape (batch_size, channel, height, width) | ||
variance: noise variance | ||
two_step: whether to use two step method | ||
''' | ||
assert im.shape == 4, "BM3D forward input must be 4D tensor with shape (batch_size, channel, height, width)" | ||
assert im.shape[0] == 1, "Only support batch_size=1" | ||
if im.shape[0] == 3: | ||
print("Warning: We do not support RGB image, inference as multiple gray images") | ||
output = torch.zeros_like(im) | ||
for c in range(im.shape[1]): | ||
output[0, c, :, :] = bm3d_cuda.forward(im[0, c, :, :], variance, two_step) | ||
return output | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output): | ||
''' | ||
grad_output: gradient of output | ||
''' | ||
raise NotImplementedError("BM3D backward is not implemented") | ||
|
||
|
||
class BM3D(Module): | ||
|
||
def __init__(self, two_step=True): | ||
''' | ||
Support interpolation mode with Bilinear and Nearest. | ||
''' | ||
super(forward_warp, self).__init__() | ||
self._two_step = two_step | ||
|
||
def forward(self, input, variance): | ||
return bm3d_function.apply(input, variance, self._two_step) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from setuptools import setup | ||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension | ||
|
||
setup( | ||
name='bm3d_cuda', | ||
ext_modules=[ | ||
CUDAExtension('bm3d_cuda', [ | ||
'cuda/bm3d_cuda.cpp', | ||
'cuda/filtering.cu', | ||
'cuda/blockmatching.cu', | ||
'cuda/dct8x8.cu', | ||
], | ||
libraries=['cufft', 'cudart', 'png'] | ||
) | ||
], | ||
cmdclass={ | ||
'build_ext': BuildExtension | ||
} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .bm3d import BM3D |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
from torch.nn import Module, Parameter | ||
from torch.autograd import Function | ||
|
||
import bm3d_cuda | ||
|
||
|
||
class bm3d_function(Function): | ||
|
||
@staticmethod | ||
def forward(ctx, im, variance, two_step=True): | ||
''' | ||
im: 4D tensor with shape (batch_size, channel, height, width) | ||
variance: noise variance | ||
two_step: whether to use two step method | ||
''' | ||
assert im.shape == 4, "BM3D forward input must be 4D tensor with shape (batch_size, channel, height, width)" | ||
assert im.shape[0] == 1, "Only support batch_size=1" | ||
if im.shape[0] == 3: | ||
print("Warning: We do not support RGB image, inference as multiple gray images") | ||
output = torch.zeros_like(im) | ||
for c in range(im.shape[1]): | ||
output[0, c, :, :] = bm3d_cuda.forward(im[0, c, :, :], variance, two_step) | ||
return output | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output): | ||
''' | ||
grad_output: gradient of output | ||
''' | ||
raise NotImplementedError("BM3D backward is not implemented") | ||
|
||
|
||
class BM3D(Module): | ||
|
||
def __init__(self, two_step=True): | ||
''' | ||
Support interpolation mode with Bilinear and Nearest. | ||
''' | ||
super(forward_warp, self).__init__() | ||
self._two_step = two_step | ||
|
||
def forward(self, input, variance): | ||
return bm3d_function.apply(input, variance, self._two_step) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from setuptools import setup | ||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension | ||
|
||
setup( | ||
name='bm3d_cuda', | ||
ext_modules=[ | ||
CUDAExtension('bm3d_cuda', [ | ||
'cuda/bm3d_cuda.cpp', | ||
'cuda/filtering.cu', | ||
'cuda/blockmatching.cu', | ||
'cuda/dct8x8.cu', | ||
], | ||
libraries=['cufft', 'cudart', 'png'] | ||
) | ||
], | ||
cmdclass={ | ||
'build_ext': BuildExtension | ||
} | ||
) |
Oops, something went wrong.