Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature Refinement to Improve High Resolution Image Inpainting #112

Merged
merged 6 commits into from
Jul 28, 2022

Conversation

ankuPRK
Copy link
Contributor

@ankuPRK ankuPRK commented Apr 27, 2022

We are a team of researchers at Geomagical Labs (geomagical.com), a subsidiary of IKEA. We work on pioneering Mixed Reality apps which allow customers to scan photorealistic models of their indoor spaces and re-imagine them with virtual furniture.

In this PR we propose an additional refinement step for LaMa to improve high-resolution inpainting results. We observed that when inpainting large regions at high resolution, LaMa struggles at structure completion. However, at low resolutions, LaMa can infill the same missing region much better. To address this we added an additional refinement step that uses the structure from low resolution predictions to guide higher resolution predictions.

Our approach can work on any inpainting network, and does not require any additional training or network modification.

How to run refinement

To run refinement, simply pass refine=True in the evaluation step as:

    python3 bin/predict.py refine=True model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output

Evaluation

Here's a few example comparisons, with each triplet showing the masked image, inpainting with LaMa, and inpainting with LaMa using refinement:
image

Comparison of unrefined and refined images on all test images (kindly shared by you) is available here: https://drive.google.com/drive/folders/15LEa9k_7-dUKb2CPUDuw7e6Zk28KCtzz?usp=sharing

We also performed some numerical evaluation on 1024x1024 size images sampled from [1], using the thin, medium, and thick masks. Results indicate that LaMa+refinement outperforms all the recent inpainting baselines on high resultion inpainting:

Method FID (thin) LPIPS (thin) FID (medium) LPIPS (medium) FID (thick) LPIPS (thick)
AOTGAN [3] 17.387 0.133 34.667 0.144 54.015 0.184
LatentDiffusion [4] 18.505 0.141 31.445 0.149 38.743 0.172
MAT [6] 16.284 0.137 27.829 0.135 38.120 0.157
ZITS [5] 15.696 0.125 23.500 0.121 31.777 0.140
LaMa-Fourier [2] 14.780 0.124 22.584 0.120 29.351 0.140
Big-LaMa [2] 13.143 0.114 21.169 0.116 29.022 0.140
Big-LaMa+refinement (ours) 13.193 0.112 19.864 0.115 26.401 0.135

Table 1. Performance comparison of various recent inpainting approaches on 1k 1024x1024 size images

Video

We have also created a video to explain the technical details of our approach:
https://www.youtube.com/watch?v=gEukhOheWgE

References

[1]
Unsplash Dataset. https://unsplash.com/data, 2020

[2]
Suvorov, R., Logacheva, E., Mashikhin, A., Remizova, A., Ashukha, A., Silvestrov, A., Kong, N., Goka, H., Park, K. and Lempitsky, V., 2022. Resolution-robust Large Mask Inpainting with Fourier Convolutions. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 2149-2159)

[3]
Zeng, Y., Fu, J., Chao, H. and Guo, B., 2022. Aggregated contextual transformations for high-resolution image inpainting. IEEE Transactions on Visualization and Computer Graphics.

[4]
Rombach, R., Blattmann, A., Lorenz, D., Esser, P. and Ommer, B., 2021. High-Resolution Image Synthesis with Latent Diffusion Models. arXiv preprint arXiv:2112.10752.

[5]
Dong, Q., Cao, C. and Fu, Y., 2022. Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding. arXiv preprint arXiv:2203.00867.

[6]
Li, W., Lin, Z., Zhou, K., Qi, L., Wang, Y. and Jia, J., 2022. MAT: Mask-Aware Transformer for Large Hole Image Inpainting. arXiv preprint arXiv:2203.15270.

@windj007
Copy link
Collaborator

Wow! That's extremely cool! Can't wait to try!

@senya-ashukha
Copy link
Collaborator

Wow folks! that is really impressive!

@windj007
Copy link
Collaborator

Just curious, have you tried your approach with other methods, e.g. MAT or ZITS? Is the improvement for them is like that for LaMa? I'm also amazed by the fact that such an optimization technique does not introduce high-frequency artifacts... How do you think, why is it so?

@ankuPRK
Copy link
Contributor Author

ankuPRK commented Apr 27, 2022

Hi, we haven't tried it with other methods due to our limited bandwidth. We started with LaMa, and then during evaluation we found that without refinement, LaMa still generalizes better than the newer methods on high resolution inpainting. So we didn't have a very strong motivation to try this on the other methods. We'd love to see how it works on them though.

