Can diffusion model be used into image to image translation ? #98

Open
lianggaoquan opened this issue Sep 26, 2022 · 11 comments

Comments

@lianggaoquan

Can diffusion models be used for image-to-image translation?

@robert-graf

Yes, just concatenate your image to the noised-image input and change the input-channel size.

@lucidrains
Owner

@lianggaoquan yea, what Robert said

i can add it later this week

@Mut1nyJD

That depends. I would say for paired i2i you can do what @robert-graf mentioned. However, if you have, for example, segmentation maps as one half of the pair, you might be better off adding a SPADE normalization layer to your UNet instead of attaching the segmentation map as input.

However, for unpaired i2i I think the current framework most likely will not work, as I can't see how the current training signal would be enough. But maybe I am wrong.
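
For readers unfamiliar with SPADE, here is a minimal sketch of such a layer. It is not code from this repo: the class name, the `hidden` width, and the use of a parameter-free `GroupNorm` are assumptions made for illustration (the SPADE paper normalizes with batch norm). The idea is that the segmentation map predicts a per-pixel scale and shift for the normalized features instead of being concatenated to the input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SPADE(nn.Module):
    """Spatially-adaptive normalization (illustrative sketch, not this repo's API)."""

    def __init__(self, feat_channels: int, seg_channels: int, hidden: int = 64):
        super().__init__()
        # Assumption: parameter-free GroupNorm as the base normalization.
        self.norm = nn.GroupNorm(1, feat_channels, affine=False)
        self.shared = nn.Sequential(
            nn.Conv2d(seg_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.to_gamma = nn.Conv2d(hidden, feat_channels, kernel_size=3, padding=1)
        self.to_beta = nn.Conv2d(hidden, feat_channels, kernel_size=3, padding=1)

    def forward(self, feats: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
        # Resize the one-hot segmentation map to the feature resolution.
        seg = F.interpolate(seg, size=feats.shape[-2:], mode="nearest")
        h = self.shared(seg)
        gamma, beta = self.to_gamma(h), self.to_beta(h)
        # Per-pixel scale and shift of the normalized features.
        return self.norm(feats) * (1 + gamma) + beta


# Usage: 5 segmentation classes at 64x64, U-Net features at 16x16.
labels = torch.randint(0, 5, (2, 64, 64))
seg = F.one_hot(labels, num_classes=5).permute(0, 3, 1, 2).float()
feats = torch.randn(2, 32, 16, 16)
out = SPADE(feat_channels=32, seg_channels=5)(feats, seg)
```

In a U-Net, a layer like this would stand in for the normalization inside each residual block, with the segmentation map passed down to every block.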

@FireWallDragonDarkFluid

Hi, is there any update on paired image translation in the repo?
Or can anyone show at least a snippet of code for modifying the repo to make it work?
Anyway, I really appreciate all the work, I have learned a lot!

@huseyin-karaca

@robert-graf Where exactly should I perform the concatenation? Could you please give more details? I tried to do it at the very beginning of the Unet forward, but it did not work.

> Yes, just concatenate your image to the noised-image input and change the input-channel size.

@robert-graf

@huseyin-karaca
This Google paper (Palette) introduced this approach: https://iterative-refinement.github.io/palette/

I did it before the forward call of the U-Net and only updated the input size of the first conv block.

# Conditional p(x_0 | y) -> p(x_0) * p(y | x_0) --> just added it to the input
if x_conditional is not None and self.opt.conditional:
    # Concatenate the conditioning image along the channel dimension.
    x = torch.cat([x, x_conditional], dim=1)
# --------------
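
To make the channel bookkeeping concrete, here is a self-contained sketch of the same idea. `TinyDenoiser` is a hypothetical stand-in for a U-Net, not a class from this repo or from the linked code; only the first convolution is widened to accept the extra conditioning channels, while the output keeps the original image channel count.

```python
import torch
import torch.nn as nn


class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for a U-Net; only the channel handling matters here."""

    def __init__(self, image_channels: int = 3, cond_channels: int = 3, width: int = 16):
        super().__init__()
        # The first block sees the noisy image plus the conditioning image.
        self.first = nn.Conv2d(image_channels + cond_channels, width, 3, padding=1)
        # Crude timestep embedding, assumed to cover 1000 timesteps.
        self.time_emb = nn.Embedding(1000, width)
        # The output has only the image channels (e.g. the predicted noise).
        self.last = nn.Conv2d(width, image_channels, 3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor, x_conditional: torch.Tensor) -> torch.Tensor:
        # Concatenate before the network body, along the channel dimension.
        x = torch.cat([x, x_conditional], dim=1)
        h = self.first(x) + self.time_emb(t)[:, :, None, None]
        return self.last(torch.relu(h))


# Usage: a batch of noisy targets x_t and their paired source images y.
x_t = torch.randn(4, 3, 32, 32)
y = torch.randn(4, 3, 32, 32)
t = torch.randint(0, 1000, (4,))
pred = TinyDenoiser()(x_t, t, y)
```

The conditioning image has to be passed at every denoising step, during both training and sampling.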

For context, the rest is in my image-to-image code under /img2img2D/diffusion.py. I hope lucidrains is fine with me linking my code here. If you are looking for the paper referenced, the preprint is coming out on Tuesday.

@huseyin-karaca

@robert-graf Thank you for your kind reply!

@heitorrapela

Hi, so to do i2i using this repo, is it okay to use the Unet with self_condition=True, or do we have to do the cat manually and change the code in another place?

@FireWallDragonDarkFluid

@heitorrapela You would have to manually change the code written in this repo to achieve i2i.
The self_condition=True in the Unet from this repo is the implementation of this paper: https://arxiv.org/abs/2208.04202
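
As a rough sketch of why self-conditioning is not image conditioning: the extra input channels carry the model's own previous estimate of x_0, not a source image. The code below is a simplified illustration under my own assumptions (`SelfCondNet` and `training_prediction` are hypothetical names, and the timestep input is left out for brevity); it is not this repo's implementation.

```python
import random

import torch
import torch.nn as nn


class SelfCondNet(nn.Module):
    """Hypothetical tiny network predicting x_0 from [previous x_0 estimate, x_t]."""

    def __init__(self, channels: int = 3):
        super().__init__()
        self.net = nn.Conv2d(2 * channels, channels, 3, padding=1)

    def forward(self, x_t, x_self_cond=None):
        # Without a previous estimate, the extra channels are filled with zeros.
        if x_self_cond is None:
            x_self_cond = torch.zeros_like(x_t)
        return self.net(torch.cat([x_self_cond, x_t], dim=1))


def training_prediction(model, x_t, p_self_cond: float = 0.5):
    """With probability p_self_cond, feed the model its own first-pass estimate."""
    x_self_cond = None
    if random.random() < p_self_cond:
        with torch.no_grad():  # no gradient flows through the first pass
            x_self_cond = model(x_t)
    return model(x_t, x_self_cond)


model = SelfCondNet()
x_t = torch.randn(2, 3, 8, 8)
pred = training_prediction(model, x_t)
```

Nothing in this loop ever sees a paired source image, which is why a separate concatenation is still needed for i2i.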

By the way, diffusion models often achieve better i2i results when starting from a pre-trained model, so maybe you could take a look at HuggingFace's diffusers: https://github.com/huggingface/diffusers

@heitorrapela

@FireWallDragonDarkFluid, thanks for the response. I was trying with self_condition, but yes, it was not what I wanted, and in the end it was still adding artifacts to the translation process.

I will see if I can implement it myself with this library or with diffusers. With diffusers I have only tried simple things so far, and I still need to train, so I have to investigate. Due to my task restrictions, I also cannot use a heavy model such as SD.

@heitorrapela

I did a quick implementation, but I am not 100% sure it is correct. I am training some models with it; here are my modifications if anyone wants to try them too:

  • I am using DDIM (sampling_timesteps < timesteps).
  • I updated the UNet channels to be 2*input_channels, e.g. Unet(dim = 64, dim_mults = (1, 2, 4, 8), flash_attn = False, channels = 6).
  • Before line 794:
    model_out = self.model(x, t, x_self_cond), I added x = torch.cat([x, x_start], dim=1)
  • As a workaround to make the loss work when forwarding the images, before L806, add:
    target = torch.cat([target, x_start], dim=1)
  • Finally, when sampling, I keep only the first three channels, which correspond to my sampled image without the initial image.
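
For comparison, here is a standalone sketch of a paired variant of these steps. It is not the modification listed above and not this repo's code. It assumes a separate conditioning image `cond` with the same shape as the target, and a network with 6 input but only 3 output channels, so neither the target concatenation nor the final slicing is needed. The small `model`, the linear beta schedule, and both function bodies are illustrative assumptions; the stand-in network also ignores the timestep, which a real U-Net would take as input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 100  # number of training timesteps (assumed)
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Hypothetical stand-in for a U-Net: 6 channels in (noisy target + condition), 3 out.
model = nn.Sequential(
    nn.Conv2d(6, 16, 3, padding=1), nn.SiLU(), nn.Conv2d(16, 3, 3, padding=1)
)


def p_losses(model, x_start, cond, t):
    """Noise-prediction loss with the condition concatenated to the noisy target."""
    noise = torch.randn_like(x_start)
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = a.sqrt() * x_start + (1 - a).sqrt() * noise
    pred_noise = model(torch.cat([x_t, cond], dim=1))
    return F.mse_loss(pred_noise, noise)


@torch.no_grad()
def ddim_sample(model, cond, steps=10):
    """Deterministic DDIM-style sampling (eta = 0) on a subset of the timesteps."""
    x = torch.randn_like(cond)
    times = torch.linspace(T - 1, 0, steps).long().tolist()
    for i, t in enumerate(times):
        a_t = alphas_cumprod[t]
        a_prev = alphas_cumprod[times[i + 1]] if i + 1 < len(times) else torch.tensor(1.0)
        eps = model(torch.cat([x, cond], dim=1))  # same concatenation at every step
        x0 = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()
        x = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps
    return x


x_start = torch.rand(2, 3, 16, 16) * 2 - 1  # paired target images in [-1, 1]
cond = torch.rand(2, 3, 16, 16) * 2 - 1     # paired source images in [-1, 1]
t = torch.randint(0, T, (2,))
loss = p_losses(model, x_start, cond, t)
sample = ddim_sample(model, cond, steps=10)
```

The main design choice is the asymmetric channel count: the condition enters only at the input, so the loss target stays the 3-channel noise.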


7 participants