NeVa token fusion (NVIDIA#9245)
* token fusion via mlp downsampling + media_type default fix

Signed-off-by: paul-gibbons <[email protected]>

* inference update

Signed-off-by: paul-gibbons <[email protected]>

* adapter fix

Signed-off-by: paul-gibbons <[email protected]>

* config refactor, remove image_token_len dependency, transpose mlp_downsample height and width

Signed-off-by: paul-gibbons <[email protected]>

* Apply isort and black reformatting

Signed-off-by: paul-gibbons <[email protected]>

* removing image_token_len in text generation strategy

Signed-off-by: paul-gibbons <[email protected]>

* fix patch_dim text generation

Signed-off-by: paul-gibbons <[email protected]>

* crop-size fix

Signed-off-by: paul-gibbons <[email protected]>

* fixing RGB reversal bug

Signed-off-by: paul-gibbons <[email protected]>

* Apply isort and black reformatting

Signed-off-by: paul-gibbons <[email protected]>

* crop_size default -> None in text_generation_strategy

Signed-off-by: paul-gibbons <[email protected]>

* Apply isort and black reformatting

Signed-off-by: paul-gibbons <[email protected]>

* patch_dim padding for mlp_downsample

Signed-off-by: paul-gibbons <[email protected]>

* Apply isort and black reformatting

Signed-off-by: paul-gibbons <[email protected]>

* patch_dim padding update

Signed-off-by: paul-gibbons <[email protected]>

* Apply isort and black reformatting

Signed-off-by: paul-gibbons <[email protected]>

* updating h/w patch_dim naming convention

Signed-off-by: paul-gibbons <[email protected]>

* Apply isort and black reformatting

Signed-off-by: paul-gibbons <[email protected]>

---------

Signed-off-by: paul-gibbons <[email protected]>
Signed-off-by: paul-gibbons <[email protected]>
Co-authored-by: paul-gibbons <[email protected]>
paul-gibbons and paul-gibbons authored Jun 3, 2024
1 parent 9956b54 commit eb411ae
Showing 8 changed files with 127 additions and 36 deletions.
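
Reviewer note: the heart of this change is a new `mlp_downsample` multimodal projector that fuses each 2x2 block of vision-encoder patches into a single token before the MLP, so an image contributes 4x fewer tokens to the language model. Because the token count now depends on crop size and patch size rather than a fixed constant, the `image_token_len` config key is removed and `crop_size` is added to the vision-encoder config instead. A minimal sketch of the resulting arithmetic, assuming the default 224x224 crop and `patch_dim: 14` from the configs below:

# Token-count arithmetic for the new mlp_downsample projector (illustrative only).
crop_size = (224, 224)      # model.mm_cfg.vision_encoder.crop_size
patch_dim = 14              # model.mm_cfg.vision_encoder.patch_dim

height_num_patches = crop_size[0] // patch_dim   # 16
width_num_patches = crop_size[1] // patch_dim    # 16

# mlp_downsample merges 2x2 patch blocks, so odd patch grids are padded up to even.
if height_num_patches % 2 != 0:
    height_num_patches += 1
if width_num_patches % 2 != 0:
    width_num_patches += 1

tokens_linear = height_num_patches * width_num_patches   # 256 media tokens (the old image_token_len)
tokens_mlp_downsample = tokens_linear // 4                # 64 media tokens after 2x2 fusion

The same computation now appears in the dataset code and in the text-generation strategy, replacing the hardcoded patch size of 14 and the fixed `image_token_len: 256`.
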
@@ -74,6 +74,7 @@ model:
from_pretrained: "openai/clip-vit-large-patch14" # path or name
from_hf: True
patch_dim: 14
crop_size: [224, 224]
hidden_size: 1024 # could be found from model but tricky in code
vision_select_layer: -2 # default to the last layer
class_token_length: 1
@@ -74,6 +74,7 @@ model:
from_pretrained: "" # path or name
from_hf: True
patch_dim: 14
crop_size: [224, 224]
hidden_size: 1024 # could be found from model but tricky in code
vision_select_layer: -2 # default to the last layer
class_token_length: 1
@@ -189,7 +190,6 @@ model:
is_multimodal: True
media_type: image # currently supported: image
sep_image_conv_front: False
image_token_len: 256
conv_template: ${model.mm_cfg.llm.model_type} # check `nemo/collections/multimodal/data/neva/conversation.py`
image_folder: null
image_aspect_ratio: 'square'
@@ -74,6 +74,7 @@ model:
from_pretrained: "" # path or name
from_hf: True
patch_dim: 14
crop_size: [224, 224]
hidden_size: 1024 # could be found from model but tricky in code
vision_select_layer: -2 # default to the last layer
class_token_length: 1
@@ -75,6 +75,7 @@ model:
from_pretrained: "" # path or name
from_hf: True
patch_dim: 14
crop_size: [336, 336]
hidden_size: 1024 # could be found from model but tricky in code
vision_select_layer: -2 # default to the last layer
class_token_length: 1
@@ -194,7 +195,6 @@ model:
num_frames: 8 # selects the number of frames to use from the video
sep_token_between_frames: False # TODO: allow usage of separator tokens between frames
sep_image_conv_front: False
image_token_len: 256
conv_template: ${model.mm_cfg.llm.model_type} # check `nemo/collections/multimodal/data/neva/conversation.py`
image_folder: null
video_folder: null
59 changes: 43 additions & 16 deletions nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -145,36 +145,34 @@ def open_video(self, file_name):
cap = decord.VideoReader(f)
return self.flatten_frames(cap)
else:
decord.bridge.set_bridge("torch")
cap = decord.VideoReader(os.path.join(self.video_folder, file_name))
return self.flatten_frames(cap)
return None

def flatten_frames(self, cap):
if self.data_cfg['splice_single_frame'] == 'first':
frame = cap[0].asnumpy()[:, :, ::-1]
frame = cap[0].asnumpy()
return Image.fromarray(frame).convert('RGB')
elif self.data_cfg['splice_single_frame'] == 'middle':
frame = cap[len(cap) // 2].asnumpy()[:, :, ::-1]
frame = cap[len(cap) // 2].asnumpy()
return Image.fromarray(frame).convert('RGB')
elif self.data_cfg['splice_single_frame'] == 'last':
frame = cap[-1].asnumpy()[:, :, ::-1]
frame = cap[-1].asnumpy()
return Image.fromarray(frame).convert('RGB')
else:
if self.data_cfg['num_frames'] == -1:
frames = []
for frame in cap:
rgb_frame = frame.asnumpy()[:, :, ::-1]
rgb_frame = frame.asnumpy()
img = Image.fromarray(rgb_frame).convert('RGB')
frames.append(img)
return frames
else:
num_frames = min(len(cap), self.data_cfg['num_frames'])
indices = np.linspace(0, len(cap) - 1, num_frames, dtype=int)
frames = []
for i in indices:
rgb_frame = cap[i].asnumpy()[:, :, ::-1]
img = Image.fromarray(rgb_frame).convert('RGB')
frames.append(img)
frames = cap.get_batch(indices)

while len(frames) < self.data_cfg['num_frames']:
frames.append(frames[-1])
@@ -262,9 +260,13 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: int
return sources

num_patches = image_token_len

if media_type == 'video':
num_patches *= multimodal_cfg['num_frames']

if multimodal_cfg['mm_mlp_adapter_type'] == 'mlp_downsample':
num_patches //= 4

if multimodal_cfg['use_im_start_end']:
replace_token = DEFAULT_IMAGE_PATCH_TOKEN[model_type] * num_patches
else:
@@ -922,9 +924,19 @@ def expand2square(pil_img, background_color):
media_tensors = torch.tensor([])
if images:
media_tensors = torch.stack(images)
cur_token_len = (media_tensors[0].shape[1] // 14) * (
media_tensors[0].shape[2] // 14
) # FIXME: 14 is hardcoded patch size
patch_dim = self.multimodal_cfg['patch_dim']

height_num_patches = media_tensors[0].shape[1] // patch_dim
width_num_patches = media_tensors[0].shape[2] // patch_dim

if self.multimodal_cfg['mm_mlp_adapter_type'] == 'mlp_downsample':
if height_num_patches % 2 != 0:
height_num_patches += 1
if width_num_patches % 2 != 0:
width_num_patches += 1

cur_token_len = height_num_patches * width_num_patches

sources = preprocess_multimodal(
copy.deepcopy(sources),
self.multimodal_cfg,
@@ -978,9 +990,19 @@ def expand2square(pil_img, background_color):
media_tensors = frames
if videos:
media_tensors = torch.stack(videos)
cur_token_len = (media_tensors[0].shape[-1] // 14) * (
media_tensors[0].shape[-2] // 14
) # FIXME: 14 is hardcoded patch size
patch_dim = self.multimodal_cfg['patch_dim']

height_num_patches = media_tensors[0].shape[-2] // patch_dim
width_num_patches = media_tensors[0].shape[-1] // patch_dim

if self.multimodal_cfg['mm_mlp_adapter_type'] == 'mlp_downsample':
if height_num_patches % 2 != 0:
height_num_patches += 1
if width_num_patches % 2 != 0:
width_num_patches += 1

cur_token_len = height_num_patches * width_num_patches

sources = preprocess_multimodal(
copy.deepcopy(sources),
self.multimodal_cfg,
@@ -1190,11 +1212,15 @@ def make_supervised_data_module(tokenizer, model_cfg) -> Dict:
add_extra_token = 1
if getattr(model_cfg, 'no_seqlen_plus_one_input_tokens', False):
add_extra_token = 0
crop_size = data_cfg.get("crop_size", (224, 224))
crop_size = mm_cfg.vision_encoder.get("crop_size", (224, 224))
if mm_cfg.vision_encoder.from_hf:
image_processor = CLIPImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
assert crop_size == (
image_processor.crop_size['height'],
image_processor.crop_size['width'],
), f"Crop size {crop_size} does not match the HuggingFace CLIP model's crop size {(image_processor.crop_size['height'], image_processor.crop_size['width'])}"
else:
# TODO(yuya): Fix this hard-code for our own CLIP
image_processor = image_transform(
@@ -1212,8 +1238,8 @@ def make_supervised_data_module(tokenizer, model_cfg) -> Dict:
sep_image_conv_front=data_cfg.sep_image_conv_front,
model_type=mm_cfg.llm.get("model_type", "nvgpt"),
conv_template=data_cfg.get("conv_template", "nvgpt"),
patch_dim=model_cfg.mm_cfg.vision_encoder.patch_dim,
crop_size=crop_size,
image_token_len=data_cfg.image_token_len,
image_folder=data_cfg.get('image_folder', None),
video_folder=data_cfg.get('video_folder', None),
image_aspect_ratio=data_cfg.image_aspect_ratio,
@@ -1223,6 +1249,7 @@ def make_supervised_data_module(tokenizer, model_cfg) -> Dict:
context_length=model_cfg.encoder_seq_length,
media_type=data_cfg.get('media_type', 'image'),
num_frames=data_cfg.get('num_frames', -1),
mm_mlp_adapter_type=model_cfg.mm_cfg.get('mm_mlp_adapter_type', 'linear'),
),
data_cfg=dict(
splice_single_frame=data_cfg.get('splice_single_frame', None),
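
Reviewer note on the rounding above: `DownSampleBlock` zero-pads the patch grid to even height and width before merging 2x2 blocks, so the dataset has to round `height_num_patches` and `width_num_patches` up in the same way, or the number of image-placeholder tokens inserted into the prompt would not match the number of embeddings the projector emits. A hedged illustration with a hypothetical crop that yields an odd 35x35 patch grid (not one of the shipped configs):

# Hypothetical odd patch grid: a 490x490 crop with patch_dim 14 -> 35x35 patches.
height_num_patches, width_num_patches = 35, 35

naive_tokens = (height_num_patches * width_num_patches) // 4                 # 306, too few placeholders
# DownSampleBlock pads 35x35 up to 36x36 before the 2x2 merge, so the dataset must too:
padded_tokens = ((height_num_patches + 1) * (width_num_patches + 1)) // 4    # 324, matches the projector output
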
15 changes: 8 additions & 7 deletions nemo/collections/multimodal/parts/utils.py
@@ -15,6 +15,7 @@
import tempfile
from typing import Any, Callable, Tuple

import decord
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
@@ -469,23 +470,23 @@ def expand2square(pil_img, background_color):

# add video processor for video neva
def video_processor(maybe_video_path):
from decord import VideoReader

if isinstance(maybe_video_path, str):
vr = VideoReader(maybe_video_path)
decord.bridge.set_bridge("torch")
vr = decord.VideoReader(maybe_video_path)
if neva_cfg.data.splice_single_frame == 'first':
frames = [Image.fromarray(vr[0].asnumpy()[:, :, ::-1]).convert('RGB')]
frames = [Image.fromarray(vr[0].asnumpy()).convert('RGB')]
elif neva_cfg.data.splice_single_frame == 'middle':
frames = [Image.fromarray(vr[len(vr) // 2].asnumpy()[:, :, ::-1]).convert('RGB')]
frames = [Image.fromarray(vr[len(vr) // 2].asnumpy()).convert('RGB')]
elif neva_cfg.data.splice_single_frame == 'last':
frames = [Image.fromarray(vr[-1].asnumpy()[:, :, ::-1]).convert('RGB')]
frames = [Image.fromarray(vr[-1].asnumpy()).convert('RGB')]
else:
if neva_cfg.data.num_frames == -1:
frames = [Image.fromarray(frame.asnumpy()[:, :, ::-1]).convert('RGB') for frame in vr]
frames = [Image.fromarray(frame.asnumpy()).convert('RGB') for frame in vr]
else:
num_frames = min(len(vr), neva_cfg.data.num_frames)
indices = np.linspace(0, len(vr) - 1, num_frames, dtype=int)
frames = [Image.fromarray(vr[i].asnumpy()[:, :, ::-1]).convert('RGB') for i in indices]
frames = vr.get_batch(indices)

while len(frames) < neva_cfg.data.num_frames:
frames.append(frames[-1])
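
Reviewer note on the video path: the `[:, :, ::-1]` channel flips were converting decord's RGB frames to BGR before building PIL images, which is the RGB-reversal bug named in the commit message; they are dropped here. Sampled multi-frame reads also switch to the torch bridge plus a single batched `get_batch` call instead of per-index numpy conversion. A minimal sketch of the new sampling path, assuming decord is installed and `example.mp4` is a placeholder path:

import decord
import numpy as np

decord.bridge.set_bridge("torch")            # frames are returned as torch tensors
vr = decord.VideoReader("example.mp4")       # hypothetical video file

num_frames = min(len(vr), 8)                 # e.g. data.num_frames = 8
indices = np.linspace(0, len(vr) - 1, num_frames, dtype=int)
frames = vr.get_batch(indices)               # (num_frames, H, W, C), RGB channel order
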
@@ -275,7 +275,9 @@ def _get_init_fn(self, init_method: str):
raise NotImplementedError("out_init_method should be zero, normal, kaiming or xavier")
return init_fn

def adapter_unfreeze(self,):
def adapter_unfreeze(
self,
):
"""
Can be customized to allow for selective training of only some params in the PEFT.
"""
@@ -402,7 +404,7 @@ class LoraQAdapter(ParallelLinearAdapter):

class LoraDenseAttentionAdapter(ParallelLinearAdapter):
"""
Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
and they do not use an bottleneck activation function
"""

@@ -411,7 +413,7 @@ class LoraDenseAttentionAdapter(ParallelLinearAdapter):

class LoraHto4HAdapter(ParallelLinearAdapter):
"""
Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
and they do not use an bottleneck activation function
"""

@@ -420,7 +422,7 @@ class LoraHto4HAdapter(ParallelLinearAdapter):

class Lora4HtoHAdapter(ParallelLinearAdapter):
"""
Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
and they do not use an bottleneck activation function
"""

@@ -688,14 +690,20 @@ def set_inference_table(self, prompt_representation: torch.Tensor):
self.is_inference_ready = True
return True

def clear_inference_table(self,):
def clear_inference_table(
self,
):
self.inference_table.fill_(0.0)
self.is_inference_ready = False

def get_inference_table(self,):
def get_inference_table(
self,
):
return self.inference_table.data

def inner_forward(self,):
def inner_forward(
self,
):
input_embeds = self.embedding(self.indices).unsqueeze(0)
intermediate_parallel, bias_parallel = self.first(input_embeds)
intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
@@ -890,6 +898,29 @@ class LoraKQVAdapterWeightTyingConfig(ParallelLinearAdapterWeightTyingConfig):
_target_: str = "{0}.{1}".format(LoraKQVAdapterWeightTying.__module__, LoraKQVAdapterWeightTying.__name__)


class DownSampleBlock(nn.Module):
def forward(self, x):
vit_embeds = x
h = w = int(vit_embeds.shape[3] ** 0.5)
vit_embeds = vit_embeds.reshape(*vit_embeds.shape[:3], h, w, -1)
vit_embeds = self.flat_square(vit_embeds)
vit_embeds = vit_embeds.reshape(*vit_embeds.shape[:3], -1, vit_embeds.shape[-1])
return vit_embeds

def flat_square(self, x):
b, T, F, h, w, c = x.size()
if w % 2 == 1:
x = torch.cat([x, torch.zeros((b, T, F, h, 1, c), dtype=x.dtype).to(x.device)], dim=4)
b, T, F, h, w, c = x.size()
if h % 2 == 1:
x = torch.cat([x, torch.zeros((b, T, F, 1, w, c), dtype=x.dtype).to(x.device)], dim=3)
b, T, F, h, w, c = x.size()
x = x.view(b, T, F, h, int(w / 2), int(c * 2))
x = x.permute(0, 1, 2, 4, 3, 5).contiguous()
x = x.view(b, T, F, int(h / 2), int(w / 2), int(c * 4))
return x


class MultimodalProjectorAdapter(nn.Module, AdapterModuleUtil):
def __init__(self, adapter_type: str, in_features: int, out_features: int, bias: bool, **kwargs) -> None:
super().__init__()
@@ -898,6 +929,14 @@ def __init__(self, adapter_type: str, in_features: int, out_features: int, bias: bool, **kwargs) -> None:
self.mm_projector = torch.nn.Linear(in_features, out_features, bias)
elif adapter_type == 'identity':
self.mm_projector = lambda x: x
elif adapter_type == 'mlp_downsample':
self.mm_projector = torch.nn.Sequential(
DownSampleBlock(),
torch.nn.LayerNorm(in_features * 4),
torch.nn.Linear(in_features * 4, out_features, bias),
torch.nn.GELU(),
torch.nn.Linear(out_features, out_features, bias),
)
else:
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', adapter_type)
if mlp_gelu_match:
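
Reviewer note on `DownSampleBlock`: `flat_square` folds each 2x2 neighborhood of the patch grid into the channel dimension (zero-padding odd grids first), so the token count drops by 4x while the channel width grows by 4x; the following `LayerNorm` and two-layer MLP then project the fused 4*in_features vector down to the LLM hidden size. A standalone shape walk-through for a single 336x336 image (24x24 patches, 1024-dim ViT embeddings), mirroring the tensor ops above rather than calling the class itself:

import torch

b, T, F_, n_tokens, c = 1, 1, 1, 576, 1024       # (batch, time, frames, 24*24 patches, hidden)
x = torch.randn(b, T, F_, n_tokens, c)

h = w = int(n_tokens ** 0.5)                     # 24
x = x.reshape(b, T, F_, h, w, c)                 # back to the 2-D patch grid
x = x.view(b, T, F_, h, w // 2, c * 2)           # fuse adjacent patches along the width
x = x.permute(0, 1, 2, 4, 3, 5).contiguous()     # swap h and w/2 (the "transpose" fix from the commit log)
x = x.view(b, T, F_, h // 2, w // 2, c * 4)      # fuse adjacent rows as well
x = x.reshape(b, T, F_, -1, c * 4)

print(x.shape)                                   # torch.Size([1, 1, 1, 144, 4096])
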
30 changes: 26 additions & 4 deletions nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -20,7 +20,7 @@
from typing import List, Set, Tuple

import torch

from transformers import CLIPImageProcessor
from nemo.collections.nlp.modules.common.lm_utils import pad_batch
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
@@ -533,7 +533,6 @@ class NevaModelTextGenerationStrategy(TextGenerationStrategy):
def __init__(self, model):
super().__init__(model)
self.forward_model = self.model.model
self.num_media_latents = model.cfg.data.get("image_token_len", 576)
self.tokenizer = self.model.tokenizer
self.image_paths = []
self.cfg = self.model.cfg
@@ -545,16 +544,39 @@ def __init__(self, model):
sep_image_conv_front=self.data_cfg.sep_image_conv_front,
conv_template=self.data_cfg.get("conv_template", "nvgpt"),
model_type=self.cfg.mm_cfg.llm.get("model_type", "nvgpt"),
image_token_len=self.data_cfg.image_token_len,
image_folder=self.data_cfg.image_folder,
patch_dim=self.cfg.mm_cfg.vision_encoder.patch_dim,
crop_size=self.cfg.mm_cfg.vision_encoder.get("crop_size", None),
image_folder=self.data_cfg.get('image_folder', None),
video_folder=self.data_cfg.get('video_folder', None),
image_aspect_ratio=self.data_cfg.image_aspect_ratio,
use_im_start_end=getattr(self.cfg.mm_cfg, 'use_im_start_end', False),
image_processor=None,
add_extra_token=add_extra_token,
context_length=self.cfg.encoder_seq_length,
media_type=getattr(self.data_cfg, 'media_type', 'image'),
num_frames=getattr(self.data_cfg, 'num_frames', 1),
mm_mlp_adapter_type=getattr(self.cfg.mm_cfg, 'mm_mlp_adapter_type', 'linear'),
)
if self.multimodal_cfg['crop_size'] is None:
image_processor = CLIPImageProcessor.from_pretrained(
self.cfg.mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
self.multimodal_cfg['crop_size'] = (
image_processor.crop_size['height'],
image_processor.crop_size['width'],
)

patch_dim = self.multimodal_cfg['patch_dim']
height_num_patches = self.multimodal_cfg['crop_size'][0] // patch_dim
width_num_patches = self.multimodal_cfg['crop_size'][1] // patch_dim

if self.multimodal_cfg['mm_mlp_adapter_type'] == 'mlp_downsample':
if height_num_patches % 2 != 0:
height_num_patches += 1
if width_num_patches % 2 != 0:
width_num_patches += 1

self.num_media_latents = height_num_patches * width_num_patches

def clip_max_len(self, maxlen: int) -> int:
"""clip the max len based on the LM model max sequence length"""
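
Reviewer note on the inference side: `num_media_latents` is no longer read from `data.image_token_len`; it is derived from `crop_size` and `patch_dim`, and when `crop_size` is missing from the config it is pulled from the HuggingFace CLIP image processor. A hedged sketch of that fallback, using the `openai/clip-vit-large-patch14` checkpoint named in the configs above:

from transformers import CLIPImageProcessor

image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
crop_size = (image_processor.crop_size['height'], image_processor.crop_size['width'])   # (224, 224)

patch_dim = 14
num_media_latents = (crop_size[0] // patch_dim) * (crop_size[1] // patch_dim)   # 256 with a linear projector
# with mm_cfg.mm_mlp_adapter_type == 'mlp_downsample' this becomes 256 // 4 == 64
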
