Skip to content

Commit

Permalink
Cherry pick llava next changes that were missed for the release (#11783)
Browse files Browse the repository at this point in the history
* have micro_batch_size and global_batch_size as class attributes in mock datamodule (#11563)

* Bug fix - Inheritance of LLaVATemplateConfig and MLlamaTemplateConfig dataclasses (#11661)

* change LLaVATemplateConfig variables to class variables

* change to use field with default attributes

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* added type info for MLlamaTemplateConfig

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

---------

Signed-off-by: yashaswikarnati <[email protected]>
Co-authored-by: yashaswikarnati <[email protected]>

* Github Actions tests for Llava Next and modify pretrain recipe to have language model path (#11424)

* modified pretrain recipe to have language_model_from_pretrained

* ci test for llava next

* fixed indent/lint issue in cicd yml file

* fix lint issues

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* Update .github/workflows/cicd-main.yml

Co-authored-by: oliver könig <[email protected]>
Signed-off-by: Yashaswi Karnati <[email protected]>

* Update .github/workflows/cicd-main.yml

Co-authored-by: oliver könig <[email protected]>
Signed-off-by: Yashaswi Karnati <[email protected]>

---------

Signed-off-by: yashaswikarnati <[email protected]>
Signed-off-by: Yashaswi Karnati <[email protected]>
Co-authored-by: yashaswikarnati <[email protected]>
Co-authored-by: oliver könig <[email protected]>

---------

Signed-off-by: yashaswikarnati <[email protected]>
Signed-off-by: Yashaswi Karnati <[email protected]>
Co-authored-by: yashaswikarnati <[email protected]>
Co-authored-by: oliver könig <[email protected]>
  • Loading branch information
3 people authored Jan 8, 2025
1 parent 053531d commit 151a362
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 24 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4557,6 +4557,21 @@ jobs:
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_ckpt
rm -rf /tmp/nemo2_ptq_engine
L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
python tests/collections/vlm/test_llava_next_train.py \
--devices=1 \
--max-steps=5 \
--experiment-dir=/tmp/nemo2_llava_next_results/${{ github.run_id }}
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_llava_next_results
Nemo_CICD_Test:
needs:
Expand Down Expand Up @@ -4716,6 +4731,7 @@ jobs:
- L2_Megatron_GPT_Reranker
- L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact
- L2_NeMo_2_PTQ_Llama2_FP8
- L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING
if: always()
runs-on: ubuntu-latest
steps:
Expand Down
28 changes: 17 additions & 11 deletions nemo/collections/multimodal/data/energon/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,24 @@
class BaseConversationTemplateConfig:
"""Conversation template config related parameters"""

system: Optional[str] = "".format() # fmt: off
system: Optional[str] = ""
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
stop_string: Optional[str] = None
chat_template = None


@dataclass
class LLaVATemplateConfig(BaseConversationTemplateConfig):
"""LLava specific template configuration which extends the base config"""
"""LLava-specific template configuration which extends the base config"""

system: Optional[str] = (
"A chat between a curious user and artificial assistant agent. The assistant gives helpful, detailed and polite answers to user's questions.".format()
) # fmt: off
system: str = field(
default="A chat between a curious user and artificial assistant agent. "
"The assistant gives helpful, detailed and polite answers to user's questions."
)
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
stop_string: str = "</s>"
chat_template = """
stop_string: str = field(default="</s>")
chat_template: str = field(
default="""
{%- for message in messages %}
{%- if message['role'] == 'system' %}
{{- message['content'].strip() + ' ' -}}
Expand All @@ -45,14 +48,17 @@ class LLaVATemplateConfig(BaseConversationTemplateConfig):
{%- endif %}
{%- endfor -%}
"""
)


class MLlamaTemplateConfig(BaseConversationTemplateConfig):
"""LLava specific template configuration which extends the base config"""
"""MLlama specific template configuration which extends the base config"""

system: Optional[str] = None
system: str = field(default=None)
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
stop_string: str = None
chat_template = """
stop_string: str = field(default=None)
chat_template: str = field(
default="""
'{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." 
}}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This 
model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n'
"""
)
14 changes: 8 additions & 6 deletions nemo/collections/vlm/llava_next/data/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from transformers import AutoProcessor

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging
Expand Down Expand Up @@ -77,20 +79,20 @@ def __init__(
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers

self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
model_name = ''
processor = None
if tokenizer is None or image_processor is None:
logging.warning(
f"Processor or tokenizer are not provided! Fall back to `llava-hf/llava-v1.6-vicuna-7b-hf`."
)
from transformers import AutoProcessor

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

model_name = "llava-hf/llava-v1.6-vicuna-7b-hf"

processor = AutoProcessor.from_pretrained(model_name)
self.tokenizer = tokenizer or AutoTokenizer(model_name)
self.image_processor = image_processor or processor.image_processor
self.tokenizer = tokenizer or AutoTokenizer(model_name)
self.image_processor = image_processor or processor.image_processor
self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
decoder_seq_len=self.decoder_seq_len,
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/vlm/recipes/llava_next_7b.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def pretrain_recipe(
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = 'none',
language_model_from_pretrained: Optional[str] = None,
) -> run.Partial:
"""
Create a Pre-training recipe for Llava1.6 7B model.
Expand Down Expand Up @@ -223,6 +223,7 @@ def pretrain_recipe(
freeze_language_model=True,
freeze_vision_model=True,
freeze_vision_projection=False,
language_model_from_pretrained=language_model_from_pretrained,
)
),
trainer=trainer,
Expand Down
13 changes: 7 additions & 6 deletions scripts/vlm/llava_next_nemo_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
from nemo.collections import vlm


def configure_recipe(nodes: int = 1, gpus_per_node: int = 8, pretrain=False):
def configure_recipe(nodes: int = 1, gpus_per_node: int = 8, pretrain=False, language_model_from_pretrained=None):
"""Configure the recipe"""
if pretrain:
recipe = vlm.llava_next_7b.pretrain_recipe(
dir="./outputs/checkpoints/llava", # Path to store checkpoints
name="llava_pretrain",
num_nodes=nodes,
num_gpus_per_node=gpus_per_node,
language_model_from_pretrained=language_model_from_pretrained,
)
else:
recipe = vlm.llava_next_7b.finetune_recipe(
Expand All @@ -33,8 +34,8 @@ def configure_recipe(nodes: int = 1, gpus_per_node: int = 8, pretrain=False):
num_nodes=nodes,
num_gpus_per_node=gpus_per_node,
)
recipe.trainer.max_steps = 100
recipe.trainer.val_check_interval = 100
recipe.trainer.max_steps = 20
recipe.trainer.val_check_interval = 20
recipe.model.config.freeze_vision_model = True
return recipe

Expand All @@ -49,9 +50,9 @@ def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecut
return executor


def run_pretraining():
def run_pretraining(language_model_from_pretrained=None):
# pylint: disable=C0115,C0116
recipe = configure_recipe(pretrain=True)
recipe = configure_recipe(pretrain=True, language_model_from_pretrained=language_model_from_pretrained)
executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

run.run(recipe, executor=executor)
Expand All @@ -67,5 +68,5 @@ def run_finetuning():

# This condition is necessary for the script to be compatible with Python's multiprocessing module.
if __name__ == "__main__":
run_pretraining()
run_pretraining(language_model_from_pretrained='/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5/')
# run_finetuning()
157 changes: 157 additions & 0 deletions tests/collections/vlm/test_llava_next_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## NOTE: This script is present for GitHub Actions testing only.
## There are no guarantees that this script is up to date with the latest NeMo.

import argparse

import torch
from megatron.core.optimizer import OptimizerConfig
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import AutoProcessor

from nemo import lightning as nl
from nemo.collections import llm, vlm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm.api import train
from nemo.lightning import AutoResume, NeMoLogger
from nemo.lightning.pytorch.callbacks import ModelCheckpoint, ParameterDebugger
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule


def get_args():
    """Parse the command-line options of the Llava Next mock-training script.

    Returns:
        argparse.Namespace: holds ``devices`` (int, default 1), ``max_steps``
        (int, default 5) and ``experiment_dir`` (str, default None).
    """
    # pylint: disable=C0115,C0116
    cli = argparse.ArgumentParser(description='Train a small Llava Next model using NeMo 2.0')

    # One row per option: (flag, value type, default, help text).
    option_table = (
        ('--devices', int, 1, "Number of devices to use for training"),
        ('--max-steps', int, 5, "Number of steps to train for"),
        ('--experiment-dir', str, None, "directory to write results and checkpoints to"),
    )
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    return cli.parse_args()


if __name__ == '__main__':
    # Script entry point: builds a small Llava Next model on the mock data
    # module and hands it to nemo.collections.llm.api.train.

    cli_args = get_args()

    # The HF processor and the NeMo tokenizer wrapper are loaded from the same model id.
    hf_model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
    seq_len = 1024
    hf_processor = AutoProcessor.from_pretrained(hf_model_id)
    hf_tokenizer = AutoTokenizer(hf_model_id)

    # Mock data module; global and micro batch size are both 2.
    mock_data = vlm.LlavaNextMockDataModule(
        seq_length=seq_len,
        tokenizer=hf_tokenizer,
        image_processor=hf_processor.image_processor,
        global_batch_size=2,
        micro_batch_size=2,
        num_workers=0,
    )

    # Llava Next model configuration. The language model uses the Llama2 7B
    # config cut down to 2 layers; language and vision towers are frozen.
    llava_next_config = vlm.LlavaNextConfig(
        language_transformer_config=llm.Llama2Config7B(seq_length=seq_len, num_layers=2),
        vision_transformer_config=vlm.HFCLIPVisionConfig(
            pretrained_model_name_or_path="openai/clip-vit-large-patch14-336"
        ),
        vision_projection_config=vlm.MultimodalProjectorConfig(
            projector_type="mlp2x_gelu",
            input_size=1024,
            hidden_size=4096,
            ffn_hidden_size=4096,
        ),
        freeze_language_model=True,
        freeze_vision_model=True,
    )

    llava_next_model = vlm.LlavaNextModel(llava_next_config, tokenizer=mock_data.tokenizer)

    # No tensor or pipeline model parallelism.
    megatron_strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        encoder_pipeline_model_parallel_size=0,
        pipeline_dtype=torch.bfloat16,
    )

    # The step interval (5000) is far above the default --max-steps (5).
    ckpt_callback = ModelCheckpoint(
        every_n_train_steps=5000,
        save_optim_on_train_end=True,
    )

    def create_verify_precision(expected_dtype: torch.dtype):
        """Return a callable that asserts a tensor's dtype equals ``expected_dtype``."""

        def verify_precision(tensor: torch.Tensor) -> None:
            assert tensor.dtype == expected_dtype

        return verify_precision

    # NOTE(review): param_fn / grad_fn are presumably applied to each parameter
    # and its gradient at the listed hooks — confirm against ParameterDebugger.
    dtype_debugger = ParameterDebugger(
        param_fn=create_verify_precision(torch.bfloat16),
        grad_fn=create_verify_precision(torch.float32),
        log_on_hooks=["on_train_start", "on_train_end"],
    )

    trainer_loggers = [
        TensorBoardLogger(
            save_dir='dummy',  ## NOTE: this gets overwritten by default
        )
    ]

    optimizer_module = MegatronOptimizerModule(
        config=OptimizerConfig(
            optimizer='adam',
            lr=6e-4,
            min_lr=6e-5,
            use_distributed_optimizer=False,
            bf16=True,
        )
    )

    megatron_trainer = nl.Trainer(
        devices=cli_args.devices,
        max_steps=cli_args.max_steps,
        accelerator="gpu",
        strategy=megatron_strategy,
        logger=trainer_loggers,
        callbacks=[ckpt_callback, dtype_debugger],
        log_every_n_steps=1,
        limit_val_batches=2,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )

    experiment_logger = NeMoLogger(
        log_dir=cli_args.experiment_dir,
    )

    auto_resume = AutoResume(
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
    )

    # NOTE(review): tokenizer='data' presumably tells train() to take the
    # tokenizer from the data module — confirm against llm.api.train.
    train(
        model=llava_next_model,
        data=mock_data,
        trainer=megatron_trainer,
        log=experiment_logger,
        resume=auto_resume,
        tokenizer='data',
        optim=optimizer_module,
    )

0 comments on commit 151a362

Please sign in to comment.