This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

Commit

open sourcing SAVi++ model
gamaleldin committed Oct 14, 2022
1 parent 229263c commit 7024884
Showing 13 changed files with 930 additions and 32 deletions.
59 changes: 46 additions & 13 deletions README.md
@@ -1,14 +1,25 @@
## Slot Attention for Video (SAVi)
## Slot Attention for Video (SAVi and SAVi++)

This repository contains the code release for "Conditional Object-Centric
Learning from Video" (ICLR 2022).
Learning from Video" (ICLR 2022) and "SAVi++: Towards End-to-End Object-Centric
Learning from Real-World Videos" (NeurIPS 2022)


<img src="savi.gif" alt="SAVi animation" width="400"/>

Paper: https://arxiv.org/abs/2111.12594
<br />

<img src="savi++_1.gif" alt="SAVi++ animation 1" width="400"/>

Project website: https://slot-attention-video.github.io/
<img src="savi++_2.gif" alt="SAVi++ animation 2" width="400"/>

Papers:
https://arxiv.org/abs/2111.12594
https://arxiv.org/abs/2206.07764

Project websites:
https://slot-attention-video.github.io/
https://slot-attention-video.github.io/savi++/

## Instructions
> ℹ️ The following instructions assume that you are using JAX on GPUs and have CUDA and CuDNN installed. For more details on how to use JAX with accelerators, including requirements and TPUs, please read the [JAX installation instructions](https://github.com/google/jax#installation).
Expand All @@ -28,10 +39,17 @@ python -m savi.main --config savi/configs/movi/savi_conditional_small.py --workd
```
to train the smallest SAVi model (SAVi-S) on the [MOVi-A](https://github.com/google-research/kubric/blob/main/challenges/movi/README.md) dataset.

Alternatively, run

```sh
python -m savi.main --config savi/configs/movi/savi++_conditional.py --workdir tmp/
```
to train the more capable SAVi++ model on the [MOVi-E](https://github.com/google-research/kubric/blob/main/challenges/movi/README.md) dataset.

The MOVi datasets are stored in a [Google Cloud Storage (GCS) bucket](https://console.cloud.google.com/storage/browser/kubric-public/tfds)
and can be downloaded to local disk prior to training for improved efficiency.

To use a local copy of MOVi-A, for example, please copy the relevant folder to your local disk and set `data_dir` in the config file (`configs/movi/savi_conditional.py`) to point to it. In more detail, first copy using commands such as
To use a local copy of MOVi-A, for example, please copy the relevant folder to your local disk and set `data_dir` in the config file (e.g., `configs/movi/savi_conditional.py`) to point to it. In more detail, first copy using commands such as

```
gsutil -m cp -r gs://kubric-public/tfds/movi_a/128x128/1.0.0 ./movi_a_128x128/
Expand All @@ -54,21 +72,25 @@ The resulting directory structure will be as follows:

In order to use the local copy, simply set `data_dir = "./"` in the config file `configs/movi/savi_conditional_small.py`. You can also copy the data to a different location and set `data_dir` accordingly.
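
For illustration, the data section of a config pointed at the local MOVi-A copy then looks roughly as follows. This is only a sketch mirroring the `config.data` block of the provided configs; the exact `tfds_name` depends on the config you train (the SAVi++ config below uses `movi_e/128x128:1.0.0`).

```python
import ml_collections

# Sketch only: mirrors the config.data block of the provided configs, pointed at
# the local MOVi-A copy downloaded with the gsutil command above.
config = ml_collections.ConfigDict()
config.data = ml_collections.ConfigDict({
    "tfds_name": "movi_a/128x128:1.0.0",  # TFDS name/config:version under data_dir
    "data_dir": "./",                     # directory containing movi_a/128x128/1.0.0
    "shuffle_buffer_size": 64 * 8,        # 8x the batch size, as in the provided configs
})
```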

To run SAVi on other MOVi dataset variants, follow the instructions above while replacing `movi_a` with, e.g. `movi_b` or `movi_c`.
To run SAVi or SAVi++ on other MOVi dataset variants, follow the instructions above while replacing `movi_a` with, e.g., `movi_b` or `movi_c`.

## Expected results

At present, this repository only contains the SAVi model configurations without ResNet backbone from our [ICLR 2022 paper](https://arxiv.org/abs/2111.12594). We here refer to these models as SAVi-S and SAVi-M. SAVi-S is trained and evaluated on downscaled 64x64 frames, whereas SAVi-M uses 128x128 frames and a larger CNN backbone. Expected results and a configuration file for the largest SAVi model variant with ResNet backbone (SAVi-L) will be added shortly.
This repository contains the SAVi model configurations from our [ICLR 2022 paper](https://arxiv.org/abs/2111.12594). We refer to these models here as SAVi-S, SAVi-M, and SAVi-L. SAVi-S is trained and evaluated on downscaled 64x64 frames, whereas SAVi-M uses 128x128 frames and a larger CNN backbone. SAVi-L is similar to SAVi-M, except that it uses a larger ResNet34 encoder and a larger slot embedding.

This repository also contains the SAVi++ model configurations from our [NeurIPS 2022 paper](https://arxiv.org/abs/2206.07764). SAVi++ uses a more powerful encoder than SAVi-L, adding transformer blocks on top of the ResNet34 backbone, and further adds data augmentation and training on depth targets. As a result, SAVi++ better handles real-world videos with complexities such as camera movement and complex object shapes and textures.
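
The corresponding pieces of the new config file (shown in full further below) are summarized in the condensed excerpt here; it is illustrative only, not a standalone configuration.

```python
import ml_collections

# Condensed excerpt of savi/configs/movi/savi++_conditional.py: the frame encoder
# pairs a ResNet-34 backbone with a small transformer over the flattened feature
# map, and depth (1 channel) is reconstructed alongside RGB-encoded optical flow.
encoder_backbone = ml_collections.ConfigDict({
    "module": "savi.modules.ResNet34",
    "norm_type": "group",
    "small_inputs": True,
})
encoder_output_transform = ml_collections.ConfigDict({
    "module": "savi.modules.Transformer",
    "num_layers": 4,
    "num_heads": 4,
})
targets = {"flow": 3, "depth": 1}  # prediction target -> number of output channels
```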

The released MOVi datasets that are part of [Kubric](https://github.com/google-research/kubric/) differ slightly from the ones used in our [ICLR 2022 paper](https://arxiv.org/abs/2111.12594) and are of slightly higher complexity (e.g., more variation in backgrounds); results are therefore not directly comparable. MOVi-A is approximately comparable to the "MOVi" dataset used in our [ICLR 2022 paper](https://arxiv.org/abs/2111.12594), whereas MOVi-C is approximately comparable to "MOVi++". We provide updated results for our released configs and the MOVi datasets with version `1.0.0` below.

| Model | MOVi-A | MOVi-B | MOVi-C | MOVi-D | MOVi-E |
|------------|------------|------------|------------|------------|------------|
| **SAVi-S** | 92.1 ± 0.1 | 72.2 ± 0.5 | 64.7 ± 0.3 | 33.8 ± 7.7 | 8.3 ± 0.9 |
| **SAVi-M** | 93.4 ± 1.0 | 75.1 ± 0.5 | 67.4 ± 0.5 | 20.8 ± 2.2 | 12.2 ± 1.1 |
| **SAVi-L** | TBA | TBA | TBA | TBA | TBA |
| Model      | MOVi-A     | MOVi-B      | MOVi-C     | MOVi-D     | MOVi-E     |
|------------|------------|-------------|------------|------------|------------|
| **SAVi-S** | 92.1 ± 0.1 | 72.2 ± 0.5  | 64.7 ± 0.3 | 33.8 ± 7.7 | 8.3 ± 0.9  |
| **SAVi-M** | 93.4 ± 1.0 | 75.1 ± 0.5  | 67.4 ± 0.5 | 20.8 ± 2.2 | 12.2 ± 1.1 |
| **SAVi-L** | 95.1 ± 0.6 | 64.8 ± 8.9  | 71.3 ± 1.6 | 59.7 ± 6.0 | 34.1 ± 1.2 |
| **SAVi++** | 85.3 ± 9.8 | 72.5 ± 11.2 | 79.1 ± 2.1 | 84.8 ± 1.4 | 85.1 ± 0.9 |


All results are in terms of **FG-ARI** (in %) on validation splits. Mean ± standard error over 5 seeds. All SAVi models reported above use bounding boxes of the first video frame as conditioning signal.
All results are **FG-ARI** (in %) on validation splits, given as mean ± standard error over 5 seeds. All SAVi and SAVi++ models reported above use bounding boxes of the first video frame as the conditioning signal.
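
FG-ARI is the adjusted Rand index computed only over pixels that belong to a foreground object in the ground-truth segmentation. As a reference point, here is a minimal NumPy/scikit-learn sketch of the metric; it is not the repository's JAX implementation.

```python
import numpy as np
from sklearn.metrics import adjusted_rand_score


def fg_ari(true_ids: np.ndarray, pred_ids: np.ndarray, bg_id: int = 0) -> float:
    """Adjusted Rand index restricted to ground-truth foreground pixels.

    true_ids / pred_ids: integer segmentation maps of identical shape, e.g.
    (time, height, width), with `bg_id` marking background in `true_ids`.
    """
    fg = true_ids.reshape(-1) != bg_id  # mask of ground-truth foreground pixels
    return adjusted_rand_score(true_ids.reshape(-1)[fg], pred_ids.reshape(-1)[fg])


# Toy example: predicted segments match ground truth up to relabeling -> FG-ARI = 1.0.
gt = np.array([[0, 1, 1], [0, 2, 2]])
pred = np.array([[7, 4, 4], [7, 5, 5]])
print(fg_ari(gt, pred))  # 1.0
```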

## Cite

Expand All @@ -83,5 +105,16 @@ All results are in terms of **FG-ARI** (in %) on validation splits. Mean ± stan
}
```

```
@inproceedings{elsayed2022savi++,
author={Elsayed, Gamaleldin F. and Mahendran, Aravindh
and van Steenkiste, Sjoerd and Greff, Klaus and Mozer, Michael C.
and Kipf, Thomas},
title = {{SAVi++: Towards end-to-end object-centric learning from real-world videos}},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2022}
}
```

## Disclaimer
This is not an official Google product.
Binary file added savi++_1.gif
Binary file added savi++_2.gif
220 changes: 220 additions & 0 deletions savi/configs/movi/savi++_conditional.py
@@ -0,0 +1,220 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Config for a conditional SAVi++ model.
SAVi++ operates on 128x128 video frames and uses a ResNet-34 backbone. This
model is comparable to the SAVi++ model evaluated on MOVi in the SAVi++
NeurIPS 2022 paper:
https://arxiv.org/abs/2206.07764
"""

import ml_collections


def get_config():
"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()

config.seed = 42
config.seed_data = True

config.batch_size = 64
config.num_train_steps = 500000

# Adam optimizer config.
config.learning_rate = 2e-4
config.warmup_steps = 2500
config.max_grad_norm = 0.05

config.log_loss_every_steps = 50
config.eval_every_steps = 1000
config.checkpoint_every_steps = 5000

config.train_metrics_spec = {
"loss": "loss",
"ari": "ari",
"ari_nobg": "ari_nobg",
}
config.eval_metrics_spec = {
"eval_loss": "loss",
"eval_ari": "ari",
"eval_ari_nobg": "ari_nobg",
}

config.data = ml_collections.ConfigDict({
"tfds_name": "movi_e/128x128:1.0.0", # Dataset for training/eval.
"data_dir": "gs://kubric-public/tfds",
"shuffle_buffer_size": config.batch_size * 8,
})

# NOTE: MOVi-A, MOVi-B, and MOVi-C only contain up to 10 instances (objects),
# i.e. it is safe to reduce config.max_instances to 10 for these datasets,
# resulting in more efficient training/evaluation. We set this default to 23,
# since MOVi-D and MOVi-E contain up to 23 objects per video. Setting
# config.max_instances to a smaller number than the maximum number of objects
# in a dataset will discard objects, ultimately giving different results.
config.max_instances = 23
config.num_slots = config.max_instances + 1 # Only used for metrics.
config.logging_min_n_colors = config.max_instances

config.preproc_train = [
"video_from_tfds",
f"sparse_to_dense_annotation(max_instances={config.max_instances})",
"temporal_random_strided_window(length=6)",
"random_resized_crop" +
"(height=128, width=128, min_object_covered=0.75)",
"transform_depth(transform='log_plus')",
"flow_to_rgb()" # NOTE: This only uses the first two flow dimensions.
]

config.preproc_eval = [
"video_from_tfds",
f"sparse_to_dense_annotation(max_instances={config.max_instances})",
"temporal_crop_or_pad(length=24)",
"resize_small(128)",
"transform_depth(transform='log_plus')",
"flow_to_rgb()" # NOTE: This only uses the first two flow dimensions.
]

config.eval_slice_size = 6
config.eval_slice_keys = [
"video", "segmentations", "flow", "boxes", "depth"
]

# Dictionary of targets and corresponding channels. Losses need to match.
config.targets = {"flow": 3, "depth": 1}
config.losses = ml_collections.ConfigDict({
f"recon_{target}": {"loss_type": "recon", "key": target}
for target in config.targets})

config.conditioning_key = "boxes"

config.model = ml_collections.ConfigDict({
"module": "savi.modules.SAVi",

# Encoder.
"encoder": ml_collections.ConfigDict({
"module": "savi.modules.FrameEncoder",
"reduction": "spatial_flatten",

"backbone": ml_collections.ConfigDict({
"module": "savi.modules.ResNet34",
"num_classes": None,
"axis_name": "time",
"norm_type": "group",
"small_inputs": True
}),
"pos_emb": ml_collections.ConfigDict({
"module": "savi.modules.PositionEmbedding",
"embedding_type": "linear",
"update_type": "project_add",
"output_transform": ml_collections.ConfigDict({
"module": "savi.modules.MLP",
"hidden_size": 64,
"layernorm": "pre"
}),
}),
# Transformer.
"output_transform": ml_collections.ConfigDict({
"module": "savi.modules.Transformer",
"num_layers": 4,
"num_heads": 4,
"qkv_size": 16 * 4,
"mlp_size": 1024,
"pre_norm": True,
}),
}),

# Corrector.
"corrector": ml_collections.ConfigDict({
"module": "savi.modules.SlotAttention",
"num_iterations": 1,
"qkv_size": 256,
}),

# Predictor.
"predictor": ml_collections.ConfigDict({
"module": "savi.modules.TransformerBlock",
"num_heads": 4,
"qkv_size": 256,
"mlp_size": 1024
}),

# Initializer.
"initializer": ml_collections.ConfigDict({
"module": "savi.modules.CoordinateEncoderStateInit",
"prepend_background": True,
"center_of_mass": False,
"embedding_transform": ml_collections.ConfigDict({
"module": "savi.modules.MLP",
"hidden_size": 256,
"output_size": 128,
"layernorm": None
}),
}),

# Decoder.
"decoder": ml_collections.ConfigDict({
"module":
"savi.modules.SpatialBroadcastDecoder",
"resolution": (8, 8), # Update if data resol. or strides change.
"early_fusion": True,
"backbone": ml_collections.ConfigDict({
"module": "savi.modules.CNN",
"features": [64, 64, 64, 64],
"kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
"strides": [(2, 2), (2, 2), (2, 2), (2, 2)],
"layer_transpose": [True, True, True, True]
}),
"pos_emb": ml_collections.ConfigDict({
"module": "savi.modules.PositionEmbedding",
"embedding_type": "linear",
"update_type": "project_add"
}),
"target_readout": ml_collections.ConfigDict({
"module": "savi.modules.Readout",
"keys": list(config.targets),
"readout_modules": [ml_collections.ConfigDict({
"module": "savi.modules.MLP",
"num_hidden_layers": 0,
"hidden_size": 0, "output_size": config.targets[k]})
for k in config.targets],
}),
}),
"decode_corrected": True,
"decode_predicted": False, # Disable prediction decoder to save memory.
})

# Define which video-shaped variables to log/visualize.
config.debug_var_video_paths = {
"recon_masks": "SpatialBroadcastDecoder_0/alphas",
}
for k in config.targets:
config.debug_var_video_paths.update({
f"{k}_recon": f"SpatialBroadcastDecoder_0/{k}_combined"})

# Define which attention matrices to log/visualize.
config.debug_var_attn_paths = {
"corrector_attn": "SlotAttention_0/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"
}

# Widths of attention matrices (for reshaping to image grid).
config.debug_var_attn_widths = {
"corrector_attn": 16,
}

return config
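
As a quick sanity check of the defaults above, the config can be loaded by file path and inspected. This is a usage sketch outside the repository; a plain `import` does not work here because of the `++` in the filename.

```python
# Usage sketch (not part of the repo): load the config by file path and inspect it.
import importlib.util

spec = importlib.util.spec_from_file_location(
    "savi_pp_conditional", "savi/configs/movi/savi++_conditional.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

config = module.get_config()
print(config.data.tfds_name)  # movi_e/128x128:1.0.0
print(config.num_slots)       # 24 (= max_instances + 1)
print(sorted(config.losses))  # ['recon_depth', 'recon_flow']
```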

