diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 8dcbb5afccc0..8196d750e9a3 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -2103,6 +2103,35 @@ jobs: rm -rf tests/collections/llm/gpt_pretrain_results rm -rf tests/collections/llm/gpt_index_mappings + L2_NeMo_2_Hyena_DDP_Pretraining_Test: + needs: [pre-flight, cicd-test-container-build] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Hyena_DDP_Pretraining_Test') + with: + RUNNER: self-hosted-azure # Assume runner has 2 GPUs + SCRIPT: | + python tests/collections/llm/gpt/model/test_hyena.py \ + --mock-data \ + --experiment-dir=tests/collections/llm/hyena_pretrain_results/${{ github.run_id }} \ + --model-size=7b_nv \ + --num-layers=4 \ + --hybrid-override-pattern=SDH* \ + --no-activation-checkpointing \ + --add-bias-output \ + --max-steps=5 \ + --warmup-steps=1 \ + --micro-batch-size=2 \ + --global-batch-size=4 \ + --no-wandb \ + --seq-length=128 \ + --hidden-dropout=0.01 \ + --attention-dropout=0.01 \ + --devices=2 \ + --debug-ddp-parity-freq=1 + + AFTER_SCRIPT: | + rm -rf tests/collections/llm/hyena_pretrain_results/${{ github.run_id }} + L2_NeMo_2_SSM_Pretraining: needs: [pre-flight, cicd-test-container-build] uses: ./.github/workflows/_test_template.yml @@ -3102,6 +3131,7 @@ jobs: - L2_HF_Transformer_SpeechLM_SFT_2gpu - L2_NeMo_2_SSM_Pretraining - L2_NeMo_2_SSM_Finetuning + - L2_NeMo_2_Hyena_DDP_Pretraining_Test - L2_NeMo_2_T5_Pretraining - L2_NeMo_2_T5_Finetuning - L2_NeMo_2_T5_LoRA diff --git a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py index eb965b082815..f850fc9ca5ab 100644 --- a/nemo/collections/common/tokenizers/bytelevel_tokenizers.py +++ b/nemo/collections/common/tokenizers/bytelevel_tokenizers.py @@ -26,47 +26,106 @@ class ByteLevelProcessor: """ def detokenize(self, tokens: List[str]) -> str: + """ + Detokenize a list of tokens into a string. + """ return ' '.join(tokens) - def tokenize(self, text) -> str: - return text + def tokenize(self, text: str) -> List[str]: + """ + Tokenize a string into a list of tokens. + """ + return list(text) - def normalize(self, text) -> str: + def normalize(self, text: str) -> str: + """ + Normalize a string. + """ return text class ByteLevelTokenizer(TokenizerSpec): - def __init__(self, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None): - self.vocab_size = 259 - self.special_start = 256 + """ + A byte-level tokenizer that encodes text as UTF-8 bytes with user control over the EOS, BOS, and PAD + tokens as well as the vocabulary size and a mapping of other special tokens to their IDs. + """ + + def __init__( + self, + special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, + vocab_size: int = 512, + _eos_id: int = 0, + _pad_id: int = 1, + _bos_id: int = None, + ): + """A byte-level tokenizer that encodes text as UTF-8 bytes. + + This tokenizer treats each byte as a token, with a default vocabulary size of 512 to accommodate + UTF-8 byte values (0-255) plus special tokens. It can handle arbitrary text input by encoding + it into bytes. + + Args: + special_tokens: Dictionary or list of special tokens to add to the vocabulary. + These tokens will be assigned IDs at the end of the vocabulary. + Defaults to None. + vocab_size: Size of the vocabulary, should be at least 256 to handle all byte values. + Special tokens will be added after this size. 
+ Defaults to 512. + _eos_id: ID to use for the end-of-sequence token. + Defaults to 0. + _pad_id: ID to use for the padding token. + Defaults to 1. + _bos_id: ID to use for the beginning-of-sequence token. + Defaults to None. + """ + self._eos_id = _eos_id + self._pad_id = _pad_id + self._bos_id = _bos_id self.special_token_to_id = { self.pad_id: self.pad_id, self.bos_id: self.bos_id, self.eos_id: self.eos_id, } + # Track special byte-tokens at end of vocabulary. + self.vocab_size = vocab_size if special_tokens is None else vocab_size + len(special_tokens) + self.special_start = self.vocab_size special_tokens = {} if special_tokens is None else special_tokens for tok in special_tokens: self.special_start -= 1 self.special_token_to_id[tok] = self.special_start - self.id_to_special_token = {v: k for k, v in self.special_token_to_id.items()} # no distinction between tokens and ids. def text_to_tokens(self, text): + """ + Convert a text to a list of tokens. + """ return self.text_to_ids(text) def tokens_to_text(self, tokens): + """ + Convert a list of tokens to a text. + """ return self.ids_to_text(tokens) def text_to_ids(self, text): + """ + Convert a text to a list of IDs. + """ return list(text.encode('utf-8')) def ids_to_text(self, ids): + """ + Convert a list of IDs to a text. + """ # remove special tokens. ids = [x for x in ids if x < self.special_start] return bytes(ids).decode('utf-8', errors='ignore').rstrip() def tokens_to_ids(self, tokens): + """ + Convert a list of tokens to a list of IDs. + """ if isinstance(tokens, str): tokens = [tokens] ids = [] @@ -75,6 +134,9 @@ def tokens_to_ids(self, tokens): return ids def ids_to_tokens(self, ids): + """ + Convert a list of IDs to a list of tokens. + """ if isinstance(ids, int): ids = [ids] tokens = [] @@ -83,12 +145,18 @@ def ids_to_tokens(self, ids): return tokens def token_to_id(self, token): + """ + Convert a token to its corresponding ID. + """ if token in self.special_token_to_id: return self.special_token_to_id[token] else: return token def id_to_token(self, id): + """ + Convert an ID to its corresponding token. + """ if id < self.special_start: return id else: @@ -96,16 +164,28 @@ def id_to_token(self, id): @property def pad_id(self): - return 256 + """ + Get the padding ID. + """ + return self._pad_id @property def bos_id(self): - return 257 + """ + Get the beginning-of-sequence ID. + """ + return self._bos_id @property def eos_id(self): - return 258 + """ + Get the end-of-sequence ID. + """ + return self._eos_id @property def unk_id(self): + """ + Get the unknown ID. 
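For clarity, a minimal usage sketch of the reworked `ByteLevelTokenizer` (illustrative only, not part of the patch; it assumes the new defaults of `_eos_id=0`, `_pad_id=1`, and `vocab_size=512` shown above):

```python
from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer

# Default vocab: 512 IDs; any special tokens are appended at the end of the vocabulary.
tok = ByteLevelTokenizer(special_tokens=["<mask>"])

ids = tok.text_to_ids("ACGT")   # raw UTF-8 bytes -> [65, 67, 71, 84]
text = tok.ids_to_text(ids)     # IDs >= special_start are dropped, rest decoded -> "ACGT"

assert tok.eos_id == 0 and tok.pad_id == 1 and tok.bos_id is None
assert tok.token_to_id("<mask>") == 512  # vocab_size grows to 513; "<mask>" takes the last slot
```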
+ """ return 259 # unused diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 6fc787266f5d..87d1970f208a 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -85,6 +85,18 @@ GPTConfig175B, GPTModel, HFAutoModelForCausalLM, + Hyena1bConfig, + Hyena7bARCLongContextConfig, + Hyena7bConfig, + Hyena40bARCLongContextConfig, + Hyena40bConfig, + HyenaConfig, + HyenaModel, + HyenaNV1bConfig, + HyenaNV7bConfig, + HyenaNV40bConfig, + HyenaNVTestConfig, + HyenaTestConfig, Llama2Config7B, Llama2Config13B, Llama2Config70B, @@ -159,6 +171,18 @@ "CustomRetrievalDataModule", "GPTModel", "GPTConfig", + "HyenaTestConfig", + "Hyena7bConfig", + "Hyena40bConfig", + "Hyena7bARCLongContextConfig", + "Hyena40bARCLongContextConfig", + "HyenaNVTestConfig", + "HyenaNV40bConfig", + "HyenaNV7bConfig", + "HyenaConfig", + "HyenaModel", + "Hyena1bConfig", + "HyenaNV1bConfig", "gpt_data_step", "gpt_forward_step", "T5Model", diff --git a/nemo/collections/llm/gpt/data/megatron/__init__.py b/nemo/collections/llm/gpt/data/megatron/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py b/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py new file mode 100644 index 000000000000..b937377f5a92 --- /dev/null +++ b/nemo/collections/llm/gpt/data/megatron/hyena/__init__.py @@ -0,0 +1,2 @@ +from .config import parse_dataset_config # noqa: F401 +from .evo2_dataset import Evo2Dataset # noqa: F401 diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/config.py b/nemo/collections/llm/gpt/data/megatron/hyena/config.py new file mode 100644 index 000000000000..84231d0fb558 --- /dev/null +++ b/nemo/collections/llm/gpt/data/megatron/hyena/config.py @@ -0,0 +1,175 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from pathlib import Path +from typing import Literal, Optional + +import yaml +from pydantic import BaseModel, model_validator + + +def infer_global_batch_size( + micro_batch_size: int, + num_nodes: int, + devices: int, + accumulate_grad_batches: int = 1, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_model_parallel_size: int = 1, +) -> int: + """Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient + batches, and model parallel sizes. + + Args: + micro_batch_size (int): The micro batch size. + num_nodes (int): The number of nodes. + devices (int): The number of devices. + accumulate_grad_batches (int): The accumulation of gradient batches. Defaults to 1. + tensor_model_parallel_size (int): The tensor model parallel size. Defaults to 1. 
+ pipeline_model_parallel_size (int): The pipeline model parallel size. Defaults to 1. + context_model_parallel_size (int): The context model parallel size. Defaults to 1. + + Returns: + int: The global batch size. + """ + if not all( + isinstance(arg, int) + for arg in [ + micro_batch_size, + num_nodes, + devices, + accumulate_grad_batches, + tensor_model_parallel_size, + pipeline_model_parallel_size, + context_model_parallel_size, + ] + ): + raise ValueError( + f"All arguments must be of type int, got {type(micro_batch_size)}, {type(num_nodes)}, {type(devices)}, " + f"{type(accumulate_grad_batches)}, {type(tensor_model_parallel_size)}, " + f"{type(pipeline_model_parallel_size)}, and {type(context_model_parallel_size)}" + ) + if micro_batch_size <= 0: + raise ValueError(f"micro_batch_size must be greater than 0, got {micro_batch_size}") + if num_nodes <= 0: + raise ValueError(f"num_nodes must be greater than 0, got {num_nodes}") + if devices <= 0: + raise ValueError(f"devices must be greater than 0, got {devices}") + if accumulate_grad_batches <= 0: + raise ValueError(f"accumulate_grad_batches must be greater than 0, got {accumulate_grad_batches}") + if tensor_model_parallel_size <= 0: + raise ValueError(f"tensor_model_parallel_size must be greater than 0, got {tensor_model_parallel_size}") + if pipeline_model_parallel_size <= 0: + raise ValueError(f"pipeline_model_parallel_size must be greater than 0, got {pipeline_model_parallel_size}") + if context_model_parallel_size <= 0: + raise ValueError(f"context_model_parallel_size must be greater than 0, got {context_model_parallel_size}") + + world_size = num_nodes * devices + if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size) != 0: + raise ValueError( + f"world_size must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size *" + f" context_model_parallel_size, got {world_size} and TP{tensor_model_parallel_size} * " + f"PP{pipeline_model_parallel_size} * CP{context_model_parallel_size}" + ) + + model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_model_parallel_size + data_parallel_size = world_size // model_parallel_size + global_batch_size = micro_batch_size * data_parallel_size * accumulate_grad_batches + return global_batch_size + + +class Evo2BlendedDatasetConfig(BaseModel): + """Configuration for blended dataset specifications. + + Validates and constructs dataset paths, weights and splits configuration. + Ensures dataset paths exist and are properly resolved relative to base data path. + + Attributes: + dataset_path: Base directory path for datasets. Used to resolve relative dataset prefixes. + dataset_prefix: Path prefix for dataset files. Can be absolute or relative to dataset_path. + dataset_weight: Weight factor for this dataset during blending (0-1). + dataset_split: Dataset partition - 'train', 'validation' or 'test'. + + Raises: + ValueError: If dataset path doesn't exist or prefix can't be resolved. + """ + + dataset_path: str | None = None + dataset_prefix: str + dataset_weight: float + dataset_split: Literal["train", "validation", "test"] + + @model_validator(mode="before") + @classmethod + def validate_dataset_prefix(cls, values: dict) -> dict: + """Ensure dataset_prefix paths exist and are properly resolved or are relative to base dataset_path if + provided. + + Args: + values (dict): Dictionary containing dataset_path and dataset_prefix. 
+ + Returns: + dict: Dictionary containing validated dataset_path and dataset_prefix. + """ + dataset_path = Path(values.get("dataset_path")) if values.get("dataset_path") else None + prefix = Path(values.get("dataset_prefix")) + + if not prefix.is_absolute(): + if dataset_path: + prefix = dataset_path / prefix + else: + prefix = Path(prefix).resolve() + parent = prefix.parent + stem = prefix.stem + if not parent.exists(): + raise ValueError(f"dataset_prefix parent path does not exist: {parent}") + matching_files = list(parent.glob(f"{stem}.*")) + if not matching_files: + raise ValueError(f"dataset_prefix file does not exist: {prefix}") + values["dataset_prefix"] = str(prefix) + return values + + +def parse_dataset_config(dataset_config_path: str, dataset_path: Optional[str] = None): + """Parse the blended training datasplit configuration and renormalize data split weights for training Hyena. + + Args: + dataset_config_path (str): Path to the dataset configuration YAML file. + dataset_path (str): Path to the dataset directory. Defaults to None. + + Returns: + defaultdict: A dictionary where keys are dataset splits and values are lists containing the normalized weight + and dataset prefix for each split. + """ + blended_dataset_config = defaultdict(list) + weight_sums = defaultdict(float) + with open(dataset_config_path, "r") as config_file: + dataset_config_batch = yaml.safe_load(config_file) + for dataset_config in dataset_config_batch: + # Validate. + config_model = Evo2BlendedDatasetConfig(dataset_path=dataset_path, **dataset_config) + # Integrate the weights for renormalization. + weight_sums[config_model.dataset_split] += abs(config_model.dataset_weight) + for dataset_config in dataset_config_batch: + # Validate. + config_model = Evo2BlendedDatasetConfig(dataset_path=dataset_path, **dataset_config) + # Add indexed dataset to split and associate with blended training weight. + blended_dataset_config[config_model.dataset_split].extend( + [config_model.dataset_weight / weight_sums[config_model.dataset_split], config_model.dataset_prefix] + ) + return blended_dataset_config diff --git a/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py new file mode 100644 index 000000000000..4bb4e4e9dc81 --- /dev/null +++ b/nemo/collections/llm/gpt/data/megatron/hyena/evo2_dataset.py @@ -0,0 +1,239 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
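As a rough sketch of how the two helpers in `config.py` fit together (the YAML contents, paths, and parallelism values below are hypothetical; each `dataset_prefix` must resolve to an existing indexed-dataset file for validation to pass):

```python
from nemo.collections.llm.gpt.data.megatron.hyena.config import (
    infer_global_batch_size,
    parse_dataset_config,
)

# 2 nodes x 8 GPUs with TP=2, PP=1, CP=1 -> data_parallel_size = 16 // 2 = 8,
# so the inferred global batch size is micro_batch_size * 8 * accumulate_grad_batches = 2 * 8 * 1 = 16.
gbs = infer_global_batch_size(micro_batch_size=2, num_nodes=2, devices=8, tensor_model_parallel_size=2)

# training_blend.yaml (hypothetical):
# - dataset_prefix: promoters_train_text_document
#   dataset_weight: 1.0
#   dataset_split: train
# - dataset_prefix: genomes_train_text_document
#   dataset_weight: 3.0
#   dataset_split: train
blend = parse_dataset_config("training_blend.yaml", dataset_path="/data/evo2")
# Weights are renormalized per split, e.g.:
# {"train": [0.25, "/data/evo2/promoters_train_text_document",
#            0.75, "/data/evo2/genomes_train_text_document"]}
```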
+ +from typing import ClassVar, Dict, Optional + +import torch +from megatron.core.datasets.gpt_dataset import GPTDataset +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import make_upper_case + + +class Evo2Dataset(GPTDataset): + """Dataset for training Evo2.""" + + CONTROL_TAGS: ClassVar[list[int]] = [64, 35] # '@' tag for splice splits/windows, '#' for contig splits + TAG_BOUNDS = 124 # start and end delim: '|' + TAG_CHARS: ClassVar[set[int]] = {95, 59, 32} # chars only found in control tags: _, ;, space + DEFAULT_EOD = 0 + TO_UPPER_TOKENS: bool = True # If set, do an in-place transform to make all tokens capital letters + RESET_PAD_EOD_MASK: bool = True # If set, unset the mask for [pad] and [eod] tokens (matches Evo2 paper). + + def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: + """Get data at the specified index.""" + # 1. Call the default gpt dataset object + databatch: dict = super().__getitem__(idx) + loss_mask = databatch.get("loss_mask", None) + if self.RESET_PAD_EOD_MASK and loss_mask is not None: + # Reset the mask for 'pad', '[eod]', '[pad token]', which will lower the loss, but matches Evo2 pub. + loss_mask = torch.ones_like(loss_mask) + labels = databatch.get("labels", None) + if labels is None or loss_mask is None: + # No next-token labels or loss to mask. + return databatch + + # Mask special label tags in loss. + control_mask = torch.isin(labels, torch.tensor(self.CONTROL_TAGS, device=labels.device)) + loss_mask[control_mask] = 0 + phylotag_mask = self.mask_phylogenetic_tags( + labels, + self.TAG_BOUNDS, + self.TAG_CHARS, + self.config.tokenizer.eod if self.config.tokenizer is not None else self.DEFAULT_EOD, + ) + databatch["loss_mask"] = loss_mask * phylotag_mask + if self.TO_UPPER_TOKENS: + # When making tokens uppercase, make sure this is done after the mask_phylogenetic_tags function which + # relies in part on the original case of the tag tokens. + databatch["tokens"], _ = make_upper_case(databatch["tokens"]) + return databatch + + @staticmethod + def mask_phylogenetic_tags( + tokenized_sequence: torch.Tensor, + terminal_tag_char: int, # e.g. ASCII for '|' + other_tag_chars: set[int], # e.g. {95, 59, 32} for '_', ';', space + eod_token_id: int, # e.g. 0 + ) -> torch.Tensor: + """ + Creates a binary mask for sequences containing phylogenetic tags and DNA. + The rules are as follows (applied per contiguous sub‐sequence between EOD tokens): + + - Any token equal to the terminal_tag_char (the pipe, '|') is masked. + - For the region *before* the first pipe (the “prefix”): + * If the first token is in taxonomy_prefixes (d, p, c, o, f, g, s), + or if the prefix is exactly one lowercase letter, + or if any token in the prefix is one of other_tag_chars, + or if not every token is a valid DNA base, + then mask the entire prefix. + - For the region between pipes: + * If any token is in other_tag_chars or not all tokens are valid DNA, mask that region. + - For the region *after* the last pipe (the “suffix”): + * If the first token is the letter 'd' (ASCII 100) or if the region contains + any other tag characters or any EOD tokens or non‐DNA, mask the suffix. + + Finally, any token equal to eod_token_id is forced to remain unmasked. + (EOD tokens “break” a sequence so that tags never span across them.) + + Args: + tokenized_sequence (torch.Tensor): shape (seq_len,) or (batch_size, seq_len) + containing ASCII values. + terminal_tag_char (int): ASCII value for the pipe character. 
+ other_tag_chars (set[int]): Set of ASCII values that appear only in tags. + eod_token_id (int): The token ID for EOD. + + Notes: + - The tag token is constructed as follows: So note that one way to know you have a tag is if you look + at the first token after the pipe and it is a 'd' character. Make sure implementation handles this. + ``` + return ( + "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( + lineage.domain if random.random() >= dropout else None, + lineage.phylum if random.random() >= dropout else None, + lineage.clazz if random.random() >= dropout else None, + lineage.order if random.random() >= dropout else None, + lineage.family if random.random() >= dropout else None, + lineage.genus if random.random() >= dropout else None, + lineage.species if random.random() >= dropout else None, + ) + if lineage is not None + else None + ) + ``` + Returns: + torch.Tensor: A mask of the same shape as input where 1 = keep (DNA) and 0 = mask (tag). + """ + device = tokenized_sequence.device + dtype = tokenized_sequence.dtype + # Handle empty or single-token sequences. + if tokenized_sequence.numel() == 0: + return torch.ones(0, device=device, dtype=torch.int) + if tokenized_sequence.numel() == 1: + mask = torch.ones(1, device=device, dtype=torch.int) + token = tokenized_sequence.item() + if token == terminal_tag_char or token in other_tag_chars: + mask[0] = 0 + return mask + + # Ensure input is 2D (batch, seq_len) + batched = tokenized_sequence.ndim == 2 + if not batched: + tokenized_sequence = tokenized_sequence.unsqueeze(0) + batch_size, seq_len = tokenized_sequence.shape + first_taxonomy_prefix_token: int = 100 + + # Valid DNA tokens: A, C, G, T, N (both uppercase and lowercase) + valid_dna = {65, 67, 71, 84, 78, 97, 99, 103, 116, 110} + valid_dna_or_control_tensor = torch.tensor( + list(valid_dna | set(Evo2Dataset.CONTROL_TAGS)), device=device, dtype=dtype + ) + + # Initialize output mask to all ones. + out_mask = torch.ones_like(tokenized_sequence, dtype=torch.int) + + # Helper: Check if all tokens in a region are valid DNA. + def region_all_valid_or_control(region: torch.Tensor) -> bool: + if region.numel() == 0: + return True + # Using torch's all() over the token values. + return bool(torch.all(torch.isin(region, valid_dna_or_control_tensor)).cpu().item()) + + # Process one EOD-free segment using the O1 logic. + def process_segment(seg_seq: torch.Tensor) -> torch.Tensor: + seg_len = seg_seq.size(0) + seg_mask = torch.ones(seg_len, device=device, dtype=torch.int) + # Identify positions of terminal tag (pipe) + pipe_pos = (seg_seq == terminal_tag_char).nonzero(as_tuple=True)[0].cpu().tolist() + if len(pipe_pos) == 0: + # If no pipe exists and any token is a known tag char or not valid DNA, + # mask the entire segment. + if not region_all_valid_or_control(seg_seq): + seg_mask.zero_() + return seg_mask + + # Always mask the pipe positions. + seg_mask[pipe_pos] = 0 + + # Does tag start before the first pipe? This determines the starting state of our state machine. + first_pipe = pipe_pos[0] + if first_pipe >= 0 and first_pipe < seg_len - 1: + # fastest check is to look at the first token after the pipe, if it is a 'd' then the + # tag starts _after_ the pipe, otherwise it starts before. + next_tok = seg_seq[first_pipe + 1].item() + if next_tok == first_taxonomy_prefix_token: + # 'd' character for domain, which is the first part of a phylo tag. + # tag starts after the pipe. + is_tag = False + else: + # tag starts before the pipe. 
+ is_tag = True + else: + # The sequence ends with a pipe, so just check everything before the pipe and return the seg mask + assert first_pipe == seg_len - 1 + # The sequence ends with a pipe, so just check everything before the pipe. + if region_all_valid_or_control(seg_seq[:first_pipe]): + return seg_mask # Pipe pos has already been masked + else: + seg_mask[:first_pipe] = 0 + return seg_mask + start = 0 + for end in pipe_pos: + if is_tag: + seg_mask[start:end] = 0 + else: + pass + is_tag = not is_tag # Flip the state machine. + start = end + 1 # position after the pipe + # Process the last segment after the last pipe. + if is_tag: + seg_mask[start:] = 0 + return seg_mask + + # Process each row by splitting on EOD tokens. + for b in range(batch_size): + row = tokenized_sequence[b] + # Get indices of EOD tokens. + eod_positions = (row == eod_token_id).nonzero(as_tuple=True)[0].cpu().tolist() + start_idx = 0 + for pos in eod_positions: + if pos > start_idx: + seg = row[start_idx:pos] + seg_mask = process_segment(seg) + out_mask[b, start_idx:pos] = seg_mask + # Leave the EOD token itself unmasked. + start_idx = pos + 1 + # Process any remaining tokens after the last EOD. + if start_idx < seq_len: + seg = row[start_idx:] + seg_mask = process_segment(seg) + out_mask[b, start_idx:] = seg_mask + + # Just to make sure we do not allow any non-DNA tokens to be unmasked, even if something went wrong with our + # mask logic. + out_mask[~torch.isin(tokenized_sequence, valid_dna_or_control_tensor)] = 0 + # Finally, force every EOD token to be unmasked. User decides outside of this function if they want EOD mask. + out_mask[tokenized_sequence == eod_token_id] = 1 + + if not batched: + out_mask = out_mask.squeeze(0) + return out_mask + + +class Evo2DatasetPadEodLossMask(Evo2Dataset): + """Dataset for training Evo2 with pad and eod loss mask (more standard approach than the Evo2 paper).""" + + TO_UPPER_TOKENS: bool = True + RESET_PAD_EOD_MASK: bool = False diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index 34075a569500..063e2812cddf 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -16,10 +16,12 @@ import os import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union import lightning.pytorch as pl from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from megatron.core.datasets.gpt_dataset import GPTDataset +from megatron.core.datasets.megatron_dataset import MegatronDataset from torch.utils import data from nemo.lightning.data import WrappedDataLoader @@ -47,6 +49,9 @@ def is_number_tryexcept(s): def is_zipped_list(paths): + """ + Check if the paths are zipped. + """ # ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"] even = paths[::2] if len(even) == 0: @@ -58,6 +63,9 @@ def is_zipped_list(paths): def validate_dataset_asset_accessibility(paths): + """ + Validate the accessibility of the dataset assets. 
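To make the masking rules above concrete, a hand-checked sketch (not part of the patch) of `Evo2Dataset.mask_phylogenetic_tags` on a single sequence:

```python
import torch

from nemo.collections.llm.gpt.data.megatron.hyena import Evo2Dataset

# "|d__Bacteria|ACGT": the tag region (both pipes and everything between them) should be
# excluded from the loss, while the DNA suffix "ACGT" is kept.
seq = torch.tensor([ord(c) for c in "|d__Bacteria|ACGT"])
mask = Evo2Dataset.mask_phylogenetic_tags(
    seq,
    terminal_tag_char=Evo2Dataset.TAG_BOUNDS,  # ord('|') == 124
    other_tag_chars=Evo2Dataset.TAG_CHARS,     # ASCII values of '_', ';', ' '
    eod_token_id=Evo2Dataset.DEFAULT_EOD,
)
# mask == [0] * len("|d__Bacteria|") + [1] * len("ACGT")
```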
+ """ if paths is None: raise ValueError("Expected path to have a value.") @@ -73,7 +81,7 @@ def validate_dataset_asset_accessibility(paths): validate_dataset_asset_accessibility(p) return - if not isinstance(paths, str) and not isisntance(paths, Path): + if not isinstance(paths, str) and not isinstance(paths, Path): raise ValueError("Expected path to be of string or Path type.") path = Path(paths) @@ -133,9 +141,13 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin): to allocate to train, validation, and test sets, respectively. Unused if ``paths`` is a dict. index_mapping_dir (Optional[str]): Path to a directory to write index mapping files. num_dataset_builder_threads (int): The number of threads to use for dataset building. - num_train_samples (Optional[int]): The number of samples to use for training, defaults to total train steps times global batch size. - num_val_samples (Optional[int]): The number of samples to use for validation, defaults to total validation steps times global batch size. - num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total test steps times global batch size. + num_train_samples (Optional[int]): The number of samples to use for training, defaults to total + train steps times global batch size. + num_val_samples (Optional[int]): The number of samples to use for validation, defaults to total + validation steps times global batch size. + num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total + test steps times global batch size. + dataset_cls (Optional[Type[MegatronDataset]]): The dataset class to use for the data module. """ def __init__( @@ -160,6 +172,7 @@ def __init__( num_train_samples: Optional[int] = None, num_val_samples: Optional[int] = None, num_test_samples: Optional[int] = None, + dataset_cls: Type[MegatronDataset] = GPTDataset, ) -> None: super().__init__() if not isinstance(paths, (list, tuple, dict)): @@ -167,6 +180,8 @@ def __init__( from megatron.core.datasets.utils import get_blend_from_list + self.dataset_cls = dataset_cls + validate_dataset_asset_accessibility(paths) build_kwargs = {} @@ -225,8 +240,10 @@ def build( trainer_limit_val_batches: Union[int, float], trainer_limit_test_batches: Union[int, float], ): + """ + Build the datasets. + """ from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder - from megatron.core.datasets.gpt_dataset import GPTDataset train_iters = trainer_max_steps assert train_iters > 0, f"max_steps {train_iters} should be greater than 0" @@ -275,13 +292,16 @@ def build( train_valid_test_num_samples = [num_train_samples, num_val_samples, num_test_samples] self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder( - GPTDataset, + self.dataset_cls, train_valid_test_num_samples, is_built_on_rank=lambda: True, config=self.gpt_dataset_config, ).build() def setup(self, stage: str = "") -> None: + """ + Setup the data module. + """ assert ( hasattr(self, "trainer") and self.trainer is not None ), "Setup should be completed when trainer and config are attached." @@ -311,12 +331,21 @@ def setup(self, stage: str = "") -> None: # ).build() def train_dataloader(self) -> TRAIN_DATALOADERS: + """ + Get the train dataloader. + """ return self._create_dataloader(self._train_ds, mode="train") def val_dataloader(self) -> EVAL_DATALOADERS: + """ + Get the validation dataloader. 
+ """ return self._create_dataloader(self._validation_ds, mode="validation") def test_dataloader(self) -> EVAL_DATALOADERS: + """ + Get the test dataloader. + """ return self._create_dataloader(self._test_ds, mode="test") def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader: @@ -335,6 +364,9 @@ def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader: @property def gpt_dataset_config(self) -> "GPTDatasetConfig": + """ + Get the GPT dataset configuration. + """ from megatron.core.datasets.gpt_dataset import GPTDatasetConfig return GPTDatasetConfig( @@ -385,16 +417,21 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.data_sampler.if_first_step = 1 def reconfigure_limit_batches(self): + """ + Reconfigure trainer.limit_train_batches and trainer.limit_val_batches in terms of num of microbatches. + """ # Override limit_train_batches in terms of num of microbatches self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, "train") - # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step + # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting + # in between a step self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, "val") def _reconfigure_limit_batches(self, limit_batches, dataloader, mode): """ Reconfigure trainer.limit_val_batches for pretraining """ - # Override limit_batches in terms of num microbatches and so there are limit_batches//num_micro_batches num of global batches + # Override limit_batches in terms of num microbatches and so there are limit_batches//num_micro_batches + # num of global batches try: from megatron.core.num_microbatches_calculator import get_num_microbatches @@ -418,7 +455,7 @@ def _reconfigure_limit_batches(self, limit_batches, dataloader, mode): limit_micro_batches = int(dl_len_in_micro_batches * limit_batches) if limit_micro_batches == 0 and limit_batches > 0.0: min_percentage = 1.0 / len(dataloader) - raise MisconfigurationException( + raise ValueError( f"You requested to check {limit_batches} of the val_dataloader but" f" {limit_batches} * {len(dataloader)} < 1. Please increase the" f" `limit_val_batches` argument. 
Try at least" diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 8d66f6f031af..b50bb991767e 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -47,6 +47,20 @@ Gemma2Model, ) from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM +from nemo.collections.llm.gpt.model.hyena import ( + Hyena1bConfig, + Hyena7bARCLongContextConfig, + Hyena7bConfig, + Hyena40bARCLongContextConfig, + Hyena40bConfig, + HyenaConfig, + HyenaModel, + HyenaNV1bConfig, + HyenaNV7bConfig, + HyenaNV40bConfig, + HyenaNVTestConfig, + HyenaTestConfig, +) from nemo.collections.llm.gpt.model.llama import ( CodeLlamaConfig7B, CodeLlamaConfig13B, @@ -208,4 +222,16 @@ "transformer_engine_full_layer_spec", "local_layer_spec", "HFAutoModelForCausalLM", + "HyenaTestConfig", + "Hyena1bConfig", + "HyenaNV1bConfig", + "Hyena7bConfig", + "Hyena40bConfig", + "Hyena7bARCLongContextConfig", + "Hyena40bARCLongContextConfig", + "HyenaNVTestConfig", + "HyenaNV40bConfig", + "HyenaNV7bConfig", + "HyenaConfig", + "HyenaModel", ] diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py new file mode 100644 index 000000000000..fc72f28d05b5 --- /dev/null +++ b/nemo/collections/llm/gpt/model/hyena.py @@ -0,0 +1,823 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Literal, Optional, Type + +import torch +from megatron.core import parallel_state +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig +from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.transformer_config import TransformerConfig + +from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer_specs import hyena_stack_spec, hyena_stack_spec_no_te +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_model import HyenaModel as MCoreHyenaModel +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond +from nemo.lightning import get_vocab_size, io, teardown +from nemo.lightning.base import NEMO_MODELS_CACHE +from nemo.lightning.io.state import TransformFns +from nemo.utils import logging + + +class HyenaModel(GPTModel): + """ + This is a wrapper around the MCoreHyenaModel to allow for inference. 
Our model follows the same API + as the GPTModel, but the megatron model class is different so we need to handle the inference wrapper + slightly differently. + """ + + def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor: + """ + Gets the inference wrapper for the Hyena model. + + Args: + params_dtype: The data type for model parameters + inference_batch_times_seqlen_threshold: Threshold for batch size * sequence length during inference + + Returns: + GPTInferenceWrapper: The inference wrapper for the model + + Raises: + ValueError: If MCoreHyenaModel instance not found or vocab size cannot be determined + """ + # This is to get the MCore model required in GPTInferenceWrapper. + mcore_model = self.module + while mcore_model: + if type(mcore_model) is MCoreHyenaModel: + break + mcore_model = getattr(mcore_model, "module", None) + if mcore_model is None or type(mcore_model) is not MCoreHyenaModel: + raise ValueError("Exact MCoreHyenaModel instance not found in the model structure.") + + vocab_size = None + if self.tokenizer is not None: + vocab_size = self.tokenizer.vocab_size + elif hasattr(self.config, 'vocab_size'): + vocab_size = self.config.vocab_size + else: + raise ValueError( + 'Unable to find vocab size.' + ' Either pass in a tokenizer with vocab size, or set vocab size in the model config' + ) + + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=mcore_model.config.hidden_size, + params_dtype=params_dtype, + inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, + padded_vocab_size=vocab_size, + ) + + model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config) + return model_inference_wrapper + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + decoder_input: Optional[torch.Tensor] = None, + loss_mask: Optional[torch.Tensor] = None, + inference_params=None, + packed_seq_params=None, + ) -> torch.Tensor: + """ + Forward pass of the Hyena model. + + Args: + input_ids: Input token IDs + position_ids: Position IDs for input tokens + attention_mask: Optional attention mask + labels: Optional labels for loss computation + decoder_input: Optional decoder input + loss_mask: Optional loss mask + inference_params: Optional inference parameters + packed_seq_params: Optional parameters for packed sequences + + Returns: + torch.Tensor: Output tensor from the model + """ + extra_kwargs = {'packed_seq_params': packed_seq_params} if packed_seq_params is not None else {} + output_tensor = self.module( + input_ids, + position_ids, + attention_mask, + decoder_input=decoder_input, + labels=labels, + inference_params=inference_params, + loss_mask=loss_mask, + **extra_kwargs, + ) + return output_tensor + + +def hyena_forward_step(model, batch) -> torch.Tensor: + """ + Performs a forward step for the Hyena model. 
+ + Args: + model: The Hyena model + batch: Dictionary containing input batch data with keys: + - tokens: Input token IDs + - position_ids: Position IDs + - labels: Labels for loss computation + - loss_mask: Mask for loss computation + + Returns: + torch.Tensor: Output from the model forward pass + """ + forward_args = { + "input_ids": batch["tokens"], + "position_ids": batch["position_ids"], + "labels": batch["labels"], + "loss_mask": batch["loss_mask"], + } + forward_args["attention_mask"] = None + return model(**forward_args) + + +@dataclass +class HyenaConfig(TransformerConfig, io.IOMixin): + """ + Configuration dataclass for Hyena. + + For adjusting ROPE when doing context extension, set seq_len_interpolation_factor relative to 8192. + For example, if your context length is 512k, then set the factor to 512k / 8k = 64. + """ + + # From megatron.core.models.hyena.hyena_model.HyenaModel + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + params_dtype: torch.dtype = torch.bfloat16 + fp16: bool = False + bf16: bool = True + num_layers: int = 2 + num_attention_heads: int = 8 + num_groups_hyena: int = None + num_groups_hyena_medium: int = None + num_groups_hyena_short: int = None + hybrid_attention_ratio: float = 0.0 + hybrid_mlp_ratio: float = 0.0 + hybrid_override_pattern: str = None + post_process: bool = True + pre_process: bool = True + seq_length: int = 2048 + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'rope' + rotary_percent: float = 1.0 + rotary_base: int = 10000 + seq_len_interpolation_factor: Optional[float] = None + apply_rope_fusion: bool = True + make_vocab_size_divisible_by: int = 128 + gated_linear_unit: bool = False + fp32_residual_connection: bool = True + normalization: str = 'RMSNorm' + add_bias_linear: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + layernorm_epsilon: float = 1e-6 + attention_backend: AttnBackend = AttnBackend.flash + # TODO: Move this to better places? + get_attention_mask_from_fusion: bool = False + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 4 + forward_step_fn: Callable = hyena_forward_step + data_step_fn: Callable = gpt_data_step + tokenizer_model_path: str = None + hyena_init_method: str = None + hyena_output_layer_init_method: str = None + hyena_filter_no_wd: bool = True + remove_activation_post_first_layer: bool = True + add_attn_proj_bias: bool = True + cross_entropy_loss_fusion: bool = False # Faster but lets default to False for more precision + tp_comm_overlap: bool = False + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + add_bias_output: bool = False + use_te: bool = True + to_upper: str = "normalized_weighted" # choose between "weighted" and "normalized_weighted" + use_short_conv_bias: bool = False + + def __post_init__(self): + """ + Post-initialization hook that sets up weight decay conditions. + """ + super().__post_init__() + self.hyena_no_weight_decay_cond_fn = hyena_no_weight_decay_cond if self.hyena_filter_no_wd else None + + def configure_model(self, tokenizer) -> "MCoreHyenaModel": + """ + Configures and returns a Hyena model instance based on the config settings. 
+ + Args: + tokenizer: Tokenizer to use for the model + + Returns: + MCoreHyenaModel: Configured Hyena model instance + """ + self.bias_activation_fusion = False if self.remove_activation_post_first_layer else self.bias_activation_fusion + + model = MCoreHyenaModel( + self, + hyena_stack_spec=hyena_stack_spec if self.use_te else hyena_stack_spec_no_te, + vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by), + max_sequence_length=self.seq_length, + num_groups_hyena=self.num_groups_hyena, + num_groups_hyena_medium=self.num_groups_hyena_medium, + num_groups_hyena_short=self.num_groups_hyena_short, + hybrid_override_pattern=self.hybrid_override_pattern, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + share_embeddings_and_output_weights=True, + hyena_init_method=self.hyena_init_method, + hyena_output_layer_init_method=self.hyena_output_layer_init_method, + remove_activation_post_first_layer=self.remove_activation_post_first_layer, + add_attn_proj_bias=self.add_attn_proj_bias, + ) + return model + + +@dataclass +class HyenaTestConfig(HyenaConfig): + """Configuration for testing Hyena models.""" + + hybrid_override_pattern: str = "SDH*" + num_layers: int = 4 + seq_length: int = 8192 + hidden_size: int = 4096 + num_groups_hyena: int = 4096 + num_groups_hyena_medium: int = 256 + num_groups_hyena_short: int = 256 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 11008 + gated_linear_unit: bool = True + num_attention_heads: int = 32 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 2 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + use_short_conv_bias: bool = False + + +@dataclass +class HyenaNVTestConfig(HyenaTestConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. 
+ """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + use_short_conv_bias: bool = True + + +@dataclass +class Hyena1bConfig(HyenaConfig): + """Config matching the 1b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*" + num_layers: int = 25 + seq_length: int = 8192 + hidden_size: int = 1920 + num_groups_hyena: int = 1920 + num_groups_hyena_medium: int = 128 + num_groups_hyena_short: int = 128 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 5120 + gated_linear_unit: bool = True + num_attention_heads: int = 15 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 4 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV1bConfig(Hyena1bConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + use_short_conv_bias: bool = True + + +@dataclass +class Hyena7bConfig(HyenaConfig): + """Config matching the 7b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*SDHSDH*" + num_layers: int = 32 + seq_length: int = 8192 + hidden_size: int = 4096 + num_groups_hyena: int = 4096 + num_groups_hyena_medium: int = 256 + num_groups_hyena_short: int = 256 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 11008 + gated_linear_unit: bool = True + num_attention_heads: int = 32 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 4 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV7bConfig(Hyena7bConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. 
+ """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + use_short_conv_bias: bool = True + + +@dataclass +class Hyena40bConfig(HyenaConfig): + """Config matching the 40b 8k context Evo2 model""" + + hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH*SDHSDH*SDH*SDHSDH*SDHSDH*" + num_layers: int = 50 + seq_length: int = 8192 + hidden_size: int = 8192 + num_groups_hyena: int = 8192 + num_groups_hyena_medium: int = 512 + num_groups_hyena_short: int = 512 + make_vocab_size_divisible_by: int = 8 + tokenizer_library: str = 'byte-level' + mapping_type: str = "base" + ffn_hidden_size: int = 21888 + gated_linear_unit: bool = True + num_attention_heads: int = 64 + use_cpu_initialization: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + params_dtype: torch.dtype = torch.bfloat16 + normalization: str = "RMSNorm" + add_qkv_bias: bool = False + add_bias_linear: bool = False + layernorm_epsilon: float = 1e-6 + recompute_granularity: str = 'full' + recompute_method: str = 'uniform' + recompute_num_layers: int = 2 + hyena_init_method: str = 'small_init' + hyena_output_layer_init_method: str = 'wang_init' + hyena_filter_no_wd: bool = True + + +@dataclass +class HyenaNV40bConfig(Hyena40bConfig): + """ + Several unintentional design choices were made to the original Arc implementation that are required to use the + original Arc checkpoints, but may result in less stable model training. If you are training from scratch, + these are the recommended configs. + """ + + remove_activation_post_first_layer: bool = False + add_attn_proj_bias: bool = False + use_short_conv_bias: bool = True + + +@dataclass +class Hyena7bARCLongContextConfig(Hyena7bConfig): + """The checkpoint from ARC requires padding to the FFN dim + due to constraintes from large TP size for training.""" + + ffn_hidden_size: int = 11264 + + +@dataclass +class Hyena40bARCLongContextConfig(Hyena40bConfig): + """The checkpoint from ARC requires padding to the FFN dim + due to constraintes from large TP size for training.""" + + ffn_hidden_size: int = 22528 + + +@io.model_importer(HyenaModel, "pytorch") +class PyTorchHyenaImporter(io.ModelConnector["HyenaModel", HyenaModel]): + """ + Importer class for converting PyTorch Hyena models to NeMo format. + """ + + def __new__(cls, path: str, model_config=None): + """ + Creates a new importer instance. + + Args: + path: Path to the PyTorch model + model_config: Optional model configuration + + Returns: + PyTorchHyenaImporter instance + """ + instance = super().__new__(cls, path) + instance.model_config = model_config + return instance + + def init(self) -> HyenaModel: + """ + Initializes a new HyenaModel instance. + + Returns: + HyenaModel: Initialized model + """ + return HyenaModel(self.config, tokenizer=self.tokenizer) + + def get_source_model(self): + """ + Returns the source model. + """ + return torch.load(str(self), map_location='cpu') + + def apply(self, output_path: Path, checkpoint_format: str = 'torch_dist') -> Path: + """ + Applies the model conversion from PyTorch to NeMo format. 
+ + Args: + output_path: Path to save the converted model + checkpoint_format: Format for saving checkpoints + + Returns: + Path: Path to the saved NeMo model + """ + source = self.get_source_model() + + if 'model' in source: + source = source['model'] + + class ModelState: + """Wrapper around the source model state dictionary that also handles some weight transformations.""" + + def __init__(self, state_dict, num_layers, fp32_suffixes): + """Wrapper around the source model state dictionary that also handles some weight transformations. + + Args: + state_dict: original state dictionary from the source model + num_layers: number of layers in the source model + """ + self.num_layers = num_layers + state_dict = self.transform_source_dict(state_dict) + self._state_dict = state_dict + self.fp32_suffixes = fp32_suffixes + + def state_dict(self): + """Return the state dictionary.""" + return self._state_dict + + def to(self, dtype): + """Convert the state dictionary to the target dtype.""" + for k, v in self._state_dict.items(): + if "_extra" not in k: + if v.dtype != dtype: + logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") + k_suffix = k.split('.')[-1] + if k_suffix in self.fp32_suffixes: + _dtype = torch.float32 + else: + _dtype = dtype + self._state_dict[k] = v.to(_dtype) + + def adjust_medium_filter(self, updated_data): + """Adjust the medium filter.""" + from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig + + for k, v in updated_data.items(): + if "filter.h" in k or "filter.decay" in k: + updated_data[k] = v[:, : HyenaConfig().hyena_medium_conv_len] + return updated_data + + def transform_source_dict(self, source): + """Transform the source state dictionary, applying some challenging layer name re-mappings and + removing extra keys, as well as truncating a filter that didn't need to extend to the full + sequence length dim. + """ + import re + + layer_map = {i + 2: i for i in range(self.num_layers)} + layer_map[self.num_layers + 3] = self.num_layers + updated_data = {} + + for key in list(source['module'].keys()): + if "_extra" in key: + source['module'].pop(key) + else: + match = re.search(r'sequential\.(\d+)', key) + if match: + original_layer_num = int(match.group(1)) + if original_layer_num in layer_map: + # Create the updated key by replacing the layer number + new_key = re.sub(rf'\b{original_layer_num}\b', str(layer_map[original_layer_num]), key) + updated_data[new_key] = source['module'][key] + else: + # Keep the key unchanged if no mapping exists + updated_data[key] = source['module'][key] + else: + updated_data[key] = source['module'][key] + updated_data = self.adjust_medium_filter(updated_data) + return updated_data + + target = self.init() + trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format=checkpoint_format) + target.to(self.config.params_dtype) + fp32_suffixes = {n.split('.')[-1] for n, p in target.named_parameters() if p.dtype == torch.float32} + source = ModelState(source, self.config.num_layers, fp32_suffixes) + source.to(self.config.params_dtype) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + logging.info(f"Converted Hyena model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + """ + Converts the state dictionary from source format to target format. 
+ + Args: + source: Source model state + target: Target model + + Returns: + Result of applying state transforms + """ + mapping = {} + mapping['sequential.0.word_embeddings.weight'] = 'embedding.word_embeddings.weight' + mapping[f'sequential.{len(self.config.hybrid_override_pattern)}.norm.weight'] = 'decoder.final_norm.weight' + te_enabled = self.config.use_te + for i, symbol in enumerate(self.config.hybrid_override_pattern): + if te_enabled: + mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = ( + f'decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight' + ) + else: + mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = f'decoder.layers.{i}.pre_mlp_layernorm.weight' + mapping[f'sequential.{i}.mlp.w3.weight'] = f'decoder.layers.{i}.mlp.linear_fc2.weight' + + if symbol != '*': + if te_enabled: + mapping[f'sequential.{i}.input_layernorm.weight'] = ( + f'decoder.layers.{i}.mixer.dense_projection.layer_norm_weight' + ) + else: + mapping[f'sequential.{i}.input_layernorm.weight'] = f'decoder.layers.{i}.norm.weight' + + mapping[f'sequential.{i}.mixer.dense_projection.weight'] = ( + f'decoder.layers.{i}.mixer.dense_projection.weight' + ) + mapping[f'sequential.{i}.mixer.hyena_proj_conv.short_conv_weight'] = ( + f'decoder.layers.{i}.mixer.hyena_proj_conv.short_conv_weight' + ) + mapping[f'sequential.{i}.mixer.dense.weight'] = f'decoder.layers.{i}.mixer.dense.weight' + mapping[f'sequential.{i}.mixer.dense.bias'] = f'decoder.layers.{i}.mixer.dense.bias' + + if symbol == 'S': + mapping[f'sequential.{i}.mixer.mixer.short_conv.short_conv_weight'] = ( + f'decoder.layers.{i}.mixer.mixer.short_conv.short_conv_weight' + ) + + elif symbol == 'D': + mapping[f'sequential.{i}.mixer.mixer.conv_bias'] = f'decoder.layers.{i}.mixer.mixer.conv_bias' + mapping[f'sequential.{i}.mixer.mixer.filter.h'] = f'decoder.layers.{i}.mixer.mixer.filter.h' + mapping[f'sequential.{i}.mixer.mixer.filter.decay'] = ( + f'decoder.layers.{i}.mixer.mixer.filter.decay' + ) + + elif symbol == 'H': + mapping[f'sequential.{i}.mixer.mixer.conv_bias'] = f'decoder.layers.{i}.mixer.mixer.conv_bias' + mapping[f'sequential.{i}.mixer.mixer.filter.gamma'] = ( + f'decoder.layers.{i}.mixer.mixer.filter.gamma' + ) + mapping[f'sequential.{i}.mixer.mixer.filter.R'] = f'decoder.layers.{i}.mixer.mixer.filter.R' + mapping[f'sequential.{i}.mixer.mixer.filter.p'] = f'decoder.layers.{i}.mixer.mixer.filter.p' + + elif symbol == '*': + if te_enabled: + mapping[f'sequential.{i}.input_layernorm.weight'] = ( + f'decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight' + ) + else: + mapping[f'sequential.{i}.input_layernorm.weight'] = f'decoder.layers.{i}.input_layernorm.weight' + + mapping[f'sequential.{i}.mixer.dense_projection.weight'] = ( + f'decoder.layers.{i}.self_attention.linear_qkv.weight' + ) + mapping[f'sequential.{i}.mixer.dense.weight'] = f'decoder.layers.{i}.self_attention.linear_proj.weight' + mapping[f'sequential.{i}.mixer.dense.bias'] = f'decoder.layers.{i}.self_attention.linear_proj.bias' + else: + raise ValueError(f'Unknown symbol: {symbol}') + + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=[ + # Transforms that are more complicated than a simple mapping of an old key name to a new one: + io.state_transform( + source_key=("sequential.*.mlp.w1.weight", "sequential.*.mlp.w2.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", + fn=TransformFns.merge_fc1, + ) + ], + ) + + @property + def tokenizer(self): + """ + Gets the tokenizer for the model. 
+ + Returns: + Tokenizer instance + """ + from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + + tokenizer = get_nmt_tokenizer( + library=self.model_config.tokenizer_library, + ) + + return tokenizer + + @property + def config(self) -> HyenaConfig: + """ + Gets the model configuration. + + Returns: + HyenaConfig: Model configuration + """ + return self.model_config + + +@io.model_importer(HyenaModel, "hf") +class HuggingFaceSavannaHyenaImporter(PyTorchHyenaImporter): + """ + Importer class for converting HuggingFace Savanna Hyena models to NeMo format. + See: https://huggingface.co/arcinstitute/savanna_evo2_7b for an example of a savanna model that this can + import and convert to NeMo format. Any of the Arc models that start with "savanna_" should work. + """ + + def get_source_model(self): + """ + Returns the source model. + """ + import huggingface_hub.errors + from huggingface_hub import hf_hub_download + + if ":" in str(self): + repo_id, revision = str(self).split(":") + else: + repo_id = str(self) + revision = None + # See HF download logic here: + # https://github.com/ArcInstitute/evo2/blob/96ac9d9cd/evo2/models.py#L191-L231 + modelname = repo_id.split("/")[-1] + download_dir = str(NEMO_MODELS_CACHE / repo_id) + weights_filename = f"{modelname}.pt" + try: + weights_path = hf_hub_download( + repo_id=repo_id, local_dir=download_dir, revision=revision, filename=weights_filename + ) + except Exception: + # Try downloading multi-part + # If file is split, download and join parts + logging.warning(f"Single path download failed, try loading checkpoint shards for {modelname}") + # If file is split, get the first part's directory to use the same cache location + weights_path = os.path.join(download_dir, weights_filename) + if os.path.exists(weights_path): + logging.info(f"Found {weights_path}") + else: + # Download and join parts + parts = [] + part_num = 0 + while True: + try: + part_path = hf_hub_download( + repo_id=repo_id, + local_dir=download_dir, + revision=revision, + filename=f"{weights_filename}.part{part_num}", + ) + parts.append(part_path) + part_num += 1 + except huggingface_hub.errors.EntryNotFoundError: + break + + # Join in the same directory + with open(weights_path, 'wb') as outfile: + for part in parts: + with open(part, 'rb') as infile: + while True: + chunk = infile.read(8192 * 1024) + if not chunk: + break + outfile.write(chunk) + + # Cleaning up the parts + for part in parts: + try: + os.remove(part) + except OSError as e: + print(f"Error removing {part}: {e}") + print("Cleaned up shards, final checkpoint saved to", weights_path) + + return torch.load(weights_path, map_location='cpu', weights_only=False) + + +HYENA_MODEL_OPTIONS: dict[str, Type[HyenaConfig]] = { + "1b": Hyena1bConfig, + "1b_nv": HyenaNV1bConfig, + "7b": Hyena7bConfig, + "7b_arc_longcontext": Hyena7bARCLongContextConfig, + "7b_nv": HyenaNV7bConfig, + "40b": Hyena40bConfig, + "40b_arc_longcontext": Hyena40bARCLongContextConfig, + "40b_nv": HyenaNV40bConfig, + "test": HyenaTestConfig, + "test_nv": HyenaNVTestConfig, +} + + +__all__ = [ + "HyenaConfig", + "Hyena7bConfig", + "HyenaNV7bConfig", + "Hyena1bConfig", + "HyenaNV1bConfig", + "Hyena40bConfig", + "HyenaNV40bConfig", + "Hyena7bARCLongContextConfig", + "Hyena40bARCLongContextConfig", + "HyenaTestConfig", + "HyenaNVTestConfig", + "HYENA_MODEL_OPTIONS", +] diff --git a/nemo/collections/llm/gpt/model/megatron/__init__.py b/nemo/collections/llm/gpt/model/megatron/__init__.py new file mode 100644 index 
000000000000..e69de29bb2d1 diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/__init__.py b/nemo/collections/llm/gpt/model/megatron/hyena/__init__.py new file mode 100644 index 000000000000..b139b030733e --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py new file mode 100644 index 000000000000..2d1a80d9f366 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py @@ -0,0 +1,362 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
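For orientation, the HYENA_MODEL_OPTIONS registry in the hyena.py hunk above maps size strings to config classes. Below is a minimal sketch of resolving one; the import path, the helper function, and the keyword overrides are assumptions based on how these names are used elsewhere in this diff, not part of the patch itself.

# Hypothetical helper; only HYENA_MODEL_OPTIONS and the field names come from the patch.
from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS

def resolve_hyena_config(model_size: str, **overrides):
    """Instantiate a registered Hyena config class by its size string."""
    if model_size not in HYENA_MODEL_OPTIONS:
        raise ValueError(f"Unknown model size '{model_size}'; choices: {sorted(HYENA_MODEL_OPTIONS)}")
    return HYENA_MODEL_OPTIONS[model_size](**overrides)

# For example, a small test-sized config with a four-layer hybrid pattern:
config = resolve_hyena_config("test", num_layers=4, hybrid_override_pattern="SDH*")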
+ +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Union + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.utils import make_viewless_tensor +from torch import Tensor, nn + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import Symbols as LayerSymbols +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_hybrid_layer_allocation import allocate_layers + +try: + from megatron.core.extensions.transformer_engine import TEDelayedScaling, TENorm, te_checkpoint + + HAVE_TE = True + LayerNormImpl = TENorm + +except ImportError: + HAVE_TE = False + + try: + from apex.normalization import FusedLayerNorm + + LayerNormImpl = FusedLayerNorm + + except ImportError: + from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm + + LayerNormImpl = WrappedTorchLayerNorm + + +HYENA_LAYER_MAP = { + LayerSymbols.HYENA_SHORT: "hyena_short_conv", + LayerSymbols.HYENA_MEDIUM: "hyena_medium_conv", + LayerSymbols.HYENA: "hyena", +} + + +@dataclass +class HyenaStackSubmodules: + """ + A class for the module specs for the HyenaStack. + """ + + hyena_layer: Union[ModuleSpec, type] = IdentityOp + attention_layer: Union[ModuleSpec, type] = IdentityOp + + +class HyenaStack(MegatronModule): + """ + A class for the HyenaStack. 
+    """
+
+    def __init__(
+        self,
+        transformer_config: TransformerConfig,
+        hyena_config: HyenaConfig,
+        hybrid_override_pattern,
+        max_sequence_length,
+        submodules: HyenaStackSubmodules,
+        pre_process: bool = True,
+        post_process: bool = True,
+        post_layer_norm: bool = False,
+    ) -> None:
+
+        super().__init__(config=transformer_config)
+        self.transformer_config = transformer_config
+        self.hyena_config = hyena_config
+        self.submodules = submodules
+        self.hybrid_override_pattern = hybrid_override_pattern
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.post_layer_norm = post_layer_norm
+
+        # Required for pipeline parallel schedules
+        self.input_tensor = None
+
+        layer_type_list = allocate_layers(self.transformer_config.num_layers, self.hybrid_override_pattern)
+
+        pp_layer_offset = 0
+        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+            pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(layer_type_list)
+
+        self.layers = nn.ModuleList()
+        for i, layer_type in enumerate(layer_type_list):
+            if layer_type in HYENA_LAYER_MAP:
+                layer = build_module(
+                    submodules.hyena_layer,
+                    self.transformer_config,
+                    self.hyena_config,
+                    operator_type=HYENA_LAYER_MAP.get(layer_type),
+                    max_sequence_length=max_sequence_length,
+                    layer_number=i + 1 + pp_layer_offset,
+                )
+            elif layer_type == LayerSymbols.ATTENTION:
+                # Transformer layers apply their own pp_layer_offset
+                layer = build_module(submodules.attention_layer, config=self.transformer_config, layer_number=i + 1)
+            else:
+                raise ValueError(f"Unexpected layer_type: {layer_type}")
+            self.layers.append(layer)
+
+        if self.post_process and self.post_layer_norm:
+            # Final layer norm before output.
+            self.final_norm = TENorm(
+                config=self.transformer_config,
+                hidden_size=self.transformer_config.hidden_size,
+                eps=self.transformer_config.layernorm_epsilon,
+            )
+        # Required for activation recomputation
+        self.num_layers_per_pipeline_rank = len(self.layers)
+
+    def set_input_tensor(self, input_tensor: Tensor):
+        """Set input tensor to be used instead of forward()'s input.
+
+        When doing pipeline parallelism the input from the previous
+        stage comes from communication, not from the input, so the
+        model's forward_step_func won't have it.
This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def _select_layers_for_pipeline_parallel(self, layer_type_list): + pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() + num_layers_per_pipeline_rank = ( + self.transformer_config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() + ) + + assert parallel_state.get_virtual_pipeline_model_parallel_world_size() is None, ( + "The Hyena hybrid model does not currently support " "virtual/interleaved pipeline parallelism" + ) + + offset = pipeline_rank * num_layers_per_pipeline_rank + selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank] + + return offset, selected_list + + def _get_layer(self, layer_number: int): + return self.layers[layer_number] + + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + rotary_pos_emb: Tensor, + ): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): + for index in range(start, end): + layer = self._get_layer(index) + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=None, + ) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + return hidden_states + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + None, + None, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + None, + None, + rotary_pos_emb, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states = checkpoint_handler(custom(layer_idx, layer_idx + self.config.recompute_num_layers)) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. 
+ if ( + layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, rotary_pos_emb) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + inference_params=None, + rotary_pos_emb: Tensor = None, + ): + """Forward pass for the HyenaStack.""" + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + if self.transformer_config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 + + if self.transformer_config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.transformer_config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = TEDelayedScaling( + config=self.transformer_config, + fp8_format=fp8_format, + override_linear_precision=(False, False, not self.transformer_config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group( + with_context_parallel=False, tp_only_amax_red=self.transformer_config.tp_only_amax_red + ) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() + + with fp8_context, rng_context: + + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + ) + else: + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + # The attention layer (currently a simplified transformer layer) + # outputs a tuple of (hidden_states, context). Context is intended + # for cross-attention, and is not needed in our model. + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_norm(hidden_states) + return hidden_states + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None + ) -> ShardedStateDict: + """ + Returns a sharded state dictionary for the current object. + + This function constructs a sharded state dictionary by iterating over the layers + in the current object, computing the sharded state dictionary for each layer, + and combining the results into a single dictionary. + + Parameters: + prefix (str): The prefix to use for the state dictionary keys. + sharded_offsets (tuple): The sharded offsets to use for the state dictionary. + metadata (dict): Additional metadata to use when computing the sharded state dictionary. + + Returns: + dict: The sharded state dictionary for the current object. + """ + + sharded_state_dict = {} + layer_prefix = f'{prefix}layers.' 
+ + for local_layer_idx, layer in enumerate(self.layers): + + global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 + state_dict_prefix = f'{layer_prefix}{local_layer_idx}.' # module list index in HyenaBlock + + sharded_prefix = f'{layer_prefix}{global_layer_offset}.' + sharded_pp_offset = [] + + layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata) + + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Add modules other than self.layers + for name, module in self.named_children(): + if not module is self.layers: + sharded_state_dict.update( + sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) + ) + + return sharded_state_dict diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py new file mode 100644 index 000000000000..b8e4710c08ed --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py @@ -0,0 +1,358 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass +class HyenaConfig: + """Configuration object for Hyena model and operators""" + + tie_projection_weights: bool = False + """ + Tie projection weights between QKV for attn and hyena (will repeat output 3 times). + """ + # + to_upper: str = "normalized_weighted" + """ + "upper" + "weighted" + Whether to convert all text to uppercase. + """ + # + lowercase_loss_reweighting: float = 0.1 + # """ + # If to_upper == "weighted" + # Weight to apply to lowercase tokens in the loss function, 1.0 is no reweighting. + # """ + + use_flashfft: bool = False + """ + Use flashfftconv instead of torch fft kernel (requires installation of flashfftconv)for hyena + """ + + use_cgcg: bool = False + """ + Use cgcg (chunked gate-conv-gate) kernel for hyena + """ + + use_cgcg_short: bool = False + """ + Use cgcg (chunked gate-conv-gate) kernel for hyena short conv + """ + + use_cgcg_mlp: bool = False + """ + Use cgcg (chunked gate-conv-gate) kernel for hyena mlp + """ + + cgcg_dtype: str = "bfloat16" + """ + dtype to use within cgcg kernel + """ + # + # cgcg_fwd_autotune: bool = False + # """ + # Whether to autotune cgcg fwd kernel + # + # @jeromeku: Note autotuning fwd kernel is unstable, + # use pre-tuned config for now. 
+ # """ + + cgcg_medium_fwd_kernel_config_chunk_size: int = 128 + """ + cgcg fwd medium conv kernel config chunk size + """ + cgcg_medium_fwd_kernel_config_block_d: int = 128 + """ + cgcg fwd medium conv kernel config block d tile size + """ + + cgcg_medium_fwd_kernel_config_threadblock_swizzle: str = "row" + """ + cgcg fwd medium conv kernel config threadblock swizzle type + """ + cgcg_medium_fwd_kernel_config_chunk_tiles_per_program: int = 3 + """ + cgcg fwd medium conv kernel config chunk tiles per program + """ + + cgcg_medium_fwd_kernel_config_num_warps: int = 4 + """ + cgcg fwd short conv kernel config num warps + """ + + cgcg_medium_fwd_kernel_config_num_stages: int = 3 + """ + cgcg fwd medium conv kernel config num mma pipeline stages + """ + + cgcg_short_fwd_kernel_config_chunk_size: int = 128 + """ + cgcg fwd short conv kernel config chunk size + """ + cgcg_short_fwd_kernel_config_block_d: int = 128 + """ + cgcg fwd short conv kernel config block d tile size + """ + + cgcg_short_fwd_kernel_config_threadblock_swizzle: str = "row" + """ + cgcg fwd short conv kernel config threadblock swizzle type + """ + cgcg_short_fwd_kernel_config_chunk_tiles_per_program: int = 1 + """ + cgcg fwd short conv kernel config chunk tiles per program + """ + + cgcg_short_fwd_kernel_config_num_warps: int = 4 + """ + cgcg fwd short conv kernel config num warps + """ + + cgcg_short_fwd_kernel_config_num_stages: int = 1 + """ + cgcg fwd short conv kernel config num mma pipeline stages + """ + + cgcg_bwd_autotune: bool = True + """ + Whether to autotune cgcg bwd kernel + """ + + cgcg_fused_bwd: bool = True + """ + Whether to use fused cgcg bwd kernel + """ + + cgcg_bwd_kernel_config_pre_conv_block_x: int = 128 + """ + cgcg bwd pre_conv kernel config block x tile size + """ + + cgcg_bwd_kernel_config_pre_conv_block_y: int = 128 + """ + cgcg bwd pre_conv kernel config block y tile size + """ + + cgcg_bwd_kernel_config_pre_conv_num_warps: int = 8 + """ + cgcg bwd pre_conv kernel config num warps + """ + + cgcg_bwd_kernel_config_post_conv_block_x: int = 32 + """ + cgcg bwd post conv kernel config block x tile size + """ + + cgcg_bwd_kernel_config_post_conv_block_y: int = 128 + """ + cgcg bwd post conv kernel config block y tile size + """ + + cgcg_bwd_kernel_config_post_conv_num_warps: int = 4 + """ + cgcg bwd post conv kernel config num warps + """ + + short_conv_L: int = 3 + """ + For Hyena models, length of the short convolution. + """ + + use_hyena_filter: bool = False + """ + Whether to use the Hyena filter. + """ + + normalize_hyena_filters: bool = False + + conv_proj_bias: bool = True # Maybe this should be false + """ + Use bias in the short conv1D, needed for model parallel for the short conv. + """ + + use_fast_heads: bool = False + """ + Use external fast heads in Hyena mixer (reduce BEFORE fftconv) + """ + + use_slow_heads: bool = False + """ + Use external outer-product heads in Hyena. + """ + + use_long_conv1d: bool = False + + num_groups_hyena: int = None + """ + Determines number of unique filters to have, for the hyena long filter. + """ + + num_groups_hyena_medium: int = None + """ + Determines number of unique filters to have, for the hyena medium filter. + """ + + num_groups_hyena_short: int = None + """ + Determines number of unique filters to have, for the hyena short filter. + """ + + num_groups_hyena_mlp: int = None # TODO: Possibly remove, only used if is_mlp is True + """ + Determines number of unique filters to have, for the hyena mlp (filter). 
+ """ + + use_depthwise_short_conv_grouping: bool = True + """ + Whether to use depthwise convolution grouping for short conv and hyena mlp filters. + """ + + hyena_filter_cls: str = "implicit_modal" + """ + """ + + hyena_width_expansion: float = 1.0 + """ + Factor to expand the projections width within hyena layers. + """ + + hyena_medium_filter_cls: str = 'explicit_single_decay' + """ + For medium hyena filters specifically, None defaults ot same as hyena_filter_cls (long filters). + """ + + hyena_filter_r_max: float = 0.99 # TODO: Possibly remove, only used in ParallelComplexModalFilter + + hyena_filter_r_min: float = 0.5 # TODO: Possibly remove, only used in ParallelComplexModalFilter + + hyena_filter_emb_dim: int = 33 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_fast_decay: float = 0.3 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_slow_decay: float = 1.2 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_order: int = 16 + + hyena_filter_num_inner_mlps: int = 2 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_w: int = 14 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_filter_wd: float = 0.0 # TODO: Where to override WD value for filters? + + hyena_filter_omega_0: float = 1 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + hyena_pos_emb: str = "fourier_fixed" # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter + + explicit_filter_decay_preset: str = "weak" + + modal_residue_factors: int = 3 # TODO: Possibly remove, only used in ImplicitRealModelFilter + + modal_pole_factors: int = 3 # TODO: Possibly remove, only used in ImplicitRealModelFilter + + modal_gamma_min: float = 0.01 + + modal_gamma_max: float = 0.1 + + use_custom_hyena_short_kernel: bool = False + """ + Use a custom causal conv layer for the hyena short conv layer. + """ + + use_custom_hyena_mlp_kernel: bool = False # TODO: Possibly remove - only relevant if is_mlp is True + """ + Use a custom causal conv layer for the hyena short conv layer. + """ + + bidirectional: bool = False + """ + A bidirectional version of hyena fftconv + """ + + hyena_short_conv_len: int = 7 + """ + Length of the hyena short conv layer, if using + """ + + fast_conv_proj: bool = True + """ + Use a custom causal conv layer for the hyena projection convs. + """ + + hyena_medium_conv_len: int = 128 + """ + Length of the medium hyena filter. + """ + + fast_conv_mixer: bool = False + """ + Use a custom causal conv layer for the hyena short conv layer. + """ + + hyena_mlp_len: int = 7 # TODO: Possibly remove, only used if is_mlp is True + """ + Length of filter used inside the hyena mlp layer. Defaults to hyena_short_conv_len if not provided. + """ + + fast_hyena_mlp_conv: bool = False # TODO: Possibly remove, only used if is_mlp is True + """ + Use a custom causal conv layer for the hyena MLP layer. + """ + + hyena_mlp_expansion_factor: float = 1.0 # TODO: Possibly remove, only used if is_mlp is True + """ + Factor to expand the projections width within hyena MLP layers only. + """ + + hyena_mlp_pregate: bool = True # TODO: Possibly remove, only used if is_mlp is True + """ + Use a pre-gate in the hyena MLP layer. + """ + + hyena_mlp_postgate: bool = True # TODO: Possibly remove, only used if is_mlp is True + """ + Use a post-gate in the hyena MLP layer. 
+ """ + + hyena_short_conv_pregate: bool = True + """ + Use a pre-gate in the hyena short conv layer. + """ + + hyena_short_conv_postgate: bool = True + """ + Use a post-gate in the hyena short conv layer. + """ + + proj_groups: int = 1 + + grouped_attention: bool = False + + # mlp_type: str = "regular" # TODO: In Savanna setting this to 'short_hyena' uses hyena for MLP (is_mlp == True) + # """ + # Types: + # regular: Megatron implementation + # llama: LLaMA MLP (SiLU-gated MLP) + # short_hyena + # identity + # """ + # + # make_gated_mlp_multiple_of: int = 16 # TODO: Use this or just have user calculate ffn_size themselves? + # """ + # Set the ff_dim to be a multiple of this value for llama mlp. Useful for sharding / using model parallel properly. + # """ diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py new file mode 100644 index 000000000000..2262dff00c43 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_hybrid_layer_allocation.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
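For orientation, the HyenaConfig fields in the hunk above are plain dataclass attributes, and the model code later in this diff overrides the group counts on a constructed instance rather than subclassing. A minimal sketch under that assumption, with made-up values:

# Only the import path, field names, and defaults come from this diff; the numbers are illustrative.
from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig

hyena_cfg = HyenaConfig()            # defaults: short_conv_L=3, hyena_short_conv_len=7, hyena_medium_conv_len=128
hyena_cfg.num_groups_hyena = 4096    # e.g. one long-filter group per channel
hyena_cfg.num_groups_hyena_medium = 256
hyena_cfg.num_groups_hyena_short = 256

# The checkpoint importer truncates medium filters to this length during conversion.
assert hyena_cfg.hyena_medium_conv_len == 128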
+
+import logging
+
+
+if __name__ != "__main__":
+    from megatron.core.utils import log_single_rank
+else:
+    from typing import Any
+
+    def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
+        """Log a message from the current rank."""
+        print(*args[1:], **kwargs)
+
+
+logger = logging.getLogger(__name__)
+
+
+class Symbols:
+    """Symbols for the hybrid layer allocation."""
+
+    HYENA_SHORT = 'S'
+    HYENA_MEDIUM = 'D'
+    HYENA = 'H'
+    ATTENTION = '*'
+    VALID = {HYENA_SHORT, HYENA_MEDIUM, HYENA, ATTENTION}
+
+
+def _allocate_override(total_layers_count: int, override_pattern: str) -> list:
+    layer_type_list = list(override_pattern)
+    override_pattern_length = len(layer_type_list)
+    if override_pattern_length != total_layers_count:
+        raise ValueError(
+            "The hybrid override pattern is the wrong "
+            f"length: got {override_pattern_length}, expected "
+            f"{total_layers_count}"
+        )
+    for layer_type in layer_type_list:
+        if layer_type not in Symbols.VALID:
+            raise ValueError(f"In hybrid override pattern, '{layer_type}' is not " f"one of {Symbols.VALID}")
+
+    return layer_type_list
+
+
+def allocate_layers(
+    total_layers_count: int,
+    override_pattern: str,
+) -> list:
+    """Allocate the layers for the hybrid model."""
+    layer_type_list = _allocate_override(total_layers_count, override_pattern)
+    log_single_rank(logger, logging.INFO, "Using hybrid override pattern")
+    actual_hyena_short_layers_count = layer_type_list.count(Symbols.HYENA_SHORT)
+    actual_hyena_medium_layers_count = layer_type_list.count(Symbols.HYENA_MEDIUM)
+    actual_hyena_layers_count = layer_type_list.count(Symbols.HYENA)
+    actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION)
+    allocation_string = ''.join(layer_type_list)
+    log_single_rank(
+        logger,
+        logging.INFO,
+        f"Hybrid allocation ({Symbols.HYENA_SHORT} is hyena_short_conv, "
+        f"{Symbols.HYENA_MEDIUM} is hyena_medium_conv, "
+        f"{Symbols.HYENA} is hyena, "
+        f"{Symbols.ATTENTION} is attention)",
+    )
+    log_single_rank(logger, logging.INFO, allocation_string)
+    log_single_rank(
+        logger,
+        logging.INFO,
+        f"{actual_hyena_short_layers_count} hyena_short_conv layers in " f"{total_layers_count} total layers.",
+    )
+    log_single_rank(
+        logger,
+        logging.INFO,
+        f"{actual_hyena_medium_layers_count} hyena_medium_conv layers in " f"{total_layers_count} total layers.",
+    )
+    log_single_rank(
+        logger,
+        logging.INFO,
+        f"{actual_hyena_layers_count} hyena layers in " f"{total_layers_count} total layers.",
+    )
+    log_single_rank(
+        logger,
+        logging.INFO,
+        f"{actual_attention_layers_count} attention layers in " f"{total_layers_count} total layers.",
+    )
+
+    return layer_type_list
+
+
+if __name__ == "__main__":
+    test_cases = [
+        (4, "SDH*"),
+        (8, "SSDDH*H*"),
+    ]
+    for t in test_cases:
+        print("")
+        allocate_layers(*t)
diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py
new file mode 100644
index 000000000000..9dc63d2d89c7
--- /dev/null
+++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024 Arc Institute. All rights reserved.
+# Copyright (c) 2024 Michael Poli. All rights reserved.
+# Copyright (c) 2024 Stanford University. All rights reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Union + +import torch +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import Tensor + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig + + +@dataclass +class HyenaLayerSubmodules: + """Submodules for the HyenaLayer.""" + + norm: Union[ModuleSpec, type] = IdentityOp + mixer: Union[ModuleSpec, type] = IdentityOp + hyena_bda: Union[ModuleSpec, type] = IdentityOp + + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityOp + + +class HyenaLayer(MegatronModule): + """Top level Hyena Layer.""" + + def __init__( + self, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + operator_type, + max_sequence_length, + submodules: HyenaLayerSubmodules, + layer_number: int = 1, + residual_in_fp32=False, + ): + """ + Top level Hyena Layer + """ + super().__init__(config=transformer_config) + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.layer_number = layer_number + self.hidden_dropout = transformer_config.hidden_dropout + self.residual_in_fp32 = residual_in_fp32 + self.mixer = build_module( + submodules.mixer, + self.transformer_config, + self.hyena_config, + max_sequence_length, + layer_number=layer_number, + operator_type=operator_type, + ) + self.norm = build_module( + submodules.norm, + self.transformer_config, + self.transformer_config.hidden_size, + eps=self.transformer_config.layernorm_epsilon, + ) + + self.hyena_bda = build_module(submodules.hyena_bda) + self.bias_dropout_add_exec_handler = torch.enable_grad + + self.pre_mlp_layernorm = build_module( + submodules.pre_mlp_layernorm, + config=self.transformer_config, + hidden_size=self.transformer_config.hidden_size, + eps=self.transformer_config.layernorm_epsilon, + ) + + self.mlp = build_module(submodules.mlp, config=self.transformer_config) + if hasattr(self.mlp, 'set_layer_number'): + self.mlp.set_layer_number(self.layer_number) + + self.mlp_bda = build_module(submodules.mlp_bda) + + self.bias_dropout_add_exec_handler = torch.enable_grad + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, # Not used in HyenaLayer + inference_params=None, + rotary_pos_emb: Tensor = None, # Not used in HyenaLayer + ): + """Forward pass for the HyenaLayer.""" + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + residual = hidden_states + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = hidden_states.to(dtype=self.transformer_config.params_dtype) + hidden_states = self.norm(hidden_states) + + mixer_out_with_bias = self.mixer(hidden_states, inference_params=inference_params) + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.hyena_bda(self.training, 
self.transformer_config.bias_dropout_fusion)( + mixer_out_with_bias, residual, self.hidden_dropout + ) + + residual = hidden_states + + pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) + + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.mlp_bda(self.training, self.transformer_config.bias_dropout_fusion)( + mlp_output_with_bias, residual, self.hidden_dropout + ) + + return hidden_states diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py new file mode 100755 index 000000000000..09b3e23f0fe0 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_layer_specs.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_block import HyenaStack, HyenaStackSubmodules +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_layer import HyenaLayer, HyenaLayerSubmodules +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_mixer import HyenaMixer, HyenaMixerSubmodules + +try: + from megatron.core.extensions.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + + def _raise_te_import_error(*args, **kwargs): + raise ImportError("Transformer Engine is not installed") + + # NeMo has a number of tests that make sure that you can initialize some modules without TE installed. 
+ TENorm = _raise_te_import_error + TELayerNormColumnParallelLinear = _raise_te_import_error + TERowParallelLinear = _raise_te_import_error + TEDotProductAttention = _raise_te_import_error + +# Layer spec with TE modules +if HAVE_TE: + hyena_stack_spec = ModuleSpec( + module=HyenaStack, + submodules=HyenaStackSubmodules( + hyena_layer=ModuleSpec( + module=HyenaLayer, + submodules=HyenaLayerSubmodules( + mixer=ModuleSpec( + module=HyenaMixer, + submodules=HyenaMixerSubmodules( + dense_projection=TELayerNormColumnParallelLinear, dense=TERowParallelLinear + ), + ), + hyena_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + ), + ) +else: + hyena_stack_spec = ModuleSpec(module=None) + +# Layer spec without TE modules, for debugging + +hyena_stack_spec_no_te = ModuleSpec( + module=HyenaStack, + submodules=HyenaStackSubmodules( + hyena_layer=ModuleSpec( + module=HyenaLayer, + submodules=HyenaLayerSubmodules( + norm=TENorm, + mixer=ModuleSpec( + module=HyenaMixer, + submodules=HyenaMixerSubmodules(dense_projection=ColumnParallelLinear, dense=RowParallelLinear), + ), + hyena_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + ), +) diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py new file mode 100644 index 000000000000..eceeb09e59d5 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py @@ -0,0 +1,258 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass +from typing import Union + +import torch +import torch.nn as nn +from einops import rearrange +from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_world_size, + get_tensor_model_parallel_world_size, +) +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import sharded_state_dict_default + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import ( + ParallelCausalDepthwiseConv1d, + ParallelHyenaOperator, + ParallelShortHyenaOperator, + divide, +) + +logger = logging.getLogger(__name__) + +try: + from transformer_engine.common.recipe import DelayedScaling, Format +except ImportError: + logger.warning("WARNING: transformer_engine not installed. Using default recipe.") + + +def set_format_recipe(): + """Set the fp8 format recipe. for Hyena""" + fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + return fp8_recipe + + +@dataclass +class HyenaMixerSubmodules: + """ + Contains the module specs for the input and output linear layers. + """ + + dense_projection: Union[ModuleSpec, type] = None + dense: Union[ModuleSpec, type] = None + + +class HyenaMixer(MegatronModule): + """ + A class for the HyenaMixer. + """ + + def __init__( + self, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + max_sequence_length, + submodules, + layer_number=1, + operator_type="H", + is_mlp=False, # TODO: Check if needed, only used when using Hyena for the MLP block + ): + + super().__init__(transformer_config) + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.is_mlp = is_mlp + self.operator_type = operator_type + self.layer_number = layer_number + self.grouped_attention = self.hyena_config.grouped_attention + + self.fast_conv_proj = self.hyena_config.fast_conv_proj + self.fast_conv_mixer = self.hyena_config.fast_conv_mixer + + # Per attention head and per partition values. 
+ assert torch.distributed.is_initialized() + self.model_parallel_size = get_tensor_model_parallel_world_size() + world_size: int = get_tensor_model_parallel_world_size() + + # Width expansion for Hyena depending on if it's a mixer of mlp + if self.is_mlp: + self.hyena_width_expansion = self.hyena_config.hyena_mlp_expansion_factor + else: + self.hyena_width_expansion = self.hyena_config.hyena_width_expansion + + # we might expand the hidden size for hyena + self.input_size = self.transformer_config.hidden_size + self.hidden_size = int(self.transformer_config.hidden_size * self.hyena_width_expansion) + + # ensures parallizable + if self.hyena_width_expansion > 1: + multiple_of = 32 + self.hidden_size = int(multiple_of * ((self.hidden_size + multiple_of - 1) // multiple_of)) + + # checks on the hidden size divisibility + assert ( + self.hidden_size % world_size == 0 + ), f"Hidden size {self.hidden_size} is not divisible by the world size {world_size}" + self.hidden_size_per_partition = divide(self.hidden_size, world_size) + self.proj_groups = self.hyena_config.proj_groups + + self.tie_projection_weights = self.hyena_config.tie_projection_weights + + self.grouped_proj_size = self.transformer_config.hidden_size // self.proj_groups + + # Strided linear layer. + if self.tie_projection_weights: + # we'll repeat the output 3 times instead + projections_size = self.hidden_size + else: + projections_size = 3 * self.hidden_size + + # qkv projections + self.dense_projection = build_module( + submodules.dense_projection, + self.input_size, + projections_size, + config=self.transformer_config, + init_method=self.transformer_config.init_method, + gather_output=False, + bias=False, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='fc1', + ) + + hyena_proj_groups = self.proj_groups if not self.grouped_attention else 1 + grouped_proj_size = self.hidden_size_per_partition // hyena_proj_groups + self.hyena_proj_conv = ParallelCausalDepthwiseConv1d( + self.hidden_size_per_partition + 2 * grouped_proj_size, + self.transformer_config, + self.hyena_config, + kernel_size=self.hyena_config.short_conv_L, + init_method=transformer_config.init_method, + bias=self.hyena_config.conv_proj_bias, + use_fast_causal_conv=self.fast_conv_proj, + ) + + if self.operator_type == "hyena_short_conv": + self.num_groups = self.hyena_config.num_groups_hyena_short + self.num_groups_per_tp_rank = self.num_groups // self.model_parallel_size + + self.mixer = ParallelShortHyenaOperator( + self.hidden_size, # pass hidden size here to avoid recalculating + self.transformer_config, + self.hyena_config, + self.transformer_config.init_method, + short_conv_class=ParallelCausalDepthwiseConv1d, + use_fast_causal_conv=self.fast_conv_mixer, + is_mlp=self.is_mlp, + use_conv_bias=self.transformer_config.use_short_conv_bias, + ) + + if self.operator_type in [ + "hyena", + "hyena_medium_conv", + ]: + if self.operator_type == "hyena_medium_conv": + self.num_groups = self.hyena_config.num_groups_hyena_medium + else: + self.num_groups = self.hyena_config.num_groups_hyena + self.num_groups_per_tp_rank = self.num_groups // self.model_parallel_size + + self.mixer = ParallelHyenaOperator( + self.hidden_size, # pass hidden size here to avoid recalculating + self.transformer_config, + self.hyena_config, + self.transformer_config.init_method, + operator_type, + max_sequence_length, + downsample_factor=1, + ) + + # Dropout. 
Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.dropout_p = self.transformer_config.attention_dropout + self.attention_dropout = nn.Dropout(self.dropout_p) + + self.dense = build_module( + submodules.dense, + self.hidden_size, + self.input_size, + config=self.transformer_config, + init_method=self.transformer_config.output_layer_init_method, + bias=True, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='fc2', + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + Sharded state dictionary for the HyenaMixer. + """ + sharded_state_dict = {} + # Submodules + for name, module in self.named_children(): + if name != 'attention_dropout': + module_sharded_sd = sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) + + sharded_state_dict.update(module_sharded_sd) + + return sharded_state_dict + + def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True): + """ + Applies sequence mixing to a sequence of 1-dimensional embeddings: batch_size, seq_len, d_model + + Args: + u: input to the operator, in format [B, L, D] + """ + # CP control + if _hyena_use_cp: + cp_group = get_context_parallel_group() + else: + cp_group = None + + if cp_group is not None and get_context_parallel_world_size() > 1: + _proj_use_cp = True + else: + _proj_use_cp = False + + features, _ = self.dense_projection(x) + features = rearrange(features, "l b d -> b l d").contiguous() + features_L_last = features.permute(0, 2, 1) + features_D_last = self.hyena_proj_conv(features_L_last, _use_cp=_proj_use_cp).permute(0, 2, 1) + + x1, x2, v = rearrange( + features_D_last, "b l (g dg p) -> b l g p dg", p=3, g=self.num_groups_per_tp_rank + ).unbind(dim=3) + + z = self.mixer(x1, x2, v) + z = rearrange(z, "b l d -> l b d").contiguous() + + y, bias = self.dense(z) + return y, bias diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py new file mode 100644 index 000000000000..1895e79ff1b7 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py @@ -0,0 +1,292 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
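A small, self-contained shape check of the split performed at the end of HyenaMixer.forward above, where the packed projection output is unbound into the two gating branches and the value branch; all sizes here are made up for illustration:

# The rearrange pattern mirrors the one in HyenaMixer.forward; everything else is illustrative.
import torch
from einops import rearrange

B, L, g, dg = 2, 16, 4, 32                       # batch, sequence, groups per TP rank, width per group
features_D_last = torch.randn(B, L, g * dg * 3)  # packed projection output, channel-last

x1, x2, v = rearrange(features_D_last, "b l (g dg p) -> b l g p dg", p=3, g=g).unbind(dim=3)

assert x1.shape == x2.shape == v.shape == (B, L, g, dg)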
+ +from copy import deepcopy +from typing import Literal, Optional + +import torch +from megatron.core import InferenceParams, parallel_state, tensor_parallel +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer +from torch import Tensor +from torch.nn.parameter import Parameter + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import ( + get_init_method, + make_upper_case, + reweighted_cross_entropy, +) + + +class HyenaModel(LanguageModule): + """ + A class for the HyenaModel. + """ + + def __init__( + self, + transformer_config: TransformerConfig, # Actually a hyena.HyenaConfig but avoid circular import + hyena_stack_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + num_groups_hyena: int, + num_groups_hyena_medium: int, + num_groups_hyena_short: int, + pre_process: bool = True, + hybrid_override_pattern: str = None, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + post_layer_norm: bool = True, + share_embeddings_and_output_weights: bool = True, + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'rope', + rotary_percent: float = 1.0, + rotary_base: int = 10000, + seq_len_interpolation_factor: Optional[float] = None, + hyena_init_method: str = None, + hyena_output_layer_init_method: str = None, + remove_activation_post_first_layer: bool = True, + add_attn_proj_bias: bool = True, + ) -> None: + super().__init__(config=transformer_config) + + self.transformer_config = transformer_config + self.hyena_config = HyenaConfig() + + # Override HyenaConfig fields with user provided values + self.hyena_config.num_groups_hyena = num_groups_hyena + self.hyena_config.num_groups_hyena_medium = num_groups_hyena_medium + self.hyena_config.num_groups_hyena_short = num_groups_hyena_short + if hyena_init_method: + self.transformer_config.init_method = get_init_method( + hyena_init_method, self.transformer_config.num_layers, self.transformer_config.hidden_size + ) + if hyena_output_layer_init_method: + self.transformer_config.output_layer_init_method = get_init_method( + hyena_output_layer_init_method, self.transformer_config.num_layers, self.transformer_config.hidden_size + ) + + if has_config_logger_enabled(transformer_config): + log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__) + + self.hyena_stack_spec: ModuleSpec = hyena_stack_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.hybrid_override_pattern = hybrid_override_pattern + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + self.post_layer_norm = post_layer_norm + # 
megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + if self.pre_process: + self.embedding = LanguageModelEmbedding( + config=self.transformer_config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=position_embedding_type, + ) + + if self.position_embedding_type == 'rope': + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.transformer_config.kv_channels, + rotary_percent=rotary_percent, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + use_cpu_initialization=self.transformer_config.use_cpu_initialization, + ) + + self.decoder = build_module( + hyena_stack_spec, + self.transformer_config, + self.hyena_config, + hybrid_override_pattern=self.hybrid_override_pattern, + max_sequence_length=self.max_sequence_length, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=self.post_layer_norm, + ) + + # In some Hyena species, the published checkpoint has identity activations after the first + # MLP block, so we replicate this behavior in this implementation if remove_activation_post_first_layer. + self.remove_activation_post_first_layer = remove_activation_post_first_layer + if self.remove_activation_post_first_layer: + if parallel_state.is_pipeline_first_stage(): + # Skip the first layer of the global model for this activation patch. + start_idx = 1 + else: + start_idx = 0 + mlp_no_act_config = deepcopy(self.decoder.layers[start_idx].mlp.config) + mlp_no_act_config.activation_func = lambda x: x + for hyena_layer in self.decoder.layers[start_idx:]: + hyena_layer.mlp.activation_func = mlp_no_act_config.activation_func + hyena_layer.mlp.config = mlp_no_act_config + + # In some Hyena species, the published checkpoint always has a bias in the linear projection + # of the self-attention layers regardless of bias in other linear layers. + self.add_attn_proj_bias = add_attn_proj_bias + if self.add_attn_proj_bias and not self.config.add_bias_linear: + for layer in self.decoder.layers: + if isinstance(layer, TransformerLayer): + linear_proj = layer.self_attention.linear_proj + output_size = linear_proj.weight.shape[0] + linear_proj.bias = Parameter( + torch.empty( + output_size, dtype=linear_proj.config.params_dtype, device=linear_proj.weight.device + ) + ) + # Always initialize bias to zero. + with torch.no_grad(): + linear_proj.bias.zero_() + setattr(linear_proj.bias, 'allreduce', True) + setattr(linear_proj, 'te_return_bias', True) + setattr(linear_proj, 'return_bias', True) + setattr(linear_proj, 'use_bias', True) + setattr(linear_proj.bias, 'sequence_parallel', linear_proj.config.sequence_parallel) + + # Output + if post_process: + if self.config.defer_embedding_wgrad_compute: + # The embedding activation buffer preserves a reference to the input activations + # of the final embedding projection layer GEMM. It will hold the activations for + # all the micro-batches of a global batch for the last pipeline stage. Once we are + # done with all the back props for all the microbatches for the last pipeline stage, + # it will be in the pipeline flush stage. During this pipeline flush we use the + # input activations stored in embedding activation buffer and gradient outputs + # stored in gradient buffer to calculate the weight gradients for the embedding + # final linear layer. 
+ self.embedding_activation_buffer = [] + self.grad_output_buffer = [] + else: + self.embedding_activation_buffer = None + self.grad_output_buffer = None + self.output_layer = tensor_parallel.ColumnParallelLinear( + transformer_config.hidden_size, + self.vocab_size, + config=transformer_config, + init_method=transformer_config.init_method, + bias=self.config.add_bias_output, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, + ) + if self.config.add_bias_output: + self.output_layer.bias.data.zero_() + + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' + self.decoder.set_input_tensor(input_tensor[0]) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + loss_mask: Tensor = None, + inference_params: InferenceParams = None, + ) -> Tensor: + """Forward pass for the HyenaModel.""" + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.transformer_config, None + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # The following assert will currently fail when running inference. + # Commented out for now. + # TODO (duncan/rwaleffe): (1) confirm that the externally-generated + # attention mask is not needed and is ignored by the model in + # inference mode, (2) reduce the size of the externally-generated + # attention mask to prevent CPU OOM (as we did for training), (3) + # force the attention mask passed to the model in inference mode to + # be None, so this assert will succeed. + # assert attention_mask is None, "The attention mask is ignored and should be set to None" + + # Run decoder. 
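+        # Note (added for clarity): hidden states flow through the decoder in Megatron's
+        # [sequence, batch, hidden] layout; the logits are transposed to batch-first
+        # ([s b h] -> [b s h]) below only when they are returned directly, i.e. when no
+        # labels are provided.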
+ hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + if not self.post_process: + return hidden_states + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits, _ = self.output_layer(hidden_states, weight=output_weight) + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + labels, lowercase_mask = make_upper_case(labels) + loss = self.compute_language_model_loss(labels, logits) + normalize_per_batch = True if self.config.to_upper == "normalized_weighted" else False + loss = reweighted_cross_entropy( + loss, + (labels, loss_mask, lowercase_mask), + lowercase_weight=self.hyena_config.lowercase_loss_reweighting, + normalize_per_batch=normalize_per_batch, + ) + return loss diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py new file mode 100644 index 000000000000..1c8870479030 --- /dev/null +++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py @@ -0,0 +1,1687 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_rank, + get_context_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel import get_cuda_rng_tracker +from megatron.core.transformer.transformer_config import TransformerConfig + +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig + +try: + from flashfftconv import FlashFFTConv +except ImportError: + + def FlashFFTConv(*args, **kwargs): + """Not imported: FlashFFTConv. An error will be raised if this is called.""" + raise Exception("Not imported: FlashFFTConv") + + +try: + # Default implementation does not make use of these features but they are included for completeness and + # for future testing. See the savanna repository at https://github.com/Zymrael/savanna/. These functions + # are not currently used in nemo or bionemo tutorials. 
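+    # If the savanna package is unavailable, the stub definitions in the except branch
+    # below are used instead; each stub raises only when it is actually called, so the
+    # default configuration (which does not enable these kernels) is unaffected.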
+ from savanna.kernels.triton_src.cgcg.interface import two_pass_chunked_gate_conv_gate + from savanna.kernels.triton_src.cgcg.src.kernel_utils import BwdKernelConfigRefactor, FwdKernelConfigRefactor + from savanna.kernels.triton_src.short_hyena.interface import run_short_hyena + from savanna.kernels.triton_src.short_hyena.src.kernel_utils import ( + PostConvKernelConfig, + PreConvKernelConfig, + ShortHyenaOperatorKernelConfig, + ) +except ImportError: + + def two_pass_chunked_gate_conv_gate(*args, **kwargs): + """Not imported: two_pass_chunked_gate_conv_gate. An error will be raised if this is called.""" + raise Exception("Not imported: two_pass_chunked_gate_conv_gate") + + def run_short_hyena(*args, **kwargs): + """Not imported: run_short_hyena. An error will be raised if this is called.""" + raise Exception("Not imported: run_short_hyena") + + def PreConvKernelConfig(*args, **kwargs): + """Not imported: PreConvKernelConfig. An error will be raised if this is called.""" + raise Exception("Not imported: PreConvKernelConfig") + + def PostConvKernelConfig(*args, **kwargs): + """Not imported: PostConvKernelConfig. An error will be raised if this is called.""" + raise Exception("Not imported: PostConvKernelConfig") + + def ShortHyenaOperatorKernelConfig(*args, **kwargs): + """Not imported: ShortHyenaOperatorKernelConfig. An error will be raised if this is called.""" + raise Exception("Not imported: ShortHyenaOperatorKernelConfig") + + def BwdKernelConfigRefactor(*args, **kwargs): + """Not imported: BwdKernelConfigRefactor. An error will be raised if this is called.""" + raise Exception("Not imported: BwdKernelConfigRefactor") + + def FwdKernelConfigRefactor(*args, **kwargs): + """Not imported: FwdKernelConfigRefactor. An error will be raised if this is called.""" + raise Exception("Not imported: FwdKernelConfigRefactor") + + +try: + from einops import rearrange +except ImportError: + raise ImportError("einops is required by the Hyena model but cannot be imported") + +try: + from causal_conv1d import causal_conv1d_fn +except ImportError: + raise ImportError("causal_conv1d is required by the Hyena model but cannot be imported") + +from typing import Literal + +# CP related utils +import torch.distributed as dist +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default + + +def _get_zigzag_indices(N, device=None): + """ + Generates the zigzag indices for rearrangement. + Args: + N (int): The total number of chunks. + device (torch.device): The device on which to create tensors. + Returns: + torch.Tensor: The zigzag indices. + """ + half_N = (N + 1) // 2 + idx1 = torch.arange(half_N, device=device) + idx2 = torch.arange(N - 1, half_N - 1, -1, device=device) + zigzag_idx = torch.empty(N, dtype=torch.long, device=device) + zigzag_idx[0::2] = idx1 + zigzag_idx[1::2] = idx2 + return zigzag_idx + + +def _get_inverse_zigzag_indices(N, device=None): + """ + Generates the inverse zigzag indices for rearrangement. + Args: + N (int): The total number of chunks. + device (torch.device): The device on which to create tensors. + Returns: + torch.Tensor: The inverse zigzag indices. 
+ """ + half_N = N // 2 + idx1 = torch.arange(half_N, device=device) + idx2 = torch.arange(N - 1, half_N - 1, -1, device=device) + zigzag_idx = torch.empty(N, dtype=torch.long, device=device) + zigzag_idx[0::2] = idx1 + zigzag_idx[1::2] = idx2 + inverse_zigzag_idx = torch.argsort(zigzag_idx) + return inverse_zigzag_idx + + +def all_to_all_single_fn( + group: dist.ProcessGroup, + type: Literal["split_to_full", "full_to_split"], + input: torch.Tensor, + with_zigzag_splitting: bool = True, +) -> torch.Tensor: + """ + Autograd-aware all_to_all_single communication function. + Args: + group (dist.ProcessGroup): The process group for communication. + type (str): Either 'split_to_full' or 'full_to_split' to specify the communication pattern. + input (torch.Tensor): Input tensor to be communicated. + with_zigzag_splitting (bool, optional): Whether to apply zigzag splitting. Defaults to True. + Returns: + torch.Tensor: Output tensor after communication. + """ + + world_size = dist.get_world_size(group=group) + + if type == "split_to_full": + """Given an split sequence, it gathers the whole sequence, while splitting across the channels dimension.""" + + B, D, local_length = input.shape + L = local_length * world_size + d = D // world_size + + # Reshape and permute input for communication + input_reshaped = rearrange( + input, "B (cp d) l -> cp B d l", cp=world_size + ).contiguous() # [cp_world_size, B, d, l] + + # Perform all_to_all_single communication + output_reshaped = torch.empty_like(input_reshaped) + dist.all_to_all_single(output_reshaped, input_reshaped, group=group) # [cp_world_size, B, d, l] + + # Permute and reshape output back to original form + output = rearrange(output_reshaped, "cp B d l -> B d (cp l)", cp=world_size).contiguous() + + if with_zigzag_splitting: + num_chunks = 2 * world_size + unzigzagged_split_length = L // num_chunks # Length of each small chunk + device = output.device + inverse_zigzag_idx = _get_inverse_zigzag_indices(num_chunks, device=device) + + # Vectorized rearrangement using inverse zigzag indices + output = ( + output.reshape(B, d, num_chunks, unzigzagged_split_length) + .index_select(dim=-2, index=inverse_zigzag_idx) + .reshape(B, d, L) + ) + + return output + + elif type == "full_to_split": + """ + Given a full sequence split across channels, splits across the sequence length while gathering the channels. 
+ """ + + B, d, L = input.shape + + if with_zigzag_splitting: + num_chunks = 2 * world_size + chunk_length = L // num_chunks # Length of each small chunk + device = input.device + zigzag_idx = _get_zigzag_indices(num_chunks, device=device) + + # Ensure L is divisible by num_chunks + if L % num_chunks != 0: + raise ValueError(f"Sequence length {L} is not divisible by num_chunks {num_chunks}") + + # Vectorized rearrangement using zigzag indices + input = ( + input.reshape(B, d, num_chunks, chunk_length).index_select(dim=-2, index=zigzag_idx).reshape(B, d, L) + ) + + # Reshape and permute inputs for communication + input_reshaped = rearrange( + input, "b d (cp l) -> cp b d l", cp=world_size + ).contiguous() # [cp_world_size, b, d, l] + + # Perform all_to_all_single communication + output_reshaped = torch.empty_like(input_reshaped) + dist.all_to_all_single(output_reshaped, input_reshaped, group=group) # [cp_world_size, B, d, l] + + # Permute and reshape outputs back to original form + output = rearrange(output_reshaped, "cp b d l -> b (cp d) l", cp=world_size).contiguous() + + return output + + else: + raise ValueError(f"Unknown type {type}") + + +from torch.autograd.function import Function + + +class AllToAllSingleFunction(Function): + """ + A custom autograd function for performing all_to_all_single communication with optional zigzag splitting. + Attributes: + - ctx: A context object that stores information for the forward and backward passes. + - group: The process group for communication. + - type: The type of communication pattern ('split_to_full' or 'full_to_split'). + - with_zigzag_splitting: A boolean indicating whether to apply zigzag splitting. + """ + + @staticmethod + def forward(ctx, input_tensor, group, type, with_zigzag_splitting): + """ + Forward pass for the AllToAllSingleFunction. + """ + ctx.group = group + ctx.type = type + ctx.with_zigzag_splitting = with_zigzag_splitting + + # Detach input_tensor to prevent PyTorch from tracking operations inside the communication + input_tensor = input_tensor.detach() + + # Perform the communication operation + output = all_to_all_single_fn( + group=ctx.group, type=ctx.type, input=input_tensor, with_zigzag_splitting=ctx.with_zigzag_splitting + ) + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for the AllToAllSingleFunction. + """ + # The backward pass will perform the reverse communication + grad_input = all_to_all_single_fn( + group=ctx.group, + type="split_to_full" if ctx.type != "split_to_full" else "full_to_split", + input=grad_output, + with_zigzag_splitting=ctx.with_zigzag_splitting, + ) + # Return the gradient w.r.t. the input_tensor and None for other arguments + return grad_input, None, None, None + + +def zigzag_get_overlapping_patches(data, seq_dim, overlap_size): + """ + Extracts the overlapping patches from data in each rank. + Arguments: + data (torch.Tensor): The concatenated data (chunk_a and chunk_b), e.g., [0, 3] , [1, 2] with zigzag and 2 GPUs. + seq_dim (int): The sequence dimension along which the data is concatenated. + overlap_size (int): The size of the overlapping patch. + Returns: + overlap_a, overlap_b (torch.Tensor): The overlapping chunks from the data. That is the end of the lowest, and + the beginning of the last, e.g., end for 0 and start for 3. + """ + assert seq_dim >= 0, "Negative indexes not supported." 
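+    # Example with 2 CP ranks and zigzag ordering: rank 0 holds chunks [0, 3] and
+    # rank 1 holds chunks [1, 2], concatenated along `seq_dim`. The trailing
+    # `overlap_size` elements of each local chunk are extracted here so that
+    # ExchangeOverlappingRegionsCausal can hand them to the neighbouring ranks as the
+    # left-hand (causal) padding for the depthwise convolution.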
+ + data_shape = list(data.shape) + modified_shape = list(data.shape) + modified_shape[seq_dim : seq_dim + 1] = [2, data_shape[seq_dim] // 2] + + reshaped_data = torch.reshape(data, modified_shape) + + # Move the dimension of the chunks to the first position + # Create a permutation where seq_dim is moved to position 0 + permute_order = list(range(len(reshaped_data.shape))) + permute_order.insert(0, permute_order.pop(seq_dim)) # Move seq_dim to index 0 + + reshaped_data = reshaped_data.permute(dims=permute_order) + + seq_len = reshaped_data.shape[seq_dim + 1] # Remember that a new dimension was added. + overlapping_patches = reshaped_data.narrow( + dim=seq_dim + 1, start=seq_len - overlap_size, length=overlap_size + ) # Last n elements. + return overlapping_patches[0], overlapping_patches[1] + + +class ExchangeOverlappingRegionsCausal(Function): + """ + A custom autograd function for exchanging overlapping regions between chunks of data in a causal manner. + The data is split across multiple GPUs using a distributed process group. + The forward method handles the exchange of overlapping regions between chunks, while the backward + method computes the gradients. + Attributes: + - ctx: A context object that stores information for the forward and backward passes. + - chunk_a: Chunk to pass to the left. + - chunk_b: Chunk to pass to the right. + - group: The CP group + - group_rank: The rank in the cp_group. + """ + + @staticmethod + def forward(ctx, chunk_a, chunk_b, group, group_rank): + """ + Forward pass for the ExchangeOverlappingRegionsCausal function. + """ + group_ranks = dist.get_process_group_ranks(group) # Get all global ranks in the cp_group + group_world_size = len(group_ranks) # Size of the cp_group + + ctx.group = group + ctx.group_rank = group_rank + ctx.group_world_size = group_world_size + ctx.group_ranks = group_ranks + + # Initialize requests + reqs = [] + + # Exchange overlaps for chunk_a + if group_rank > 0: + # Receive overlap from previous rank + recv_shape = list(chunk_a.shape) + recv_prev_a = torch.empty(recv_shape, dtype=chunk_a.dtype, device=chunk_a.device) + req_recv_a = dist.irecv(recv_prev_a, src=group_ranks[group_rank - 1]) + reqs.append(req_recv_a) + else: + recv_prev_a = None + + if group_rank < group_world_size - 1: + # Send overlap to next rank + req_send_a = dist.isend(chunk_a.contiguous(), dst=group_ranks[group_rank + 1]) + reqs.append(req_send_a) + + # Exchange overlaps for chunk_b + if group_rank < group_world_size - 1: + # Receive overlap from next rank + recv_shape = list(chunk_b.shape) + recv_next_b = torch.empty(recv_shape, dtype=chunk_b.dtype, device=chunk_b.device) + req_recv_b = dist.irecv(recv_next_b, src=group_ranks[group_rank + 1]) + reqs.append(req_recv_b) + else: + recv_next_b = None + + if group_rank > 0: + # Send overlap to previous rank + req_send_b = dist.isend(chunk_b.contiguous(), dst=group_ranks[group_rank - 1]) + reqs.append(req_send_b) + + # Wait for all communication to finish + for req in reqs: + req.wait() + + # If no chunks received, use zeros instead (for consistency) + if recv_prev_a is None: + recv_prev_a = torch.zeros_like(chunk_a, dtype=chunk_a.dtype, device=chunk_a.device) + if recv_next_b is None: + recv_next_b = chunk_a.clone().contiguous() # Got to receive from the same rank, but previous split. + + return recv_prev_a, recv_next_b + + @staticmethod + def backward(ctx, grad_chunk_a, grad_chunk_b): + """ + Backward pass for the ExchangeOverlappingRegionsCausal function. 
+ """ + # chunk_a, chunk_b = ctx.saved_tensors + group_rank = ctx.group_rank + group_world_size = ctx.group_world_size + group_ranks = ctx.group_ranks + + # Initialize gradients with zeros + _grad_chunk_a = torch.zeros_like(grad_chunk_a) + _grad_chunk_b = torch.zeros_like(grad_chunk_b) + + # Initialize requests + reqs = [] + + # Handling grad_chunk_a + + # If rank > 0, send grad_recv_prev_a to rank - 1 + if group_rank > 0: + req_send_a = dist.isend(grad_chunk_a.contiguous(), dst=group_ranks[group_rank - 1]) + reqs.append(req_send_a) + else: + # At rank 0, there's no previous rank to receive from, so we only consider local gradient contributions + pass # No action needed + + # If rank < world_size - 1, receive grad_chunk_a from rank + 1 + if group_rank < group_world_size - 1: + grad_chunk_a_recv = torch.empty_like(grad_chunk_a) + req_recv_a = dist.irecv(grad_chunk_a_recv, src=group_ranks[group_rank + 1]) + reqs.append(req_recv_a) + + # Handling grad_chunk_b + + # If rank < world_size - 1, send grad_recv_next_b to rank + 1 + if group_rank < group_world_size - 1: + req_send_b = dist.isend(grad_chunk_b.contiguous(), dst=group_ranks[group_rank + 1]) + reqs.append(req_send_b) + + # If rank > 0, receive grad_chunk_b from rank - 1 + if group_rank > 0: + grad_chunk_b_recv = torch.empty_like(grad_chunk_b) + req_recv_b = dist.irecv(grad_chunk_b_recv, src=group_ranks[group_rank - 1]) + reqs.append(req_recv_b) + + # Wait for all communication to finish + for req in reqs: + req.wait() + + # Add received gradients + if group_rank < group_world_size - 1: + _grad_chunk_a = grad_chunk_a_recv + + if group_rank > 0: + _grad_chunk_b = grad_chunk_b_recv + + if group_rank == group_world_size - 1: + _grad_chunk_a = grad_chunk_b # In the last split, the chunks are exchanged locally. + + return _grad_chunk_a, _grad_chunk_b, None, None, None + + +# End of CP related functions + + +def hyena_no_weight_decay_cond(name, param): + """ + Condition for no weight decay for Hyena parameters. + """ + # ImplicitModalFilter parameters + if name.endswith('filter.p') or name.endswith('filter.R') or name.endswith('filter.gamma'): + no_wd = True + + # ExplicitSingleDecayFilter parameters + elif name.endswith('filter.h'): + no_wd = True + + # TODO: Add overrides for other filter types if needed + # Alternatively - do something broader, like `if '.filter.' in name` ??? + + # ParallelShortHyenaOperator parameters --> The parameters of the internal ParallelCausalDepthwiseConv1d object + elif name.endswith('short_conv.short_conv_weight'): + no_wd = True + + # All other parameters - use default MCore behavior: + # Do not regularize biases and norm parameters + # (See megatron.core.optimizer._get_pram_groups) + else: + no_wd = name.endswith(".bias") or len(param.shape) == 1 + + return no_wd + + +@torch.jit.script +def _mul_sum(y, q): + """ + Multiply and sum the elements of two tensors along dimension 1. 
+ """ + return (y * q).sum(dim=1) + + +def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=False): + """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D.""" + seqlen = u.shape[-1] + fft_size = 2 * seqlen + + # check if k is less than seqlen + if k.shape[-1] < seqlen: + # Pad the filter k to the length of the input sequence u + k = torch.nn.functional.pad(k, (0, seqlen - k.shape[-1])) + + # bidirectional + if bidirectional: + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + # split k along the channel dimension + k, k2 = k.split(k.shape[1] // 2, dim=1) + + # get fft of both k's + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + k2_f = torch.fft.rfft(k2, n=fft_size) / fft_size + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + k2_f = k2_f.unsqueeze(1) + + y1 = u_f * k_f + y2 = u_f.conj() * k2_f.conj() + + y = torch.fft.irfft(y1 + y2, n=fft_size, norm="forward")[..., :seqlen] + + # causal + else: + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] + + out = y + u * D.unsqueeze(-1) + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + + +class ImplicitModalFilter(nn.Module): + """ + An implicit modal filter. + """ + + def __init__( + self, + d_model, + order=64, + L_cache=None, + gamma_min=0.01, + gamma_max=0.1, + lr=None, + ): + super().__init__() + self.order = order + self.d_model = d_model + # Do not register into buffer, so it doesn't cast to BF16! + self.t = rearrange(torch.arange(L_cache, dtype=torch.float32), "L -> 1 1 L").to( + device=torch.cuda.current_device() + ) # <- this should be arange + self.use_cached_t = False + with get_cuda_rng_tracker().fork(): + gamma = torch.rand(self.d_model, order, dtype=torch.float32) * (gamma_max - gamma_min) + gamma_min + gamma = gamma.cuda().log() + self.gamma = nn.Parameter(gamma) + + R = 1e-1 * torch.randn(d_model, order, dtype=torch.float32) / math.sqrt(order) + self.R = nn.Parameter(R) + self.p = nn.Parameter(-torch.ones(d_model, order, dtype=torch.float32)) + setattr(self.gamma, 'tensor_model_parallel', True) + setattr(self.R, 'tensor_model_parallel', True) + setattr(self.p, 'tensor_model_parallel', True) + + def get_t(self, L): + """ + Get the t tensor. + """ + # Assumes L <= L_cache + if self.use_cached_t: + return self.t[..., :L] + + t = rearrange(torch.arange(L, dtype=torch.float32, device=self.t.device), "L -> 1 1 L") + self.t = t + self.use_cached_t = True + + return t + + def compute_filter(self, L, t): + """ + Compute the filter for convolution. + """ + assert ( + t.dtype == torch.float32 + ), f"t must be float32. At lower precision, indexes will be merged together. Current dtype: {t.dtype}" + # TODO: See if we can get this kernel to stay FP32. We can but it does not work with the distributed optimizer. + # assert ( + # self.p.dtype == torch.float32 + # ), f"p must be float32. At lower precision, indexes will be merged together. Current dtype: {self.p.dtype}" + # assert ( + # self.gamma.dtype == torch.float32 + # ), ("gamma must be float32. At lower precision, indexes will be merged together. 
" + # f"Current dtype: {self.gamma.dtype}") + # assert ( + # self.R.dtype == torch.float32 + # ), f"R must be float32. At lower precision, indexes will be merged together. Current dtype: {self.R.dtype}" + + logp = -torch.exp(self.p.to(torch.float32)) + glogp = logp * torch.exp(self.gamma.to(torch.float32)) + h = torch.exp(glogp[..., None] * t) + h = torch.einsum('do,dot->dt', self.R.to(torch.float32), h) + h = h[None] + + return h, None + + def filter(self, L, *args, **kwargs): + """ + Get t and the convolution filter for t and the requested sequence length. + """ + t = self.get_t(L) + h = self.compute_filter(L, t) + return h + + def forward(self, L, **kwargs): + """ + Return the final convolutional filter for the requested sequence length. + """ + return self.filter(L) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint(state_dict, prefix, {'gamma': 0, 'R': 0, 'p': 0}, sharded_offsets) + + +class ExplicitSingleDecayFilter(nn.Module): + """ + An explicit single decay filter. + """ + + def __init__( + self, + d_model, + L_cache, + log_r_min=0, + log_r_max=2, + unit_passthrough=False, + decay_preset="strong", + small_init=True, + num_decay_repeats=1, + ): + super().__init__() + with get_cuda_rng_tracker().fork(): + h = torch.randn(d_model, L_cache) / math.sqrt(L_cache) + assert decay_preset in ["strong", "normal", "weak"] + if decay_preset == "strong": + log_r_min = 0 + log_r_max = 2 + elif decay_preset == "normal": + log_r_min = -1 + log_r_max = 2 + elif decay_preset == "weak": + log_r_min = -2 + log_r_max = 2 + + if small_init: + h = h * 1e-5 + if unit_passthrough: + h[:, :1] = 1.0 + self.num_decay_repeats = num_decay_repeats + self.h = nn.Parameter(h) + t = torch.linspace(0, 1, L_cache)[None] + self.log_r_min = log_r_min + self.log_r_max = log_r_max + self.model_parallel_rank = get_tensor_model_parallel_rank() + self.model_parallel_size = get_tensor_model_parallel_world_size() + global_d_model = d_model * self.model_parallel_size // self.num_decay_repeats + decay_domain = torch.logspace(log_r_min, log_r_max, global_d_model)[:, None].repeat(self.num_decay_repeats, 1) + decay_domain = decay_domain[self.model_parallel_rank * d_model : (self.model_parallel_rank + 1) * d_model, :] + decay = torch.exp((-decay_domain * t).cuda()) + self.register_buffer("decay", decay) + setattr(self.h, 'tensor_model_parallel', True) + setattr(self.decay, 'tensor_model_parallel', True) + + def forward(self, L, *args, **kwargs): + """ + Forward pass for the explicit single decay filter. This returns the filter for the requested sequence length. + """ + return self.filter(L, *args, **kwargs) + + @torch.compile(mode="max-autotune") + def filter(self, L, *args, **kwargs): + """ + Compute the filter as a function of h and decay for the requested sequence length. + """ + h = self.h[:, :L] + h = h * self.decay[:, :L] + return h + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, + prefix, + { + 'h': 0, + 'decay': 0, + }, + sharded_offsets, + ) + + +def small_init_init_method(dim): + """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving + the Normalization of Self-Attention - Nguyen, T. 
& Salazar, J. (2019), using a normal distribution.
+    """
+    std = math.sqrt(2 / (5 * dim))
+
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+    return init_
+
+
+def wang_init_method(n_layers, dim):
+    """
+    Initialize the weights of the model using the Wang initialization method.
+    """
+    std = 2 / n_layers / math.sqrt(dim)
+
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+    return init_
+
+
+def get_init_method(init_method_name, num_layers, hidden_size):
+    """
+    Gets parameter initialization methods for the linear layers of the model.
+    """
+    if init_method_name == "wang_init":
+        return wang_init_method(num_layers, hidden_size)
+    elif init_method_name == "small_init":
+        return small_init_init_method(hidden_size)
+    else:
+        raise NotImplementedError(f"Unknown init method {init_method_name}")
+
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+
+def initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
+    """Initialize affine weight for model parallel on GPU."""
+
+    weight.model_parallel = True
+    weight.partition_dim = partition_dim
+    weight.partition_stride = stride
+
+    with get_cuda_rng_tracker().fork():
+        init_method(weight.data)  # modify the data in place
+
+
+def get_groups_and_group_sizes(hidden_size, num_groups, world_size, expand_factor):
+    """
+    Get the groups and group sizes for the model.
+    """
+    width_per_tp_group = divide(hidden_size, world_size)
+    num_groups_per_tp = int(divide(num_groups, world_size) * expand_factor)
+    group_dim = width_per_tp_group // num_groups_per_tp
+    return width_per_tp_group, num_groups_per_tp, group_dim
+
+
+class ParallelHyenaOperator(nn.Module):
+    """
+    A class for the ParallelHyenaOperator.
+ """ + + def __init__( + self, + hidden_size, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + init_method, + operator_type, + max_sequence_length, + downsample_factor=1, + zigzag=True, + ): + super().__init__() + + self.hidden_size = hidden_size + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.operator_type = operator_type + self.fp16 = transformer_config.fp16 + self.bf16 = transformer_config.bf16 + self.cgcg_dtype = getattr(torch, hyena_config.cgcg_dtype) # torch.float32 + + if self.operator_type == "hyena_medium_conv" and hyena_config.hyena_medium_filter_cls is not None: + self.hyena_filter_cls = hyena_config.hyena_medium_filter_cls + else: + self.hyena_filter_cls = hyena_config.hyena_filter_cls + + self.downsample_factor = downsample_factor + self.bidirectional = hyena_config.bidirectional + self.use_hyena_filter = hyena_config.use_hyena_filter + self.use_fast_heads = hyena_config.use_fast_heads + self.use_slow_heads = hyena_config.use_slow_heads + + self.zigzag = zigzag + + self.model_parallel_size = get_tensor_model_parallel_world_size() + self.model_parallel_rank = get_tensor_model_parallel_rank() + + self.L = max_sequence_length + + if self.operator_type == "hyena_medium_conv": + self.num_groups = ( + hyena_config.num_groups_hyena_medium + if hyena_config.num_groups_hyena_medium is not None + else hyena_config.num_groups_hyena + ) + elif self.operator_type == "hyena_short_conv": + self.num_groups = ( + hyena_config.num_groups_hyena_short + if hyena_config.num_groups_hyena_short is not None + else hyena_config.num_groups_hyena + ) + else: + # default to the global num_groups_hyena + self.num_groups = hyena_config.num_groups_hyena + + if self.num_groups is None: + self.num_groups = transformer_config.hidden_size + + world_size: int = get_tensor_model_parallel_world_size() + + self.width_per_tp_group, self.num_groups, self.group_dim = get_groups_and_group_sizes( + self.hidden_size, self.num_groups, world_size, hyena_config.hyena_width_expansion + ) + + self.short_conv_L = hyena_config.short_conv_L + self.use_medium_hyena = True if self.operator_type == "hyena_medium_conv" else False + self.hyena_medium_conv_len = hyena_config.hyena_medium_conv_len + + # TODO: Check which if of these use_* is needed, if any + self.use_long_conv1d = hyena_config.use_long_conv1d + self.use_flashfft = hyena_config.use_flashfft + self.use_cgcg = hyena_config.use_cgcg + self.is_medium_cgcg = self.use_cgcg and self.use_medium_hyena + + if self.use_flashfft: + self.fftconv_fn = FlashFFTConv(self.L, dtype=torch.float16 if self.fp16 else torch.bfloat16) + + if self.use_medium_hyena and self.use_cgcg: + if os.environ.get("SAVANNA_DEBUG", "0") == "1": + import pdb + + pdb.set_trace() + self.cgcg_fn = two_pass_chunked_gate_conv_gate + + self.cgcg_fwd_config = FwdKernelConfigRefactor( + CHUNK_SIZE=self.hyena_config.cgcg_medium_fwd_kernel_config_chunk_size, + BLOCK_D=min(self.group_dim, self.hyena_config.cgcg_medium_fwd_kernel_config_block_d), + CHUNK_TILES_PER_PROGRAM=self.hyena_config.cgcg_medium_fwd_kernel_config_chunk_tiles_per_program, + THREADBLOCK_SWIZZLE=self.hyena_config.cgcg_medium_fwd_kernel_config_threadblock_swizzle, + num_warps=self.hyena_config.cgcg_medium_fwd_kernel_config_num_warps, + num_stages=self.hyena_config.cgcg_medium_fwd_kernel_config_num_stages, + ) + + self.cgcg_bwd_config = BwdKernelConfigRefactor( + pre_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_x, + 
pre_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_y, + pre_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_num_warps, + post_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_x, + post_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_y, + post_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_post_conv_num_warps, + ) + + # TODO: Check which of these filters can be removed + # At the moment only "explicit_single_decay" and "implicit_modal" are used + if self.hyena_filter_cls == "explicit_single_decay": + self.filter = ExplicitSingleDecayFilter( + d_model=self.num_groups, + L_cache=self.hyena_medium_conv_len, + decay_preset=hyena_config.explicit_filter_decay_preset, + ) + self.kernel_size = self.hyena_medium_conv_len + elif self.hyena_filter_cls == "implicit_modal": + self.filter = ImplicitModalFilter( + d_model=self.num_groups, + L_cache=self.L, + order=hyena_config.hyena_filter_order, + gamma_min=hyena_config.modal_gamma_min, + gamma_max=hyena_config.modal_gamma_max, + ) + self.kernel_size = self.L + else: + raise ValueError(f"Unknown hyena filter class: {self.hyena_filter_cls}") + + with get_cuda_rng_tracker().fork(): + if self.use_slow_heads: + self.conv_bias = nn.Parameter( + torch.empty( + self.num_groups, + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + ) + else: + self.conv_bias = nn.Parameter( + torch.empty( + self.width_per_tp_group, + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + ) + # Add attribute to prevent automatic casting during model conversion + setattr(self.conv_bias, 'tensor_model_parallel', True) + bounds = math.sqrt(1 / self.kernel_size) + conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) + self.conv_bias.data = conv_init_method(self.conv_bias.data) + self.conv_bias.model_parallel = True + self.conv_bias.partition_dim = 0 + self.conv_bias.stride = 1 + + def multihead_forward(self, q, k, v, h): + """ + Multihead forward pass for the ParallelHyenaOperator. 
+ """ + batch_size = q.shape[0] + group_dim = self.group_dim + num_groups = self.num_groups + + L = v.shape[-1] + fft_size = 2 * L + kv = rearrange(k, "b (h d1) l -> b d1 1 h l", d1=group_dim) * rearrange( + v, "b (h d2) l -> b 1 d2 h l", d2=group_dim + ) + if self.use_flashfft: + # treat mhfftconv as a large batched fftconv + kv_reshape = kv.reshape(-1, num_groups, L) + y = self.fftconv_fn(kv_reshape, h[0]) + y = y.view(batch_size, group_dim, group_dim, num_groups, L) + else: + kv_f = torch.fft.rfft(kv.to(torch.float32), n=fft_size) / fft_size + h_f = torch.fft.rfft(h.to(torch.float32), n=fft_size) # h L+1 + + y = torch.fft.irfft(kv_f * h_f, n=fft_size, norm="forward")[..., :L] + y = y.to(dtype=q.dtype) + + out = y + kv * self.conv_bias.unsqueeze(-1) + q = rearrange(q, "b (h d1) l -> b d1 1 h l", d1=group_dim) + z = _mul_sum(out, q) + z = rearrange(z, "b d2 h l -> b (h d2) l") + + z = z.to(v.dtype) + return z + + def forward(self, x1, x2, v, _hyena_use_cp=True): + """ + Note: + Input shapes: bs, seq_length, (num_groups, group_size) + Output shapes: bs, seq_length, num_groups, group_size + """ + + B, L, G, DG = x1.shape + + # CP control + if _hyena_use_cp: + cp_group = get_context_parallel_group() + else: + cp_group = None + + # downsampled = self.downsample_factor > 1 + + # Only permute if not medium cgcg + if not self.is_medium_cgcg: + x1 = rearrange(x1, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim) + x2 = rearrange(x2, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim) + v = rearrange(v, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim) + + x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L] + + # FIXME: add support post cp refactor + # if self.downsample_factor > 1: + # x1 = x1[..., :: self.downsample_factor] + # x2 = x2[..., :: self.downsample_factor] + # v = v[..., :: self.downsample_factor] + # L = L // self.downsample_factor + + # The kernel length must be adjusted in CP settings + _L_kernel = L if cp_group is None else L * len(torch.distributed.get_process_group_ranks(cp_group)) + if self.use_medium_hyena: + h = self.filter(min(self.hyena_medium_conv_len, _L_kernel)) + else: + h = self.filter(_L_kernel) + + if type(h) == tuple: + h = h[0] + + conv_bias = self.conv_bias + local_size = None + + if cp_group is not None and len(torch.distributed.get_process_group_ranks(cp_group)) > 1: + + x1, x2, v = [ + AllToAllSingleFunction.apply(tensor, cp_group, "split_to_full", True) for tensor in [x1, x2, v] + ] + # the tensors are now split across channels, but have full length. 
+ # [ B, H // num_ranks, L] + + rank = torch.distributed.get_rank(cp_group) + local_size = self.num_groups // get_context_parallel_world_size() + + if isinstance(self.filter, (ImplicitModalFilter)): + h = h[:, rank * local_size : (rank + 1) * local_size] + elif isinstance(self.filter, ExplicitSingleDecayFilter): + h = h[rank * local_size : (rank + 1) * local_size] + else: + raise ValueError(f"Kernels of type {self.filter.__class__} have not been verified with CP.") + + local_bias_size = self.width_per_tp_group // get_context_parallel_world_size() + conv_bias = self.conv_bias[rank * local_bias_size : (rank + 1) * local_bias_size] + + if self.use_slow_heads: + return self.multihead_forward(x1, x2, v, h) + + elif self.use_long_conv1d: + h = h.repeat_interleave(self.group_dim, dim=-2) + z = x2 * v + + z = ( + F.conv1d(z, h[:, None].flip(-1), padding=L - 1, groups=v.shape[1])[..., :L] + + conv_bias.unsqueeze(-1) * z + ) + z = z.to(v.dtype) + z = x1 * z + + elif self.is_medium_cgcg: + # TODO: if the conditions are met, we should not rearrange to l last in the first place + # @jeromeku, done as of 2024-09-28 refactor (see above) + # x1 = rearrange(x1, "b (d g) l -> b l g d", g=self.num_groups) + # x2 = rearrange(x2, "b (d g) l -> b l g d", g=self.num_groups) + # v = rearrange(v, "b (d g) l -> b l g d", g=self.num_groups) + dtype = x1.dtype + if os.environ.get("SAVANNA_DEBUG", "0") == "1": + import pdb + + pdb.set_trace() + # Mapping from x1, x2, and v -> kernel args + # x1 is post-gate (C) + # x2 is pre-gate (B) + # v is x + + if self.cgcg_dtype != dtype: + x = v.to(self.cgcg_dtype) + B = x2.to(self.cgcg_dtype) + C = x1.to(self.cgcg_dtype) + h = h[:, None].to(self.cgcg_dtype) + else: + x = v + B = x2 + C = x1 + h = h[:, None] + + bs, seqlen, g, dg = x.shape + + # @jeromeku: Refactor as of 2024-09-28 + # No more backward kernel config + # default schedule is "default" as other schedules are not supported + # fwd_kernel config is of class FwdKernelConfigRefactor + # Explicitly pass in shape for internal checking + z = self.cgcg_fn( + x=x, # x1.to(self.cgcg_dtype), + B=B, # x2.to(self.cgcg_dtype), + C=C, # v.to(self.cgcg_dtype), + h=h, # h[:, None].to(self.cgcg_dtype), # g, 1, filter_l + bs=bs, + seqlen=seqlen, + g=g, + dg=dg, + fwd_autotune=False, # @jeromeku explicitly set to False for now + bwd_autotune=self.hyena_config.cgcg_bwd_autotune, + fused_bwd=self.hyena_config.cgcg_fused_bwd, + fwd_kernel_cfg=self.cgcg_fwd_config, + bwd_kernel_cfg=None if self.hyena_config.cgcg_bwd_autotune else self.cgcg_bwd_config, + ) + z = z.reshape(bs, seqlen, g * dg) + if self.cgcg_dtype != dtype: + z = z.to(dtype) + return z + else: + h = h.repeat_interleave(self.group_dim, dim=-2) + + if self.hyena_config.use_flashfft: + # squeeze h dim (kernel), to get rid of leading 1 dim + h = h.squeeze(0) + z = self.fftconv_fn(v, h, x2, x1) + else: + z = x2 * v + # with torch.autocast("cuda"): + z = fftconv_func( + u=z.to(torch.float32), + k=h.to(torch.float32), + D=conv_bias.to(torch.float32), + dropout_mask=None, + gelu=False, + bidirectional=self.bidirectional, + ) + z = z.to(v.dtype) + z = x1 * z + + # if downsampled: + # z = z.repeat_interleave(self.downsample_factor, dim=-1) + + # print( + # f"[rank={dist.get_rank()}] shape of z = {z.shape} | " + # f"num_groups = {self.num_groups}, local_size = {local_size}" + # ) # DEBUG + + if cp_group is not None and len(torch.distributed.get_process_group_ranks(cp_group)) > 1: + z = AllToAllSingleFunction.apply(z, cp_group, "full_to_split", True) + # [ B, H, L // num_ranks] + return 
rearrange(z, "b d l -> b l d") + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + Sharded state dictionary for the ParallelHyenaOperator. + """ + sharded_state_dict = {} + # Parameters + self._save_to_state_dict(sharded_state_dict, '', keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + sharded_state_dict, + prefix, + tensor_parallel_layers_axis_map={ + 'conv_bias': 0, + }, # parameters sharded across TP + sharded_offsets=sharded_offsets, + ) + # Submodules + for name, module in self.named_children(): + module_sharded_sd = sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) + + sharded_state_dict.update(module_sharded_sd) + return sharded_state_dict + + +class ParallelShortHyenaOperator(nn.Module): + """ + A class for the ParallelShortHyenaOperator. + """ + + def __init__( + self, + hidden_size, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + init_method, + short_conv_class, + use_fast_causal_conv=False, + is_mlp=False, # TODO: Check if needed, only used when using Hyena for the MLP block + local_init=False, + use_conv_bias=True, + ): + super().__init__() + self.transformer_config = transformer_config + self.hyena_config = hyena_config + self.is_mlp = is_mlp + self.hidden_size = hidden_size + self.cgcg_dtype = getattr(torch, hyena_config.cgcg_dtype) + self.use_cgcg_mlp = hyena_config.use_cgcg_mlp and self.is_mlp + self.use_cgcg_short = hyena_config.use_cgcg_short and not self.is_mlp + self.use_custom_hyena_mlp_kernel = hyena_config.use_custom_hyena_mlp_kernel + self.use_custom_hyena_short_kernel = hyena_config.use_custom_hyena_short_kernel + self.use_fast_causal_conv = use_fast_causal_conv + + # world_size = mpu.get_model_parallel_world_size() if not local_init else 1 + # world_size: int = torch.distributed.get_world_size() if not local_init else 1 + + world_size: int = get_tensor_model_parallel_world_size() if not local_init else 1 + # assert, if using fast_conv_mixer, then the hyena_short_conv_len must be 3 + if use_fast_causal_conv: + assert hyena_config.hyena_short_conv_len <= 4, "fast_conv_mixer requires hyena_short_conv_len <= 4" + + # for mlp type + if is_mlp: + # option to have a different kernel size for the short conv inside the mlp + if hyena_config.hyena_mlp_len is not None: + kernel_size = hyena_config.hyena_mlp_len + else: + kernel_size = hyena_config.hyena_short_conv_len + + # check for fast causal conv + if hyena_config.fast_hyena_mlp_conv: + assert hyena_config.hyena_mlp_len <= 4, "fast_hyena_mlp_conv requires hyena_mlp_len <= 4" + use_fast_causal_conv = True + + self.pregate = hyena_config.hyena_mlp_pregate + self.postgate = hyena_config.hyena_mlp_postgate + + self.num_groups = ( + hyena_config.num_groups_hyena_mlp + if hyena_config.num_groups_hyena_mlp is not None + else hyena_config.num_groups_hyena + ) + + if self.num_groups is None: + self.num_groups = transformer_config.hidden_size + + self.num_groups = int(self.num_groups * hyena_config.hyena_mlp_expansion_factor) + # handle mixer case + else: + + kernel_size = hyena_config.hyena_short_conv_len + self.pregate = hyena_config.hyena_short_conv_pregate + self.postgate = hyena_config.hyena_short_conv_postgate + self.num_groups = ( + hyena_config.num_groups_hyena_short + if hyena_config.num_groups_hyena_short is not None + else hyena_config.num_groups_hyena + ) + if self.num_groups is None: + self.num_groups = transformer_config.hidden_size + + self.num_groups = int(self.num_groups * 
hyena_config.hyena_width_expansion) + + self.width_per_tp_group, self.num_groups, self.group_dim = get_groups_and_group_sizes( + self.hidden_size, self.num_groups, world_size, hyena_config.hyena_width_expansion + ) + + self.short_conv = short_conv_class( + self.width_per_tp_group, + transformer_config, + hyena_config=hyena_config, + kernel_size=kernel_size, + init_method=init_method, + bias=hyena_config.conv_proj_bias, + use_fast_causal_conv=use_fast_causal_conv, + num_groups=self.num_groups, + repeat_h_dg=False, + local_init=local_init, + ) + self.kernel_fn, self.fwd_kernel_cfg, self.bwd_kernel_cfg = self.prepare_kernel_configs() + self.use_conv_bias = use_conv_bias + if self.use_conv_bias: + with get_cuda_rng_tracker().fork(): + self.conv_bias = nn.Parameter( + torch.empty( + self.num_groups, + device=torch.cuda.current_device(), + dtype=torch.float32, + ) + ) + setattr(self.conv_bias, 'tensor_model_parallel', True) + bounds = math.sqrt(1 / kernel_size) + conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) + self.conv_bias.data = conv_init_method(self.conv_bias.data) + self.conv_bias.model_parallel = True + self.conv_bias.partition_dim = 0 + self.conv_bias.stride = 1 + + def prepare_kernel_configs(self): + """ + Prepare the kernel configurations for the ParallelShortHyenaOperator. + """ + if self.is_mlp and self.use_cgcg_mlp: + + kernel_fn = two_pass_chunked_gate_conv_gate + fwd_kernel_cfg = FwdKernelConfigRefactor( + CHUNK_SIZE=self.hyena_config.cgcg_short_fwd_kernel_config_chunk_size, + BLOCK_D=min(self.group_dim, self.hyena_config.cgcg_short_fwd_kernel_config_block_d), + CHUNK_TILES_PER_PROGRAM=self.hyena_config.cgcg_short_fwd_kernel_config_chunk_tiles_per_program, + THREADBLOCK_SWIZZLE=self.hyena_config.cgcg_short_fwd_kernel_config_threadblock_swizzle, + num_warps=self.hyena_config.cgcg_short_fwd_kernel_config_num_warps, + num_stages=self.hyena_config.cgcg_short_fwd_kernel_config_num_stages, + ) + bwd_kernel_cfg = BwdKernelConfigRefactor( + pre_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_x, + pre_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_y, + pre_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_num_warps, + post_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_x, + post_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_y, + post_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_post_conv_num_warps, + ) + return kernel_fn, fwd_kernel_cfg, bwd_kernel_cfg + elif not self.is_mlp and self.use_cgcg_short: + + kernel_fn = two_pass_chunked_gate_conv_gate + fwd_kernel_cfg = FwdKernelConfigRefactor( + CHUNK_SIZE=self.hyena_config.cgcg_short_fwd_kernel_config_chunk_size, + BLOCK_D=min(self.group_dim, self.hyena_config.cgcg_short_fwd_kernel_config_block_d), + CHUNK_TILES_PER_PROGRAM=self.hyena_config.cgcg_short_fwd_kernel_config_chunk_tiles_per_program, + THREADBLOCK_SWIZZLE=self.hyena_config.cgcg_short_fwd_kernel_config_threadblock_swizzle, + num_warps=self.hyena_config.cgcg_short_fwd_kernel_config_num_warps, + num_stages=self.hyena_config.cgcg_short_fwd_kernel_config_num_stages, + ) + bwd_kernel_cfg = BwdKernelConfigRefactor( + pre_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_x, + pre_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_block_y, + pre_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_pre_conv_num_warps, + post_conv_BLOCK_X=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_x, + 
post_conv_BLOCK_Y=self.hyena_config.cgcg_bwd_kernel_config_post_conv_block_y, + post_conv_num_warps=self.hyena_config.cgcg_bwd_kernel_config_post_conv_num_warps, + ) + return kernel_fn, fwd_kernel_cfg, bwd_kernel_cfg + + elif self.is_mlp and self.use_custom_hyena_mlp_kernel: + fn = run_short_hyena + fwd_kernel_cfg = ShortHyenaOperatorKernelConfig( + PreConvKernelConfig( + BLOCK_M=256, + BLOCK_N=256, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + PostConvKernelConfig( + BLOCK_M=128, + BLOCK_N=128, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + ) + bwd_kernel_cfg = ShortHyenaOperatorKernelConfig( + PreConvKernelConfig( + BLOCK_M=256, + BLOCK_N=256, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + PostConvKernelConfig( + BLOCK_M=128, + BLOCK_N=128, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + ) + return fn, fwd_kernel_cfg, bwd_kernel_cfg + + elif not self.is_mlp and self.use_custom_hyena_short_kernel: + fn = run_short_hyena + fwd_kernel_cfg = ShortHyenaOperatorKernelConfig( + PreConvKernelConfig( + BLOCK_M=256, + BLOCK_N=256, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + PostConvKernelConfig( + BLOCK_M=128, + BLOCK_N=128, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + ) + bwd_kernel_cfg = ShortHyenaOperatorKernelConfig( + PreConvKernelConfig( + BLOCK_M=256, + BLOCK_N=256, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + PostConvKernelConfig( + BLOCK_M=128, + BLOCK_N=128, + NUM_PIPELINE_STAGES=1, + num_warps=4, + num_ctas=1, + ), + ) + return fn, fwd_kernel_cfg, bwd_kernel_cfg + else: + return None, None, None + + def forward(self, x1, x2, v, _hyena_use_cp=True): + """ + Note: + Input shapes: bs, seq_length, (num_groups, group_size) + Output shapes: bs, seq_length, num_groups, group_size + """ + B, L, G, DG = x1.shape + + if self.use_custom_hyena_mlp_kernel or self.use_custom_hyena_short_kernel: + z = self.kernel_fn( + x1, + x2, + v, + self.short_conv.short_conv_weight, + repeat_interleave=True, + use_causal_conv=self.use_fast_causal_conv, + autotune=False, + fwd_kernel_cfg=self.fwd_kernel_cfg, + bwd_kernel_cfg=self.bwd_kernel_cfg, + ) + return rearrange(z, "b l g dg -> b l (g dg)", g=G) + + elif self.use_cgcg_mlp or self.use_cgcg_short: + dtype = x1.dtype + if os.environ.get("SAVANNA_DEBUG", "0") == "1": + import pdb + + pdb.set_trace() + # @jeromeku: Refactor as of 2024-09-28 + # No more backward kernel config + # default schedule is "default" as other schedules are not supported + # fwd_kernel config is of class FwdKernelConfigRefactor + # Explicitly pass in shape for internal checking + + # Mapping from x1, x2, and v -> kernel args + # x1 is post-gate (C) + # x2 is pre-gate (B) + # v is x + + if self.cgcg_dtype != dtype: + x = v.to(self.cgcg_dtype) + B = x2.to(self.cgcg_dtype) + C = x1.to(self.cgcg_dtype) + h = self.short_conv.short_conv_weight.to(self.cgcg_dtype) # g, 1, filter_l + else: + x = v + B = x2 + C = x1 + h = self.short_conv.short_conv_weight # g, 1, filter_l + + bs, seqlen, g, dg = x.shape + + z = self.kernel_fn( + x, # x1.to(self.cgcg_dtype), + B, # x2.to(self.cgcg_dtype), + C, # v.to(self.cgcg_dtype), + h, # g, 1, filter_l + bs=bs, + seqlen=seqlen, + g=g, + dg=dg, + # Explicitly set fwd autotune to False for now + fwd_autotune=False, + bwd_autotune=self.hyena_config.cgcg_bwd_autotune, + fused_bwd=self.hyena_config.cgcg_fused_bwd, + fwd_kernel_cfg=self.fwd_kernel_cfg, + bwd_kernel_cfg=None if self.hyena_config.cgcg_bwd_autotune else self.bwd_kernel_cfg, + ) + out = rearrange(z, "b 
l g d -> b l (g d)") + if self.cgcg_dtype != dtype: + out = out.to(dtype) + return out + + else: + x1 = rearrange(x1, "b l g dg -> b (g dg) l") + x2 = rearrange(x2, "b l g dg -> b (g dg) l") + v = rearrange(v, "b l g dg -> b (g dg) l") + + x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L] + + z = x2 * v if self.pregate else v + if not self.use_conv_bias: + z = self.short_conv(z, _use_cp=_hyena_use_cp) + else: + # maybe handle num_groups + bias = self.conv_bias.repeat_interleave(self.group_dim, dim=0) + z = self.short_conv(z, _use_cp=_hyena_use_cp) + rearrange(bias, "h -> 1 h 1") * z # conv(z) + bias * z + + z = x1 * z if self.postgate else z + + return rearrange(z, "b d l -> b l d") + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + Sharded state dictionary for the ParallelShortHyenaOperator. + """ + sharded_state_dict = {} + # Parameters + self._save_to_state_dict(sharded_state_dict, '', keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + sharded_state_dict, + prefix, + tensor_parallel_layers_axis_map={ + 'conv_bias': 0, + }, # parameters sharded across TP + sharded_offsets=sharded_offsets, + ) + # Submodules + for name, module in self.named_children(): + module_sharded_sd = sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) + + sharded_state_dict.update(module_sharded_sd) + return sharded_state_dict + + +class ParallelCausalDepthwiseConv1d(nn.Module): + """ + A class for the ParallelCausalDepthwiseConv1d. + """ + + def __init__( + self, + d_model, + transformer_config: TransformerConfig, + hyena_config: HyenaConfig, + kernel_size, + init_method, + bias=False, # not currently supported + use_fast_causal_conv=False, + num_groups=None, # enables some weight sharing + repeat_h_dg=True, + local_init=False, + ): + super().__init__() + self.d_model = d_model + self.kernel_size = kernel_size + self.use_bias = bias + self.use_fast_causal_conv = use_fast_causal_conv + self.num_groups = num_groups + + if self.num_groups is None: + self.num_groups = self.d_model + + self.group_dim = self.d_model // self.num_groups + + if self.use_fast_causal_conv: + assert causal_conv1d_fn is not None, "custom causal conv not installed" + weight_shape = [self.num_groups, kernel_size] + # use torch + else: + if hyena_config.use_depthwise_short_conv_grouping: + weight_shape = [self.num_groups, 1, kernel_size] + self.conv_groups = self.d_model + + else: + if repeat_h_dg: + weight_shape = [self.num_groups, self.group_dim, kernel_size] + else: + weight_shape = [self.num_groups, 1, kernel_size] + + self.conv_groups = self.num_groups + + with get_cuda_rng_tracker().fork(): + self.short_conv_weight = nn.Parameter( + torch.empty( + weight_shape, + device=torch.cuda.current_device(), + dtype=transformer_config.params_dtype, + ) + ) + setattr(self.short_conv_weight, 'tensor_model_parallel', True) + + # Use the standard PyTorch Conv1d class init: + # https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html + bounds = math.sqrt(1 / hyena_config.short_conv_L) + conv_init_method = partial(torch.nn.init.uniform_, a=-bounds, b=bounds) + if local_init: + self.short_conv_weight.data = conv_init_method(self.short_conv_weight.data) + else: + # Call this on the module because it also modifies module attributes in addition to the data. + initialize_affine_weight_gpu(self.short_conv_weight, conv_init_method, partition_dim=0) + + def forward(self, x, _use_cp=True): + """ + Forward pass for the ParallelCausalDepthwiseConv1d. 
+ """ + assert x.ndim == 3, "Only 3D tensors supported." + + x_shape = x.shape + weight = self.short_conv_weight + pad_size = self.kernel_size - 1 + + if _use_cp and get_context_parallel_world_size() > 1: + + cp_group = get_context_parallel_group() + cp_rank = get_context_parallel_rank() + + # Transfer patches across ranks. + seq_dim = 2 # Last dimension. + chunk_a, chunk_b = zigzag_get_overlapping_patches(x, seq_dim=seq_dim, overlap_size=pad_size) + received_a, received_b = ExchangeOverlappingRegionsCausal.apply(chunk_a, chunk_b, cp_group, cp_rank) + + # Pad and rearrange + x = rearrange(x, "b h (nc s) -> (nc b) h s", nc=2) + padding = torch.concat([received_a, received_b], dim=0) + + x = torch.concat([padding, x], dim=-1) + + else: + x = F.pad(x, (pad_size, 0)) + + # maybe handle num_groups + weight = weight.repeat_interleave(self.group_dim, dim=0) + + if self.use_fast_causal_conv: + y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] + else: + + y = F.conv1d( + x, + weight, + bias=None, + stride=1, + padding=0, + groups=self.conv_groups, + ) + + if _use_cp and get_context_parallel_world_size() > 1: + y = rearrange(y, "(nc b) h s -> b h (nc s)", nc=2) + + assert y.shape == x_shape, f"y.shape = {y.shape} | x.shape = {x_shape}" + + return y + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, + prefix, + { + 'short_conv_weight': 0, + }, + sharded_offsets, + ) + + +def make_upper_case(tokens): + """ + Replace lowercase ASCII characters with uppercase. + """ + # tokens, labels, loss_mask, attention_mask, position_ids = batch + + lowercase_mask = (tokens >= 97) & (tokens <= 122) + uppercase_tensor = tokens.clone() + uppercase_tensor[lowercase_mask] -= 32 + + return uppercase_tensor, lowercase_mask + + +def reweighted_cross_entropy(loss, labels, lowercase_weight=1.0, normalize_per_batch=True): + """ + Modified for lower case loss reweighting, using the cross_entropy function as a base. 
+ + If normalize_per_batch, loss_weights are normalized by the number of tokens in the batch so + the magnitude of the loss is not affected by the number of upper/lower case letters + otherwise, loss_weights are normalized by the number of tokens: combined_loss/len + + performs mean reduction and applies loss_mask + """ + + labels, loss_mask, lowercase_mask = labels[0], labels[1], labels[2] + + upper_loss_mask = loss_mask.bool() & (~lowercase_mask.bool()) + lower_loss_mask = loss_mask.bool() & lowercase_mask.bool() + + loss_weights = torch.zeros_like(loss_mask) + loss_weights[upper_loss_mask] = 1.0 + loss_weights[lower_loss_mask] = lowercase_weight + + if normalize_per_batch: + # Get per-microbatch normalization factor + weight_sum = loss_weights.sum() + mask_sum = loss_mask.sum() + weight_normalizer = torch.maximum(weight_sum, torch.ones_like(weight_sum)) + loss_weights = (mask_sum * loss_weights) / weight_normalizer + + # Apply loss weights and loss mask to the loss + loss = loss * loss_weights * loss_mask + + return loss diff --git a/nemo/lightning/io/registry.py b/nemo/lightning/io/registry.py index fc2257b46bde..24af449b2e13 100644 --- a/nemo/lightning/io/registry.py +++ b/nemo/lightning/io/registry.py @@ -48,11 +48,12 @@ try: - from nemo.collections.common.tokenizers import SentencePieceTokenizer + from nemo.collections.common.tokenizers import ByteLevelTokenizer, SentencePieceTokenizer from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer track_io(SentencePieceTokenizer, artifacts=[FileArtifact("model_path")]) track_io(TiktokenTokenizer, artifacts=[FileArtifact("vocab_file")]) + track_io(ByteLevelTokenizer) except ImportError: - # SentencePieceTokenizer is not available, no need to track it + # Tokenizers are not available, no need to track it. pass diff --git a/nemo/lightning/pytorch/callbacks/flops_callback.py b/nemo/lightning/pytorch/callbacks/flops_callback.py index 717157a8436d..220ad087cf29 100644 --- a/nemo/lightning/pytorch/callbacks/flops_callback.py +++ b/nemo/lightning/pytorch/callbacks/flops_callback.py @@ -22,6 +22,8 @@ from nemo.collections.llm.gpt.model.base import GPTConfig from nemo.lightning.pytorch.callbacks import PEFT from nemo.utils import flops_formulas, logging +from nemo.utils.hyena_flops_formulas import hyena + __all__ = ["FLOPsMeasurementCallback", "MM_FLOPsMeasurementCallback"] @@ -32,6 +34,7 @@ "nemotron": flops_formulas.nemotron, "mixtral": flops_formulas.mixtral, "bert": flops_formulas.bert, + "hyena": hyena, } @@ -43,7 +46,7 @@ class FLOPsMeasurementCallback(Callback): model_config (GPTConfig): Model parameters. data_config (pl.LightningDataModule): Data module being used in the experiment. model_name (str): Name of the model being run. The following models are supported: - gpt3, llama2, llama3, nemotron, mixtral, bert. + gpt3, llama2, llama3, nemotron, mixtral, bert, hyena. """ @@ -69,6 +72,8 @@ def __init__( ffn_hs = self.model_cfg.ffn_hidden_size attention_heads = self.model_cfg.num_attention_heads moe_router_topk = self.model_cfg.moe_router_topk + model_pattern = getattr(self.model_cfg, "hybrid_override_pattern", None) + vocab_size = self.data_cfg.tokenizer.vocab_size if hasattr(self.data_cfg, "tokenizer") else None # this handles both- 1. key is present, value is None; 2. 
key is absent query_groups = self.model_cfg.num_query_groups @@ -84,6 +89,8 @@ def __init__( attention_heads=attention_heads, moe_router_topk=moe_router_topk, query_groups=query_groups, + vocab_size=vocab_size, + model_pattern=model_pattern, ) self.model = self.model.lower() if self.model is not None else self.model @@ -169,7 +176,7 @@ class MM_FLOPsMeasurementCallback(FLOPsMeasurementCallback): """ Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps for multi-modal models. The following models are supported: - hf_clip_vit_l, neva_projection, gpt3, llama2, llama3, nemotron, mixtral, bert. + hf_clip_vit_l, neva_projection, gpt3, llama2, llama3, nemotron, mixtral, bert, hyena Args: model_name_config_dict (dict): diff --git a/nemo/utils/flops_formulas.py b/nemo/utils/flops_formulas.py index c356f9b9476a..31f235f3c573 100644 --- a/nemo/utils/flops_formulas.py +++ b/nemo/utils/flops_formulas.py @@ -38,6 +38,8 @@ class FLOPSConfig: class_token_len: Optional[int] = None projector_type: Optional[str] = None inp_s: Optional[int] = None + model_pattern: Optional[str] = None + vocab_size: Optional[int] = None model_channels: Optional[int] = None vec_in_dim: Optional[int] = None diff --git a/nemo/utils/hyena_flops_formulas.py b/nemo/utils/hyena_flops_formulas.py new file mode 100644 index 000000000000..2b713b465b80 --- /dev/null +++ b/nemo/utils/hyena_flops_formulas.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +# TODO(@cye): Merge MCore HyenaConfig with NeMo HyenaConfig to have all model params in 1 config. +from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig +from nemo.utils.flops_formulas import FLOPSConfig + + +def hyena(config: FLOPSConfig): + """Model FLOPs for Hyena family. FPL = 'flops per layer'.""" + + # TODO(@cye): For now, pull the Hyena defaults directly from a constant dataclass. Merge this config with the NeMo + # model config. + hyena_config = HyenaConfig() + # Hyena Parameters + hyena_short_conv_L = hyena_config.short_conv_L + hyena_short_conv_len = hyena_config.hyena_short_conv_len + hyena_medium_conv_len = hyena_config.hyena_medium_conv_len + + def _hyena_layer_count(model_pattern: Optional[str]): + """Count how many small, medium, and large Hyena layers there are in the model. Also, count the + number of Attention layers. + """ + S, D, H, A = 0, 0, 0, 0 + if model_pattern is None: + return 0, 0, 0, 0 + for layer in model_pattern: + if layer == "S": + S += 1 + elif layer == "D": + D += 1 + elif layer == "H": + H += 1 + elif layer == "*": + A += 1 + return S, D, H, A + + # Count S, D, H, and * layers in HyenaModel. 
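    # Illustrative only (the pattern string here is hypothetical, not taken from this repo):
    # for a hybrid pattern such as "SDH*SDH*", _hyena_layer_count returns (2, 2, 2, 2),
    # i.e. two short-conv (S), two medium-conv (D), and two long/FFT (H) Hyena layers plus
    # two attention (*) layers.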
+ S, D, H, A = _hyena_layer_count(config.model_pattern) + # Logits FLOPs per batch for a flattened L x H -> V GEMM. + logits_fpl = 2 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size + # Hyena Mixer Common FLOPs - Pre-Attention QKV Projections, Post-Attention Projections, and + # GLU FFN FLOPs per layer. + pre_attn_qkv_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.hs**2 + post_attn_proj_fpl = 2 * config.gbs * config.enc_seq_len * config.hs**2 + # 3 Batched GEMMs: y = A(gelu(Bx) * Cx) where B,C: H -> F and A: F -> H. + glu_ffn_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.ffn_hs * config.hs + # Transformer (Self) Attention FLOPs - QK Attention Logits ((L, D) x (D, L)) & Attention-Weighted + # Values FLOPs ((L, L) x (L, D)) + attn_fpl = 2 * 2 * config.gbs * config.hs * config.enc_seq_len**2 + # Hyena Projection + hyena_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * hyena_short_conv_L * config.hs + # Hyena Short Conv + hyena_short_conv_fpl = 2 * config.gbs * config.enc_seq_len * hyena_short_conv_len * config.hs + # Hyena Medium Conv + hyena_medium_conv_fpl = 2 * config.gbs * config.enc_seq_len * hyena_medium_conv_len * config.hs + # Hyena Long Conv (FFT) + hyena_long_conv_fft_fpl = config.gbs * 10 * config.enc_seq_len * math.log2(config.enc_seq_len) * config.hs + # Based off of https://gitlab-master.nvidia.com/clara-discovery/savanna/-/blob/main/savanna/mfu.py#L182 + # Assumption: 1x Backwards Pass FLOPS = 2x Forward Pass FLOPS + return 3 * ( + logits_fpl + + config.layers * (pre_attn_qkv_proj_fpl + post_attn_proj_fpl + glu_ffn_fpl) + + A * attn_fpl + + (S + D + H) * hyena_proj_fpl + + S * hyena_short_conv_fpl + + D * hyena_medium_conv_fpl + + H * hyena_long_conv_fft_fpl + ) diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_config.py b/tests/collections/llm/gpt/data/megatron/hyena/test_config.py new file mode 100644 index 000000000000..cafd12f26ba5 --- /dev/null +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_config.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from collections import defaultdict +from contextlib import contextmanager +from pathlib import Path +from typing import Union + +import pytest +import yaml + +from nemo.collections.llm.gpt.data.megatron.hyena.config import Evo2BlendedDatasetConfig, parse_dataset_config + + +@contextmanager +def change_dir(new_dir: Union[str, Path]): + """ + Context manager for temporarily changing the working directory using os. + + Args: + new_dir (Union[str, Path]): The directory to change to + + Yields: + str: The new working directory path + + Example: + with change_dir('/path/to/dir'): + # Do some work in the new directory + ... 
+ # Original directory is restored + """ + prev_dir = os.getcwd() + new_dir = os.path.expanduser(str(new_dir)) + try: + os.chdir(new_dir) + yield new_dir + finally: + os.chdir(prev_dir) + + +@pytest.fixture +def temp_dataset_config(): + # Create a temporary directory for the dataset path + temp_dir = tempfile.TemporaryDirectory() + dataset_path = temp_dir.name + + # Create a temporary YAML file for the dataset configuration + temp_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") + dataset_config_path = temp_yaml.name + + # Define dataset configuration content + dataset_config_content = [ + {"dataset_prefix": "dataset1", "dataset_weight": 0.5, "dataset_split": "train"}, + {"dataset_prefix": "dataset2", "dataset_weight": 0.5, "dataset_split": "train"}, + {"dataset_prefix": "dataset1", "dataset_weight": 0.6, "dataset_split": "validation"}, + {"dataset_prefix": "dataset2", "dataset_weight": 0.6, "dataset_split": "validation"}, + {"dataset_prefix": "dataset2", "dataset_weight": 0.2, "dataset_split": "test"}, + ] + + # Write the dataset configuration content to the YAML file + with open(dataset_config_path, "w") as yaml_file: + yaml.dump(dataset_config_content, yaml_file) + + # Create dummy dataset files in the temporary directory + for dataset in dataset_config_content: + dataset_file = Path(dataset_path) / f"{dataset['dataset_prefix']}.txt" + dataset_file.touch() + + yield dataset_config_path, dataset_path + + # Clean up temporary files and directories + temp_yaml.close() + os.remove(dataset_config_path) + temp_dir.cleanup() + + +@pytest.fixture +def tmp_dataset(tmp_path): + """Create temporary dataset files for testing.""" + dataset_dir = tmp_path / "data" + dataset_dir.mkdir() + (dataset_dir / "dataset.bin").touch() + return dataset_dir + + +def test_valid_absolute_path(tmp_dataset): + """Test configuration with valid absolute path.""" + config = Evo2BlendedDatasetConfig( + dataset_prefix=str(tmp_dataset / "dataset"), dataset_weight=0.5, dataset_split="train" + ) + assert config.dataset_prefix == str(tmp_dataset / "dataset") + assert config.dataset_weight == 0.5 + assert config.dataset_split == "train" + + +def test_valid_relative_path(tmp_dataset): + """Test configuration with valid relative path and base data path.""" + config = Evo2BlendedDatasetConfig( + dataset_path=str(tmp_dataset), dataset_prefix="dataset", dataset_weight=0.5, dataset_split="validation" + ) + assert config.dataset_prefix == str(tmp_dataset / "dataset") + + +def test_invalid_relative_path_without_base(): + """Test relative path fails without base data path.""" + with pytest.raises(ValueError, match=f"dataset_prefix file does not exist: {Path('dataset').resolve()}"): + Evo2BlendedDatasetConfig(dataset_prefix="dataset", dataset_weight=0.5, dataset_split="train") + + +def test_valid_relative_path_without_base(tmp_dataset: Path): + """Test relative path in current workdir does not fail without base data path.""" + # changing temporary cwd since Path(dataset_prefix).resolve() will resolve relative paths to the current working directory + with change_dir(tmp_dataset): + Evo2BlendedDatasetConfig(dataset_prefix="dataset", dataset_weight=0.5, dataset_split="train") + + +def test_nonexistent_parent_path(tmp_path): + """Test configuration fails with nonexistent parent directory.""" + invalid_path = tmp_path / "nonexistent" / "dataset" + with pytest.raises(ValueError, match="parent path does not exist"): + Evo2BlendedDatasetConfig(dataset_prefix=str(invalid_path), dataset_weight=0.5, dataset_split="train") + + +def 
test_nonexistent_dataset_file(tmp_dataset): + """Test configuration fails with nonexistent dataset file.""" + invalid_path = tmp_dataset / "nonexistent_dataset" + with pytest.raises(ValueError, match="dataset_prefix file does not exist"): + Evo2BlendedDatasetConfig(dataset_prefix=str(invalid_path), dataset_weight=0.5, dataset_split="train") + + +def test_path_resolution(tmp_dataset): + """Test proper path resolution with different input formats.""" + relative_path = Path("dataset") + absolute_path = tmp_dataset / "dataset" + + config1 = Evo2BlendedDatasetConfig( + dataset_path=str(tmp_dataset), dataset_prefix=str(relative_path), dataset_weight=0.5, dataset_split="train" + ) + # changing temporary cwd since Path(dataset_prefix).resolve() will resolve relative paths to the current working directory + with change_dir(tmp_dataset): + config2 = Evo2BlendedDatasetConfig( + dataset_prefix=str(absolute_path), dataset_weight=0.5, dataset_split="train" + ) + + assert config1.dataset_prefix == config2.dataset_prefix + + +def test_parse_dataset_config(temp_dataset_config): + dataset_config_path, dataset_path = temp_dataset_config + + # Call the function to test + result = parse_dataset_config(dataset_config_path, dataset_path) + + print(result) + # Define the expected result + expected_result = defaultdict( + list, + { + "train": [0.5, str(Path(dataset_path) / "dataset1"), 0.5, str(Path(dataset_path) / "dataset2")], + "validation": [0.5, str(Path(dataset_path) / "dataset1"), 0.5, str(Path(dataset_path) / "dataset2")], + "test": [ + 1.0, + str(Path(dataset_path) / "dataset2"), + ], + }, + ) + + # Assert the result matches the expected result + assert result == expected_result diff --git a/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py new file mode 100644 index 000000000000..429c06291f9e --- /dev/null +++ b/tests/collections/llm/gpt/data/megatron/hyena/test_evo2_dataset.py @@ -0,0 +1,1059 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +import timeit +from typing import Tuple + +import pytest +import torch + +from nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset import Evo2Dataset, Evo2DatasetPadEodLossMask + +""" +The tag token is constructed as follows: So note that one way to know you have a tag is if you look at the first +token after the pipe and it is a 'd' character. Make sure tests are consistent with this simplification. + @staticmethod + def _construct_taxonomy_token( + lineage: Evo2TaxonomyLineage, dropout: float = 0.0, seed: Optional[int] = None + ) -> Optional[str]: + '''Construct a special Taxonomy token for natural language prompting of DNA generation models. 
+ + Args: + lineage (Evo2TaxonomyLineage): The taxonomy lineage information. + dropout (float): The probability of dropping out segments of the lineage. Defaults to 0.0. + seed (Optional[int]): The seed for the random number generator. Defaults to None. + + Returns: + Optional[str]: The constructed taxonomy token or None if lineage is None. + ''' + # If dropout > 0, randomly drop out segments of the lineage for training on incomplete lineages. + with Evo2Preprocessor.preprocessing_context_manager(seed if seed is not None else None): + return ( + "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( + lineage.domain if random.random() >= dropout else None, + lineage.phylum if random.random() >= dropout else None, + lineage.clazz if random.random() >= dropout else None, + lineage.order if random.random() >= dropout else None, + lineage.family if random.random() >= dropout else None, + lineage.genus if random.random() >= dropout else None, + lineage.species if random.random() >= dropout else None, + ) + if lineage is not None + else None + ) +""" + + +@pytest.fixture +def tag_tokens(): + """Standard tokens for phylogenetic tag tests, defined in Evo2_DataseT: + + CONTROL_TAGS: ClassVar[list[int]] = [64, 35] # '@' tag for splice splits/windows, '#' for contig splits + TAG_BOUNDS = 124 # start and end delim: '|' + TAG_CHARS: ClassVar[set[int]] = {95, 59, 32} # chars only found in control tags: _, ;, space + DEFAULT_EOD = 0 + """ + return { + "terminal": 124, # | + "other_chars": {95, 59, 32}, # _, ;, space + "eod": 0, # end of document token + } + + +def test_mask_phylogenetic_tags_with_eod(tag_tokens): + """ + Tests a sequence where an EOD splits two partial tags. + + Example sequence (ASCII): + 65 124 100 0 124 65 + 'A' '|' 'd' EOD '|' 'A' + + - Segment 1: "A|d" => keep 'A' (DNA), mask '|' and 'd' + - EOD => masked + - Segment 2: "|A" => mask '|', keep 'A' (DNA) + + Expected masking: [1, 0, 0, 1, 0, 1] + """ + sequence = torch.tensor([65, 124, 100, 0, 124, 65]) # "A|d" + EOD + "|A" + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], # '|' + other_tag_chars=tag_tokens["other_chars"], # { '_',';',' ' } + eod_token_id=tag_tokens["eod"], # 0 + ) + + expected_mask = torch.tensor([1, 0, 0, 1, 0, 1]) + assert torch.equal(mask, expected_mask) + + +def test_mask_phylogenetic_tags_middle(tag_tokens): + """Tests masking a phylogenetic tag that appears in the middle of a DNA sequence. + + The sequence contains: + 1. Normal DNA (ATG) + 2. A phylo tag (|d_|) + 3. More DNA (TCGA) + + Expected behavior: The DNA should be unmasked (1s) while everything between + and including the pipe characters should be masked (0s), as it's a valid phylo tag. + """ + sequence = torch.tensor( + [ + 65, + 84, + 71, # ATG + 124, + 100, + 110, + 102, + 111, + 95, + 116, + 97, + 103, + 124, # |d__tag| + 84, + 67, + 71, + 65, # TCGA + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], # | + other_tag_chars=tag_tokens["other_chars"], # _, ;, space + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 1, + 1, + 1, # DNA unmasked + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, # phylo tag masked + 1, + 1, + 1, + 1, # DNA unmasked + ] + ) + assert torch.equal(mask, expected_mask) + + +def test_mask_partial_tag_start(tag_tokens): + """Tests handling a sequence that starts with a partial phylogenetic tag. 
+ + The sequence starts with characters that would be inside a phylo tag, + followed by a closing pipe and DNA. Since we want to prevent the model from + learning non-DNA outputs, we mask all potential tag characters even without + complete tag delimiters. + + Sequence: "tag;_|ATG" (starting mid-tag) + Expected: All tag characters and delimiters masked, only DNA unmasked + """ + sequence = torch.tensor( + [ + 116, + 97, + 103, + 59, + 95, # tag;_ + 124, # | + 65, + 84, + 71, # ATG + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 0, + 0, + 0, + 0, + 0, # partial tag start masked + 0, # closing pipe masked + 1, + 1, + 1, # DNA unmasked + ] + ) + assert torch.equal(mask, expected_mask) + + +def test_mask_partial_tag_end(tag_tokens): + """Tests handling a sequence that ends with a partial phylogenetic tag. + + The sequence contains DNA followed by an opening pipe and tag characters, + but no closing pipe. Per requirements, we aggressively mask any potential + tag characters to ensure the model only learns DNA bases {A,C,G,T}. + + Sequence: "ATG|info_" (ending mid-tag) + Expected: DNA unmasked, all tag-related characters masked + """ + sequence = torch.tensor( + [ + 65, + 84, + 71, # ATG + 124, # | + 100, + 110, + 102, + 111, + 95, # info_ + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 1, + 1, + 1, # DNA unmasked + 0, # opening pipe masked + 0, + 0, + 0, + 0, + 0, # partial tag end masked + ] + ) + assert torch.equal(mask, expected_mask) + + +def test_standalone_tag(tag_tokens): + """Tests masking of a single complete tag with no surrounding sequence. + + Tests that a standalone tag (|tag_|) is fully masked since it contains + non-DNA characters. This ensures the model only learns to output + {A,C,G,T} tokens. + + Sequence: |tag_| + Expected: All tokens masked (all zeros) + """ + sequence = torch.tensor([124, 100, 97, 103, 95, 124]) # |dtag_| + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + expected = torch.tensor([0, 0, 0, 0, 0, 0]) # All masked + assert torch.equal(mask, expected) + + +def test_sequence_starting_with_tag(tag_tokens): + """Tests sequence that begins with a complete tag followed by DNA. + + Verifies that when a sequence starts with a complete tag followed by + DNA bases, the tag portion is masked while the DNA portion remains + unmasked. + + Sequence: |tag_|ATG + Expected: Tag masked (zeros), DNA unmasked (ones) + """ + sequence = torch.tensor( + [ + 124, + 100, # d token for domain + 97, + 103, + 95, + 124, # |tag_| + 65, + 84, + 71, # ATG + ] + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + expected = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1]) # Tag masked, DNA unmasked + assert torch.equal(mask, expected) + + +def test_sequence_ending_with_tag(tag_tokens): + """Tests sequence that ends with a complete tag. 
+ + Verifies that when a sequence ends with a complete tag, the DNA portion + remains unmasked while the entire tag portion is masked. + + Sequence: ATG|tag_| + Expected: DNA unmasked (ones), tag masked (zeros) + """ + sequence = torch.tensor( + [ + 65, + 84, + 71, # ATG + 124, + 100, + 97, + 103, + 95, + 124, # |tag_| + ] + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + expected = torch.tensor([1, 1, 1, 0, 0, 0, 0, 0, 0]) # DNA unmasked, tag masked + assert torch.equal(mask, expected) + + +def test_mask_multiple_tags(tag_tokens): + """Tests handling multiple phylogenetic tags in sequence, demonstrating state transitions. + + This tests how the masking switches states between phylo and non-phylo regions: + 1. Starts in non-phylo state with DNA + 2. Switches to phylo state at first pipe (with tag chars) + 3. Switches back to non-phylo at closing pipe + 4. Pattern repeats for second tag + + Sequence: "ATG|tag_1|CG|tag_2|AT" + Expected: Only DNA sequences should remain unmasked + """ + sequence = torch.tensor( + [ + 65, + 84, + 71, # ATG + 124, + 100, + 97, + 103, + 95, + 49, + 124, # |tag_1| + 67, + 71, # CG + 124, + 100, + 97, + 103, + 95, + 50, + 124, # |tag_2| + 65, + 84, # AT + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 1, + 1, + 1, # DNA unmasked + 0, + 0, + 0, + 0, + 0, + 0, + 0, # first tag masked + 1, + 1, # DNA unmasked + 0, + 0, + 0, + 0, + 0, + 0, + 0, # second tag masked + 1, + 1, # DNA unmasked + ] + ) + assert torch.equal(mask, expected_mask) + + +def test_mask_dna_after_pipe(tag_tokens): + """Tests the scenario where we have a pipe followed by DNA sequence. + + This tests the edge case of a pipe character appearing at the start of a sequence. + Even if DNA follows, we mask the pipe character to prevent the model from + learning to output non-DNA tokens. + + Sequence: "|ATG" (pipe followed by DNA) + Expected: Pipe masked, DNA unmasked + """ + sequence = torch.tensor( + [ + 124, # | + 65, + 84, + 71, # ATG + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor([0, 1, 1, 1]) # Pipe masked, DNA unmasked + assert torch.equal(mask, expected_mask) + + +def test_ambiguous_dna_char_followed_by_tag_start(tag_tokens): + """Tests handling of an ambiguous DNA character followed by a tag start. + + When we see a character that could be either DNA or the end of a truncated tag + followed by a pipe, we should mask both for safety since we can't disambiguate + whether the character was part of a tag. 
+ + Sequence: "t|AAAT" (t could be DNA or end of tag) + Expected: First t and pipe masked (0), AAAT unmasked (1) + """ + sequence = torch.tensor( + [ + 116, # t (ambiguous - could be DNA or end of tag) + 124, # | + 65, # A + 65, # A + 65, # A + 84, # T + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor([0, 0, 1, 1, 1, 1]) # Ambiguous t and pipe masked, DNA unmasked + assert torch.equal(mask, expected_mask) + + +def test_dna_followed_by_unambiguous_tag_start(tag_tokens): + """Tests handling of DNA sequence followed by clear tag start. + + When we see DNA followed by |d, it's unambiguous - the d clearly indicates + the start of a phylogenetic tag (d__), so we can safely unmask the DNA and + mask the tag portion. + + Sequence: "AAAT|d" (AAAT is DNA, |d starts tag) + Expected: AAAT unmasked (1), |d masked (0) + """ + sequence = torch.tensor( + [ + 65, # A + 65, # A + 65, # A + 84, # T + 124, # | + 100, # d (clearly starts d__ tag) + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor([1, 1, 1, 1, 0, 0]) # DNA unmasked, tag start masked + assert torch.equal(mask, expected_mask) + + +def test_double_partial_tags_with_dna_middle(tag_tokens): + """Tests a sequence that has partial tags at both ends with DNA in the middle. + + Tests the specific case where a sequence slice cuts through phylogenetic tags + on both ends, with valid DNA sequence in the middle. The behavior we want is: + 1. The partial tag at the start should be masked + 2. The DNA in the middle should be unmasked + 3. 
The partial tag at the end should be masked + + Sequence: "cacata|acagataaaataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + """ + sequence = torch.tensor( + [ + 99, + 97, + 99, + 97, + 116, + 97, # cacata + 124, # | + 97, + 99, + 97, + 103, + 97, + 116, + 97, + 97, + 97, + 97, + 116, + 97, # acagataaaata + 84, + 65, + 67, + 65, + 71, + 71, + 71, + 65, + 65, + 84, + 65, # TACAGGGAATA + 124, # | + 100, + 95, + 95, # d__ + ] + ) + + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + + expected_mask = torch.tensor( + [ + 0, + 0, + 0, + 0, + 0, + 0, # partial start tag masked + 0, # pipe masked + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, # middle DNA unmasked + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, # middle DNA unmasked + 0, # pipe masked + 0, + 0, + 0, # partial end tag masked + ] + ) + + assert torch.equal(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_predna(tag_tokens): + """ + Sequence: "GAATA[EOD]cacata|acagataaaataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "GAATA0cacata|acagataaaataTACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("GAATA0") * [1] + [0] * len("cacata|") + len("acagataaaataTACAGGGAATA") * [1] + [0] * len("|d__"), + dtype=torch.int32, + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_pretag(tag_tokens): + """ + Sequence: "cacata|[EOD]acagataaaataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "cacata|0acagataaaataTACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata") * [1] + [0] + [1] * len("0acagataaaataTACAGGGAATA") + len("|d__") * [0], dtype=torch.int32 + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_predna_middletag(tag_tokens): + """ + Sequence: "GAATA[EOD]cacata|acagataaaata|d__tag;|TACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "GAATA0cacata|acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("GAATA0") * [1] + + len("cacata|") * [0] + + [1] * len("acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def 
test_packed_partial_tag_subsequence_pretag_middletag(tag_tokens): + """ + Sequence: "cacata|[EOD]acagataaaata|d__tag;|TACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata") * [1] + + [0] # masked pipe. + + [1] * len("0acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_pretag_middletag_bs2(tag_tokens): + """ + Sequence: "cacata|[EOD]acagataaaata|d__tag;|TACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata") * [1] + + [0] + + [1] * len("0acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + expected_mask = torch.stack([expected_mask, expected_mask]) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence]), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_pretag_middletag_bs3(tag_tokens): + """ + Sequence: "cacata|[EOD]acagataaaata|d__tag;|TACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__somet" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata") * [1] + + [0] + + [1] * len("0acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__somet") * [0], + dtype=torch.int32, + ) + + sequence_alpha2 = "GAATA0cacata|acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence2 = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha2], dtype=torch.int32) + expected_mask2 = torch.tensor( + len("GAATA0") * [1] + + len("cacata|") * [0] + + [1] * len("acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + + expected_mask = torch.stack([expected_mask, expected_mask, expected_mask2]) + + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence, sequence2]), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_packed_partial_tag_subsequence_pretag_middletag_bs3_cuda(tag_tokens): + sequence_alpha = "cacata|0acagataaaata|d__tag;|TACAGGGAATA|d__somet" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], 
dtype=torch.int32) + expected_mask = torch.tensor( + len("cacata") * [1] + + [0] + + [1] * len("0acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__somet") * [0], + dtype=torch.int32, + ) + + sequence_alpha2 = "GAATA0cacata|acagataaaata|d__tag;|TACAGGGAATA|d__" + sequence2 = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha2], dtype=torch.int32) + expected_mask2 = torch.tensor( + len("GAATA0") * [1] + + len("cacata|") * [0] + + [1] * len("acagataaaata") + + len("|d__tag;|") * [0] + + len("TACAGGGAATA") * [1] + + len("|d__") * [0], + dtype=torch.int32, + ) + + expected_mask = torch.stack([expected_mask, expected_mask, expected_mask2]) + + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence, sequence2]).cuda(), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask.cpu(), expected_mask) + + +def test_multiple_packed_tags(tag_tokens): + """ + Tests a sequence with multiple packed tags. + """ + sequence_alpha = "|d__tag;|0|d__tag;|0|d__somet" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("|d__tag;|") * [0] + len("0") * [1] + len("|d__tag;|") * [0] + len("0") * [1] + len("|d__somet") * [0], + dtype=torch.int32, + ) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_multiple_eods(tag_tokens): + """ + Tests a sequence with multiple EODs. + """ + sequence_alpha = "ACGT0tacg0" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor(len(sequence_alpha) * [1], dtype=torch.int32) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_multiple_eods_prefix_no_suffix(tag_tokens): + """ + Tests a sequence with multiple EODs. + """ + sequence_alpha = "0ACGT0tacg0aa" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor(len(sequence_alpha) * [1], dtype=torch.int32) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_no_eods_with_batch(tag_tokens): + """ + Tests a sequence with multiple EODs. 
+ """ + sequence_alpha = "ACATAGATTT" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor(len(sequence_alpha) * [1], dtype=torch.int32) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence]), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, torch.stack([expected_mask, expected_mask])) + + +def test_no_eods_one_tag_with_batch_bs2(tag_tokens): + """ + Tests a sequence with multiple EODs. + """ + sequence_alpha = "ACAT|d__tag;|AGATTT" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor(len("ACAT") * [1] + len("|d__tag;|") * [0] + len("AGATTT") * [1], dtype=torch.int32) + mask = Evo2DatasetPadEodLossMask.mask_phylogenetic_tags( + tokenized_sequence=torch.stack([sequence, sequence]), + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, torch.stack([expected_mask, expected_mask])) + + +def test_packed_partial_tag_subsequence_predna_with_control(tag_tokens): + """ + Sequence: "GAATA[EOD]cacata|acagataaa@ataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "GAATA0cacata|acagataaaa@taTACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("GAATA0") * [1] + [0] * len("cacata|") + len("acagataaaa@taTACAGGGAATA") * [1] + [0] * len("|d__"), + dtype=torch.int32, + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def test_packed_partial_tag_subsequence_predna_with_control2(tag_tokens): + """ + Sequence: "GAATA[EOD]cacata|acagataaa@ataTACAGGGAATA|d__" + Expected: First partial tag masked (0s), middle DNA unmasked (1s), end tag masked (0s) + + """ + sequence_alpha = "GA#ATA0cacata|acagataaaa@taTACAGGGAATA|d__" + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32) + expected_mask = torch.tensor( + len("GA#ATA0") * [1] + [0] * len("cacata|") + len("acagataaaa@taTACAGGGAATA") * [1] + [0] * len("|d__"), + dtype=torch.int32, + ) + mask = Evo2Dataset.mask_phylogenetic_tags( + tokenized_sequence=sequence, + terminal_tag_char=tag_tokens["terminal"], + other_tag_chars=tag_tokens["other_chars"], + eod_token_id=tag_tokens["eod"], + ) + torch.testing.assert_close(mask, expected_mask) + + +def _construct_taxonomy_token(dropout: float = 0.0) -> str: + """Construct a special Taxonomy token for natural language prompting of DNA generation models. + + Args: + dropout (float): The probability of dropping out segments of the lineage. Defaults to 0.0. + + Returns: + Optional[str]: The constructed taxonomy token or None if lineage is None. + """ + # If dropout > 0, randomly drop out segments of the lineage for training on incomplete lineages. 
+ return "|d__{};p__{};c__{};o__{};f__{};g__{};s__{}|".format( + "somedomain" if random.random() >= dropout else None, + "somephylum" if random.random() >= dropout else None, + "someclass" if random.random() >= dropout else None, + "someorder" if random.random() >= dropout else None, + "somefamily" if random.random() >= dropout else None, + "lineage.genus" if random.random() >= dropout else None, + "lineage.speciescactaca" if random.random() >= dropout else None, + ) + + +def mask_phylogenetic_tags_old(tokenized_sequence, terminal_tag_char, other_tag_chars, eod_token_id): + """ + Optimized version to create a phylonetic tag mask for batched tokenized sequences with correct handling of partial tags. + Args: + - tokenized_sequence (torch.Tensor): A batched tensor of shape (batch_size, seq_length). + - terminal_tag_char (int): The token ID representing the start and end of a phylogenetic tag ('|'). + - other_tag_chars (set of int): A set of token IDs that are uniquely part of the tag ('_', ';', etc.). + - eod_token_id (int): The token ID representing the end-of-document (EOD). + Returns: + - mask_vector (torch.Tensor): A batched mask of the same shape as tokenized_sequence where + 1 represents non-tag tokens and 0 represents tokens within the masked region. + """ + device = tokenized_sequence.device + batch_size, seq_len = tokenized_sequence.shape + mask_vector = torch.ones_like(tokenized_sequence, dtype=torch.int, device=device) + + # To address when unbalanced tags are present + terms = torch.tensor([0, seq_len - 1], device=device) + other_tags = torch.tensor(list(other_tag_chars), device=device) + for batch_idx in range(batch_size): + tag_term_locs = torch.where(tokenized_sequence[batch_idx] == terminal_tag_char)[0] + tag_end_locs = torch.where(tokenized_sequence[batch_idx] == eod_token_id)[0] + + merged_tags = torch.cat((terms, tag_term_locs, tag_end_locs)).sort()[0] + merged_tags = merged_tags.unique() + + start = 0 # First and last locations are always added + for end in merged_tags[1:]: + if torch.isin(tokenized_sequence[batch_idx][start:end], other_tags).sum() > 0: + # end token is not part of the tag + if eod_token_id == tokenized_sequence[batch_idx][end]: + end = end - 1 + if eod_token_id == tokenized_sequence[batch_idx][start]: + start = start + 1 + + mask_vector[batch_idx][start : (end + 1)] = 0 + start = end + return mask_vector + + +def benchmark_phylo_tag_masking(num_iterations: int = 1000) -> Tuple[float, float]: + """Benchmark the performance of phylogenetic tag masking functions. 
+ + Args + num_iterations: Number of iterations to run for timing + """ + tax_token = _construct_taxonomy_token(dropout=0.0) + sequence_alpha = ( + tax_token[2:] + + "".join(random.choice("ACGTacgt") for _ in range(5000)) + + tax_token[:-25] + + "0" + + tax_token[36:] + + "".join(random.choice("ACGTacgt") for _ in range(5000)) + ) + sequence = torch.tensor([ord(t) if t != "0" else 0 for t in sequence_alpha], dtype=torch.int32, device="cpu") + + # Time the new implementation + new_time1 = timeit.timeit( + lambda: Evo2Dataset.mask_phylogenetic_tags(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), + number=num_iterations, + ) + + # Time the old implementation + old_time1 = timeit.timeit( + lambda: mask_phylogenetic_tags_old(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), + number=num_iterations, + ) + + # Time the new implementation + new_time2 = timeit.timeit( + lambda: Evo2Dataset.mask_phylogenetic_tags(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), + number=num_iterations, + ) + + # Time the old implementation + old_time2 = timeit.timeit( + lambda: mask_phylogenetic_tags_old(sequence.unsqueeze(0), 124, {95, 59, 32}, 0), + number=num_iterations, + ) + new_time = (new_time1 + new_time2) / 2 + old_time = (old_time1 + old_time2) / 2 + return old_time, new_time + + +def test_phylo_tag_masking_speed(): + num_iterations = 2000 + old_time, new_time = benchmark_phylo_tag_masking(num_iterations=num_iterations) + # Assert performance equivalent to within 20% or better on a small example. + assert old_time / num_iterations > (new_time / num_iterations) * 0.8 + + +if __name__ == "__main__": + num_iterations = 2000 + old_time, new_time = benchmark_phylo_tag_masking(num_iterations=num_iterations) + print(f"Old implementation average time: {old_time/num_iterations:.6f} seconds") + print(f"New implementation average time: {new_time/num_iterations:.6f} seconds") + print(f"Speed improvement: {(old_time/new_time - 1)*100:.2f}%") diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py new file mode 100644 index 000000000000..8c41ea9b413d --- /dev/null +++ b/tests/collections/llm/gpt/model/test_hyena.py @@ -0,0 +1,685 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from dataclasses import asdict +from typing import Type + +# TODO add back support for slurm resilience. 
+# import nvidia_resiliency_ext.ptl_resiliency as res_module +import torch +from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.llm.gpt.data import MockDataModule, PreTrainingDataModule +from nemo.collections.llm.gpt.data.megatron.hyena.config import parse_dataset_config +from nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset import Evo2Dataset, Evo2DatasetPadEodLossMask +from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192, + userbuffers_fp8_h100_h8192_tp4_mbs1_seqlen8192, +) +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.lightning import NeMoLogger +from nemo.lightning.pytorch import callbacks as nl_callbacks +from nemo.lightning.pytorch.callbacks import ModelCheckpoint +from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.lightning.pytorch.strategies.utils import RestoreConfig +from nemo.utils.exp_manager import TimingCallback + +torch._dynamo.config.suppress_errors = True + + +def parse_args(): + """Parse arguments for Evo2 model training.""" + parser = argparse.ArgumentParser( + description="Train a Hyena model using NeMo 2.0.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + data_group = parser.add_mutually_exclusive_group(required=True) + + data_group.add_argument( + "-d", + "--dataset-config", + type=str, + help="Path to the blended / weighted training dataset configuration YAML.", + ) + data_group.add_argument( + "--mock-data", + action="store_true", + help="Train with Mock data (for testing/debugging), either set this or provide a dataset config.", + ) + + parser.add_argument( + "--dataset-dir", + type=str, + help="Absolute path to the dataset directory. Defaults to using the absolute or relative paths (dataset_prefix) specified in the dataset config YAML.", + ) + + parser.add_argument("--num-nodes", type=int, default=1, help="Number of nodes to use for training, defaults to 1.") + parser.add_argument("--devices", type=int, default=1, help="Number of devices to use for training, defaults to 1.") + parser.add_argument("--seq-length", type=int, default=8192, help="Training sequence length") + parser.add_argument( + "--tensor-parallel-size", type=int, default=1, help="Order of tensor parallelism. Defaults to 1." + ) + parser.add_argument( + "--pipeline-model-parallel-size", type=int, default=1, help="Order of pipeline parallelism. Defaults to 1." + ) + parser.add_argument( + "--context-parallel-size", type=int, default=1, help="Order of context parallelism. Defaults to 1." 
+ ) + parser.add_argument("--no-wandb", action="store_true", default=False, help="Disable Wandb logging") + parser.add_argument("--wandb-project", type=str, default="nemo_evo2", help="Wandb project name") + parser.add_argument("--wandb-run-id", type=str, default=None, help="Wandb run identifier") + parser.add_argument( + "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group" + ) + parser.add_argument( + "--wandb-job-type", + type=str, + default=None, + help="A unique string representing a type of run, which is useful when you're grouping runs together into larger experiments using group.", + ) + parser.add_argument("--sequence-parallel", action="store_true", help="Set to enable sequence parallelism.") + parser.add_argument("--fp8", action="store_true", help="Set to enable FP8") + parser.add_argument("--micro-batch-size", type=int, default=1, help="Micro-batch size for data-parallel training.") + parser.add_argument( + "--global-batch-size", + type=int, + default=None, + help="Global batch size for training. If set to None, infer it from the TP, CP, and PP parameters.", + ) + parser.add_argument( + "--grad-acc-batches", type=int, default=1, help="Number of batches to accumulate gradients over." + ) + parser.add_argument("--max-steps", type=int, help="Number of training optimizer update steps.") + parser.add_argument( + "--val-check-interval", type=int, help="Number of steps between validation measurements and model checkpoints." + ) + parser.add_argument("--grad-reduce-in-fp32", action="store_true", default=False, help="Gradient reduce in FP32.") + parser.add_argument( + "--fp8-wgrad", + action="store_true", + default=False, + help="Faster option that is maybe less accurate (TBD) when using fp8.", + ) + parser.add_argument( + "--no-aligned-megatron-ddp", action="store_true", default=False, help="Do not do aligned gradient updates etc." + ) + parser.add_argument("--use-megatron-comm-overlap-8k", action="store_true", default=False) + parser.add_argument( + "--tp-comm-overlap-backend", + type=str, + choices=["nccl", "mpi", "gloo"], + default="nccl", + help="TP communication backend to use. Defaults to 'nccl'.", + ) + parser.add_argument("--align-param-gather", action="store_true", default=False) + # parser.add_argument("--straggler-detection", action="store_true", default=False) + parser.add_argument( + "--model-size", + type=str, + choices=sorted(HYENA_MODEL_OPTIONS.keys()), + default="7b", + help="Model architecture to use, choose between 7b, 40b, or test (a sub-model of 4 layers, less than 1B " + "parameters). '_arc_1m' models have GLU / FFN dimensions that support 1M context length when trained " + "with TP<=8.", + ) + parser.add_argument( + "--add-bias-output", + action="store_true", + default=False, + help="Add bias to the output layer to enable learning a simple prior.", + ) + parser.add_argument( + "--experiment-dir", + type=str, + required=True, + help="Directory to write model checkpoints and results to.", + ) + parser.add_argument( + "--limit-val-batches", + type=int, + default=20, + help="Number of validation steps", + ) + parser.add_argument( + "--log-every-n-steps", + type=int, + default=1, + required=False, + help="Number of steps between logging.", + ) + parser.add_argument( + "--ckpt-dir", + type=str, + default=None, + help="Directory to restore an initial checkpoint from. 
Use this for supervised fine-tuning.", + ) + parser.add_argument("--wd", type=float, default=0.01, help="Weight decay for optimizer.") + parser.add_argument( + "--restore-optimizer-from-ckpt", + action="store_true", + help="Restore optimizer state from initial checkpoint. Defaults to False.", + ) + parser.add_argument( + "--no-average-in-collective", + action="store_true", + default=False, + help="Avaerage optimizer state in collective rather than dividing by dp size and summing.", + ) + parser.add_argument("--seed", type=int, default=1234, help="Set random seed for training.") + parser.add_argument("--workers", type=int, default=8, help="Number of workers to use for data loading.") + parser.add_argument( + "--gc-interval", + type=int, + default=0, + help="Set to a value > 0 if you want to synchronize garbage collection, will do gc every gc-interval steps.", + ) + parser.add_argument( + "--enable-preemption", + action="store_true", + default=False, + help="Enable preemption hooks. If enabled this will save a checkpoint whenver slurm exits.", + ) + parser.add_argument( + "--ckpt-async-save", + action="store_true", + default=False, + ) + parser.add_argument( + "--ckpt-format", + type=str, + choices=["torch_dist", "zarr"], + default="torch_dist", + help="Specify checkpoint format to use. Defaults to 'torch_dist', as 'zarr' is deprecated. Only use if " + "resuming training from a zarr checkpoint.", + ) + parser.add_argument( + "--eod-pad-in-loss-mask", + action="store_true", + default=False, + help="Do not predict EOD/Pad tokens (typical default, but not default in original evo2).", + ) + parser.add_argument( + "--cross-entropy-loss-fusion", + action="store_true", + default=False, + help="Use the faster, but maybe less accurate fused form of cross entropy, " + "which also has bf16 grads internally.", + ) + parser.add_argument( + "--no-fp32-residual-connection", + action="store_true", + default=False, + help="If set, turn off fp32 residual connections which may be faster but may impact accuracy.", + ) + parser.add_argument( + "--debug-ddp-parity-freq", + type=int, + default=0, + help="Set to value > 0 to debug DDP weight parity between ranks.", + ) + parser.add_argument( + "--hybrid-override-pattern", + type=str, + help="Override the hybrid override pattern in the config (specifies hyena layer ordering and type).", + ) + parser.add_argument( + "--num-layers", type=int, help="If set, override the number of layers specified in the requested config." + ) + parser.add_argument( + "--tflops-callback", + action="store_true", + default=False, + help="Enable tflops calculation callback for Hyena / Evo2. Defaults to False.", + ) + parser.add_argument( + "--log-parameters-and-shapes", + action="store_true", + default=False, + help="Log training parameters shapes and dtypes for debugging.", + ) + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate.") + parser.add_argument("--min-lr", type=float, default=3e-5, help="Min learning rate in cosine annealing.") + parser.add_argument("--warmup-steps", type=int, default=2500, help="Number of warmup steps in cosine annealing") + # NSYS profiling/tooling arguments + parser.add_argument( + "--nsys-profiling", + action="store_true", + default=False, + help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling" + " output you must run the whole program with `nsys`. 
For example: " + " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true " + "--capture-range=cudaProfilerApi --capture-range-end=stop [regular python command here]`", + ) + # start, end, rank + parser.add_argument( + "--nsys-start-step", + type=int, + required=False, + default=0, + help="Start nsys profiling after this step.", + ) + parser.add_argument( + "--nsys-end-step", + type=int, + required=False, + help="End nsys profiling after this step.", + ) + parser.add_argument( + "--no-renormalize-loss", + action="store_true", + default=False, + help="Do not renormalize the loss weights.", + ) + # rank as list of integers + parser.add_argument( + "--nsys-ranks", + type=int, + nargs="+", + required=False, + default=[0], + help="Enable nsys profiling for these ranks.", + ) + parser.add_argument( + "--activation-checkpoint-recompute-num-layers", + type=int, + help="If set, override the default value set in the config.", + ) + parser.add_argument( + "--disable-checkpointing", + action="store_false", + default=True, + dest="create_checkpoint_callback", + help="Disable creating a ModelCheckpoint callback.", + ) + parser.add_argument( + "--clip-grad", + type=float, + default=1.0, + help="Grad clip value. Note that when using DDP this may need to be inflated.", + ) + parser.add_argument( + "--seq-len-interpolation-factor", + type=float, + help="Adjusts the linear scaling of ROPE (Rotary Position Embedding) for context extension. " + "Set this factor relative to your base context length e.g., for an original context length of 8192 and " + "an extended context length of 524288, use 524288/8192 = 64.", + ) + parser.add_argument( + "--overlap-param-gather", + action="store_true", + default=False, + help="Overlap the parameter gather with the optimizer step. This is currently disabled due to a NeMo bug " + "when using DDP. Making this an option defaulting to False is a temporary solution until the bug is fixed.", + ) + parser.add_argument( + "--overlap-grad-reduce", + action="store_true", + default=False, + help="Overlap the gradient reduce with the optimizer step.", + ) + parser.add_argument( + "--hidden-dropout", + type=float, + default=0.0, + help="Dropout probability for the hyena layers", + ) + parser.add_argument( + "--attention-dropout", + type=float, + default=0.0, + help="Dropout probability for the attention layers.", + ) + recompute_group = parser.add_mutually_exclusive_group(required=False) + recompute_group.add_argument("--no-activation-checkpointing", action="store_true", default=False) + recompute_group.add_argument("--selective-activation-checkpointing", action="store_true", default=False) + return parser.parse_args() + + +def main(): + """Main function to run Evo2 training.""" + args = parse_args() + + # Parse dataset configuration. + + # Instantiate tokenizer. + tokenizer = get_nmt_tokenizer( + "byte-level", + ) + + # Infer global batch size. + global_batch_size = args.global_batch_size + if args.mock_data: + data = MockDataModule( + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + global_batch_size=global_batch_size, + num_workers=args.workers, + tokenizer=tokenizer, + ) + else: + blended_dataset_config = parse_dataset_config(args.dataset_config, args.dataset_dir) + dataset_cls = Evo2DatasetPadEodLossMask if args.eod_pad_in_loss_mask else Evo2Dataset + # Instantiate pre-training module. 
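+ # blended_dataset_config carries the weighted dataset paths parsed above; dataset_cls controls
+ # whether EOD/pad tokens are masked out of the loss (--eod-pad-in-loss-mask) or predicted as usual.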
+ data = PreTrainingDataModule( + paths=blended_dataset_config, + dataset_cls=dataset_cls, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + global_batch_size=global_batch_size, + seed=args.seed, + num_workers=args.workers, + tokenizer=tokenizer, + eod_mask_loss=args.eod_pad_in_loss_mask, + ) + + if args.no_activation_checkpointing: + activation_checkpointing_args = { + "recompute_granularity": None, + "recompute_method": None, + "recompute_num_layers": None, + } + elif args.selective_activation_checkpointing: + activation_checkpointing_args = { + "recompute_granularity": "selective", + "recompute_method": None, + "recompute_num_layers": None, + } + else: + if args.activation_checkpoint_recompute_num_layers is not None: + activation_checkpointing_args = { + "recompute_num_layers": args.activation_checkpoint_recompute_num_layers, + } + else: + activation_checkpointing_args = {} + + # Retrieve model config. + config_modifiers_init = { + "tp_comm_overlap": args.use_megatron_comm_overlap_8k, + "seq_length": args.seq_length, + "hidden_dropout": args.hidden_dropout, + "attention_dropout": args.attention_dropout, + "to_upper": "weighted" if args.no_renormalize_loss else "normalized_weighted", + "distribute_saved_activations": False if args.sequence_parallel else True, + "cross_entropy_loss_fusion": args.cross_entropy_loss_fusion, + "fp32_residual_connection": not args.no_fp32_residual_connection, + "add_bias_output": args.add_bias_output, + **activation_checkpointing_args, + } + if args.hybrid_override_pattern: + config_modifiers_init["hybrid_override_pattern"] = args.hybrid_override_pattern + if args.num_layers: + config_modifiers_init["num_layers"] = args.num_layers + + if args.model_size not in HYENA_MODEL_OPTIONS: + raise ValueError(f"Invalid model size: {args.model_size}") + evo2_config = HYENA_MODEL_OPTIONS[args.model_size](**config_modifiers_init) + + # Instantiate model. + model = llm.HyenaModel(evo2_config, tokenizer=data.tokenizer) + + # Setup callbacks. + callbacks = [ + RichModelSummary(max_depth=4), + LearningRateMonitor(), + TimingCallback(), + ] + if args.create_checkpoint_callback: + checkpoint_callback = ModelCheckpoint( + every_n_train_steps=args.val_check_interval, + dirpath=args.experiment_dir, + save_top_k=5, + always_save_context=True, + save_optim_on_train_end=True, + save_context_on_train_end=True, + ) + callbacks.append(checkpoint_callback) + + if args.enable_preemption: + callbacks.append(nl_callbacks.PreemptionCallback()) + if args.debug_ddp_parity_freq > 0: + callbacks.append(nl_callbacks.DdpParityChecker(interval=args.debug_ddp_parity_freq)) + if args.log_parameters_and_shapes: + callbacks.append(nl_callbacks.ParameterDebugger()) + if args.tflops_callback: + # Add callback that logs the tera-FLOPS per second per GPU during training. + flop_meas_callback = FLOPsMeasurementCallback( + asdict(evo2_config), + data, + "hyena", + ) + callbacks.append(flop_meas_callback) + + # TODO(@cye): Add this back when it works with 24.12. + # if args.straggler_detection: + # callbacks.append( + # res_module.StragglerDetectionCallback( + # report_time_interval=300, + # calc_relative_gpu_perf=True, + # calc_individual_gpu_perf=True, + # num_gpu_perf_scores_to_print=5, + # gpu_relative_perf_threshold=0.7, + # gpu_individual_perf_threshold=0.7, + # stop_if_detected=True, + # enable_ptl_logging=True, + # ) + # ) + if args.use_megatron_comm_overlap_8k: + # Pick the floating point appropriate config. 
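+ # The userbuffer overlap configs below target H100, hidden size 8192, TP=4, MBS=1, and 8k
+ # sequence length (per their names); FP8 runs select the fp8-specific preset.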
+ if args.fp8: + tp_comm_overlap_cfg = userbuffers_fp8_h100_h8192_tp4_mbs1_seqlen8192 + else: + tp_comm_overlap_cfg = userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 + callbacks.append( + MegatronCommOverlapCallback( + tp_comm_overlap=evo2_config.tp_comm_overlap, + tp_comm_overlap_cfg=tp_comm_overlap_cfg, + tp_comm_bootstrap_backend=args.tp_comm_overlap_backend, + wgrad_deferral_limit=22, # default from NeMo + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing. + align_param_gather=args.align_param_gather, + ) + ) + + if args.gc_interval > 0: + callbacks.append( + nl_callbacks.GarbageCollectionCallback( + gc_interval_train=args.gc_interval, gc_interval_val=args.gc_interval + ) + ) + if args.nsys_profiling: + if args.nsys_end_step is None: + nsys_end_step = args.max_steps + else: + nsys_end_step = args.nsys_end_step + callbacks.append( + nl_callbacks.NsysCallback( + start_step=args.nsys_start_step, end_step=nsys_end_step, ranks=args.nsys_ranks, gen_shape=True + ) + ) + + loggers = [] + nemo_logger_kwargs = {} + if (not args.no_wandb) and args.wandb_project: + wandb_logger = WandbLogger( + name=( + f"evo2-size-{args.model_size}-TP{args.tensor_parallel_size}-" + f"PP{args.pipeline_model_parallel_size}-CP{args.context_parallel_size}" + f"-GBS{global_batch_size}-MBS{args.micro_batch_size}-SkipLossRenorm{args.no_renormalize_loss}" + f"-NOAC{args.no_activation_checkpointing}-SELAC{args.selective_activation_checkpointing}" + f"-ACRNL{evo2_config.recompute_num_layers}" + f"-PAT{evo2_config.hybrid_override_pattern}" + f"-F32R{evo2_config.fp32_residual_connection}" + f"-FCE{evo2_config.cross_entropy_loss_fusion}" + f"-AIC{not args.no_average_in_collective}" + f"-PEOD{args.eod_pad_in_loss_mask}" + f"-BO{args.add_bias_output}" + f"-GCLP{args.clip_grad}" + f"-HDO{args.hidden_dropout}" + f"-ADO{args.attention_dropout}" + f"-LR{args.lr}-MINLR{args.min_lr}-WUSTEPS{args.warmup_steps}-WD{args.wd}" + f"-GRFP32{args.grad_reduce_in_fp32}-FP8WG{args.fp8_wgrad and args.fp8}" + f"-OGR{args.overlap_grad_reduce}-OPG{args.overlap_param_gather}" + f"-NODES{args.num_nodes}-FP8{args.fp8}" + ), + group=args.wandb_group, + job_type=args.wandb_job_type, + id=args.wandb_run_id, # set this to use the same curve name for restarts. + project=args.wandb_project, + save_dir=args.experiment_dir, + ) + loggers.append(wandb_logger) + nemo_logger_kwargs["wandb"] = wandb_logger + tb_logger = TensorBoardLogger( + save_dir="dummy", ## NOTE: this gets overwritten by default + ) + nemo_logger_kwargs["tensorboard"] = tb_logger + loggers.append(tb_logger) + + nemo_logger = NeMoLogger(log_dir=args.experiment_dir, **nemo_logger_kwargs) + ddp: DistributedDataParallelConfig = DistributedDataParallelConfig( + check_for_nan_in_grad=True, + overlap_grad_reduce=args.overlap_grad_reduce, + overlap_param_gather=args.overlap_param_gather, # Verify that this works using + grad_reduce_in_fp32=args.grad_reduce_in_fp32, + align_param_gather=args.align_param_gather, + average_in_collective=not args.no_average_in_collective, + ) + # Initialize Megatron Strategy and Trainer. 
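+ # The DistributedDataParallelConfig above is handed to the strategy; tensor/pipeline/context
+ # parallel sizes come from the CLI args, and checkpoints use the format chosen via --ckpt-format.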
+ strategy = nl.MegatronStrategy( + ddp=ddp, + tensor_model_parallel_size=args.tensor_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + context_parallel_size=args.context_parallel_size, + pipeline_dtype=torch.bfloat16, + sequence_parallel=args.sequence_parallel, + ckpt_load_optimizer=True, + ckpt_save_optimizer=True, + ckpt_async_save=args.ckpt_async_save, + save_ckpt_format=args.ckpt_format, + ckpt_load_strictness="log_all", # or rebasing to https://github.com/NVIDIA/NeMo/pull/11988/files#diff-7667eae242a8ef776bff78cd08e79bc81df4896a450f0a781f6ed317a3dfb7ffR139 + ) + trainer = nl.Trainer( + devices=args.devices, + num_nodes=args.num_nodes, + max_steps=args.max_steps, + accelerator="gpu", + strategy=strategy, + logger=loggers, + callbacks=callbacks, + log_every_n_steps=args.log_every_n_steps, + limit_val_batches=args.limit_val_batches, + num_sanity_val_steps=0, + use_distributed_sampler=False, + plugins=nl.MegatronMixedPrecision( + precision="bf16-mixed", + params_dtype=torch.bfloat16, + grad_reduce_in_fp32=args.grad_reduce_in_fp32, + fp8="hybrid" if args.fp8 else None, + fp8_amax_history_len=16 if args.fp8 else 1, + fp8_amax_compute_algo="max" if args.fp8 else "most_recent", + fp8_wgrad=args.fp8 + and ( + args.fp8_wgrad or args.use_megatron_comm_overlap_8k + ), # faster and less accurate when set to True, and MUST be True if using TP communication overlap + ), + val_check_interval=args.val_check_interval, + enable_checkpointing=args.create_checkpoint_callback, + ) + + # Logger setup + nemo_logger.setup( + trainer, + resume_if_exists=True, + ) + + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_past_end=False, + resume_from_directory=args.experiment_dir, + restore_config=( + RestoreConfig( + path=args.ckpt_dir, + load_model_state=True, + load_optim_state=args.restore_optimizer_from_ckpt, + ) + if args.ckpt_dir + else None + ), + ) + resume.setup(trainer, model) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer="adam", + lr=args.lr, + adam_beta1=0.9, + adam_beta2=0.95, + weight_decay=args.wd, + clip_grad=args.clip_grad, + use_distributed_optimizer=True, + bf16=True, + ) + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=args.warmup_steps, + min_lr=args.min_lr, + ) + + opt = MegatronOptimizerModule(opt_config, sched, no_weight_decay_cond=evo2_config.hyena_no_weight_decay_cond_fn) + opt.connect(model) + + # Start training + trainer.fit(model, data) + + +if __name__ == "__main__": + """ Example command to run the script, use --help for more options.: + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc-per-node=8 \ + /opt/NeMo/tests/collections/llm/gpt/model/test_hyena.py \ + --num-nodes=1 \ + --devices=8 \ + --max-steps=500000 \ + --val-check-interval=200 \ + --experiment-dir= \ + --dataset-config= \ + --seq-length=8192 \ + --tensor-parallel-size=1 \ + --pipeline-model-parallel-size=1 \ + --context-parallel-size=1 \ + --global-batch-size=16 \ + --micro-batch-size=2 \ + --model-size=7b \ + --fp8 \ + --clip-grad 0 \ + --overlap-grad-reduce \ + --lr=0.0003 \ + --warmup-steps=2500 \ + --wandb-project=nemo_evo2 + + """ + main() diff --git a/tests/collections/llm/gpt/model/test_hyena_accuracy.py b/tests/collections/llm/gpt/model/test_hyena_accuracy.py new file mode 100644 index 000000000000..0e5c455e7405 --- /dev/null +++ b/tests/collections/llm/gpt/model/test_hyena_accuracy.py @@ -0,0 +1,280 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
 All rights reserved.
+# Copyright (c) 2024 Arc Institute. All rights reserved.
+# Copyright (c) 2024 Michael Poli. All rights reserved.
+# Copyright (c) 2024 Stanford University. All rights reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+###########################################################
+# BEGIN COPY/pasted bionemo stuff:
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Iterator, Literal, Optional, Set, TypeVar
+
+import lightning.pytorch as pl
+import megatron.core.num_microbatches_calculator
+import pytest
+import torch
+import torch.distributed
+from megatron.core import parallel_state
+from megatron.core.dist_checkpointing.mapping import ShardedTensor
+from megatron.core.tensor_parallel import random as tp_random
+from megatron.core.transformer.enums import AttnBackend
+from megatron.core.transformer.module import Float16Module, MegatronModule
+
+from nemo.collections import llm
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+from nemo.lightning.io.pl import MegatronCheckpointIO
+
+
+def _munge_key_megatron_to_nemo2(k: str) -> str:
+ return f"module.{k}"
+
+
+def _munge_sharded_tensor_key_megatron_to_nemo2(v: ShardedTensor) -> ShardedTensor:
+ # This works with PP=1, how do we handle PP>1?
+ key = v.key
+ v.key = _munge_key_megatron_to_nemo2(key)
+ return v
+
+
+def _key_in_filter(k: str, filter: Set[str]) -> bool:
+ for prefix in filter:
+ if k.startswith(prefix):
+ return True
+ return False
+
+
+MegatronModelType = TypeVar("MegatronModelType", bound=MegatronModule)
+
+
+def _reset_microbatch_calculator():
+ """Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron, which NeMo uses to initialize model parallel in
+ nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo
+ """ # noqa: D205, D415
+ megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
+
+
+def _dummy() -> None:
+ return
+
+
+def _teardown_apex_megatron_cuda():
+ """Cleans GPU allocation and model and data parallel settings after usage of a model:
+ - sets the global variables related to model and data parallelism to None in Apex and Megatron.
+ - releases all unoccupied cached GPU memory currently held by the caching CUDA allocator, see torch.cuda.empty_cache
+ """ # noqa: D205, D415
+ torch.cuda.empty_cache()
+ _reset_microbatch_calculator()
+ parallel_state.destroy_model_parallel()
+
+
+def _initialize_distributed_parallel_state(
+ devices: int = 1,
+ tensor_model_parallel_size: int = 1,
+ pipeline_model_parallel_size: int = 1,
+ pipeline_model_parallel_split_rank: int = 0,
+ context_parallel_size: int = 1,
+ interactive: bool = False,
+) -> None:
+ # initialize pytorch DDP
+ # if not interactive and not torch.distributed.is_initialized():
+ if not torch.distributed.is_initialized():
+ logging.info("pytorch DDP is not initialized. 
Initializing with pytorch-lightning...")
+ trainer = pl.Trainer(devices=devices, strategy="ddp" if not interactive else "auto", num_nodes=1)
+
+ if trainer.strategy.launcher is not None:
+ trainer.strategy.launcher.launch(_dummy, trainer=trainer)
+ trainer.strategy.setup_environment()
+
+ if not interactive and parallel_state.is_unitialized():
+ logging.info("Megatron DDP is not initialized. Initializing...")
+ parallel_state.initialize_model_parallel(
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
+ context_parallel_size=context_parallel_size,
+ )
+
+
+@contextmanager
+def distributed_model_parallel_state(
+ seed: Optional[int] = 42,
+ devices: int = 1,
+ tensor_model_parallel_size: int = 1,
+ pipeline_model_parallel_size: int = 1,
+ pipeline_model_parallel_split_rank: int = 0,
+ context_parallel_size: int = 1,
+ interactive: bool = False,
+) -> Iterator[None]:
+ """Context manager for creating and cleaning up distributed model parallel state for tests.
+ Use like:
+ with distributed_model_parallel_state():
+ # your test code here
+ # After the block your state is cleaned up.
+ """ # noqa: D205
+ initial_states: Optional[Any] = None
+
+ try:
+ _teardown_apex_megatron_cuda()
+ _initialize_distributed_parallel_state(
+ devices=devices,
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
+ context_parallel_size=context_parallel_size,
+ interactive=interactive,
+ )
+ # Our goal is to set required state on entry, and then restore current state on exit for the RNGs.
+ # There are two possibilities that are handled below:
+ # 1. If the RNG state is not initialized, we need to set it up and then
+ # unset it on exit to restore the current state. We track that this is the case when `initial_states` is `None`.
+ # 2. If the RNG state is initialized, we need to track this state and reset it on exit to be what it was on entry.
+ # We track that this is the case when `initial_states` is not `None`.
+ if tp_random.get_cuda_rng_tracker().is_initialized():
+ initial_states = tp_random.get_cuda_rng_tracker().get_states()
+ if seed is not None:
+ # Set the seed if provided; this case is valid whether or not the RNG had state previously.
+ # On exit the RNG state will be restored to what it was on entry.
+ tp_random.model_parallel_cuda_manual_seed(seed)
+ else:
+ # This is the case where the RNG state is not initialized and no seed was provided.
+ # We need to raise an error in this case, as we cannot restore the RNG state on exit and we need a seed
+ # to initialize the RNG state to. This only happens if the user overrides the default seed and sets it
+ # to None, and additionally if the RNG state was not initialized externally, as there is a default seed of 42.
+ if initial_states is None:
+ raise ValueError(
+ "You must provide a seed if the initial parallel state is unset. "
+ "Either provide a seed or leave the default seed (rather than setting it to None) "
+ "or initialize the RNG state externally."
+ ) + yield + finally: + if initial_states is not None: + tp_random.get_cuda_rng_tracker().set_states(initial_states) + else: + # Reset to the unset state + tp_random.get_cuda_rng_tracker().reset() + _teardown_apex_megatron_cuda() + + +# END COPY/pasted bionemo stuff +############################################################### + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) # Capture all levels in the logger itself + + +def load_weights_sharded_inplace_nemo2_to_mcore( + model: MegatronModelType, + distributed_checkpoint_dir: str | Path, + skip_keys_with_these_prefixes: Set[str], + ckpt_format: Literal["zarr", "torch_dist"] = "zarr", +): + logger.info("Start setting up state dict") + sharded_state_dict = { + _munge_key_megatron_to_nemo2(k): _munge_sharded_tensor_key_megatron_to_nemo2(v) + for k, v in model.sharded_state_dict().items() + if not _key_in_filter( + k, skip_keys_with_these_prefixes + ) # and "_extra_state" not in k # extra state is needed for fp8 sharded states + } + MegatronCheckpointIO(save_ckpt_format=ckpt_format).load_checkpoint( + distributed_checkpoint_dir, sharded_state_dict=sharded_state_dict + ) + + +@pytest.mark.skip(reason="Skipping test due to slow runtime and non-availability of model/test data in CI.") +def test_golden_values(use_te: bool = True): + """Step 1: + # add local .ssh/*.pub key to eos ~/.ssh/authorized_keys + mkdir -p arc_model/checkpoints/ + rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/interleaved_hyena_7b arc_model/checkpoints/ + rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/interleaved_hyena_7b_no_te arc_model/checkpoints/ + mkdir -p arc_model/gold_standards/ + rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/interleaved_7b_golden_value.pt arc_model/gold_standards/ + rsync -avz --progress --partial login-eos01.eos.clusters.nvidia.com:/lustre/fsw/healthcareeng_bionemo/arc_evo2/savanna_outputs/final_7b_no_fp8_golden_value.pt arc_model/gold_standards/ + """ + if use_te: + cfg_path = "arc_model/checkpoints/interleaved_hyena_7b/weights" # TODO interleaved checkpoint + else: + cfg_path = "arc_model/checkpoints/interleaved_hyena_7b_no_te/weights" + + with torch.inference_mode(), distributed_model_parallel_state(): + hyena_config = llm.Hyena7bConfig(use_te=use_te, attention_backend=AttnBackend.fused) + tokenizer = get_nmt_tokenizer( + "byte-level", + ) + raw_megatron_model = hyena_config.configure_model(tokenizer).eval().cuda() + device = raw_megatron_model.parameters().__next__().device + load_weights_sharded_inplace_nemo2_to_mcore(raw_megatron_model, cfg_path, {}, "zarr") + """ + fp8='hybrid', fp8_margin=0, fp8_interval=1, fp8_amax_history_len=16, fp8_amax_compute_algo='max', fp8_wgrad=True, fp8_dot_product_attention=False, fp8_multi_head_attention=False, tp_only_amax_red=False + """ + model = Float16Module(hyena_config, raw_megatron_model) + input_seq = 
"GAAATTAGCGCGTCCGGAATGATACGAGGGGAAACGAAATTTTGAATTAATGGAGAAAAAAGACGAGAAACCTTAAGCAAAAAAATTTTAGCTTCGAATATTTATTAATTTCTGAGATGTTGTTAAACGATTTTCGATTCCAAGTTGTGCGCACGAACGTTATTGCAAATAAATGCTGCTTATTCGGATGTTTCCACGATCTTTGTTGCAATGGTAGTCGAGTACCCGATAACCCAATTTCGTTACATCGGCCTATCTGTAGAATATCCAATCTATGGTTCATAAAAAATCTGATCGTTTGTTTTTAAGAAATTAAACGCGTTAAATTGAACGAATTTCGAATACCGGTCTTAGCGAAGGACCTCCCCTCTTGCTTGCGTATTGCCCCGCGAAATTTCTTTTCGGCGATGAACGATACAAAAAATTCTATCGAATGTTACTTCTATTCTCTGCCTCGTCTATGACTTGGAGATTGGTCTATGTCGTTCGTTTTCTCGCGAGTTTCCAATATGTCCGTAGTATGTGAACGCTGGTATTCGTGAAGATAAATTATTGTTTTTACAATTTCTTTCAAAAATATATAATTTTAATTTATATAAT" + input_ids = torch.tensor(tokenizer.text_to_ids(input_seq)).int().unsqueeze(0).to(device) + position_ids = torch.arange(len(input_seq)).unsqueeze(0).to(device) + attention_mask = None + outputs = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask) + gold_standard_no_fp8 = torch.load("arc_model/gold_standards/final_7b_no_fp8_golden_value.pt").to( + device=outputs.device, dtype=outputs.dtype + ) + gold_standard_fp8 = torch.load("arc_model/gold_standards/interleaved_7b_golden_value.pt").to( + device=outputs.device, dtype=outputs.dtype + ) + + our_generation_str = "".join( + [chr(idx) for idx in outputs.softmax(dim=-1).argmax(dim=-1).flatten().detach().cpu().numpy().tolist()] + ) + their_generation_str_fp8 = "".join( + [ + chr(idx) + for idx in gold_standard_fp8.softmax(dim=-1).argmax(dim=-1).flatten().detach().cpu().numpy().tolist() + ] + ) + their_generation_str_no_fp8 = "".join( + [ + chr(idx) + for idx in gold_standard_no_fp8.softmax(dim=-1) + .argmax(dim=-1) + .flatten() + .detach() + .cpu() + .numpy() + .tolist() + ] + ) + char_matches_ours_v_theirs_no_fp8 = [ + our_generation_str[i] == their_generation_str_no_fp8[i] for i in range(len(their_generation_str_no_fp8)) + ] + char_matches_ours_v_theirs_fp8 = [ + our_generation_str[i] == their_generation_str_fp8[i] for i in range(len(their_generation_str_fp8)) + ] + char_matches_theirs_v_theirs_fp8_vs_not = [ + their_generation_str_fp8[i] == their_generation_str_no_fp8[i] + for i in range(len(their_generation_str_no_fp8)) + ] + token_similarity_vs_no_fp8 = sum(char_matches_ours_v_theirs_no_fp8) / len(char_matches_ours_v_theirs_no_fp8) + token_similarity_vs_fp8 = sum(char_matches_ours_v_theirs_fp8) / len(char_matches_ours_v_theirs_fp8) + token_similarity_theirs = sum(char_matches_theirs_v_theirs_fp8_vs_not) / len( + char_matches_theirs_v_theirs_fp8_vs_not + ) + assert ( + token_similarity_vs_no_fp8 >= token_similarity_theirs + and token_similarity_vs_fp8 >= token_similarity_theirs + ) + torch.testing.assert_close(outputs, gold_standard_no_fp8) diff --git a/tests/core/test_save_restore.py b/tests/core/test_save_restore.py index 8ac9dfeca1ae..a06018839013 100644 --- a/tests/core/test_save_restore.py +++ b/tests/core/test_save_restore.py @@ -30,6 +30,13 @@ from nemo.utils.exceptions import NeMoBaseException +@pytest.fixture(scope="module", autouse=True) +def set_env(): + os.environ["HF_HOME"] = "/home/TestData/hf_home_test_save_restore" + yield + del os.environ["HF_HOME"] + + def classpath(cls): return f'{cls.__module__}.{cls.__name__}' diff --git a/tests/lightning/pytorch/callbacks/test_flops_callback.py b/tests/lightning/pytorch/callbacks/test_flops_callback.py new file mode 100644 index 000000000000..757ff5e924f3 --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_flops_callback.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024 Arc Institute. 
All rights reserved. +# Copyright (c) 2024 Michael Poli. All rights reserved. +# Copyright (c) 2024 Stanford University. All rights reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from nemo.collections.llm.gpt.model.base import GPTConfig +from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback + + +class MockDataModule: + def __init__(self, global_batch_size, vocab_size): + self.global_batch_size = global_batch_size + self.tokenizer = self + self.vocab_size = vocab_size + + +def test_flops_measurement_callback_bert(): + model_config = GPTConfig( + seq_length=128, + hidden_size=768, + num_layers=12, + ffn_hidden_size=3072, + num_attention_heads=12, + moe_router_topk=0, + num_query_groups=12, + ) + + train_step_time = 1.23 + global_batch_size = 1 + num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + model_name = "bert" + data_module = MockDataModule(global_batch_size=global_batch_size, vocab_size=100) + callback = FLOPsMeasurementCallback(model_config, data_module, model_name) + total_flops, flops_per_gpu = callback.eval_model_flops() + + expected_total_flops = 84146651135.99998 + expected_flops_per_gpu = expected_total_flops / num_devices + + assert total_flops == expected_total_flops + assert flops_per_gpu == expected_flops_per_gpu + + tflops_per_sec_per_gpu = callback.eval_tflops_per_sec_per_gpu(train_step_time) + expected_tflops_per_sec_per_gpu = expected_flops_per_gpu / (1e12 * train_step_time) + assert tflops_per_sec_per_gpu == expected_tflops_per_sec_per_gpu diff --git a/tests/utils/test_flops_formulas.py b/tests/utils/test_flops_formulas.py new file mode 100644 index 000000000000..8176f333f67f --- /dev/null +++ b/tests/utils/test_flops_formulas.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from nemo.utils.flops_formulas import FLOPSConfig, bert, gpt3, llama2, llama3, mixtral, nemotron +from nemo.utils.hyena_flops_formulas import hyena + + +@pytest.fixture +def flops_config(): + return FLOPSConfig( + gbs=1, + enc_seq_len=128, + hs=768, + layers=12, + ffn_hs=3072, + attention_heads=12, + moe_router_topk=2, + query_groups=12, + vocab_size=50257, + model_pattern="SDH*", + ) + + +def test_gpt3(flops_config): + expected_flops = 97240743936 + assert gpt3(flops_config) == expected_flops + + +def test_llama2(flops_config): + expected_flops = 107659395072.0 + assert llama2(flops_config) == expected_flops + + +def test_llama3(flops_config): + expected_flops = 164433494016.0 + assert llama3(flops_config) == expected_flops + + +def test_nemotron(flops_config): + expected_flops = 218036699136.0 + assert nemotron(flops_config) == expected_flops + + +def test_mixtral(flops_config): + expected_flops = 172889210880.0 + assert mixtral(flops_config) == expected_flops + + +def test_bert(flops_config): + expected_flops = 84146651135.99998 + assert bert(flops_config) == expected_flops + + +def test_hyena(flops_config): + expected_flops = 116883062784.0 + assert hyena(flops_config) == expected_flops
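As a quick cross-check of the values asserted in test_flops_callback.py, the expected TFLOPs/sec/GPU follows directly from the expected total FLOPs, the device count, and the mocked step time. The sketch below is a standalone illustration, not part of the diff; the constants are copied from that test, and a single GPU is assumed, matching the fallback used when torch.distributed is not initialized.

    # Reproduces the arithmetic asserted in test_flops_callback.py (values copied from the test).
    total_flops = 84146651135.99998  # expected_total_flops for the BERT-style GPTConfig in the test
    num_devices = 1                  # assumed: torch.distributed not initialized
    train_step_time = 1.23           # seconds per training step

    flops_per_gpu = total_flops / num_devices
    tflops_per_sec_per_gpu = flops_per_gpu / (1e12 * train_step_time)
    print(round(tflops_per_sec_per_gpu, 3))  # ~68.412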