When applying the refinement, we're basically asking the network to find a high-resolution featuremap that produces an output that, when downscaled, looks like the low-resolution output. We hypothesize that when this featuremap is found, it contains high-level encoded information learned from training about the contents in that region. For example, optimization may adjust a feature that was "kind-of-brick-like" to become "very-red-brick-like." The optimized high level features then gets decoded into low and high frequency brick-like textures.

So, it's possible that these latent features of the downscaler control low and high frequency details jointly. We also tried to optimize the latent features of the upscaler, but it didn't work well, and produced cloud-like artifacts. So it probably has something to do with large and overlapping receptive fields of pixels of the featuremap.

@windj007
Copy link
Collaborator

Thank you for the clarification!

that without refinement, LaMa still generalizes better

Yeah, but in lower resolution these methods are stronger than LaMa - so they might probably benefit more from your method.
Maybe the difference between effects of your method on different architectures would highlight the inherent inductive biases in them - and help build a better new archtecture.

these latent features of the downscaler control low and high frequency details jointly.

Yes, lo-freq and high-freq details do not seem to be disentangled in the features.

Copy link
Collaborator

@windj007 windj007 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd kindly suggest simplifying the usage and making default config values more sensible for broader set of environments (e.g. single-gpu). Anyway, this is a great contribution!

