From f58e94733427a10d5b769fb75f5df2ba3c9363f7 Mon Sep 17 00:00:00 2001 From: Nando Metzger <42088121+nandometzger@users.noreply.github.com> Date: Mon, 13 May 2024 09:47:44 +0200 Subject: [PATCH 1/2] Fix upsampling misalignments after resizing In the current implementation, we have the following problem: If the user input an image of odd shape (either not divible by 8 or aspect ratio which don't allow to resize to a multiple of 8) the code does something unexpected: It (or the stable diffusion backbone) cut's of the boarders to make it divisible by 8. After that we just resample it to original resolution. This can cause a misalignment (depht vs RGB) of up to 7 pixels in the extrem case at the bottom right corner. I propose to modify the resize_max_resolution function to slightly change the aspect ratio (which will be at max 768/7~=1% off from the true ratio), but this will in turn fix the misalignment when upsampling to match the resolution. --- marigold/util/image_util.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/marigold/util/image_util.py b/marigold/util/image_util.py index 90f0623..d988ec4 100644 --- a/marigold/util/image_util.py +++ b/marigold/util/image_util.py @@ -80,6 +80,7 @@ def resize_max_res( img: torch.Tensor, max_edge_resolution: int, resample_method: InterpolationMode = InterpolationMode.BILINEAR, + div8=False, ) -> torch.Tensor: """ Resize image to limit maximum edge length while keeping aspect ratio. @@ -104,6 +105,11 @@ def resize_max_res( new_width = int(original_width * downscale_factor) new_height = int(original_height * downscale_factor) + # round it up or down to the next multiple of 8 to avoid upsampling misalignments due to smaller latent dimension + if div8: + new_width = round(new_width / 8) * 8 + new_height = round(new_height / 8) * 8 + resized_img = resize(img, (new_height, new_width), resample_method, antialias=True) return resized_img From c79fd1b7860a1b20944acf90da63135ce315ee11 Mon Sep 17 00:00:00 2001 From: Nando Metzger <42088121+nandometzger@users.noreply.github.com> Date: Mon, 13 May 2024 09:54:13 +0200 Subject: [PATCH 2/2] Update image_util.py --- marigold/util/image_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/marigold/util/image_util.py b/marigold/util/image_util.py index d988ec4..c413835 100644 --- a/marigold/util/image_util.py +++ b/marigold/util/image_util.py @@ -105,7 +105,8 @@ def resize_max_res( new_width = int(original_width * downscale_factor) new_height = int(original_height * downscale_factor) - # round it up or down to the next multiple of 8 to avoid upsampling misalignments due to smaller latent dimension + # round it up or down to the next multiple of 8 + # to avoid upsampling misalignments due to smaller latent dimension if div8: new_width = round(new_width / 8) * 8 new_height = round(new_height / 8) * 8