-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
rowanz
committed
Dec 30, 2021
1 parent
bf336c3
commit c73420a
Showing
54 changed files
with
21,995 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,77 @@ | ||
# merlot_reserve | ||
Code release for "MERLOT Reserve: Neural Script Knowledge through Vision and Language and Sound" | ||
|
||
MERLOT Reserve (in submission) is a model for learning joint representations of vision, language, and sound from YouTube. The learned model can be used in a zero-shot or finetuned setting, where it does well on tasks like VCR and TVQA. | ||
|
||
Visit our project page at [rowanzellers.com/merlotreserve](https://rowanzellers.com/merlotreserve) or read the [full paper](#) to learn more. | ||
|
||
![](https://i.imgur.com/Z9iEsLZ.png "MERLOT Reserve Teaser") | ||
|
||
## What's here | ||
|
||
We are releasing the following: | ||
* JAX code, and model checkpoints, for the MERLOT model | ||
* Code for pretraining the model | ||
* Code for finetuning the model on VCR and TVQA | ||
* Code for doing zero-shot inference with the model | ||
|
||
## Environment and setup | ||
|
||
There are two different ways to run MERLOT Reserve: | ||
|
||
* *Pretraining on videos* You'll need a TPU Pod VM for this. This step shouldn't be necessary for most people, as we have released model checkpoints. | ||
* *Finetuning on VCR or TVQA* I've done this on a TPU v3-8 VM. This should be possible on GPU(s), but I haven't tested this on such hardware. | ||
* *Zero-shot inference* I've ran this on a GPU (even an older, Titan X from 2016 works.) | ||
|
||
### Installation on a GPU Machine | ||
Install Cuda 11.4 (I used [this link](https://developer.download.nvidia.com/compute/cuda/11.4.2/local_installers/cuda-repo-ubuntu1804-11-4-local_11.4.2-470.57.02-1_amd64.deb)) and [CUDNN 8.2](https://developer.nvidia.com/rdp/cudnn-download). You might have to add something like this to your `PATH`: | ||
|
||
`export LD_LIBRARY_PATH=/usr/local/cuda/lib64` | ||
|
||
Create the environment: | ||
```bash | ||
conda create --name mreserve python=3.8 && conda activate mreserve | ||
conda install -y python=3.8 tqdm numpy pyyaml scipy ipython cython typing h5py pandas matplotlib | ||
|
||
# Install jax | ||
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html | ||
# If doing this on TPUs instead of locally... | ||
# pip install "jax[tpu]>=0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html | ||
|
||
# This is needed sometimes https://stackoverflow.com/questions/66060487/valueerror-numpy-ndarray-size-changed-may-indicate-binary-incompatibility-exp | ||
pip uninstall numpy | ||
pip install numpy==1.19.5 | ||
|
||
pip install -r requirements.txt | ||
``` | ||
|
||
You can then try out the interactive script at [demo/demo_video.py](demo/demo_video.py). It will handle downloading the model checkpoint for you. | ||
|
||
### Installation on a Cloud TPU VM | ||
|
||
See the instructions in [pretrain/](pretrain/) to set up your environment on a TPU v3-8 VM. | ||
|
||
## Checkpoints | ||
|
||
These should get auto-downloaded if you use `PretrainedMerlotReserve` in [mreserve/modeling.py](mreserve/modeling.py). All are flax checkpoint files: | ||
|
||
```bash | ||
# pretrained checkpoints | ||
gs://merlotreserve/ckpts/base | ||
gs://merlotreserve/ckpts/base_resadapt | ||
gs://merlotreserve/ckpts/large | ||
gs://merlotreserve/ckpts/large_resadapt | ||
|
||
# finetuned checkpoints | ||
gs://merlotreserve/vcr_ckpts/vcr_finetune_base | ||
gs://merlotreserve/vcr_ckpts/vcr_finetune_large | ||
|
||
gs://merlotreserve/tvqa_ckpts/tvqa_finetune_base | ||
gs://merlotreserve/tvqa_ckpts/tvqa_finetune_large | ||
|
||
# TVQA Data | ||
gs://merlotreserve/finetune_data/tvqa/ | ||
|
||
# VCR data | ||
gs://merlotreserve/finetune_data/vcr/ | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# THE DEMO | ||
|
||
Here, `demo_video.py` is a python script for interactive Q/A on videos. | ||
|
||
First download the video | ||
```bash | ||
pip install youtube-dl | ||
youtube-dl -f "best[height<=480,ext=mp4]" https://www.youtube.com/watch?v=pmjPjZZRhNQ -o "%(id)s.%(ext)s" | ||
``` | ||
Then run the demo! `ipython -i demo_video.py` | ||
|
||
Check out [zero_shot_ek](zero_shot_ek) and [zero_shot_qa](zero_shot_qa) for using it on QA tasks. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
""" | ||
Demo for doing interesting things with a video | ||
""" | ||
import sys | ||
sys.path.append('../') | ||
|
||
from mreserve.preprocess import video_to_segments, preprocess_video, encoder, MASK | ||
from mreserve.modeling import PretrainedMerlotReserve | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
# This handles loading the model and getting the checkpoints. | ||
grid_size = (18, 32) | ||
model = PretrainedMerlotReserve.from_pretrained(model_name='large', image_grid_size=grid_size) | ||
|
||
## First open the video and break it up into segments. you can only have 8. | ||
# Each segment is 5 seconds so it corresponds to seconds 15 - 55 of the video | ||
|
||
# Feel free to change the URL! | ||
video_segments = video_to_segments('pmjPjZZRhNQ.mp4') | ||
video_segments = video_segments[3:11] | ||
|
||
# Set up a fake classification task. | ||
video_segments[0]['text'] = 'in this video i\'ll be<|MASK|>' | ||
video_segments[0]['use_text_as_input'] = True | ||
for i in range(1,8): | ||
video_segments[i]['use_text_as_input'] = False | ||
|
||
video_pre = preprocess_video(video_segments, output_grid_size=grid_size, verbose=True) | ||
|
||
# Now we embed the entire video and extract the text. result is [seq_len, H]. we extract a hidden state for every | ||
# MASK token | ||
out_h = model.embed_video(**video_pre) | ||
out_h = out_h[video_pre['tokens'] == MASK] | ||
|
||
options = ['making coffee', 'going backpacking'] | ||
|
||
# the following is all the labels from activitynet. why not! some of them don't make sense grammatically though. | ||
options += ['Applying sunscreen', 'Archery', 'Arm wrestling', 'Assembling bicycle', 'BMX', 'Baking cookies', 'Ballet', 'Bathing dog', 'Baton twirling', 'Beach soccer', 'Beer pong', 'Belly dance', 'Blow-drying hair', 'Blowing leaves', 'Braiding hair', 'Breakdancing', 'Brushing hair', 'Brushing teeth', 'Building sandcastles', 'Bullfighting', 'Bungee jumping', 'Calf roping', 'Camel ride', 'Canoeing', 'Capoeira', 'Carving jack-o-lanterns', 'Changing car wheel', 'Cheerleading', 'Chopping wood', 'Clean and jerk', 'Cleaning shoes', 'Cleaning sink', 'Cleaning windows', 'Clipping cat claws', 'Cricket', 'Croquet', 'Cumbia', 'Curling', 'Cutting the grass', 'Decorating the Christmas tree', 'Disc dog', 'Discus throw', 'Dodgeball', 'Doing a powerbomb', 'Doing crunches', 'Doing fencing', 'Doing karate', 'Doing kickboxing', 'Doing motocross', 'Doing nails', 'Doing step aerobics', 'Drinking beer', 'Drinking coffee', 'Drum corps', 'Elliptical trainer', 'Fixing bicycle', 'Fixing the roof', 'Fun sliding down', 'Futsal', 'Gargling mouthwash', 'Getting a haircut', 'Getting a piercing', 'Getting a tattoo', 'Grooming dog', 'Grooming horse', 'Hammer throw', 'Hand car wash', 'Hand washing clothes', 'Hanging wallpaper', 'Having an ice cream', 'High jump', 'Hitting a pinata', 'Hopscotch', 'Horseback riding', 'Hula hoop', 'Hurling', 'Ice fishing', 'Installing carpet', 'Ironing clothes', 'Javelin throw', 'Kayaking', 'Kite flying', 'Kneeling', 'Knitting', 'Laying tile', 'Layup drill in basketball', 'Long jump', 'Longboarding', 'Making a cake', 'Making a lemonade', 'Making a sandwich', 'Making an omelette', 'Mixing drinks', 'Mooping floor', 'Mowing the lawn', 'Paintball', 'Painting', 'Painting fence', 'Painting furniture', 'Peeling potatoes', 'Ping-pong', 'Plastering', 'Plataform diving', 'Playing accordion', 'Playing badminton', 'Playing bagpipes', 'Playing beach volleyball', 'Playing blackjack', 'Playing congas', 'Playing drums', 'Playing field hockey', 'Playing flauta', 'Playing guitarra', 'Playing harmonica', 'Playing ice hockey', 'Playing kickball', 'Playing lacrosse', 'Playing piano', 'Playing polo', 'Playing pool', 'Playing racquetball', 'Playing rubik cube', 'Playing saxophone', 'Playing squash', 'Playing ten pins', 'Playing violin', 'Playing water polo', 'Pole vault', 'Polishing forniture', 'Polishing shoes', 'Powerbocking', 'Preparing pasta', 'Preparing salad', 'Putting in contact lenses', 'Putting on makeup', 'Putting on shoes', 'Rafting', 'Raking leaves', 'Removing curlers', 'Removing ice from car', 'Riding bumper cars', 'River tubing', 'Rock climbing', 'Rock-paper-scissors', 'Rollerblading', 'Roof shingle removal', 'Rope skipping', 'Running a marathon', 'Sailing', 'Scuba diving', 'Sharpening knives', 'Shaving', 'Shaving legs', 'Shot put', 'Shoveling snow', 'Shuffleboard', 'Skateboarding', 'Skiing', 'Slacklining', 'Smoking a cigarette', 'Smoking hookah', 'Snatch', 'Snow tubing', 'Snowboarding', 'Spinning', 'Spread mulch', 'Springboard diving', 'Starting a campfire', 'Sumo', 'Surfing', 'Swimming', 'Swinging at the playground', 'Table soccer', 'Tai chi', 'Tango', 'Tennis serve with ball bouncing', 'Throwing darts', 'Trimming branches or hedges', 'Triple jump', 'Tug of war', 'Tumbling', 'Using parallel bars', 'Using the balance beam', 'Using the monkey bar', 'Using the pommel horse', 'Using the rowing machine', 'Using uneven bars', 'Vacuuming floor', 'Volleyball', 'Wakeboarding', 'Walking the dog', 'Washing dishes', 'Washing face', 'Washing hands', 'Waterskiing', 'Waxing skis', 'Welding', 'Windsurfing', 'Wrapping presents', 'Zumba'] | ||
label_space = model.get_label_space(options) | ||
|
||
# Dot product the <|MASK|> tokens and the options together | ||
logits = 100.0 * jnp.einsum('bh,lh->bl', out_h, label_space) | ||
|
||
for i, logits_i in enumerate(logits): | ||
print(f"Idx {i}", flush=True) | ||
probs = jax.nn.softmax(logits_i, -1) | ||
for idx_i in jnp.argsort(-probs): | ||
p_i = probs[idx_i] | ||
print("{:.1f} {}".format(p_i * 100.0, options[idx_i], flush=True)) |
Oops, something went wrong.