Skip to content

Commit

Permalink
finish bm3d pytorch support
Browse files Browse the repository at this point in the history
  • Loading branch information
lizhihao6 committed Dec 11, 2023
1 parent 96d57fd commit 4a14bc6
Show file tree
Hide file tree
Showing 34 changed files with 2,434 additions and 51,991 deletions.
3 changes: 3 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
BasedOnStyle: Google
IndentWidth: 2
ColumnLimit: 80
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
*.so
*.egg-info
*.o
*.ninja*
*.egg
.vscode
*.pyc
25 changes: 0 additions & 25 deletions CMakeLists.txt

This file was deleted.

1 change: 1 addition & 0 deletions build/lib/BM3D/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .bm3d import BM3D
44 changes: 44 additions & 0 deletions build/lib/BM3D/bm3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
from torch.nn import Module, Parameter
from torch.autograd import Function

import bm3d_cuda


class bm3d_function(Function):

@staticmethod
def forward(ctx, im, variance, two_step=True):
'''
im: 4D tensor with shape (batch_size, channel, height, width)
variance: noise variance
two_step: whether to use two step method
'''
assert im.shape == 4, "BM3D forward input must be 4D tensor with shape (batch_size, channel, height, width)"
assert im.shape[0] == 1, "Only support batch_size=1"
if im.shape[0] == 3:
print("Warning: We do not support RGB image, inference as multiple gray images")
output = torch.zeros_like(im)
for c in range(im.shape[1]):
output[0, c, :, :] = bm3d_cuda.forward(im[0, c, :, :], variance, two_step)
return output

@staticmethod
def backward(ctx, grad_output):
'''
grad_output: gradient of output
'''
raise NotImplementedError("BM3D backward is not implemented")


class BM3D(Module):

def __init__(self, two_step=True):
'''
Support interpolation mode with Bilinear and Nearest.
'''
super(forward_warp, self).__init__()
self._two_step = two_step

def forward(self, input, variance):
return bm3d_function.apply(input, variance, self._two_step)
19 changes: 19 additions & 0 deletions build/lib/BM3D/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
name='bm3d_cuda',
ext_modules=[
CUDAExtension('bm3d_cuda', [
'cuda/bm3d_cuda.cpp',
'cuda/filtering.cu',
'cuda/blockmatching.cu',
'cuda/dct8x8.cu',
],
libraries=['cufft', 'cudart', 'png']
)
],
cmdclass={
'build_ext': BuildExtension
}
)
1 change: 1 addition & 0 deletions build/lib/pytorch_bm3d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .bm3d import BM3D
44 changes: 44 additions & 0 deletions build/lib/pytorch_bm3d/bm3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
from torch.nn import Module, Parameter
from torch.autograd import Function

import bm3d_cuda


class bm3d_function(Function):

@staticmethod
def forward(ctx, im, variance, two_step=True):
'''
im: 4D tensor with shape (batch_size, channel, height, width)
variance: noise variance
two_step: whether to use two step method
'''
assert im.shape == 4, "BM3D forward input must be 4D tensor with shape (batch_size, channel, height, width)"
assert im.shape[0] == 1, "Only support batch_size=1"
if im.shape[0] == 3:
print("Warning: We do not support RGB image, inference as multiple gray images")
output = torch.zeros_like(im)
for c in range(im.shape[1]):
output[0, c, :, :] = bm3d_cuda.forward(im[0, c, :, :], variance, two_step)
return output

@staticmethod
def backward(ctx, grad_output):
'''
grad_output: gradient of output
'''
raise NotImplementedError("BM3D backward is not implemented")


class BM3D(Module):

def __init__(self, two_step=True):
'''
Support interpolation mode with Bilinear and Nearest.
'''
super(forward_warp, self).__init__()
self._two_step = two_step

def forward(self, input, variance):
return bm3d_function.apply(input, variance, self._two_step)
19 changes: 19 additions & 0 deletions build/lib/pytorch_bm3d/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
name='bm3d_cuda',
ext_modules=[
CUDAExtension('bm3d_cuda', [
'cuda/bm3d_cuda.cpp',
'cuda/filtering.cu',
'cuda/blockmatching.cu',
'cuda/dct8x8.cu',
],
libraries=['cufft', 'cudart', 'png']
)
],
cmdclass={
'build_ext': BuildExtension
}
)
Loading

0 comments on commit 4a14bc6

Please sign in to comment.