Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Effstabledreamfusion #492

Merged
merged 11 commits into from
Oct 2, 2024

Conversation

jadevaibhav
Copy link
Contributor

@jadevaibhav jadevaibhav commented Jul 25, 2024

Efficient training of DreamFusion-like systems on higher-resolution images

I am working on a feature with Dreamfusion system(which can be extended to others). The basic idea is: to train using a higher-resolution image, we subsample pixels from it for NeRF rendering with a mask. Then we calculate the SDS loss at the original resolution image. The computational benefit is from a subsampling number of rays for NeRF training, while we train using higher resolution images (for a better visual model) in diffusion; resulting in roughly the same compute cost.

On testing using the demo prompt, using 128x128 image resolution and 64x64 subsampling for NeRF training, I get the following result.
Screenshot 2024-07-25 at 4 33 10 PM
I would like any feedback on potential issues with this idea, and how to improve results. I am looking forward to hearing from this community! @DSaurus @voletiv @bennyguo @thuliu-yt16

@jadevaibhav
Copy link
Contributor Author

jadevaibhav commented Jul 29, 2024

For comparison, with the efficient sampling method described above, I get ~30 min for training NeRF with 128x128 resolution (subsampled to 64x64). Without efficient sampling I get ~41 min of training duration (128x128 resolution), keeping all other parameters the same.

@jadevaibhav jadevaibhav marked this pull request as ready for review August 5, 2024 04:46
DSaurus
DSaurus previously approved these changes Aug 9, 2024
Copy link
Collaborator

@DSaurus DSaurus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jadevaibhav ,

Great job! Thank you for contributing to threestudio. Could you provide examples about how to run efficient dream fusion and some 3D rendering videos of results? Then I'm glad to merge these commits and add this feature in README.

@jadevaibhav
Copy link
Contributor Author

Hi @DSaurus, thanks for your approval! I have created a separate yaml config for this, so you just have to run:

python launch.py --config configs/dreamfusion-sd-eff.yaml --train  system.prompt_processor.prompt="a zoomed out DSLR photo of a baby bunny sitting on top of a stack of pancakes"

Here are the videos I generated, although they are not good quality... I am still investigating where the issue with generation quality is, and if this method can be extended to other generative systems.

it10000-test.mp4
it10000-test.mp4

@DSaurus
Copy link
Collaborator

DSaurus commented Aug 26, 2024

Hi @jadevaibhav ,

Perhaps you could try to cache the rendering images without gradient first. Then, you sample some rays of this complete rendering image and update the corresponding pixels to do the SDS process. I think it is more robust for 3D generation.

@jadevaibhav
Copy link
Contributor Author

@DSaurus, could you please explain what you mean here?
If I understand correctly, caching multiple images before updating through SDS would be equivalent to directly generating bigger-resolution images. This defeats the purpose of generating a sub-sampled grid... My idea is essentially to take advantage of the continuous representation of 3D space learned through MLP. So at each iteration, we randomly sub-sample a set of ray directions, and over the complete optimization process, we learn at the original (bigger) resolution.

Here's my code of sub-sampling for clarity:

def mask_ray_directions(
    H: int,
    W:int,
    s_H:int,
    s_W:int
    ) -> Float[Tensor, "s_H s_W"]:
    """
    Masking the (H,W) image to (s_H,s_W), for efficient training at higher resolution image.
    pixels from (s_H,s_W) are sampled more (1-aspect_ratio) than outside pixels(aspect_ratio).
    the masking is deferred to before calling get_rays().
    """
    indices_all = torch.meshgrid(
        torch.arange(W, dtype=torch.float32) ,
        torch.arange(H, dtype=torch.float32) ,
        indexing="xy",
    )
    
    mask = torch.zeros(H,W, dtype=torch.bool)
    mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = True

    in_ind_1d = (indices_all[0]+H*indices_all[1])[mask]
    out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)]
    ### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already 
    ### leads to more samples inside anyways

    p = 0.5#(s_H*s_W)/(H*W)
    select_ind = in_ind_1d[
        torch.multinomial(
        torch.ones_like(in_ind_1d)*(1-p),int((1-p)*(s_H*s_W)),replacement=False)]
    select_ind = torch.concatenate(
        [select_ind, out_ind_1d[torch.multinomial(
            torch.ones_like(out_ind_1d)*(p),int((p)*(s_H*s_W)),replacement=False)]
        ],
        dim=0).to(dtype=torch.int).view(s_H,s_W)

    
    return select_ind

@DSaurus
Copy link
Collaborator

DSaurus commented Aug 27, 2024

@jadevaibhav Sure, my idea is to use these cached images multiple times, and each time you can apply your sub-sampler to update these images. If my understanding is correct, the current mask sub-sampler will render images that are not complete. However, diffusion models like Stable Diffusion are not designed to recover these incomplete images. I think this is the reason why the current mask sub-sampler leads to unstable results.

@jadevaibhav
Copy link
Contributor Author

@DSaurus the sub-sampler is used on generated directions, so we only pass selected directions to NeRF. And while calculating SDS loss, I pass the original resolution image with rendered color filled at given indices, and 0 elsewhere. I also believe that diffusion is unable to recover the incomplete image.
Rather than creating an incomplete image, I am thinking of doing an interpolation using these rendered colors. This way, even the gradients are not being wasted. What do you think?
I will be happy to continue the caching discussion on Discord if you want. Also, should we merge the current version in the meantime?

@jadevaibhav
Copy link
Contributor Author

Hi @DSaurus thanks for approving the PR! I don't have the write access, so could you please merge?

I looked into the "interpolation", but currently there is no way to do it with randomly sampled positions. I was looking into the grid_sample() method, but I can't define a transformation or mapping from the original resolution coordinate system to the sampled grid coordinates. I am now experimenting with uniform subsampling, with a random offset for the top-left grid corner.

@jadevaibhav
Copy link
Contributor Author

I finished the new experiment, and it works better than before! The training time is still the same (~33 mins)!

Screenshot 2024-09-22 at 8 53 12 PM

it10000-test-new.mp4

@DSaurus
Copy link
Collaborator

DSaurus commented Sep 29, 2024

@jadevaibhav LGTM! Could you please create a file named eff_dreamfusion.py in the system folder and put your current code into this file?

@jadevaibhav
Copy link
Contributor Author

Sure!

@jadevaibhav
Copy link
Contributor Author

Done! Please review @DSaurus

DSaurus
DSaurus previously approved these changes Oct 2, 2024
Copy link
Collaborator

@DSaurus DSaurus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jadevaibhav Thanks!

@jadevaibhav jadevaibhav requested a review from DSaurus October 2, 2024 01:42
@DSaurus DSaurus merged commit bdd6db0 into threestudio-project:main Oct 2, 2024
1 check passed
@jadevaibhav
Copy link
Contributor Author

Thanks! I would like to contribute more, is there any new papers/ implementations we're looking at?

@DSaurus
Copy link
Collaborator

DSaurus commented Oct 2, 2024

@jadevaibhav I think it would be great if you are interested in implementing Wonder3D and its following papers, which could generate 3D objects in seconds.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants