-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 25f83ee
Showing
148 changed files
with
9,595 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
recursive-include lavis/configs *.yaml *.json | ||
recursive-include lavis/projects *.yaml *.json | ||
|
||
recursive-exclude lavis/datasets/download_scripts * | ||
recursive-exclude lavis/output * | ||
|
||
include requirements.txt | ||
include lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
<div align="center"> | ||
<h1> | ||
<b> | ||
Data-Efficient Multimodal Fusion on a Single GPU | ||
</b> | ||
</h1> | ||
|
||
<p align="center"> | ||
<a href='https://arxiv.org/abs/2312.10144'><img src='https://img.shields.io/badge/arXiv-2312.10144-b31b1b.svg' /></a> | ||
</p> | ||
|
||
<h4> | ||
<b> | ||
<a href="https://www.cs.toronto.edu/~nvouitsis/">Noël Vouitsis*</a>, <a href="https://www.linkedin.com/in/zhaoyan-liu-9309aa180/">Zhaoyan Liu*</a>, <a href="https://www.cs.toronto.edu/~satyag/">Satya Krishna Gorti*</a>, <a href="http://linkedin.com/in/valentin-villecroze">Valentin Villecroze</a>, <a href="http://jescresswell.github.io/">Jesse C. Cresswell</a>, <a href="http://www.cs.toronto.edu/~guangweiyu/">Guangwei Yu</a>, <a href="https://sites.google.com/view/gabriel-loaiza-ganem/">Gabriel Loaiza-Ganem</a>, <a href="https://www.cs.toronto.edu/~mvolkovs/">Maksims Volkovs</a> | ||
</b> | ||
</h4> | ||
</div> | ||
|
||
|
||
## Introduction | ||
This repository contains the official implementation of our <b>CVPR 2024</b> paper <a href='https://arxiv.org/abs/2312.10144'>Data-Efficient Multimodal Fusion on a Single GPU</a>. We release code for the image-text setting, including code for dataset downloading, feature extraction, fusion training and evaluation. We note that our code is based on the [LAVIS](https://github.com/salesforce/LAVIS) library. | ||
|
||
## Installation | ||
|
||
1. (Optional) Creating conda environment | ||
|
||
```bash | ||
conda create -n fusemix python=3.8 | ||
conda activate fusemix | ||
``` | ||
|
||
2. Build from source | ||
|
||
```bash | ||
git clone https://github.com/layer6ai-labs/fusemix | ||
cd fusemix | ||
pip install -e . | ||
``` | ||
|
||
## Getting Started | ||
### Model Zoo | ||
Model zoo summarizes supported models, to view: | ||
```python | ||
from lavis.models import model_zoo | ||
print(model_zoo) | ||
# ====================================================================== | ||
# Architectures Types | ||
# ====================================================================== | ||
# dinov2_feature_extractor vits14, vitb14, vitl14, vitg14 | ||
# bge_feature_extractor large | ||
# cohere_feature_extractor v3 | ||
# mlp_contrastive_fusion base | ||
``` | ||
|
||
### Dataset Zoo | ||
Dataset zoo summarizes supported datasets, to view: | ||
|
||
```python | ||
from lavis.datasets.builders import dataset_zoo | ||
dataset_names = dataset_zoo.get_names() | ||
print(dataset_names) | ||
``` | ||
|
||
### Dataset Downloading | ||
Please refer to `lavis/datasets/download_scripts` for scripts to download the required datasets. | ||
|
||
|
||
### Feature Extraction | ||
|
||
```bash | ||
bash run_scripts/feature_extract/feat_extract_bge_large_coco_cap.sh | ||
``` | ||
|
||
|
||
### FuseMix Training | ||
|
||
```bash | ||
bash run_scripts/fusion/mlp_contrastive_fusion_pretrain_dinov2_vitg14_bge_large_coco_vg_sbu_cap_cc3m.sh | ||
``` | ||
|
||
### Evaluation | ||
|
||
```bash | ||
bash run_scripts/fusion/mlp_contrastive_fusion_retrieval_dinov2_vitg14_bge_large_coco.sh | ||
``` | ||
|
||
## Citation | ||
If you find this work useful in your research, please cite the following paper: | ||
``` | ||
@inproceedings{vouitsis2024dataefficient, | ||
title={Data-Efficient Multimodal Fusion on a Single GPU}, | ||
author={Noël Vouitsis and Zhaoyan Liu and Satya Krishna Gorti and Valentin Villecroze and Jesse C. Cresswell and Guangwei Yu and Gabriel Loaiza-Ganem and Maksims Volkovs}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
year={2024}, | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import argparse | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
import torch.backends.cudnn as cudnn | ||
|
||
import lavis.tasks as tasks | ||
from lavis.common.config import Config | ||
from lavis.common.dist_utils import get_rank, init_distributed_mode | ||
from lavis.common.logger import setup_logger | ||
from lavis.common.optims import ( | ||
LinearWarmupCosineLRScheduler, | ||
LinearWarmupStepLRScheduler, | ||
) | ||
from lavis.common.utils import now | ||
|
||
# imports modules for registration | ||
from lavis.datasets.builders import * | ||
from lavis.models import * | ||
from lavis.processors import * | ||
from lavis.runners.runner_base import RunnerBase | ||
from lavis.tasks import * | ||
from lavis.models.fusion_models.mixup import ( | ||
CosineMixupAlphaScheduler, | ||
ExpMixupAlphaScheduler, | ||
) | ||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description="Training") | ||
|
||
parser.add_argument("--cfg-path", required=True, help="path to configuration file.") | ||
parser.add_argument( | ||
"--options", | ||
nargs="+", | ||
help="override some settings in the used config, the key-value pair " | ||
"in xxx=yyy format will be merged into config file (deprecate), " | ||
"change to --cfg-options instead.", | ||
) | ||
|
||
args = parser.parse_args() | ||
# if 'LOCAL_RANK' not in os.environ: | ||
# os.environ['LOCAL_RANK'] = str(args.local_rank) | ||
|
||
return args | ||
|
||
|
||
def setup_seeds(config): | ||
seed = config.run_cfg.seed + get_rank() | ||
|
||
random.seed(seed) | ||
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
|
||
cudnn.benchmark = False | ||
cudnn.deterministic = True | ||
|
||
|
||
def main(): | ||
# allow auto-dl completes on main process without timeout when using NCCL backend. | ||
# os.environ["NCCL_BLOCKING_WAIT"] = "1" | ||
|
||
# set before init_distributed_mode() to ensure the same job_id shared across all ranks. | ||
job_id = now() | ||
|
||
cfg = Config(parse_args()) | ||
|
||
init_distributed_mode(cfg.run_cfg) | ||
|
||
setup_seeds(cfg) | ||
|
||
# set after init_distributed_mode() to only log on master. | ||
setup_logger() | ||
|
||
cfg.pretty_print() | ||
|
||
task = tasks.setup_task(cfg) | ||
datasets = task.build_datasets(cfg) | ||
model = task.build_model(cfg) | ||
|
||
runner = RunnerBase( | ||
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets | ||
) | ||
runner.evaluate(skip_reload=True) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import os | ||
import sys | ||
|
||
from omegaconf import OmegaConf | ||
|
||
from lavis.common.registry import registry | ||
|
||
from lavis.datasets.builders import * | ||
from lavis.models import * | ||
from lavis.processors import * | ||
from lavis.tasks import * | ||
|
||
|
||
root_dir = os.path.dirname(os.path.abspath(__file__)) | ||
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) | ||
|
||
registry.register_path("library_root", root_dir) | ||
repo_root = os.path.join(root_dir, "..") | ||
registry.register_path("repo_root", repo_root) | ||
cache_root = os.path.join(repo_root, default_cfg.env.cache_root) | ||
registry.register_path("cache_root", cache_root) | ||
|
||
registry.register("MAX_INT", sys.maxsize) | ||
registry.register("SPLIT_NAMES", ["train", "val", "test"]) |
Oops, something went wrong.