Training RAFT for MFT

The official checkpoint was trained on Sintel, FlyingThings3D, and a dataset generated by Kubric. The procedure to generate the Kubric data is somewhat involved, so we also share a config that trains without Kubric (the results are not much worse).

Setting up the datasets

Sintel

Place or symlink the Sintel dataset into datasets/Sintel-complete.

The datasets/Sintel-complete directory should have a structure like:

.
`-- training
    |-- clean
    |   |-- cave_4
    |   |   |-- frame_0001.png
    |   |   `-- frame_0002.png
    |   `-- alley_1
    |       |-- frame_0001.png
    |       `-- frame_0002.png
    |-- final
    |   `-- alley_1
    |       |-- frame_0001.png
    |       `-- frame_0002.png
    |-- flow
    |   `-- ambush_6
    |       |-- frame_0001.flo
    |       `-- frame_0002.flo
    `-- occlusions_rev
        `-- market_2
            |-- frame_0001.png
            `-- frame_0002.png

The occlusions_rev directory contains revised occlusion annotations that also mark out-of-view pixels as occluded; it can be downloaded here.
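
A minimal sanity check of the layout (a standalone sketch, not part of the MFT codebase; it only verifies that the directories from the tree above exist):

from pathlib import Path

# check that the Sintel layout matches the tree above
root = Path('datasets/Sintel-complete/training')
for subdir in ['clean', 'final', 'flow', 'occlusions_rev']:
    assert (root / subdir).is_dir(), f'missing {root / subdir}'
print('found', len(list((root / 'clean').iterdir())), 'clean sequences')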

FlyingThings3D

Place or symlink the FlyingThings3D dataset into datasets/FlyingThings3D. Additionally, we used optical flow occlusion masks, available for download from google drive or the MFT website (2.3GB). The masks were generated long ago by a script, which we include as generate_occlusion_maps_FlyingThings3D.py for documentation; it may still work, but we have not tested it recently. Note that the FlyingThings3D authors also provide occlusion maps in the newer FlyingThings3D dataset subset (“DispNet/FlowNet2.0 dataset subsets” on the download page). Feel free to create a pull request if you implement a dataloader for the FT3D subset.

The datasets/FlyingThings3D directory should have a structure like:

.
|-- frames_cleanpass
|   `-- TRAIN
|       `-- A
|           `-- 0003
|               `-- left
|                   |-- 0006.png
|                   |-- 0007.png
|                   `-- 0008.png
|-- optical_flow
|   `-- TRAIN
|       `-- A
|           `-- 0003
|               |-- into_future
|               |   `-- left
|               |       `-- OpticalFlowIntoFuture_0006_L.pfm
|               `-- into_past
|                   `-- left
|                       `-- OpticalFlowIntoPast_0006_L.pfm
`-- optical_flow_occlusion_png
    `-- TRAIN
        `-- A
            `-- 0003
                |-- into_future
                |   `-- left
                |       `-- OpticalFlowIntoFuture_0006_L.png
                `-- into_past
                    `-- left
                        `-- OpticalFlowIntoPast_0006_L.png
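
The .pfm flow files are standard PFM images. A minimal reader (a sketch, not the MFT dataloader; FlyingThings3D flow files are 3-channel 'PF' images with the third channel unused):

import numpy as np

def read_pfm(path):
    # minimal PFM reader: 3-line ASCII header, then raw float32 rows stored bottom-to-top
    with open(path, 'rb') as f:
        header = f.readline().decode().rstrip()
        channels = 3 if header == 'PF' else 1
        width, height = map(int, f.readline().decode().split())
        scale = float(f.readline().decode().rstrip())
        endian = '<' if scale < 0 else '>'  # negative scale means little-endian
        data = np.fromfile(f, dtype=endian + 'f4')
    return np.flipud(data.reshape(height, width, channels))

flow = read_pfm('datasets/FlyingThings3D/optical_flow/TRAIN/A/0003/into_future/left/OpticalFlowIntoFuture_0006_L.pfm')
uv = flow[:, :, :2]  # horizontal and vertical flow components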

Kubric LongFlow

The generated dataset can be downloaded from google drive (42.9GB). Place it into datasets/kubric_movi_e_longterm; it should have a structure like:

.
`-- train
    |-- 00000
    |   |-- images
    |   |   |-- 0000.png
    |   |   |-- 0001.png
    |   |   `-- 0002.png
    |   `-- flowou
    |       |-- 0000_to_0000.flowou.png
    |       |-- 0000_to_0001.flowou.png
    |       `-- 0000_to_0002.flowou.png
    `-- 05794
        |-- images ...
        `-- flowou ...
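
The file names encode the frame pairing: each flowou file stores the flow from frame 0000 to the target frame named in the file. A sketch of how a loader could enumerate the pairs (an illustration of the layout above, not the MFT dataloader):

from pathlib import Path

seq_dir = Path('datasets/kubric_movi_e_longterm/train/00000')
reference = seq_dir / 'images' / '0000.png'
for flowou_path in sorted((seq_dir / 'flowou').glob('0000_to_*.flowou.png')):
    # '0000_to_0002.flowou.png' -> target frame '0002.png'
    target_id = flowou_path.name.split('.')[0].split('_to_')[1]
    target = seq_dir / 'images' / f'{target_id}.png'
    print(reference.name, '->', target.name, 'via', flowou_path.name)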

Generating the LongFlow dataset

We generated the dataset from kubric MOVi-E by computing dense flow between frame 0000 and all the other frames in each sequence. See the multiflow_from_kubric.py script that was used to generate the dataset (based on the Kubric Point-Tracking Dataset).

Training

We start the training from the raft-sintel.pth checkpoint provided by the RAFT authors here. Place it into the checkpoints/ directory.
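
To quickly verify the checkpoint is in place and readable (a sanity check, not part of the training procedure; the official RAFT checkpoints are plain PyTorch state dicts):

import torch

# load the checkpoint on CPU and count its parameter tensors
state_dict = torch.load('checkpoints/raft-sintel.pth', map_location='cpu')
print(f'loaded {len(state_dict)} parameter tensors')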

Install the dependencies:

pip install torchvision tensorboard

And run the training:

python -m MFT.RAFT.train @train_params.txt
# or python -m MFT.RAFT.train @train_params_no_kubric.txt
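
The @train_params.txt syntax suggests argparse's fromfile expansion, where each line of the file is read as one command-line argument. A minimal sketch of the mechanism (assuming the training script creates its parser with fromfile_prefix_chars='@'; the parameters shown are hypothetical):

import argparse

parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
parser.add_argument('--name')                 # hypothetical parameter, for illustration
parser.add_argument('--num_steps', type=int)  # hypothetical parameter, for illustration
args = parser.parse_args(['@train_params.txt'])  # each line of the file becomes one argument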