diff --git a/README.md b/README.md index d7d1ae7..5ef174e 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,25 @@ -## Slot Attention for Video (SAVi) +## Slot Attention for Video (SAVi and SAVi++) This repository contains the code release for "Conditional Object-Centric -Learning from Video" (ICLR 2022). +Learning from Video" (ICLR 2022) and "SAVi++: Towards End-to-End Object-Centric +Learning from Real-World Videos" (NeurIPS 2022). + SAVi animation -Paper: https://arxiv.org/abs/2111.12594 +
+ +SAVi++ animation 1 -Project website: https://slot-attention-video.github.io/ +SAVi++ animation 2 +Papers: +https://arxiv.org/abs/2111.12594 +https://arxiv.org/abs/2206.07764 + +Project websites: +https://slot-attention-video.github.io/ +https://slot-attention-video.github.io/savi++/ ## Instructions > ℹ️ The following instructions assume that you are using JAX on GPUs and have CUDA and CuDNN installed. For more details on how to use JAX with accelerators, including requirements and TPUs, please read the [JAX installation instructions](https://github.com/google/jax#installation). @@ -28,10 +39,17 @@ python -m savi.main --config savi/configs/movi/savi_conditional_small.py --workd ``` to train the smallest SAVi model (SAVi-S) on the [MOVi-A](https://github.com/google-research/kubric/blob/main/challenges/movi/README.md) dataset. +or + +```sh +python -m savi.main --config savi/configs/movi/savi++_conditional.py --workdir tmp/ +``` +to train the more capable SAVi++ model on the [MOVi-E](https://github.com/google-research/kubric/blob/main/challenges/movi/README.md) dataset. + The MOVi datasets are stored in a [Google Cloud Storage (GCS) bucket](https://console.cloud.google.com/storage/browser/kubric-public/tfds) and can be downloaded to local disk prior to training for improved efficiency. -To use a local copy of MOVi-A, for example, please copy the relevant folder to your local disk and set `data_dir` in the config file (`configs/movi/savi_conditional.py`) to point to it. In more detail, first copy using commands such as +To use a local copy of MOVi-A, for example, please copy the relevant folder to your local disk and set `data_dir` in the config file (e.g., `configs/movi/savi_conditional.py`) to point to it. In more detail, first copy using commands such as ``` gsutil -m cp -r gs://kubric-public/tfds/movi_a/128x128/1.0.0 ./movi_a_128x128/ @@ -54,21 +72,25 @@ The resulting directory structure will be as follows: In order to use the local copy simply set `data_dir = "./"` in the config file `configs/movi/savi_conditional_small.py`. You can also copy it into a different location and set the `data_dir` accordingly. -To run SAVi on other MOVi dataset variants, follow the instructions above while replacing `movi_a` with, e.g. `movi_b` or `movi_c`. +To run SAVi or SAVi++ on other MOVi dataset variants, follow the instructions above while replacing `movi_a` with, e.g. `movi_b` or `movi_c`. ## Expected results -At present, this repository only contains the SAVi model configurations without ResNet backbone from our [ICLR 2022 paper](https://arxiv.org/abs/2111.12594). We here refer to these models as SAVi-S and SAVi-M. SAVi-S is trained and evaluated on downscaled 64x64 frames, whereas SAVi-M uses 128x128 frames and a larger CNN backbone. Expected results and a configuration file for the largest SAVi model variant with ResNet backbone (SAVi-L) will be added shortly. +This repository contains the SAVi model configurations from our [ICLR 2022 paper](https://arxiv.org/abs/2111.12594). We here refer to these models as SAVi-S, SAVi-M, and SAVi-L. SAVi-S is trained and evaluated on downscaled 64x64 frames, whereas SAVi-M uses 128x128 frames and a larger CNN backbone. SAVi-L is similar to SAVi-M except that it uses larger ResNet34 encoder and slot embedding. + +This repository contains also the SAVi++ model configurations from our [NeurIPS 2022 paper](https://arxiv.org/abs/2206.07764). SAVi++ uses a more powerful encoder than SAVi-L that adds transformer blocks to the ResNet34. 
SAVi++ also adds data augmentation and training on depth targets. SAVi++ is better able to handle real-world videos with additional complexity, such as camera motion and complex object shapes and textures. The released MOVi datasets as part of [Kubric](https://github.com/google-research/kubric/) differ slightly from the ones used in our [ICLR 2022 paper](https://arxiv.org/abs/2111.12594) and are of slightly higher complexity (e.g., more variation in backgrounds), results are therefore not directly comparable. MOVi-A is approximately comparable to the "MOVi" dataset used in our [ICLR 2022 paper](https://arxiv.org/abs/2111.12594), whereas MOVi-C is approximately comparable to "MOVi++". We provide updated results for our released configs and the MOVi datasets with version `1.0.0` below. -| Model | MOVi-A | MOVi-B | MOVi-C | MOVi-D | MOVi-E | -|------------|------------|------------|------------|------------|------------| -| **SAVi-S** | 92.1 ± 0.1 | 72.2 ± 0.5 | 64.7 ± 0.3 | 33.8 ± 7.7 | 8.3 ± 0.9 | -| **SAVi-M** | 93.4 ± 1.0 | 75.1 ± 0.5 | 67.4 ± 0.5 | 20.8 ± 2.2 | 12.2 ± 1.1 | -| **SAVi-L** | TBA | TBA | TBA | TBA | TBA | +| Model | MOVi-A | MOVi-B | MOVi-C | MOVi-D | MOVi-E | +|------------|-------------|-------------|------------|------------|-----------| +| **SAVi-S** | 92.1 ± 0.1 | 72.2 ± 0.5 | 64.7 ± 0.3 | 33.8 ± 7.7 | 8.3 ± 0.9 | +| **SAVi-M** | 93.4 ± 1.0 | 75.1 ± 0.5 | 67.4 ± 0.5 | 20.8 ± 2.2 | 12.2 ± 1.1| +| **SAVi-L** | 95.1 ± 0.6 | 64.8 ± 8.9 | 71.3 ± 1.6 | 59.7 ± 6.0 | 34.1 ± 1.2| +| **SAVi++** | 85.3 ± 9.8 | 72.5 ± 11.2 | 79.1 ± 2.1 | 84.8 ± 1.4 | 85.1 ± 0.9| + -All results are in terms of **FG-ARI** (in %) on validation splits. Mean ± standard error over 5 seeds. All SAVi models reported above use bounding boxes of the first video frame as conditioning signal. +All results are in terms of **FG-ARI** (in %) on validation splits. Mean ± standard error over 5 seeds. All SAVi and SAVi++ models reported above use bounding boxes of the first video frame as conditioning signal. ## Cite @@ -83,5 +105,16 @@ All results are in terms of **FG-ARI** (in %) on validation splits. Mean ± stan } ``` +``` +@inproceedings{elsayed2022savi++, + author={Elsayed, Gamaleldin F. and Mahendran, Aravindh + and van Steenkiste, Sjoerd and Greff, Klaus and Mozer, Michael C. + and Kipf, Thomas}, + title = {{SAVi++: Towards end-to-end object-centric learning from real-world videos}}, + booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, + year = {2022} +} +``` + ## Disclaimer This is not an official Google product. diff --git a/savi++_1.gif b/savi++_1.gif new file mode 100644 index 0000000..d026a0d Binary files /dev/null and b/savi++_1.gif differ diff --git a/savi++_2.gif b/savi++_2.gif new file mode 100644 index 0000000..df55840 Binary files /dev/null and b/savi++_2.gif differ diff --git a/savi/configs/movi/savi++_conditional.py b/savi/configs/movi/savi++_conditional.py new file mode 100644 index 0000000..bd4835a --- /dev/null +++ b/savi/configs/movi/savi++_conditional.py @@ -0,0 +1,220 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config for a conditional SAVi++ model. + +SAVi++ operates on 128x128 video frames and uses a ResNet-34 backbone. This +model is comparable to the SAVi++ model evaluated on MOVi in the SAVi++ +NeurIPS 2022 paper: +https://arxiv.org/abs/2206.07764 +""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 + + # Adam optimizer config. + config.learning_rate = 2e-4 + config.warmup_steps = 2500 + config.max_grad_norm = 0.05 + + config.log_loss_every_steps = 50 + config.eval_every_steps = 1000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "tfds_name": "movi_e/128x128:1.0.0", # Dataset for training/eval. + "data_dir": "gs://kubric-public/tfds", + "shuffle_buffer_size": config.batch_size * 8, + }) + + # NOTE: MOVi-A, MOVi-B, and MOVi-C only contain up to 10 instances (objects), + # i.e. it is safe to reduce config.max_instances to 10 for these datasets, + # resulting in more efficient training/evaluation. We set this default to 23, + # since MOVi-D and MOVi-E contain up to 23 objects per video. Setting + # config.max_instances to a smaller number than the maximum number of objects + # in a dataset will discard objects, ultimately giving different results. + config.max_instances = 23 + config.num_slots = config.max_instances + 1 # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "video_from_tfds", + f"sparse_to_dense_annotation(max_instances={config.max_instances})", + "temporal_random_strided_window(length=6)", + "random_resized_crop" + + "(height=128, width=128, min_object_covered=0.75)", + "transform_depth(transform='log_plus')", + "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. + ] + + config.preproc_eval = [ + "video_from_tfds", + f"sparse_to_dense_annotation(max_instances={config.max_instances})", + "temporal_crop_or_pad(length=24)", + "resize_small(128)", + "transform_depth(transform='log_plus')", + "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. + ] + + config.eval_slice_size = 6 + config.eval_slice_keys = [ + "video", "segmentations", "flow", "boxes", "depth" + ] + + # Dictionary of targets and corresponding channels. Losses need to match. + config.targets = {"flow": 3, "depth": 1} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in config.targets}) + + config.conditioning_key = "boxes" + + config.model = ml_collections.ConfigDict({ + "module": "savi.modules.SAVi", + + # Encoder. 
+ "encoder": ml_collections.ConfigDict({ + "module": "savi.modules.FrameEncoder", + "reduction": "spatial_flatten", + + "backbone": ml_collections.ConfigDict({ + "module": "savi.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "savi.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "savi.modules.MLP", + "hidden_size": 64, + "layernorm": "pre" + }), + }), + # Transformer. + "output_transform": ml_collections.ConfigDict({ + "module": "savi.modules.Transformer", + "num_layers": 4, + "num_heads": 4, + "qkv_size": 16 * 4, + "mlp_size": 1024, + "pre_norm": True, + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "savi.modules.SlotAttention", + "num_iterations": 1, + "qkv_size": 256, + }), + + # Predictor. + "predictor": ml_collections.ConfigDict({ + "module": "savi.modules.TransformerBlock", + "num_heads": 4, + "qkv_size": 256, + "mlp_size": 1024 + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "savi.modules.CoordinateEncoderStateInit", + "prepend_background": True, + "center_of_mass": False, + "embedding_transform": ml_collections.ConfigDict({ + "module": "savi.modules.MLP", + "hidden_size": 256, + "output_size": 128, + "layernorm": None + }), + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "savi.modules.SpatialBroadcastDecoder", + "resolution": (8, 8), # Update if data resol. or strides change. + "early_fusion": True, + "backbone": ml_collections.ConfigDict({ + "module": "savi.modules.CNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (2, 2)], + "layer_transpose": [True, True, True, True] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "savi.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + "target_readout": ml_collections.ConfigDict({ + "module": "savi.modules.Readout", + "keys": list(config.targets), + "readout_modules": [ml_collections.ConfigDict({ + "module": "savi.modules.MLP", + "num_hidden_layers": 0, + "hidden_size": 0, "output_size": config.targets[k]}) + for k in config.targets], + }), + }), + "decode_corrected": True, + "decode_predicted": False, # Disable prediction decoder to save memory. + }) + + # Define which video-shaped variables to log/visualize. + config.debug_var_video_paths = { + "recon_masks": "SpatialBroadcastDecoder_0/alphas", + } + for k in config.targets: + config.debug_var_video_paths.update({ + f"{k}_recon": f"SpatialBroadcastDecoder_0/{k}_combined"}) + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "SlotAttention_0/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/savi/configs/movi/savi_conditional_large.py b/savi/configs/movi/savi_conditional_large.py new file mode 100644 index 0000000..d1fc44c --- /dev/null +++ b/savi/configs/movi/savi_conditional_large.py @@ -0,0 +1,209 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Config for a large conditional SAVi model (SAVi-L). + +SAVi (ResNet) operates on 128x128 video frames and uses a ResNet34 backbone. +This model is comparable to the SAVi Large model evaluated on MOVi++ in the SAVi +ICLR paper: +https://arxiv.org/abs/2111.12594 + +By default, this config uses bounding box coordinates as conditioning signal. +Set `center_of_mass` to `True` to condition on center-of-mass coords instead. +""" + +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.seed = 42 + config.seed_data = True + + config.batch_size = 64 + config.num_train_steps = 500000 + + # Adam optimizer config. + config.learning_rate = 2e-4 + config.warmup_steps = 2500 + config.max_grad_norm = 0.05 + + config.log_loss_every_steps = 50 + config.eval_every_steps = 1000 + config.checkpoint_every_steps = 5000 + + config.train_metrics_spec = { + "loss": "loss", + "ari": "ari", + "ari_nobg": "ari_nobg", + } + config.eval_metrics_spec = { + "eval_loss": "loss", + "eval_ari": "ari", + "eval_ari_nobg": "ari_nobg", + } + + config.data = ml_collections.ConfigDict({ + "tfds_name": "movi_a/128x128:1.0.0", # Dataset for training/eval. + "data_dir": "gs://kubric-public/tfds", + "shuffle_buffer_size": config.batch_size * 8, + }) + + # NOTE: MOVi-A, MOVi-B, and MOVi-C only contain up to 10 instances (objects), + # i.e. it is safe to reduce config.max_instances to 10 for these datasets, + # resulting in more efficient training/evaluation. We set this default to 23, + # since MOVi-D and MOVi-E contain up to 23 objects per video. Setting + # config.max_instances to a smaller number than the maximum number of objects + # in a dataset will discard objects, ultimately giving different results. + config.max_instances = 23 + config.num_slots = config.max_instances + 1 # Only used for metrics. + config.logging_min_n_colors = config.max_instances + + config.preproc_train = [ + "video_from_tfds", + f"sparse_to_dense_annotation(max_instances={config.max_instances})", + "temporal_random_strided_window(length=6)", + "resize_small(128)", + "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. + ] + + config.preproc_eval = [ + "video_from_tfds", + f"sparse_to_dense_annotation(max_instances={config.max_instances})", + "temporal_crop_or_pad(length=24)", + "resize_small(128)", + "flow_to_rgb()" # NOTE: This only uses the first two flow dimensions. + ] + + # Evaluate on full video sequence by iterating over smaller chunks. + config.eval_slice_size = 6 + config.eval_slice_keys = ["video", "segmentations", "flow", "boxes"] + + # Dictionary of targets and corresponding channels. Losses need to match. + config.targets = {"flow": 3} + config.losses = ml_collections.ConfigDict({ + f"recon_{target}": {"loss_type": "recon", "key": target} + for target in config.targets}) + + config.conditioning_key = "boxes" + + config.model = ml_collections.ConfigDict({ + "module": "savi.modules.SAVi", + + # Encoder. 
+ "encoder": ml_collections.ConfigDict({ + "module": "savi.modules.FrameEncoder", + "reduction": "spatial_flatten", + "backbone": ml_collections.ConfigDict({ + "module": "savi.modules.ResNet34", + "num_classes": None, + "axis_name": "time", + "norm_type": "group", + "small_inputs": True + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "savi.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add", + "output_transform": ml_collections.ConfigDict({ + "module": "savi.modules.MLP", + "hidden_size": 64, + "layernorm": "pre" + }), + }), + }), + + # Corrector. + "corrector": ml_collections.ConfigDict({ + "module": "savi.modules.SlotAttention", + "num_iterations": 1, + "qkv_size": 128, + }), + + # Predictor. + "predictor": ml_collections.ConfigDict({ + "module": "savi.modules.TransformerBlock", + "num_heads": 4, + "qkv_size": 128, + "mlp_size": 256 + }), + + # Initializer. + "initializer": ml_collections.ConfigDict({ + "module": "savi.modules.CoordinateEncoderStateInit", + "prepend_background": True, + "center_of_mass": False, + "embedding_transform": ml_collections.ConfigDict({ + "module": "savi.modules.MLP", + "hidden_size": 256, + "output_size": 128, + "layernorm": None + }), + }), + + # Decoder. + "decoder": ml_collections.ConfigDict({ + "module": + "savi.modules.SpatialBroadcastDecoder", + "resolution": (8, 8), # Update if data resolution or strides change. + "backbone": ml_collections.ConfigDict({ + "module": "savi.modules.CNN", + "features": [64, 64, 64, 64], + "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)], + "strides": [(2, 2), (2, 2), (2, 2), (2, 2)], + "layer_transpose": [True, True, True, True] + }), + "pos_emb": ml_collections.ConfigDict({ + "module": "savi.modules.PositionEmbedding", + "embedding_type": "linear", + "update_type": "project_add" + }), + "target_readout": ml_collections.ConfigDict({ + "module": "savi.modules.Readout", + "keys": list(config.targets), + "readout_modules": [ + ml_collections.ConfigDict({ + "module": "savi.modules.Dense", + "features": config.targets[k] + }) for k in config.targets + ], + }), + }), + "decode_corrected": True, + "decode_predicted": False, # Disable prediction decoder to save memory. + }) + + # Define which video-shaped variables to log/visualize. + config.debug_var_video_paths = { + "recon_masks": "SpatialBroadcastDecoder_0/alphas", + } + for k in config.targets: + config.debug_var_video_paths.update({ + f"{k}_recon": f"SpatialBroadcastDecoder_0/{k}_combined"}) + + # Define which attention matrices to log/visualize. + config.debug_var_attn_paths = { + "corrector_attn": "SlotAttention_0/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn" + } + + # Widths of attention matrices (for reshaping to image grid). + config.debug_var_attn_widths = { + "corrector_attn": 16, + } + + return config + + diff --git a/savi/lib/evaluator.py b/savi/lib/evaluator.py index 725df44..77a79cd 100644 --- a/savi/lib/evaluator.py +++ b/savi/lib/evaluator.py @@ -228,7 +228,7 @@ def eval_step( # Join predictions along sequence dimension. concat_fn = lambda _, *x: functools.partial(np.concatenate, axis=2)([*x]) - preds = jax.tree_multimap(concat_fn, preds_per_slice[0], *preds_per_slice) + preds = jax.tree_map(concat_fn, preds_per_slice[0], *preds_per_slice) # Truncate to original sequence length. # NOTE: This op assumes that all predictions have a (complete) time axis. 
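The evaluator change above replaces the deprecated `jax.tree_multimap` with `jax.tree_map`, which likewise maps a function over several pytrees leaf by leaf. The sketch below illustrates how the per-slice predictions get joined along the sequence axis; the dict key and array shapes are made up for illustration, and only the `concat_fn` lambda and the `tree_map` call mirror the diff.

```python
import functools

import jax
import numpy as np

# Two hypothetical per-slice prediction pytrees. The sequence (time) axis is
# assumed to sit at position 2, matching the axis used in eval_step.
preds_per_slice = [
    {"flow": np.zeros((2, 4, 6, 16))},
    {"flow": np.ones((2, 4, 6, 16))},
]

# The first tree only supplies the pytree structure (its leaves arrive as `_`
# and are ignored); `*x` receives the aligned leaves of the remaining trees.
concat_fn = lambda _, *x: functools.partial(np.concatenate, axis=2)([*x])
preds = jax.tree_map(concat_fn, preds_per_slice[0], *preds_per_slice)

print(preds["flow"].shape)  # (2, 4, 12, 16): slices joined along axis=2.
```

Because `jax.tree_map` accepts additional trees as positional arguments, it is a drop-in replacement for `jax.tree_multimap` in this call.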
diff --git a/savi/lib/preprocessing.py b/savi/lib/preprocessing.py index 4f0aee9..576e529 100644 --- a/savi/lib/preprocessing.py +++ b/savi/lib/preprocessing.py @@ -910,3 +910,187 @@ def __call__(self, features: Features) -> Features: features[self.flow_key] = tf.image.convert_image_dtype( flow_rgb, tf.float32) return features + + +@dataclasses.dataclass +class TransformDepth: + """Applies one of several possible transformations to depth features.""" + transform: str + depth_key: str = DEPTH + + def __call__(self, features: Features) -> Features: + if self.depth_key in features: + if self.transform == "log": + depth_norm = tf.math.log(features[self.depth_key]) + elif self.transform == "log_plus": + depth_norm = tf.math.log(1. + features[self.depth_key]) + elif self.transform == "invert_plus": + depth_norm = 1. / (1. + features[self.depth_key]) + else: + raise ValueError(f"Unknown depth transformation {self.transform}") + + features[self.depth_key] = depth_norm + return features + + +@dataclasses.dataclass +class RandomResizedCrop(RandomVideoPreprocessOp): + """Random-resized crop for each of the two views. + + Assumption: Height and width are the same for all video-like modalities. + + We randomly crop the input and record the transformation this crop corresponds + to as a new feature. Cropped images are resized to (height, width). Boxes are + adjusted accordingly, and boxes outside the crop are discarded. Flow is rescaled + so as to be pixel accurate after the operation. lidar_points_2d are + transformed using the computed transformation. These points may lie outside + the image after the operation. + + Attributes: + height: An integer representing the height to resize to. + width: An integer representing the width to resize to. + min_object_covered, aspect_ratio_range, area_range, max_attempts: See + docstring of `stateless_sample_distorted_bounding_box`. Aspect ratio range + has not been scaled by target aspect ratio. This differs from other + implementations of this data augmentation. + relative_box_area_threshold: If the ratio of areas before and after cropping is + lower than this threshold, then the box is discarded (set to NOTRACK_BOX). + """ + # Target size. + height: int + width: int + + # Crop sampling attributes. + min_object_covered: float = 0.1 + aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.) + area_range: Tuple[float, float] = (0.08, 1.0) + max_attempts: int = 100 + + # Box retention attributes. + relative_box_area_threshold: float = 0.0 + + def apply(self, tensor: tf.Tensor, seed: tf.Tensor, key: str, + video_shape: tf.Tensor) -> tf.Tensor: + """Applies the crop operation on tensor.""" + param = self.sample_augmentation_params(video_shape, seed) + si, sj = param[0], param[1] + crop_h, crop_w = param[2], param[3] + + to_float32 = lambda x: tf.cast(x, tf.float32) + + if key == self.boxes_key: + # First crop the boxes. + cropped_boxes = crop_or_pad_boxes( + tensor, si, sj, + crop_h, crop_w, + video_shape[1], video_shape[2]) + # We do not need to scale the boxes because they are in normalized coords. + resized_boxes = cropped_boxes + # Lastly, detect NOTRACK_BOX boxes and avoid manipulating those.
+ no_track_boxes = tf.convert_to_tensor(NOTRACK_BOX) + no_track_boxes = tf.reshape(no_track_boxes, [1, 4]) + resized_boxes = tf.where( + tf.reduce_all(tensor == no_track_boxes, axis=-1, keepdims=True), + tensor, resized_boxes) + + if self.relative_box_area_threshold > 0: + # Thresholds boxes that have been cropped too much, as in their area is + # lower, in relative terms, than `relative_box_area_threshold`. + area_before_crop = tf.reduce_prod(tensor[..., 2:] - tensor[..., :2], + axis=-1) + # Sets minimum area_before_crop to 1e-8 we avoid divisions by 0. + area_before_crop = tf.maximum(area_before_crop, + tf.zeros_like(area_before_crop) + 1e-8) + area_after_crop = tf.reduce_prod( + resized_boxes[..., 2:] - resized_boxes[..., :2], axis=-1) + # As the boxes have normalized coordinates, they need to be rescaled to + # be compared against the original uncropped boxes. + scale_x = to_float32(crop_w) / to_float32(self.width) + scale_y = to_float32(crop_h) / to_float32(self.height) + area_after_crop *= scale_x * scale_y + + ratio = area_after_crop / area_before_crop + return tf.where( + tf.expand_dims(ratio > self.relative_box_area_threshold, -1), + resized_boxes, no_track_boxes) + + else: + return resized_boxes + + else: + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[..., tf.newaxis] + + # Crop. + seq_len, n_channels = tensor.get_shape()[0], tensor.get_shape()[3] + crop_size = (seq_len, crop_h, crop_w, n_channels) + tensor = tf.slice(tensor, tf.stack([0, si, sj, 0]), crop_size) + + # Resize. + resize_method = tf.image.ResizeMethod.BILINEAR + if (tensor.dtype == tf.int32 or tensor.dtype == tf.int64 or + tensor.dtype == tf.uint8): + resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR + tensor = tf.image.resize(tensor, [self.height, self.width], + method=resize_method) + out_size = (seq_len, self.height, self.width, n_channels) + tensor = tf.ensure_shape(tensor, out_size) + + if key == self.flow_key: + # Rescales optical flow. + scale_x = to_float32(self.width) / to_float32(crop_w) + scale_y = to_float32(self.height) / to_float32(crop_h) + tensor = tf.stack( + [tensor[..., 0] * scale_y, tensor[..., 1] * scale_x], axis=-1) + + if key in (self.padding_mask_key, self.segmentations_key): + tensor = tensor[..., 0] + return tensor + + def sample_augmentation_params(self, video_shape: tf.Tensor, rng: tf.Tensor): + """Sample a random bounding box for the crop.""" + sample_bbox = tf.image.stateless_sample_distorted_bounding_box( + video_shape[1:], + bounding_boxes=tf.constant([0.0, 0.0, 1.0, 1.0], + dtype=tf.float32, shape=[1, 1, 4]), + seed=rng, + min_object_covered=self.min_object_covered, + aspect_ratio_range=self.aspect_ratio_range, + area_range=self.area_range, + max_attempts=self.max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_bbox + + # The specified bounding box provides crop coordinates. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + + return tf.stack([offset_y, offset_x, target_height, target_width]) + + def estimate_transformation(self, param: tf.Tensor, video_shape: tf.Tensor + ) -> tf.Tensor: + """Computes the affine transformation for crop params. + + Args: + param: Crop parameters in the [y, x, h, w] format of shape [4,]. + video_shape: Unused. + + Returns: + Affine transformation of shape [3, 3] corresponding to cropping the image + at [y, x] of size [h, w] and resizing it into [self.height, self.width]. 
+ """ + del video_shape + crop = tf.cast(param, tf.float32) + si, sj = crop[0], crop[1] + crop_h, crop_w = crop[2], crop[3] + ei, ej = si + crop_h - 1.0, sj + crop_w - 1.0 + h, w = float(self.height), float(self.width) + + a1 = (ei - si + 1.)/h + a2 = 0. + a3 = si - 0.5 + a1 / 2. + a4 = 0. + a5 = (ej - sj + 1.)/w + a6 = sj - 0.5 + a5 / 2. + affine = tf.stack([a1, a2, a3, a4, a5, a6, 0., 0., 1.]) + return tf.reshape(affine, [3, 3]) diff --git a/savi/lib/trainer.py b/savi/lib/trainer.py index eb04ebd..c31e21b 100644 --- a/savi/lib/trainer.py +++ b/savi/lib/trainer.py @@ -48,7 +48,7 @@ def train_step( rng: PRNGKey, step: int, state_vars: flax.core.FrozenDict, - opt: flax.optim.Optimizer, + opt: flax.optim.Optimizer, # pytype: disable=module-attr batch: Dict[str, ArrayTree], loss_fn: losses.LossFn, learning_rate_fn: Callable[[Array], Array], @@ -57,7 +57,7 @@ def train_step( ground_truth_max_num_instances: int, conditioning_key: Optional[str] = None, max_grad_norm: Optional[float] = None, - ) -> Tuple[flax.optim.Optimizer, flax.core.FrozenDict, PRNGKey, + ) -> Tuple[flax.optim.Optimizer, flax.core.FrozenDict, PRNGKey, # pytype: disable=module-attr metrics.Collection, int]: """Perform a single training step. @@ -166,7 +166,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, peak_value=config.learning_rate, warmup_steps=config.warmup_steps, decay_steps=config.num_train_steps) - optimizer_def = flax.optim.Adam(learning_rate=config.learning_rate) + optimizer_def = flax.optim.Adam(learning_rate=config.learning_rate) # pytype: disable=module-attr # Construct TrainMetrics and EvalMetrics, metrics collections. train_metrics_cls = utils.make_metrics_collection("TrainMetrics", diff --git a/savi/lib/utils.py b/savi/lib/utils.py index 62b3d41..262cfb4 100644 --- a/savi/lib/utils.py +++ b/savi/lib/utils.py @@ -24,7 +24,7 @@ from flax import linen as nn from flax import traverse_util import jax -from jax.experimental import optimizers as jax_optim +from jax.example_libraries import optimizers as jax_optim import jax.numpy as jnp import jax.ops import matplotlib @@ -47,7 +47,7 @@ class TrainState: """Data structure for checkpointing the model.""" step: int - optimizer: flax.optim.Optimizer + optimizer: flax.optim.Optimizer # pytype: disable=module-attr variables: flax.core.FrozenDict rng: PRNGKey @@ -85,7 +85,7 @@ def flatten_named_dicttree(metrics_res: DictTree, sep: str = "/"): def clip_grads(grad_tree: ArrayTree, max_norm: float, epsilon: float = 1e-6): """Gradient clipping with epsilon. - Adapted from jax.experimental.optimizers.clip_grads. + Adapted from jax.example_libraries.optimizers.clip_grads. Args: grad_tree: ArrayTree of gradients. diff --git a/savi/modules/__init__.py b/savi/modules/__init__.py index 8b85856..1eb3327 100644 --- a/savi/modules/__init__.py +++ b/savi/modules/__init__.py @@ -27,5 +27,7 @@ CoordinateEncoderStateInit) from .misc import (Dense, GRU, Identity, MLP, PositionEmbedding, Readout) from .video import (CorrectorPredictorTuple, FrameEncoder, Processor, SAVi) +from .resnet import (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, + ResNet200) diff --git a/savi/modules/decoders.py b/savi/modules/decoders.py index 3f5504f..02cbf48 100644 --- a/savi/modules/decoders.py +++ b/savi/modules/decoders.py @@ -38,6 +38,7 @@ class SpatialBroadcastDecoder(nn.Module): resolution: Sequence[int] backbone: Callable[[], nn.Module] pos_emb: Callable[[], nn.Module] + early_fusion: bool = False # Fuse slot features before constructing targets. 
target_readout: Optional[Callable[[], nn.Module]] = None # Vmapped application of module, consumes time axis (axis=1). @@ -68,26 +69,39 @@ def __call__(self, slots: Array, train: bool = False) -> Array: # Define intermediates for logging / visualization. self.sow("intermediates", "alphas", alphas) + if self.early_fusion: + # To save memory, fuse the slot features before predicting targets. + # The final target output should be equivalent to the late fusion when + # using linear prediction. + bb_features = jnp.reshape( + bb_features, (batch_size, n_slots) + spatial_dims + (-1,)) + # Combine backbone features by alpha masks. + bb_features = jnp.sum(bb_features * alphas, axis=1) + targets_dict = self.target_readout()(bb_features, train) preds_dict = dict() for target_key, channels in targets_dict.items(): - - # channels.shape = (batch_size, n_slots, h, w, c) - channels = jnp.reshape( - channels, (batch_size, n_slots) + (spatial_dims) + (-1,)) - - # masked_channels.shape = (batch_size, n_slots, h, w, c) - masked_channels = channels * alphas - - # decoded_target.shape = (batch_size, h, w, c) - decoded_target = jnp.sum(masked_channels, axis=1) # Combine target. + if self.early_fusion: + # decoded_target.shape = (batch_size, h, w, c) after next line. + decoded_target = channels + else: + # channels.shape = (batch_size, n_slots, h, w, c) + channels = jnp.reshape( + channels, (batch_size, n_slots) + (spatial_dims) + (-1,)) + + # masked_channels.shape = (batch_size, n_slots, h, w, c) + masked_channels = channels * alphas + + # decoded_target.shape = (batch_size, h, w, c) + decoded_target = jnp.sum(masked_channels, axis=1) # Combine target. preds_dict[target_key] = decoded_target if not train: # Define intermediates for logging / visualization. self.sow("intermediates", f"{target_key}_slots", channels) - self.sow("intermediates", f"{target_key}_masked", masked_channels) + if not self.early_fusion: + self.sow("intermediates", f"{target_key}_masked", masked_channels) self.sow("intermediates", f"{target_key}_combined", decoded_target) preds_dict["segmentations"] = jnp.argmax(alpha_logits, axis=1) diff --git a/savi/modules/resnet.py b/savi/modules/resnet.py new file mode 100644 index 0000000..86cb41f --- /dev/null +++ b/savi/modules/resnet.py @@ -0,0 +1,230 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of ResNet V1 in Flax. 
+ +"Deep Residual Learning for Image Recognition" +He et al., 2015, [https://arxiv.org/abs/1512.03385] +""" + +import functools + +from typing import Any, Tuple, Type, List, Optional, Callable, Iterable +import flax.linen as nn +import jax.numpy as jnp + + +Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1), use_bias=False) +Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3), use_bias=False) + + +class ResNetBlock(nn.Module): + """ResNet block without bottleneck used in ResNet-18 and ResNet-34.""" + + filters: int + norm: Any + kernel_dilation: Tuple[int, int] = (1, 1) + strides: Tuple[int, int] = (1, 1) + + @nn.compact + def __call__(self, x): + residual = x + + x = Conv3x3( + self.filters, + strides=self.strides, + kernel_dilation=self.kernel_dilation, + name="conv1")(x) + x = self.norm(name="bn1")(x) + x = nn.relu(x) + x = Conv3x3(self.filters, name="conv2")(x) + # Initializing the scale to 0 has been common practice since "Fixup + # Initialization: Residual Learning Without Normalization" Tengyu et al, + # 2019, [https://openreview.net/forum?id=H1gsz30cKX]. + x = self.norm(scale_init=nn.initializers.zeros, name="bn2")(x) + + if residual.shape != x.shape: + residual = Conv1x1( + self.filters, strides=self.strides, name="proj_conv")( + residual) + residual = self.norm(name="proj_bn")(residual) + + x = nn.relu(residual + x) + return x + + +class BottleneckResNetBlock(ResNetBlock): + """Bottleneck ResNet block used in ResNet-50 and larger.""" + + @nn.compact + def __call__(self, x): + residual = x + + x = Conv1x1(self.filters, name="conv1")(x) + x = self.norm(name="bn1")(x) + x = nn.relu(x) + x = Conv3x3( + self.filters, + strides=self.strides, + kernel_dilation=self.kernel_dilation, + name="conv2")(x) + x = self.norm(name="bn2")(x) + x = nn.relu(x) + x = Conv1x1(4 * self.filters, name="conv3")(x) + # Initializing the scale to 0 has been common practice since "Fixup + # Initialization: Residual Learning Without Normalization" Tengyu et al, + # 2019, [https://openreview.net/forum?id=H1gsz30cKX]. + x = self.norm(name="bn3")(x) + + if residual.shape != x.shape: + residual = Conv1x1( + 4 * self.filters, strides=self.strides, name="proj_conv")( + residual) + residual = self.norm(name="proj_bn")(residual) + + x = nn.relu(residual + x) + return x + + +class ResNetStage(nn.Module): + """ResNet stage consistent of multiple ResNet blocks.""" + + stage_size: int + filters: int + block_cls: Type[ResNetBlock] + norm: Any + first_block_strides: Tuple[int, int] + + @nn.compact + def __call__(self, x): + for i in range(self.stage_size): + x = self.block_cls( + filters=self.filters, + norm=self.norm, + strides=self.first_block_strides if i == 0 else (1, 1), + name=f"block{i + 1}")( + x) + return x + + +class ResNet(nn.Module): + """Construct ResNet V1 with `num_classes` outputs. + + Attributes: + num_classes: Number of nodes in the final layer. + block_cls: Class for the blocks. ResNet-50 and larger use + `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and + ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions). + stage_sizes: List with the number of ResNet blocks in each stage. Number of + stages can be varied. + norm_type: Which type of normalization layer to apply. Options are: + "batch": BatchNorm, "group": GroupNorm, "layer": LayerNorm. Defaults to + BatchNorm. + width_factor: Factor applied to the number of filters. The 64 * width_factor + is the number of filters in the first stage, every consecutive stage + doubles the number of filters. 
+ small_inputs: Bool, if True, ignore strides and skip max pooling in the root + block and use smaller filter size. + stage_strides: Stride per stage. This overrides all other arguments. + include_top: Whether to include the fully-connected layer at the top + of the network. + axis_name: Axis name over which to aggregate batchnorm statistics. + """ + num_classes: int + block_cls: Type[ResNetBlock] + stage_sizes: List[int] + norm_type: str = "batch" + width_factor: int = 1 + small_inputs: bool = False + stage_strides: Optional[List[Tuple[int, int]]] = None + include_top: bool = False + axis_name: Optional[str] = None + output_initializer: Callable[[Any, Iterable[int], Any], Any] = ( + nn.initializers.zeros) + + @nn.compact + def __call__(self, x, *, train: bool): + """Apply the ResNet to the inputs `x`. + + Args: + x: Inputs. + train: Whether to use BatchNorm in training or inference mode. + + Returns: + The output head with `num_classes` entries. + """ + width = 64 * self.width_factor + + if self.norm_type == "batch": + norm = functools.partial( + nn.BatchNorm, use_running_average=not train, momentum=0.9, + axis_name=self.axis_name) + elif self.norm_type == "layer": + norm = nn.LayerNorm + elif self.norm_type == "group": + norm = nn.GroupNorm + else: + raise ValueError(f"Invalid norm_type: {self.norm_type}") + + # Root block. + x = nn.Conv( + features=width, + kernel_size=(7, 7) if not self.small_inputs else (3, 3), + strides=(2, 2) if not self.small_inputs else (1, 1), + use_bias=False, + name="init_conv")( + x) + x = norm(name="init_bn")(x) + + if not self.small_inputs: + x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") + + # Stages. + for i, stage_size in enumerate(self.stage_sizes): + if i == 0: + first_block_strides = ( + 1, 1) if self.stage_strides is None else self.stage_strides[i] + else: + first_block_strides = ( + 2, 2) if self.stage_strides is None else self.stage_strides[i] + + x = ResNetStage( + stage_size, + filters=width * 2**i, + block_cls=self.block_cls, + norm=norm, + first_block_strides=first_block_strides, + name=f"stage{i + 1}")(x) + + # Head. + if self.include_top: + x = jnp.mean(x, axis=(1, 2)) + x = nn.Dense( + self.num_classes, kernel_init=self.output_initializer, name="head")(x) + return x + + +ResNetWithBasicBlk = functools.partial(ResNet, block_cls=ResNetBlock) +ResNetWithBottleneckBlk = functools.partial(ResNet, + block_cls=BottleneckResNetBlock) + +ResNet18 = functools.partial(ResNetWithBasicBlk, stage_sizes=[2, 2, 2, 2]) +ResNet34 = functools.partial(ResNetWithBasicBlk, stage_sizes=[3, 4, 6, 3]) +ResNet50 = functools.partial(ResNetWithBottleneckBlk, stage_sizes=[3, 4, 6, 3]) +ResNet101 = functools.partial(ResNetWithBottleneckBlk, + stage_sizes=[3, 4, 23, 3]) +ResNet152 = functools.partial(ResNetWithBottleneckBlk, + stage_sizes=[3, 8, 36, 3]) +ResNet200 = functools.partial(ResNetWithBottleneckBlk, + stage_sizes=[3, 24, 36, 3]) diff --git a/savi/modules/video.py b/savi/modules/video.py index d7fa961..6d0e71a 100644 --- a/savi/modules/video.py +++ b/savi/modules/video.py @@ -184,5 +184,11 @@ def __call__(self, inputs: Array, padding_mask: Optional[Array] = None, elif self.reduction is not None: raise ValueError("Unknown reduction type: {}.".format(self.reduction)) - x = self.output_transform()(x, train=train) + output_block = self.output_transform() + + if hasattr(output_block, "qkv_size"): + # Project to qkv_size if used transformer. + x = nn.relu(nn.Dense(output_block.qkv_size)(x)) + + x = output_block(x, train=train) return x
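The final `video.py` hunk inserts a ReLU-activated `Dense` projection so that the encoder features match the transformer's `qkv_size` before the output transform is applied. Below is a minimal, self-contained sketch of that dispatch; `ToyTransform` and `EncoderTail` are hypothetical stand-ins rather than `savi.modules` classes, and the real `FrameEncoder` builds its submodules from config factories instead of taking instances.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyTransform(nn.Module):
  """Stand-in for a transformer-style output transform with a token width."""
  qkv_size: int = 64

  @nn.compact
  def __call__(self, x, train: bool = False):
    del train  # Unused in this toy module.
    return nn.Dense(self.qkv_size)(x)


class EncoderTail(nn.Module):
  """Applies an output transform, first matching its feature width if needed."""
  output_transform: nn.Module

  @nn.compact
  def __call__(self, x, train: bool = False):
    if hasattr(self.output_transform, "qkv_size"):
      # Project backbone features to the transformer width, as in video.py.
      x = nn.relu(nn.Dense(self.output_transform.qkv_size)(x))
    return self.output_transform(x, train=train)


features = jnp.zeros((2, 16 * 16, 512))  # (batch, flattened locations, channels)
model = EncoderTail(output_transform=ToyTransform(qkv_size=64))
params = model.init(jax.random.PRNGKey(0), features)
print(model.apply(params, features).shape)  # (2, 256, 64)
```

Keying the projection on the presence of a `qkv_size` attribute leaves the existing SAVi configs untouched: their MLP output transforms expose no such attribute, so no extra projection layer is added.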