batch['mask'] = (batch['mask'] > 0) * 1
batch = model(batch)
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
unpad_to_size = batch.get('unpad_to_size', None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L84-87 should be outside if-else - they need to be executed regardless refinement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L84-87 already get addressed inside the refiner. Refiner works on unpadded images (it does the necesssary padding internally and then unpads the output appropriately). We can:

  • add an assertion to check unpad_to_size is not None
  • enable refiner to just run on padded image, if unpad_to_size is None.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I see. I'd move padding-unpadding from the refiner to predict.py - so both parts of the code are simplified and no logic duplication is introduced. What do you think, is it possible and does it make sense?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can let the refiner get padded input. But refiner still needs some padding in place. Because -

Suppose your input image is a square of size 1000. Then the original image isn't padded because 1000%8==0, but in the refiner, once we downscale the image, it's size becomes 500, and 500%8!=0. So we have to pad it to make it 504x504.

So we can't get rid of lines 301 and 302 in refinement.py, but we can:

  • let the padded image to be input to refiner, so that we take L84-87 outside the if-else.
  • refiner then doesn't check unpad_to_size argument at all.
  • Padding would happen in the refiner to ensure downscaled image size is divisible by 8.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refiner still needs some padding in place

I see, thank you for the clarification! Let's just leave that piece of the as is - and add a comment about "padding-unpadding is handled within refiner"

image size is divisible by 8.

Padding size depends on depth of the generator and thus needs to be configurable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, will add the comment. Yeah the padding size of the refiner is not exactly 8, but exactly equal to dataset.pad_out_to_modulo in the predict config. I'll add a comment there in the PR


refine: False # refiner will only run if this is True
refiner:
gpu_ids: 0,1 # the GPU ids of the machine to use. If only single GPU, use: "0,"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest using only 0 by default - or even introduce "None" default (so refiner would rely on the parent device setting). That would make this work by default in more environments without any modifications by default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the refiner needs around 24GB GPU to process 1.8 megapixel images (~1200x1500). Since most people have two 12GB GPUs instead, we decided to split the model onto two GPUs, that's why the default config setting.
Do you suggest to still make it None by default?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, right, I have not thought about memory consumption. It seems that the most of the consumption comes from storing activations for backward... And you're splitting res-blocks between GPUs to distribute that memory - not to speedup inference - because GPUs are called sequentially.

I have a couple ideas how to overcome that without complex logic or requirement to have two GPUs:

  • Set param.requires_grad_(False) for all parameters in the generator - that will lead to storing only activations, not gradients for parameters.
  • Use activation checkpointing - it does something very similar to what you're doing - it splits a nn.Sequential in multiple chunks and runs each chunk with torch.no_grad - so only activations between chunks have to be stored. That will slow the optimization down, but maybe not severely.
  • torch.cuda.amp - optimize in fp16 instead of fp32. In case of refinement there is no adversarial training, so there should not be stability issues due to reduced precision (but I'm not sure)

Copy link
Contributor Author

@ankuPRK ankuPRK May 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ideas! param.requires_grad_ is already set to False, since we freeze the model here: https://github.com/geomagical/lama-with-refiner/blob/24a20f804390c6ab969c28abbe999c940f8d6a56/bin/predict.py#L58
I also manually verified the requires_grad for all the params of the model, they were False.

We were already looking at activation checkpointing, will focus on it more now that you have also mentioned it. Will try the third idea also.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.cuda.amp isn't working because pytorch doesn't seem to support Half dtype for torch.fft.rfftn. PFA link to the relevant issues in Pytorch repo:

pytorch/pytorch#70664
pytorch/pytorch#71680

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also manually verified the requires_grad for all the params of the model

Great, thank you!

torch.cuda.amp isn't working because pytorch doesn't seem to support Half

Sure, I've forgot that I've already tried half and failed because of that... We could wrap rfftn/irfftn with conversion to and from .float(), but I'm not sure there wouldn't be other issues..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @windj007, sorry for coming back after 2 months! We picked up the experiments, our findings:

  1. We were able to perform the optimization in mixed precision. I haven't benchmarked it quantitatively, but qualitative results look good. However, for 1024x1024 images, it only reduces the memory from 21-22GB -> 17-18GB, so it is still not sufficient to fit on a single 12GB GPU
  2. We also tried to play with checkpointing. Performing it naively throws RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn, which we bypass by setting use_reentrant=False. However, this setting has some memory leak problem, which causes the GPU consumption to increase at each training loop, eventually leading to OOM error. We plan to raise this issue on the Pytorch repo.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please share your code for mixed-precision?

Copy link
Contributor Author

@ankuPRK ankuPRK Aug 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, it's in amp_float16 branch of our fork:

https://github.com/geomagical/lama-with-refiner/tree/amp_float16

You can get this code by:

git clone [email protected]:geomagical/lama-with-refiner.git
git checkout amp_float16

Also, I've changed the config file of the refiner to run on a single GPU. But yeah feel free to play around with config parameters or anything :)

Link to the config file in the code: https://github.com/geomagical/lama-with-refiner/blob/amp_float16/configs/prediction/default.yaml

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much 🙏

@ankuPRK
Copy link
Contributor Author

ankuPRK commented May 3, 2022

Thanks for your feedback, appreciate it! We are very much interested in addressing your concerns until you're comfortable enough to merge this PR into your code. The usage complication is primarily because we try to fit the refinement on multiple devices. Refinement requires at least 24GB GPU to run on images of sizes like 1024x1024.

refine: False # refiner will only run if this is True
refiner:
gpu_ids: 0,1 # the GPU ids of the machine to use. If only single GPU, use: "0,"
modulo: ${dataset.pad_out_to_modulo}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@windj007 refiner padding is defined here

@senya-ashukha
Copy link
Collaborator

Screenshot 2022-07-23 at 21 39 24

Hi folks! Your results are amazing, so I've added them to the main page! Fee free to share twitter post and paper!

@ankuPRK
Copy link
Contributor Author

ankuPRK commented Jul 23, 2022

@senya-ashukha that works too! :)

@ankuPRK
Copy link
Contributor Author

ankuPRK commented Jul 25, 2022

Hello,

Thanks for your shout out on the main README. Let us know if you would like us to close this PR.

@senya-ashukha senya-ashukha merged commit bd69ec3 into advimman:main Jul 28, 2022
@senya-ashukha
Copy link
Collaborator

Hi @ankuPRK, @windj007 is on long vacation. I've merged it, but @windj007 may discuss it further as he returns.

@ankuPRK
Copy link
Contributor Author

ankuPRK commented Jul 28, 2022

Sounds good, thanks a lot! Happy to follow up on any further review comments regarding code/formatting/ideas or anything :)

@mhashas
Copy link

mhashas commented Aug 11, 2022

is there anyway i can make this work on a gpu of 11GB? or run it on CPU?

@ankuPRK
Copy link
Contributor Author

ankuPRK commented Aug 15, 2022

@mhashas You can try using our mixed precision branch: #112 (comment)

We can probably have a CPU version, but it will be super slow. If you are okay with a slow CPU version, I think replacing the device with torch.device('cpu') in predict.py and refinement.py should work.

HeunSeungLim pushed a commit to HeunSeungLim/lama_HL that referenced this pull request Nov 14, 2022
Feature Refinement to Improve High Resolution Image Inpainting
@202112213501021
Copy link

Can I separate out the Feature Refinement to Improve High Resolution Image Inpainting technique, just for a single image?

sagor155670 pushed a commit to sagor155670/lama that referenced this pull request Aug 27, 2024
Feature Refinement to Improve High Resolution Image Inpainting
@hjj-lmx
Copy link

hjj-lmx commented Dec 10, 2024

Do you have the complete branch code?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants