A replication of the Rectified Flow paper using PyTorch and U-ViT.
To train a new model, edit the experiment's YAML file and run:
python multi_gpu_trainer.py example
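For reference, the training objective from the Rectified Flow paper can be sketched as follows. This is a minimal sketch, not this repo's trainer. It assumes noise sits at t = 0 and data at t = 1, as in the paper, and that the network is called as `model(x_t, t, y)` (a hypothetical signature) and returns a velocity with the same shape as its input.

```python
from typing import Callable

import torch

# Hypothetical signature: model(x_t, t, y) -> velocity shaped like x_t
VelocityModel = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


def rectified_flow_loss(model: VelocityModel, x1: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Rectified-flow regression loss for one batch (sketch).

    x1: data images, shape (B, C, H, W); y: class labels, shape (B,).
    """
    x0 = torch.randn_like(x1)                      # noise endpoint at t = 0
    t = torch.rand(x1.shape[0], device=x1.device)  # one random time per sample in [0, 1)
    t_ = t.view(-1, 1, 1, 1)                        # broadcast t over C, H, W
    x_t = t_ * x1 + (1.0 - t_) * x0                 # point on the straight line noise -> data
    target = x1 - x0                                # constant velocity along that line
    v = model(x_t, t, y)
    return ((v - target) ** 2).mean()               # mean squared error to the target velocity
```

In a training loop this loss is backpropagated as usual; a class-conditional U-ViT would play the role of `model`.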
The Oxford Flowers training data must be split manually; a NumPy version of the labels is included in this repo.
To run inference, download my pretrained weights and run:
python sample_img.py --device "cuda:0" --load "last" --SavedDir tmp/ --ExpConfig example/example.yaml --n_sqrt 16 --steps 200
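The "steps" option presumably sets how many fixed-size steps are used to integrate the learned velocity field from noise to an image. Below is a minimal Euler-method sketch of that idea, not the code in `sample_img.py`. It assumes a hypothetical `model(x, t, y)` that returns a velocity shaped like `x`, with noise at t = 0 and data at t = 1.

```python
import torch


@torch.no_grad()
def euler_sample(model, shape, y, steps=200, device="cpu"):
    """Integrate dx/dt = model(x, t, y) from t = 0 (noise) to t = 1 with `steps` Euler steps."""
    x = torch.randn(shape, device=device)  # start from Gaussian noise
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((shape[0],), i * dt, device=device)  # current time, one per sample
        x = x + model(x, t, y) * dt                          # explicit Euler update
    return x
```

More steps reduce discretization error at the cost of more network evaluations, which is the trade-off the "steps" value controls.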
Or use an ODE solver (install torchdiffeq first):
pip install torchdiffeq
python sample_img_ODESolver.py --device "cuda:0" --load "last" --SavedDir tmp/ --ExpConfig example/example.yaml --n_sqrt 16 --rtol 0.001
Each inference script takes 6 parameters ("steps" is used by sample_img.py, "rtol" by sample_img_ODESolver.py):
"device", usually 'cuda:0';
"load", the checkpoint to load, either the best epoch or the last epoch;
"SavedDir", where to save the generated images;
"ExpConfig", the yaml file of your experiment;
"n_sqrt", you will get N² samples for each class, where N = n_sqrt;
"steps", the number of sampling steps, 200 worked well in my experiments;
"rtol", the acceptable relative error per solver step, 1e-3 is good enough.
The results should look like the sample images at the top of this README.
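Since each class yields n_sqrt² samples (e.g. `--n_sqrt 16` gives 256 images), they fit naturally into an n_sqrt × n_sqrt grid. A small sketch of tiling such a batch into one image with plain tensor ops; the function name `to_grid` is hypothetical, and `torchvision.utils.make_grid(samples, nrow=n_sqrt)` does the same job with padding between tiles.

```python
import torch


def to_grid(samples: torch.Tensor, n_sqrt: int) -> torch.Tensor:
    """Tile a batch of n_sqrt**2 images (B, C, H, W) into one (C, n_sqrt*H, n_sqrt*W) image."""
    b, c, h, w = samples.shape
    assert b == n_sqrt * n_sqrt, "expected exactly n_sqrt**2 samples"
    grid = samples.reshape(n_sqrt, n_sqrt, c, h, w)  # (row, col, C, H, W), row-major order
    grid = grid.permute(2, 0, 3, 1, 4)                # (C, row, H, col, W)
    return grid.reshape(c, n_sqrt * h, n_sqrt * w)    # merge row/H and col/W axes
```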
To interpolate between two images, run:
python image_interpolation.py --device "cuda:0" --load "last" --SavedDir tmp/ --ExpConfig example/example.yaml --input_image images/image1.jpg --target_image images/image2.jpg --rtol 0.0001 --mix_depth -0.02 --spherical True
This feature is experimental and does not work well yet!
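The `--spherical True` flag suggests that latents are blended with spherical linear interpolation (slerp) instead of a straight line, which avoids the norm shrinkage that linear blending causes for high-dimensional Gaussian noise. Below is a generic slerp sketch; what the flag actually does in `image_interpolation.py` is an assumption here, and the role of `--mix_depth` is not covered.

```python
import torch


def slerp(z0: torch.Tensor, z1: torch.Tensor, alpha: float, eps: float = 1e-7) -> torch.Tensor:
    """Spherical linear interpolation between two same-shaped tensors, alpha in [0, 1]."""
    a, b = z0.flatten(), z1.flatten()
    cos = torch.dot(a, b) / (a.norm() * b.norm() + eps)  # cosine of the angle between them
    omega = torch.acos(cos.clamp(-1.0, 1.0))
    so = torch.sin(omega)
    if so.abs() < eps:
        # Nearly parallel vectors: fall back to plain linear interpolation
        return (1 - alpha) * z0 + alpha * z1
    return (torch.sin((1 - alpha) * omega) / so) * z0 + (torch.sin(alpha * omega) / so) * z1
```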
Enjoy!