Karras UNet 1D + 3D #295

Closed
MaximilienLC opened this issue Feb 21, 2024 · 30 comments

@MaximilienLC

Would the adjustments needed to turn the 2D Karras UNet into a 1D one be the same as the adjustments made to the 2D "standard" UNet to create the 1D one?

Thanks!

@lucidrains
Owner

lucidrains commented Feb 21, 2024

@MaximilienLC yup, it is quite simple; just replace the mp conv2d with mp conv1d https://github.com/lucidrains/lumiere-pytorch/blob/main/lumiere_pytorch/mp_lumiere.py#L120 + some other adjustments
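A minimal NumPy sketch of what a magnitude-preserving 1D conv might look like, assuming "mp" refers to the magnitude-preserving layers of Karras et al.'s EDM2 paper. The helper names `mp_normalize` and `mp_conv1d` are hypothetical, and the library's own layer may differ in details (for example, how weights are normalized during training). The point is that going from 2D to 1D only changes the weight shape, from (out, in, kh, kw) to (out, in, k), and therefore the fan-in.

```python
import numpy as np

def mp_normalize(w, eps=1e-4):
    # Rescale each output filter so its elements have unit RMS (EDM2-style weight normalization)
    flat = w.reshape(w.shape[0], -1)
    rms = np.linalg.norm(flat, axis=1, keepdims=True) / np.sqrt(flat.shape[1])
    return (flat / (eps + rms)).reshape(w.shape)

def mp_conv1d(x, w):
    # x: (in_ch, length); w: (out_ch, in_ch, k) with odd k; zero 'same' padding
    out_ch, in_ch, k = w.shape
    if x.shape[0] != in_ch or k % 2 == 0:
        raise ValueError("channel mismatch or even kernel size")
    fan_in = in_ch * k
    # unit-RMS filter / sqrt(fan_in) -> each filter has (almost exactly) unit L2 norm,
    # so unit-variance i.i.d. inputs give roughly unit-variance outputs
    w = mp_normalize(w) / np.sqrt(fan_in)
    pad = k // 2
    xp = np.pad(x, ((0, 0), (pad, pad)))
    length = x.shape[1]
    # sliding windows: (length, in_ch, k)
    windows = np.stack([xp[:, i:i + k] for i in range(length)])
    # cross-correlation, the convention deep-learning conv layers use
    return np.einsum("lik,oik->ol", windows, w)
```

Because the weights are rescaled on every call, the raw weight scale does not affect the output magnitude; only the filter directions matter.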

would you like me to just knock this out for you tomorrow morning?

@lucidrains
Owner

@MaximilienLC what type of data are you working with?

@MaximilienLC
Author

I trust you more than I trust myself so I'll gladly take the offer hhh. Working with audio-like signals that cannot be easily converted to spectrograms. Thanks a lot!

@lucidrains
Owner

oh yes audio duh

ok, I'll knock this out tomorrow morning!

@lucidrains
Owner

@MaximilienLC gave this more thought while walking doggo, and i think you still need 2d (to mix adjacent frequency information). i'll make it so that in the unet2d you can finely control whether to downsample only the height or only the width (and customize the conv2d kernel height/width too). that should cover the 1d case as well
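The idea of downsampling only the height or only the width can be illustrated with pooling that uses an independent factor per axis. This sketch assumes 2x average pooling (an assumption, not necessarily what the library uses), and `downsample` is a hypothetical helper. With a height of 1 and width-only pooling, the 2D operation reduces to ordinary 1D pooling along time, which is why per-axis control generalizes the 1D case.

```python
import numpy as np

def downsample(x, factor_h=1, factor_w=2):
    # x: (channels, height, width); average-pool with separate factors per axis
    c, h, w = x.shape
    if h % factor_h or w % factor_w:
        raise ValueError("spatial dims must be divisible by the pooling factors")
    # split each axis into (blocks, factor) and average over the factor axes
    x = x.reshape(c, h // factor_h, factor_h, w // factor_w, factor_w)
    return x.mean(axis=(2, 4))
```

Pairing this with conv kernels of shape (1, k) would likewise leave the height axis untouched, so a height-1 input behaves like a 1D signal throughout.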

@lucidrains
Owner

lucidrains commented Feb 21, 2024

oh sorry I misread, thought you said you were working with spectrogram. nevermind!

@MaximilienLC
Author

MaximilienLC commented Feb 21, 2024

btw unrelated to my first question, but while I have you here hhh:

I'm planning to eventually transition to conditional modeling, following the concatenation scheme proposed in Palette & Image Super-Resolution via Iterative Refinement (add extra channels with your conditional info). I'm not super familiar with all the changes that need to be made yet but will be eventually (is the concatenation all you need?) Do you think that your lib is flexible enough for me to add this conditioning feature (btw would be happy to propose a PR if I succeed)?
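A minimal sketch of the concatenation scheme, in NumPy and with hypothetical helper names (`q_sample`, `conditional_input`). In SR3 and Palette, the clean (un-noised) conditioning signal is stacked onto the noisy sample along the channel axis, so the network's first layer takes `channels + cond_channels` inputs while its output keeps `channels`. The same condition is concatenated again at every sampling step. Whether the UNet in this repo exposes a hook for the extra input channels is not stated in the thread.

```python
import numpy as np

def q_sample(x0, alpha_bar, noise):
    # DDPM forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * noise

def conditional_input(x_t, cond):
    # Palette / SR3-style conditioning: append the clean conditioning signal as extra channels
    # x_t: (batch, channels, length), cond: (batch, cond_channels, length)
    if x_t.shape[0] != cond.shape[0] or x_t.shape[2:] != cond.shape[2:]:
        raise ValueError("batch size and length must match; only the channel count may differ")
    return np.concatenate([x_t, cond], axis=1)
```

For audio-like signals the condition must first be resampled to the target length, much as SR3 upsamples the low-resolution image before concatenating it.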

@QuantPrincess

While we are on the topic... If I wanted to adjust the original Unet to support a 3D input, would the adjustments be roughly similar, i.e. changing the conv to 3D plus some other minor adjustments?

@lucidrains
Owner

lucidrains commented Feb 22, 2024

@QuantPrincess yup, similar complexity. can throw that in there too before month's end. i'm planning on just copy pasting my own code and making a few changes lol

@QuantPrincess

Amazing! Thank you.

@lucidrains
Owner

lucidrains commented Feb 22, 2024

@QuantPrincess what are you using it for? the latest video unets are all pretrained on images with conv2d, with conv1d slipped in for the time dimension during the video training stage (unless Sora changed all that, haven't read the technical paper yet)
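The "conv2d plus a conv1d over time" recipe is a factorized (sometimes called (2+1)D or pseudo-3D) convolution. This single-channel NumPy sketch, with hypothetical helper names, applies a 2D kernel to each frame and then a 1D kernel along the frame axis at every pixel. Real layers also mix channels and may insert nonlinearities between the two steps; those are omitted here. A full k x k x k kernel has k^3 weights per channel pair, while the factorized pair has k^2 + k (27 vs. 12 for k = 3).

```python
import numpy as np

def conv2d_same(img, k):
    # single-channel 2D cross-correlation with zero 'same' padding (odd kernel sizes)
    kh, kw = k.shape
    ph, pw = kh // 2, kw // 2
    p = np.pad(img, ((ph, ph), (pw, pw)))
    h, w = img.shape
    out = np.zeros((h, w))
    for i in range(kh):
        for j in range(kw):
            out += k[i, j] * p[i:i + h, j:j + w]
    return out

def conv1d_time_same(video, k):
    # 1D cross-correlation along the frame axis, applied independently at every pixel
    pt = len(k) // 2
    p = np.pad(video, ((pt, pt), (0, 0), (0, 0)))
    t = video.shape[0]
    return sum(k[s] * p[s:s + t] for s in range(len(k)))

def factorized_conv3d(video, k_space, k_time):
    # video: (frames, height, width); spatial conv per frame, then temporal conv per pixel
    spatial = np.stack([conv2d_same(frame, k_space) for frame in video])
    return conv1d_time_same(spatial, k_time)
```

Because the spatial kernels only ever see single frames, they can be pretrained on still images before the temporal kernels are added.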

@lucidrains
Owner

@MaximilienLC knocked it out just now - let me know if this runs for you

@lucidrains
Owner

@QuantPrincess i can get out a 3d version too, if you aren't doing video stuff. if you are, i'd recommend this

@QuantPrincess

QuantPrincess commented Feb 22, 2024

Thanks so much, really appreciate the recommendation. Not doing video stuff though, just some 3D imaging. I will also try to adjust your code base for 3D in the meantime!

@lucidrains
Owner

lucidrains commented Feb 22, 2024

@QuantPrincess ohh got it, are you working with CT / MRI segmentation? yea i can add 3d for you then tomorrow morning

@QuantPrincess

Haha yes CTs! Thanks so much for your help. Your code base is really a pleasure to work with.

@lucidrains
Owner

lucidrains commented Feb 22, 2024

@QuantPrincess awesome! yea i'll get that done

brings back memories of grad school when i tried to segment kidneys in CT scans with gofai algorithms (our team used watershed segmentation, super sh**ty). oh how far things have come

@MaximilienLC
Author

> @MaximilienLC knocked it out just now - let me know if this runs for you

Thanks @lucidrains, will do in the coming days and report back!

btw, you might have missed my latest question

> I'm planning to eventually transition to conditional modeling, following the concatenation scheme proposed in Palette & Image Super-Resolution via Iterative Refinement (add extra channels with your conditional info). I'm not super familiar with all the changes that need to be made yet but will be eventually (is the concatenation all you need?) Do you think that your lib is flexible enough for me to add this conditioning feature (btw would be happy to propose a PR if I succeed)?

@lucidrains
Owner

@MaximilienLC unfortunately my open source journey may come to an end soon, so i may not get around to that. PR is welcome!

@QuantPrincess

Thank you for this! I will check it out today and report back. Really appreciate your help. :)

@lucidrains
Owner

@QuantPrincess hey! so i realized it still lacks a few features to be usable (factorized attention, and being able to specify downsamples in space vs time separately). the number of CT slices will be much smaller than the spatial dimensions

@lucidrains
Owner

lucidrains commented Feb 23, 2024

@QuantPrincess let me get those in there this weekend, but do feel free to try it out as it is in the meanwhile

lucidrains changed the title from "Karras UNet 1D question" to "Karras UNet 1D + 3D" on Feb 23, 2024
@lucidrains
Owner

@QuantPrincess let me know if this is intuitive

with downsample_types, you can control at each stage whether to downsample the image (spatial), the frames (slices), or all (both). you can also control how many MP resnet blocks are at each stage by passing a tuple of integers into num_blocks_per_stage
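To see why per-stage control matters for CT volumes with few slices, this sketch tracks the (frames, height, width) shape through a sequence of downsamples. It assumes each downsample halves the chosen axes and that the choices are spelled "image", "frame", and "all" as in the comment above; `stage_shapes` is a hypothetical helper, not the library's code, so check the actual signature before relying on these names.

```python
def stage_shapes(frames, height, width, downsample_types):
    # Hypothetical helper: track (frames, height, width) through each 2x downsample.
    # 'image' halves height and width, 'frame' halves the frame/slice axis, 'all' halves both.
    shapes = [(frames, height, width)]
    for kind in downsample_types:
        if kind not in ("image", "frame", "all"):
            raise ValueError(f"unknown downsample type: {kind!r}")
        if kind in ("frame", "all"):
            if frames % 2:
                raise ValueError(f"cannot halve {frames} frames")
            frames //= 2
        if kind in ("image", "all"):
            if height % 2 or width % 2:
                raise ValueError(f"cannot halve {height}x{width}")
            height //= 2
            width //= 2
        shapes.append((frames, height, width))
    return shapes
```

For example, a 32-slice volume at 256x256 with ("image", "image", "all", "all") goes through (32, 256, 256), (32, 128, 128), (32, 64, 64), (16, 32, 32), (8, 16, 16). Spending the early downsamples on the spatial axes keeps the short slice axis from collapsing before the spatial resolution has been reduced.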

@lucidrains
Owner

@QuantPrincess i can get to the factorized attention mid next week and finish off this issue

@QuantPrincess

Thank you! I will look at it tonight and get back to you. I really can't express enough how awesome your code base is to work with!

@Parskatt
Contributor

I think there might be an implementation mistake in the 3D UNet; I'm getting exploding activations. I'll see if I can make a simple repro.

@Parskatt
Contributor

@lucidrains see #296

@lucidrains
Owner

@Parskatt thank you Johan! you beat me to it

lucidrains added a commit that referenced this issue Feb 28, 2024
lucidrains added a commit that referenced this issue Feb 28, 2024
@lucidrains
Owner

@QuantPrincess ok, with this flag, you can do attention across space and time separately (axial attention)
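Axial attention over space and time can be sketched as two passes: spatial tokens attend within each frame, then each spatial position attends across frames. This NumPy version is single-head with no learned projections (q = k = v = x) and no residual connections, so it only illustrates the axis factorization. The thread does not name the flag, so none is assumed here; `attend`, `axial_attention`, and `score_entries` are hypothetical helpers.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attend(x):
    # x: (..., seq, dim); single-head self-attention with q = k = v = x
    d = x.shape[-1]
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(d)
    return softmax(scores) @ x

def axial_attention(x):
    # x: (frames, height, width, dim)
    t, h, w, d = x.shape
    # spatial pass: the h*w tokens of each frame attend to each other
    s = attend(x.reshape(t, h * w, d)).reshape(t, h, w, d)
    # temporal pass: each spatial position attends across the t frames
    s = np.transpose(s, (1, 2, 0, 3)).reshape(h * w, t, d)
    out = attend(s).reshape(h, w, t, d)
    return np.transpose(out, (2, 0, 1, 3))

def score_entries(t, h, w):
    # attention-score entries per head: full joint attention vs. axial (space, then time)
    n = h * w
    return (t * n) ** 2, t * n ** 2 + n * t ** 2
```

For 32 slices at 64x64, full joint attention needs about 1.7e10 score entries per head versus about 5.4e8 for the axial version, which is the usual motivation for factorizing.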

this is my last open source contribution until further notice, good luck!

@QuantPrincess

Thanks so much @lucidrains! Can't wait to test out your work on the axial attention. Best of luck on your next endeavors!

WillyChap pushed a commit to WillyChap/denoising-diffusion-pytorch that referenced this issue Sep 27, 2024
WillyChap pushed a commit to WillyChap/denoising-diffusion-pytorch that referenced this issue Sep 27, 2024
WillyChap pushed a commit to WillyChap/denoising-diffusion-pytorch that referenced this issue Sep 27, 2024
WillyChap pushed a commit to WillyChap/denoising-diffusion-pytorch that referenced this issue Sep 27, 2024

4 participants