Project Page | Arxiv | Weights
The official implementation of BrightDreamer: Generic 3D Gaussian Generative Framework for Fast Text-to-3D Synthesis.
If you find this work interesting or useful, please give me a ⭐!
User Interactive Demo
gui_demo.mp4
Interpolation Demonstration
interpolation.mp4
If you have any questions about this project, please feel free to open an issue.
git clone https://github.com/lutao2021/BrightDreamer.git
cd BrightDreamer
conda create -n BrightDreamer python=3.9
conda activate BrightDreamer
You first need to install versions of torch and torchvision that suit your environment. The versions used in our experiments are
torch==1.13.1+cu117 torchvision==0.14.1+cu117
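As a quick sanity check after installation, the installed build can be compared with the versions listed above. The following is a minimal sketch, not part of this repository; the helper names `parse_version` and `matches_reference` are hypothetical.

```python
import re


def parse_version(version):
    """Split a string such as '1.13.1+cu117' into ((1, 13, 1), 'cu117')."""
    release, _, local = version.partition("+")
    numbers = tuple(int(part) for part in re.findall(r"\d+", release)[:3])
    return numbers, local


def matches_reference(installed, reference="1.13.1+cu117"):
    """Return True when release numbers and CUDA tag both equal the reference."""
    return parse_version(installed) == parse_version(reference)
```

Once torch is installed, calling `matches_reference(torch.__version__)` tells you whether your build is the one used in our experiments. A different version may still work; the check only flags the difference.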
Then install the remaining packages with
pip install -r requirements.txt
mkdir submodules
cd submodules
git clone https://github.com/graphdeco-inria/diff-gaussian-rasterization.git --recursive
git clone https://gitlab.inria.fr/bkerbl/simple-knn.git
pip install ./diff-gaussian-rasterization/
pip install ./simple-knn/
cd ..
To run inference with a pre-trained model (either the provided weights or one you trained yourself), choose one of the following methods. Inference needs about 16 GB of GPU VRAM.
1) Infer from a single prompt passed on the command line.
python inference.py --model_path /path/to/ckpt_file --prompt "input text prompt" --save_path /folder/to/save/videos --default_radius 3.5 --default_polar 60
# example
CUDA_VISIBLE_DEVICES=0 python inference.py --model_path models/vehicle.pth --prompt "Racing car, cyan, lightweight aero kit, sequential gearbox" --save_path workspace_inference --default_radius 3.5 --default_polar 60
2) Infer from prompts typed interactively in the terminal, so the model does not have to be reloaded for each prompt. Type a prompt to generate a result, or type "exit" to quit.
python inference_cmd.py --model_path /path/to/ckpt_file --save_path /folder/to/save/videos --default_radius 3.5 --default_polar 60
# example
CUDA_VISIBLE_DEVICES=0 python inference_cmd.py --model_path models/vehicle.pth --save_path workspace_inference --default_radius 3.5 --default_polar 60
3) Infer through the GUI in a browser. (Optional) If you run the program on a remote server and want to use it from your local PC, first open an SSH tunnel to the server.
ssh -L 5000:127.0.0.1:5000 <your server host>
Next, you can start the back-end program.
python inference_gui.py --model_path /path/to/ckpt_file
# example
CUDA_VISIBLE_DEVICES=0 python inference_gui.py --model_path models/vehicle.pth
Then you can open the page in your local browser.
127.0.0.1:5000
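If the page does not load, it can help to check whether anything is listening on the forwarded port. The sketch below uses only the Python standard library and is not part of this repository; the helper name `port_is_open` is hypothetical, and port 5000 is taken from the commands above.

```python
import socket


def port_is_open(host="127.0.0.1", port=5000, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts.
        return False
```

A result of `False` on your local PC suggests that either the SSH tunnel or the back-end program is not running.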
To accelerate training, we cache the text embeddings of the training prompt set, at the cost of extra RAM and disk space. Depending on the size of the prompt set, this saves several minutes per epoch (a speedup of about 15%). You can also merge the provided prompts into a single txt file for mixed training.
python embedding_cache.py --prompts_set vehicle
python embedding_cache.py --prompts_set daily_life
python embedding_cache.py --prompts_set animal
python embedding_cache.py --prompts_set mix
The cached text embeddings will be saved at ./vehicle.pkl, ./daily_life.pkl, and ./animal.pkl.
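The idea behind the cache can be illustrated with a small, generic sketch: encode each prompt once, store the results in a pickle file, and load them back during training. This is not the code in embedding_cache.py, and the layout of the real .pkl files is not described here. The `toy_encoder` stand-in, the helper names, and the prompt-to-embedding dict layout are all assumptions made for illustration.

```python
import hashlib
import os
import pickle
import tempfile


def toy_encoder(prompt, dim=8):
    """Stand-in for a real text encoder: a deterministic vector from a hash."""
    digest = hashlib.sha256(prompt.encode("utf-8")).digest()
    return [byte / 255.0 for byte in digest[:dim]]


def build_cache(prompts, cache_path, encoder=toy_encoder):
    """Encode every prompt once and store {prompt: embedding} in a pickle file."""
    cache = {prompt: encoder(prompt) for prompt in prompts}
    with open(cache_path, "wb") as handle:
        pickle.dump(cache, handle)
    return cache


def load_cache(cache_path):
    """Load the {prompt: embedding} dict written by build_cache."""
    with open(cache_path, "rb") as handle:
        return pickle.load(handle)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as folder:
        path = os.path.join(folder, "demo_cache.pkl")
        build_cache(["a red sports car", "a wooden chair"], path)
        print(sorted(load_cache(path)))
```

The trade-off is the one noted above: the encoder runs only once per prompt, but all embeddings must fit on disk and, once loaded, in RAM.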
Train the BrightDreamer generator. Example training commands are provided in the following scripts.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts/vehicle.sh
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts/daily_life.sh
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts/animal.sh
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts/mix_training.sh
Key hyper-parameters:
--prompts_set     The training prompts set.
--cache_file      The cached text embeddings.
--batch_size      The number of prompts in a single iteration on each card. The actual batch size is the number of GPUs * batch_size.
--c_batch_size    The number of cameras per prompt used to render images for calculating the SDS loss.
--lr              Learning rate.
--eval_interval   How often test images are output.
--test_interval   How often test videos are rendered.
--guidance        The U-Net used to calculate the SDS loss.
--workspace       The output folder.
--ckpt            The checkpoint from which to resume training.
With a batch_size of 8 and a c_batch_size of 4, training may use about 65 GB of GPU memory on a single card. In our experiments, 4 cards also work well, only more slowly. A larger batch size gives better results. On a server with eight 80 GB GPUs, we trained for 36 hours on the vehicle prompts set, 60 hours on the daily life prompts set, and 30 hours on the animal prompts set.
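To make the batch arithmetic concrete, the sketch below computes the actual batch size (number of GPUs * batch_size) and the number of rendered views per iteration. The second quantity assumes each prompt is rendered from exactly c_batch_size cameras, which follows from the parameter descriptions above but is not taken from the training code. The helper name `effective_batch` is hypothetical.

```python
def effective_batch(num_gpus, batch_size, c_batch_size):
    """Return (prompts per iteration, rendered views per iteration)."""
    prompts = num_gpus * batch_size   # actual batch size across all cards
    renders = prompts * c_batch_size  # assumed: c_batch_size views per prompt
    return prompts, renders


if __name__ == "__main__":
    # Settings quoted above: 8 cards, batch_size 8, c_batch_size 4.
    print(effective_batch(8, 8, 4))
    # Same per-card settings on 4 cards.
    print(effective_batch(4, 8, 4))
```

With the same per-card settings, moving from 8 cards to 4 halves the actual batch size, so compare results across GPU counts with that in mind.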
- A better and more diverse prompts set would substantially improve training quality.
- A better diffusion model could improve training quality.
- More training tricks could be introduced into our framework to improve quality and alleviate the 'Janus' problem.
Our code is inspired by stable-dreamfusion, Stable Diffusion, gaussian-splatting and DeepFloyd-IF. Thanks to the authors for their outstanding work and for open-sourcing it!
If you find this work useful, please consider citing it:
@misc{jiang2024brightdreamer,
title={BrightDreamer: Generic 3D Gaussian Generative Framework for Fast Text-to-3D Synthesis},
author={Lutao Jiang and Lin Wang},
year={2024},
eprint={2403.11273},
archivePrefix={arXiv},
primaryClass={cs.CV}
}