Evo2 merge 20250214 #12263

Merged · 67 commits · Mar 3, 2025

Changes from 16 commits

Commits (67)
995dc9c
Initial commit of Hyena model needed for Evo2
JRD971000 Feb 14, 2025
55d6548
Delete attention.py
jstjohn Feb 14, 2025
b2a4e19
Add missing imports and update forward of gpt model
jstjohn Feb 14, 2025
e1b8b20
Add in blended dataset config test for evo2
jstjohn Feb 14, 2025
c0c4bbd
Add ability to change dataset class
jstjohn Feb 15, 2025
688e8ce
Alit/evo2 merge 20250214
JRD971000 Feb 17, 2025
c28efaf
Merge branch 'alit/evo2-merge-20250214' into 'evo2-merge-20250214'
JRD971000 Feb 17, 2025
733a79d
Performance improvement and fix for masking test
jstjohn Feb 18, 2025
363c015
Remove no grad decorator
jstjohn Feb 18, 2025
54a361b
Fixup and simplify token mask logic
jstjohn Feb 18, 2025
a32ac16
Update tests and code for non-dna safety
jstjohn Feb 18, 2025
33d8957
More tests on masking logic
jstjohn Feb 18, 2025
9816ff1
Switch renormlization to be per row rather than per micro-batch
jstjohn Feb 18, 2025
fadacc3
Safe handling of divide by zero and handle control chars in phylo tag…
jstjohn Feb 19, 2025
020a508
Apply isort and black reformatting
JRD971000 Feb 19, 2025
a7a5092
Apply isort and black reformatting
artbataev Feb 19, 2025
9c3fb74
Add profiling benchmarking to our evo2 dataset tests
jstjohn Feb 19, 2025
3df19a7
Address formatting and linting errors
jstjohn Feb 21, 2025
7d52494
Merge branch 'main' of github.com:NVIDIA/NeMo into evo2-merge-20250214
jstjohn Feb 21, 2025
dd14d63
Address more pylint warnings
jstjohn Feb 21, 2025
02a4e35
Address pylance errors
jstjohn Feb 21, 2025
d50f267
Merge branch 'main' of github.com:NVIDIA/NeMo into evo2-merge-20250214
jstjohn Feb 21, 2025
25e0ce0
More pylint fixes
jstjohn Feb 21, 2025
e83d7bb
DCO Sign-off for previous commits
jstjohn Feb 21, 2025
fc0bf3b
Address PR feedback
jstjohn Feb 21, 2025
c081ba3
Address copilot PR feedback
jstjohn Feb 21, 2025
4480311
Adding hyena L2 test to CI/CD
jstjohn Feb 21, 2025
04eba8e
Address import error exception issue
jstjohn Feb 21, 2025
4b55408
Update kingdom -> domain in evo2 taxonomy token string
jstjohn Feb 21, 2025
8b9fdb2
Merge branch 'main' of github.com:NVIDIA/NeMo into evo2-merge-20250214
jstjohn Feb 21, 2025
4d1fe01
Create dictionary of standard string representations of hyena models …
jstjohn Feb 21, 2025
266c722
Merge branch 'main' into evo2-merge-20250214
ko3n1g Feb 22, 2025
c83ef09
Adding a hugging face importer and 1b model configs for lighter testing
jstjohn Feb 22, 2025
21300b0
Merge branch 'evo2-merge-20250214' of github.com:NVIDIA/NeMo into evo…
jstjohn Feb 22, 2025
a9867cc
Fix missing import
jstjohn Feb 22, 2025
9111382
Merge branch 'main' into evo2-merge-20250214
jstjohn Feb 22, 2025
471c9bf
fix dist sampler
JRD971000 Feb 22, 2025
5d2b612
revert dist sampler true
JRD971000 Feb 24, 2025
d280df1
Merge branch 'main' of github.com:NVIDIA/NeMo into evo2-merge-20250214
jstjohn Feb 24, 2025
fe99a49
Fix the multi-part download naming in savanna
jstjohn Feb 24, 2025
fa3fac0
Adding 1b models to main llm import
jstjohn Feb 24, 2025
d8ac417
Merge branch 'main' of github.com:NVIDIA/NeMo into evo2-merge-20250214
jstjohn Feb 25, 2025
30a7824
Fix failing LLM CPU unit tests
jstjohn Feb 25, 2025
2be3af5
Addressing PR feedback
jstjohn Feb 25, 2025
7921c44
bug fixing
dorotat-nv Feb 26, 2025
081ae40
Apply isort and black reformatting
JRD971000 Feb 26, 2025
8f32cbf
Merge branch 'main' into evo2-merge-20250214
JRD971000 Feb 26, 2025
8b8b515
add header to test_flops_callback.py
JRD971000 Feb 26, 2025
ae7a387
Add hyena stage to CI/CD requirements and address flops callback unus…
jstjohn Feb 27, 2025
8ca01eb
Merge branch 'main' of github.com:NVIDIA/NeMo into evo2-merge-20250214
jstjohn Feb 27, 2025
44d08b5
Use the custom eos/bos tokens passed to the bytelevel tokenizer
jstjohn Feb 27, 2025
3c1b74e
Add back TE import guards so CI passes
jstjohn Feb 27, 2025
5ced598
Merge branch 'main' of github.com:NVIDIA/NeMo into evo2-merge-20250214
jstjohn Feb 27, 2025
353d346
use cache
akoumpa Mar 2, 2025
8a7fa12
Apply isort and black reformatting
akoumpa Mar 2, 2025
a6971ca
fix no te
akoumpa Mar 2, 2025
9b4737b
Apply isort and black reformatting
akoumpa Mar 2, 2025
bc69b7a
Update test_save_restore.py
akoumpa Mar 2, 2025
3ea659b
David Guzman review of Evo2 (#12440)
jstjohn Mar 2, 2025
1db1d3b
Fix pylint and flake8 issues
jstjohn Mar 2, 2025
c77f262
fix cli args for pretraining test
jstjohn Mar 2, 2025
7fc0637
Move conv init into rng tracker
jstjohn Mar 2, 2025
faf8f3f
Address flake8 issues and remove unused function
jstjohn Mar 2, 2025
b0e1d74
Merge branch 'main' into evo2-merge-20250214
jstjohn Mar 2, 2025
917271c
Handle TE not being installed test a bit better
jstjohn Mar 2, 2025
aee348c
define even if TE missing
akoumpa Mar 2, 2025
31c132f
fix
akoumpa Mar 2, 2025
42 changes: 28 additions & 14 deletions nemo/collections/common/tokenizers/bytelevel_tokenizers.py
@@ -36,19 +36,31 @@ def normalize(self, text) -> str:


class ByteLevelTokenizer(TokenizerSpec):
def __init__(self, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None):
self.vocab_size = 259
self.special_start = 256
def __init__(
self,
special_tokens: Optional[Union[Dict[str, str], List[str]]] = None,
vocab_size: int = 512,
_eos_id: int = 0,
_pad_id: int = 1,
_bos_id: int = None,
):
self.vocab_size = vocab_size if special_tokens is None else vocab_size + len(special_tokens)
self.special_start = vocab_size
self._eos_id = _eos_id
self._pad_id = _pad_id
self._bos_id = _bos_id
self.special_token_to_id = {
self.pad_id: self.pad_id,
self.bos_id: self.bos_id,
self.eos_id: self.eos_id,
}
# Track special byte-tokens at end of vocabulary.
self.vocab_size = vocab_size if special_tokens is None else vocab_size + len(special_tokens)
self.special_start = self.vocab_size
special_tokens = {} if special_tokens is None else special_tokens
for tok in special_tokens:
self.special_start -= 1
self.special_token_to_id[tok] = self.special_start

self.id_to_special_token = {v: k for k, v in self.special_token_to_id.items()}

# no distinction between tokens and ids.
@@ -61,10 +73,16 @@ def tokens_to_text(self, tokens):
def text_to_ids(self, text):
return list(text.encode('utf-8'))

def decode_token(self, token: int):
return str(chr(self.clamp(token)))

def clamp(self, n):
return max(32, min(n, self.vocab_size))

def ids_to_text(self, ids):
# remove special tokens.
ids = [x for x in ids if x < self.special_start]
return bytes(ids).decode('utf-8', errors='ignore').rstrip()
return "".join(list(map(self.decode_token, ids)))

def tokens_to_ids(self, tokens):
if isinstance(tokens, str):
@@ -89,23 +107,19 @@ def token_to_id(self, token):
return token

def id_to_token(self, id):
if id < self.special_start:
if id not in self.id_to_special_token:
return id
else:
return self.id_to_special_token[id]

@property
def pad_id(self):
return 256

@property
def bos_id(self):
return 257
return self._pad_id

@property
def eos_id(self):
return 258
return self._eos_id

@property
def unk_id(self):
return 259 # unused
def bos_id(self):
return self._bos_id
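
For context, a minimal usage sketch of the updated tokenizer API in this hunk. The constructor arguments and the sample string below are illustrative assumptions, not values taken from the PR:

from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer

# Hypothetical values: the new signature exposes vocab_size and the special-token ids.
tokenizer = ByteLevelTokenizer(vocab_size=512, _eos_id=0, _pad_id=1)

ids = tokenizer.text_to_ids("ACGT")   # raw UTF-8 byte values: [65, 67, 71, 84]
text = tokenizer.ids_to_text(ids)     # decode_token() maps each id back to a character
assert text == "ACGT"
print(tokenizer.eos_id, tokenizer.pad_id)  # 0 1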
20 changes: 20 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -82,6 +82,16 @@
GPTConfig175B,
GPTModel,
HFAutoModelForCausalLM,
Hyena7bARCLongContextConfig,
Hyena7bConfig,
Hyena40bARCLongContextConfig,
Hyena40bConfig,
HyenaConfig,
HyenaModel,
HyenaNV7bConfig,
HyenaNV40bConfig,
HyenaNVTestConfig,
HyenaTestConfig,
Llama2Config7B,
Llama2Config13B,
Llama2Config70B,
@@ -156,6 +166,16 @@
"CustomRetrievalDataModule",
"GPTModel",
"GPTConfig",
"HyenaTestConfig",
"Hyena7bConfig",
"Hyena40bConfig",
"Hyena7bARCLongContextConfig",
"Hyena40bARCLongContextConfig",
"HyenaNVTestConfig",
"HyenaNV40bConfig",
"HyenaNV7bConfig",
"HyenaConfig",
"HyenaModel",
"gpt_data_step",
"gpt_forward_step",
"T5Model",
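
With these additions, the Hyena/Evo2 configs and model class become importable from the top-level llm collection; a minimal sketch (only the names themselves come from this diff):

# Names below are re-exported by nemo/collections/llm/__init__.py in this change.
from nemo.collections.llm import Hyena7bConfig, Hyena40bConfig, HyenaModel, HyenaTestConfig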
Collaborator: what's the reason for this megatron folder in llm/gpt/data?

Collaborator: These are sub-classes of megatron data modules that it was decided shouldn't go into the megatron-lm repo.

Empty file.
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/megatron/hyena/__init__.py
@@ -0,0 +1,2 @@
from .config import parse_dataset_config
from .evo2_dataset import Evo2Dataset
164 changes: 164 additions & 0 deletions nemo/collections/llm/gpt/data/megatron/hyena/config.py
@@ -0,0 +1,164 @@
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024 Arc Institute. All rights reserved.
# Copyright (c) 2024 Michael Poli. All rights reserved.
# Copyright (c) 2024 Stanford University. All rights reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from pathlib import Path
from typing import Literal, Optional

import yaml
from pydantic import BaseModel, model_validator


def infer_global_batch_size(
Contributor: why is this method in this file? We used to have it in bionemo; do you need it in NeMo? If yes, then it shouldn't be under megatron/hyena.

micro_batch_size: int,
num_nodes: int,
devices: int,
accumulate_grad_batches: int = 1,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
context_model_parallel_size: int = 1,
) -> int:
"""Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient batches, and model parallel sizes.

Args:
micro_batch_size (int): The micro batch size.
num_nodes (int): The number of nodes.
devices (int): The number of devices.
accumulate_grad_batches (int): The accumulation of gradient batches. Defaults to 1.
tensor_model_parallel_size (int): The tensor model parallel size. Defaults to 1.
pipeline_model_parallel_size (int): The pipeline model parallel size. Defaults to 1.
context_model_parallel_size (int): The context model parallel size. Defaults to 1.

Returns:
int: The global batch size.
"""
if not all(
isinstance(arg, int)
for arg in [
micro_batch_size,
num_nodes,
devices,
accumulate_grad_batches,
tensor_model_parallel_size,
pipeline_model_parallel_size,
context_model_parallel_size,
]
):
raise ValueError(
f"All arguments must be of type int, got {type(micro_batch_size)}, {type(num_nodes)}, {type(devices)}, "
f"{type(accumulate_grad_batches)}, {type(tensor_model_parallel_size)}, {type(pipeline_model_parallel_size)}, and {type(context_model_parallel_size)}"
)
if micro_batch_size <= 0:
raise ValueError(f"micro_batch_size must be greater than 0, got {micro_batch_size}")
if num_nodes <= 0:
raise ValueError(f"num_nodes must be greater than 0, got {num_nodes}")
if devices <= 0:
raise ValueError(f"devices must be greater than 0, got {devices}")
if accumulate_grad_batches <= 0:
raise ValueError(f"accumulate_grad_batches must be greater than 0, got {accumulate_grad_batches}")
if tensor_model_parallel_size <= 0:
raise ValueError(f"tensor_model_parallel_size must be greater than 0, got {tensor_model_parallel_size}")
if pipeline_model_parallel_size <= 0:
raise ValueError(f"pipeline_model_parallel_size must be greater than 0, got {pipeline_model_parallel_size}")
if context_model_parallel_size <= 0:
raise ValueError(f"context_model_parallel_size must be greater than 0, got {context_model_parallel_size}")

world_size = num_nodes * devices
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size) != 0:
raise ValueError(
f"world_size must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size, "
f"got {world_size} and TP{tensor_model_parallel_size} * PP{pipeline_model_parallel_size} * CP{context_model_parallel_size}"
)

model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size
data_parallel_size = world_size // model_parallel_size
global_batch_size = micro_batch_size * data_parallel_size * accumulate_grad_batches
return global_batch_size
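
A quick worked example of the arithmetic above, with illustrative parallelism settings (none of these numbers come from the PR):

from nemo.collections.llm.gpt.data.megatron.hyena.config import infer_global_batch_size

# 2 nodes x 8 devices = world_size 16; TP * PP * CP = 2, so data-parallel size is 8.
# Global batch = micro_batch_size * data_parallel_size * accumulate_grad_batches = 2 * 8 * 4.
gbs = infer_global_batch_size(
    micro_batch_size=2,
    num_nodes=2,
    devices=8,
    accumulate_grad_batches=4,
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    context_model_parallel_size=1,
)
assert gbs == 64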


class Evo2BlendedDatasetConfig(BaseModel):
Contributor: also, shouldn't Evo2BlendedDatasetConfig in fact be BlendedDatasetConfig and be located somewhere here?

Contributor: it seems that in NeMo you mostly pass paths from the command line:
data = llm.PreTrainingDataModule(

"""Configuration for blended dataset specifications.

Validates and constructs dataset paths, weights and splits configuration.
Ensures dataset paths exist and are properly resolved relative to base data path.

Attributes:
dataset_path: Base directory path for datasets. Used to resolve relative dataset prefixes.
dataset_prefix: Path prefix for dataset files. Can be absolute or relative to dataset_path.
dataset_weight: Weight factor for this dataset during blending (0-1).
dataset_split: Dataset partition - 'train', 'validation' or 'test'.

Raises:
ValueError: If dataset path doesn't exist or prefix can't be resolved.
"""

dataset_path: str | None = None
dataset_prefix: str
dataset_weight: float
dataset_split: Literal["train", "validation", "test"]

@model_validator(mode="before")
@classmethod
def validate_dataset_prefix(cls, values: dict) -> dict:
"""Ensure dataset_prefix paths exist and are properly resolved or are relative to base dataset_path if provided."""
dataset_path = Path(values.get("dataset_path")) if values.get("dataset_path") else None
prefix = Path(values.get("dataset_prefix"))

if not prefix.is_absolute():
if dataset_path:
prefix = dataset_path / prefix
else:
prefix = Path(prefix).resolve()
parent = prefix.parent
stem = prefix.stem
if not parent.exists():
raise ValueError(f"dataset_prefix parent path does not exist: {parent}")
matching_files = list(parent.glob(f"{stem}.*"))
if not matching_files:
raise ValueError(f"dataset_prefix file does not exist: {prefix}")
values["dataset_prefix"] = str(prefix)
return values
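
To make the validation behaviour concrete, a small self-contained sketch; the prefix name and files are hypothetical and exist only to satisfy the existence checks in the validator above:

import tempfile
from pathlib import Path

from nemo.collections.llm.gpt.data.megatron.hyena.config import Evo2BlendedDatasetConfig

with tempfile.TemporaryDirectory() as tmp:
    # The validator requires the parent directory to exist and "<prefix stem>.*" to match a file.
    Path(tmp, "promoters_train.bin").touch()
    Path(tmp, "promoters_train.idx").touch()
    cfg = Evo2BlendedDatasetConfig(
        dataset_path=tmp,
        dataset_prefix="promoters_train",  # relative prefix, resolved against dataset_path
        dataset_weight=0.5,
        dataset_split="train",
    )
    print(cfg.dataset_prefix)  # absolute path ending in /promoters_train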


def parse_dataset_config(dataset_config_path: str, dataset_path: Optional[str] = None):
"""Parse the blended training datasplit configuration and renormalize data split weights for training Hyena.

Args:
dataset_config_path (str): Path to the dataset configuration YAML file.
dataset_path (str): Path to the dataset directory. Defaults to None.

Returns:
defaultdict: A dictionary where keys are dataset splits and values are lists containing the normalized weight
and dataset prefix for each split.
"""
blended_dataset_config = defaultdict(list)
weight_sums = defaultdict(float)
with open(dataset_config_path, "r") as config_file:
dataset_config_batch = yaml.safe_load(config_file)
for dataset_config in dataset_config_batch:
# Validate.
config_model = Evo2BlendedDatasetConfig(dataset_path=dataset_path, **dataset_config)
# Integrate the weights for renormalization.
weight_sums[config_model.dataset_split] += abs(config_model.dataset_weight)
for dataset_config in dataset_config_batch:
# Validate.
config_model = Evo2BlendedDatasetConfig(dataset_path=dataset_path, **dataset_config)
# Add indexed dataset to split and associate with blended training weight.
blended_dataset_config[config_model.dataset_split].extend(
[config_model.dataset_weight / weight_sums[config_model.dataset_split], config_model.dataset_prefix]
)
return blended_dataset_config
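
And an end-to-end sketch of parse_dataset_config; the YAML contents, file names, and weights are hypothetical, but they show how weights within a split are renormalized to sum to one:

import tempfile
from pathlib import Path

from nemo.collections.llm.gpt.data.megatron.hyena.config import parse_dataset_config

with tempfile.TemporaryDirectory() as tmp:
    for name in ("promoters_train", "enhancers_train"):
        Path(tmp, f"{name}.bin").touch()  # placeholder files so prefix validation passes
    config_yaml = Path(tmp, "dataset_config.yaml")
    config_yaml.write_text(
        "- dataset_prefix: promoters_train\n"
        "  dataset_weight: 3.0\n"
        "  dataset_split: train\n"
        "- dataset_prefix: enhancers_train\n"
        "  dataset_weight: 1.0\n"
        "  dataset_split: train\n"
    )
    blend = parse_dataset_config(str(config_yaml), dataset_path=tmp)
    # blend["train"] -> [0.75, "<tmp>/promoters_train", 0.25, "<tmp>/enhancers_train"]
    print(blend["train"])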