Skip to content

Commit

Permalink
change variance from uint into float
Browse files Browse the repository at this point in the history
  • Loading branch information
lizhihao6 committed Dec 11, 2023
1 parent e53eeeb commit 88fa137
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
10 changes: 5 additions & 5 deletions pytorch_bm3d/cuda/bm3d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ extern "C" void run_DCT2D8x8(float* d_transformed_stacks,
extern "C" void run_hard_treshold_block(
const uint2 start_point, float* patch_stack, float* w_P,
const uint* __restrict num_patches_in_stack, const uint2 stacks_dim,
const Params params, const uint sigma, const dim3 num_threads,
const Params params, const float sigma, const dim3 num_threads,
const dim3 num_blocks, const uint shared_memory_size);

extern "C" void run_IDCT2D8x8(float* d_gathered_stacks,
Expand All @@ -77,7 +77,7 @@ extern "C" void run_wiener_filtering(
const uint2 start_point, float* patch_stack,
const float* __restrict patch_stack_basic, float* w_P,
const uint* __restrict num_patches_in_stack, uint2 stacks_dim,
const Params params, const uint sigma, const dim3 num_threads,
const Params params, const float sigma, const dim3 num_threads,
const dim3 num_blocks, const uint shared_memory_size);

// Cuda error handling
Expand Down Expand Up @@ -283,7 +283,7 @@ class BM3D {
arrays.
*/
void first_step(std::vector<raw_int*>& denoised_image, int width, int height,
int channels, uint* sigma) {
int channels, float* sigma) {
// image dimensions
const uint2 image_dim = make_uint2(width, height);

Expand Down Expand Up @@ -511,7 +511,7 @@ class BM3D {
}

void second_step(std::vector<raw_int*>& denoised_image, int width, int height,
int channels, uint* sigma) {
int channels, float* sigma) {
// Image dimensions
const uint2 image_dim = make_uint2(width, height);

Expand Down Expand Up @@ -880,7 +880,7 @@ class BM3D {
component and each following width*height pixels represent color components
*/
void denoise_host_image(raw_int* src_image, raw_int* dst_image, int width,
int height, int channels, uint* sigma,
int height, int channels, float* sigma,
bool two_step) {
Stopwatch total;
total.start();
Expand Down
6 changes: 3 additions & 3 deletions pytorch_bm3d/cuda/bm3d_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

// CUDA forward declarations

torch::Tensor bm3d_cuda_forward(torch::Tensor input, BM3D& bm3d, uint& variance,
bool twostep) {
torch::Tensor bm3d_cuda_forward(torch::Tensor input, BM3D& bm3d,
float& variance, bool twostep) {
// Allocate images
auto output = torch::zeros_like(input);

Expand Down Expand Up @@ -39,7 +39,7 @@ torch::Tensor bm3d_cuda_forward(torch::Tensor input, BM3D& bm3d, uint& variance,
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

torch::Tensor bm3d_forward(torch::Tensor input, uint& variance, bool twostep) {
torch::Tensor bm3d_forward(torch::Tensor input, float& variance, bool twostep) {
CHECK_INPUT(input);

BM3D bm3d;
Expand Down
8 changes: 4 additions & 4 deletions pytorch_bm3d/cuda/filtering.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ __global__ void hard_treshold_block(
// 3D groups
uint2 stacks_dim, // IN: dimensions limiting addresses of reference patches
const Params params, // IN: denoising parameters
const uint sigma // IN: noise variance
const float sigma // IN: noise variance
) {
extern __shared__ float data[];

Expand Down Expand Up @@ -321,7 +321,7 @@ __global__ void wiener_filtering(
// 3D groups
uint2 stacks_dim, // IN: dimensions limiting addresses of reference patches
const Params params, // IN: denoising parameters
const uint sigma // IN: Noise variance
const float sigma // IN: Noise variance
) {
extern __shared__ float data[];

Expand Down Expand Up @@ -410,7 +410,7 @@ extern "C" void run_get_block(const uint2 start_point,
extern "C" void run_hard_treshold_block(
const uint2 start_point, float* patch_stack, float* w_P,
const uint* __restrict num_patches_in_stack, const uint2 stacks_dim,
const Params params, const uint sigma, const dim3 num_threads,
const Params params, const float sigma, const dim3 num_threads,
const dim3 num_blocks, const uint shared_memory_size) {
hard_treshold_block<<<num_blocks, num_threads, shared_memory_size>>>(
start_point, patch_stack, w_P, num_patches_in_stack, stacks_dim, params,
Expand Down Expand Up @@ -443,7 +443,7 @@ extern "C" void run_wiener_filtering(
const uint2 start_point, float* patch_stack,
const float* __restrict patch_stack_basic, float* w_P,
const uint* __restrict num_patches_in_stack, uint2 stacks_dim,
const Params params, const uint sigma, const dim3 num_threads,
const Params params, const float sigma, const dim3 num_threads,
const dim3 num_blocks, const uint shared_memory_size) {
wiener_filtering<<<num_blocks, num_threads, shared_memory_size>>>(
start_point, patch_stack, patch_stack_basic, w_P, num_patches_in_stack,
Expand Down
2 changes: 1 addition & 1 deletion test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
variance = 20 * 20

lq, gt = lq * scale, gt * scale
variance = variance * (scale ** 2)
variance = variance * (scale ** 2) + 0.0001

bm3d = BM3D(two_step=True)

Expand Down

0 comments on commit 88fa137

Please sign in to comment.