Merge branch 'r2.0.0rc1' into r2.0.0rc1
terrykong authored Jul 11, 2024
2 parents 72c19e1 + cf167e6 commit 9ff422f
Showing 12 changed files with 95 additions and 43 deletions.
3 changes: 2 additions & 1 deletion docs/source/nlp/nemo_megatron/intro.rst
@@ -20,6 +20,7 @@ To learn more about using NeMo to train Large Language Models at scale, please r
peft/landing_page
positional_embeddings
mcore_customization
reset_learning_rate


References
@@ -28,4 +29,4 @@ References
.. bibliography:: ../nlp_all.bib
:style: plain
:labelprefix: nlp-megatron
:keyprefix: nlp-megatron-
:keyprefix: nlp-megatron-
30 changes: 30 additions & 0 deletions docs/source/nlp/nemo_megatron/reset_learning_rate.rst
@@ -0,0 +1,30 @@
.. _reset_learning_rate:

Reset Learning Rate
-------------------

The reset learning rate feature lets you reset the learning rate of an existing checkpoint to its initial value (either 0 or ``optim.min_lr``, depending on the warmup steps) when performing continual pretraining.

Parameters
----------

* ``reset_lr`` (boolean): Enables resetting the learning rate to its initial value. This feature is only supported with the distributed optimizer and ``megatron_amp_O2``.
* ``reset_lr_steps`` (boolean): Enables adjusting the learning rate schedule's ``max_steps`` and ``decay_steps`` by subtracting the number of steps already completed at the checkpoint.
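
A minimal sketch of enabling both flags, assuming they are read from the ``model`` section of the usual megatron GPT pretraining config (the exact key locations may differ):

.. code-block:: python

    from omegaconf import OmegaConf

    # Hypothetical config file name; in practice this is the YAML used for the continued run.
    cfg = OmegaConf.load("megatron_gpt_config.yaml")

    cfg.model.reset_lr = True        # reset the learning rate schedule to its initial value
    cfg.model.reset_lr_steps = True  # also shrink max_steps/decay_steps by the completed steps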

Use Cases
---------

1. ``reset_lr=True, reset_lr_steps=False``
When pretraining an existing checkpoint "from scratch" on a different dataset. The learning rate will be reset to its initial value. This allows the model to start training on a new dataset with the same learning rate dynamics as if it were starting from scratch.

2. ``reset_lr=True, reset_lr_steps=True``
When continuing training from an existing checkpoint with the same configuration. The learning rate will be reset to its initial value, and the ``max_steps`` and ``decay_steps`` for learning rate schedule will be recalculated by subtracting the number of steps already completed at the checkpoint. Specifically:
* ``max_steps`` will be recalculated as ``max_steps -= completed_steps``.
* ``decay_steps`` will be recalculated as ``decay_steps -= completed_steps``.
This ensures that the learning rate reaches the ``min_lr`` value by the end of training without changing the ``trainer.max_steps``:

.. image:: https://github.com/NVIDIA/NeMo/releases/download/v2.0.0rc0/asset-post-reset-learning-rate-example.png
:alt: Example learning rate schedule when resetting the learning rate for continual pretraining
:width: 1080px
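
As a purely illustrative walk-through of the ``reset_lr_steps`` arithmetic (the step counts below are assumptions, not values from a real run):

.. code-block:: python

    completed_steps = 100_000   # optimizer steps already completed at the checkpoint (assumed)
    max_steps = 300_000         # trainer.max_steps configured for the run
    decay_steps = 280_000       # decay_steps of the original learning rate schedule

    # With reset_lr=True and reset_lr_steps=True, the scheduler is rebuilt as if training
    # started from step 0, but only over the remaining budget:
    max_steps -= completed_steps     # 200_000
    decay_steps -= completed_steps   # 180_000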


5 changes: 4 additions & 1 deletion nemo/collections/llm/gpt/model/base.py
@@ -4,7 +4,6 @@
import pytorch_lightning as L
import torch
import torch.distributed
from megatron.core.models.gpt import gpt_layer_specs
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -66,12 +65,16 @@ def gpt_forward_step(model, batch) -> torch.Tensor:


def transformer_engine_layer_spec(config: "GPTConfig") -> ModuleSpec:
from megatron.core.models.gpt import gpt_layer_specs

return gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm
)


def local_layer_spec(config: "GPTConfig") -> ModuleSpec:
from megatron.core.models.gpt import gpt_layer_specs

return gpt_layer_specs.get_gpt_layer_local_spec(
num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm
)
7 changes: 5 additions & 2 deletions nemo/lightning/base.py
@@ -45,8 +45,11 @@ def teardown(trainer: Trainer, model: Optional[nn.Module] = None) -> None:
trainer._teardown() # noqa: SLF001
if model is not None:
for obj in gc.get_objects():
if torch.is_tensor(obj) and obj.is_cuda:
del obj
try:
if torch.is_tensor(obj) and obj.is_cuda:
del obj
except:
pass

gc.collect()
torch.cuda.empty_cache()
3 changes: 2 additions & 1 deletion nemo/lightning/io/connector.py
@@ -139,7 +139,7 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] =
from nemo.lightning import MegatronStrategy, Trainer

_trainer = trainer or Trainer(
devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False, ddp="pytorch")
devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False)
)

_trainer.strategy.connect(model)
@@ -161,6 +161,7 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
trainer (pl.Trainer): The trainer with the strategy to save the model.
"""
trainer.strategy._setup_optimizers = False
trainer.strategy._init_model_parallel = False
trainer.strategy.setup(trainer)
trainer.save_checkpoint(output_path)

16 changes: 10 additions & 6 deletions nemo/lightning/io/pl.py
@@ -5,17 +5,14 @@

import pytorch_lightning as pl
import torch
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.types import _PATH
from megatron.core.dist_checkpointing.serialization import (
get_default_load_sharded_strategy,
get_default_save_sharded_strategy,
)

# from nemo.utils.callbacks.torch_dist_async import TorchDistAsyncSaveShardedStrategy
from megatron.core.dist_checkpointing.strategies import tensorstore
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
FullyParallelLoadStrategyWrapper,
@@ -28,7 +25,12 @@

from nemo.lightning.io.capture import IOProtocol
from nemo.lightning.io.mixin import IOMixin
from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO

try:
from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO
except ImportError:
AsyncCompatibleCheckpointIO = CheckpointIO


log = logging.getLogger(__name__)

@@ -163,7 +165,9 @@ def load_checkpoint(
raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")

if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
from megatron.core.dist_checkpointing.strategies.tensorstore import TensorStoreLoadShardedStrategy

sharded_strategy = TensorStoreLoadShardedStrategy(load_directly_on_device=True)
else:
sharded_strategy = None

5 changes: 3 additions & 2 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -17,7 +17,7 @@
import shutil
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union

import pytorch_lightning
import torch
@@ -30,7 +30,6 @@
from nemo.lightning.io.pl import TrainerContext
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO
from nemo.utils.model_utils import ckpt_to_dir


@@ -401,6 +400,8 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
if self.async_save:
checkpoint_io = trainer.strategy.checkpoint_io
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO

if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
raise ValueError('Async save requires async compatible CheckpointIO')
storage_options = dict(finalize_fn=finalize_fn)
3 changes: 3 additions & 0 deletions nemo/lightning/pytorch/callbacks/model_transform.py
@@ -63,6 +63,9 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._maybe_apply_transform(trainer)

def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._maybe_apply_transform(trainer)

def _maybe_apply_transform(self, trainer):
if self._needs_to_call:
self.apply_transform(trainer)
25 changes: 18 additions & 7 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -7,6 +7,7 @@
import torch.nn as nn
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.trainer.states import TrainerFn
from typing_extensions import override

from nemo.lightning.io.pl import ckpt_to_dir
@@ -102,13 +103,26 @@ def apply_transform(self, trainer):
logging.info("Initializing model parallel")
trainer.strategy.init_model_parallel()

logging.info("Setting up optimizers")
trainer.strategy.setup_optimizers(trainer)
if trainer.state.fn == TrainerFn.FITTING:
logging.info("Setting up optimizers")
trainer.strategy.setup_optimizers(trainer)

def on_load_checkpoint(
def on_save_checkpoint(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
) -> None:
pl_module.strict_loading = False
# Filter out non-trainable parameters
trainable_params = set(name for name, param in pl_module.named_parameters() if param.requires_grad)
filtered_state_dict = {}
for name, value in checkpoint['state_dict'].items():
if name in trainable_params:
filtered_state_dict[name] = value
elif self.adapter_key_filter(name): # Include all adapter-related parameters
filtered_state_dict[name] = value

checkpoint['state_dict'] = filtered_state_dict

def adapter_key_filter(self, key: str) -> bool:
return ".adapter." in key or key.endswith(".adapters")


class AdapterWrapper(nn.Module):
@@ -232,9 +246,6 @@ class WrappedAdapterIO(_WrappingCheckpointIO):
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
assert self.checkpoint_io is not None

key = "sharded_state_dict" if "sharded_state_dict" in checkpoint else "state_dict"
checkpoint[key] = dict(filter(lambda x: ".adapter." in x[0], checkpoint[key].items()))

self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)

from nemo.utils.get_rank import is_global_rank_zero
3 changes: 3 additions & 0 deletions nemo/lightning/pytorch/strategies.py
@@ -111,6 +111,7 @@ def __init__(
ckpt_parallel_save_within_dp=False,
ckpt_parallel_load=False,
ckpt_parallel_save_optim=True,
ckpt_load_directly_on_device=True,
setup_optimizers: bool = True,
init_model_parallel: bool = True,
**kwargs,
@@ -147,6 +148,7 @@ def __init__(
self.parallel_save_within_dp = ckpt_parallel_save_within_dp
self.parallel_load = ckpt_parallel_load
self.parallel_save_optim = ckpt_parallel_save_optim
self.load_directly_on_device = ckpt_load_directly_on_device

self._ddp = ddp
if ddp == "megatron":
@@ -582,6 +584,7 @@ def checkpoint_io(self) -> CheckpointIO:
parallel_save=self.parallel_save,
parallel_save_within_dp=self.parallel_save_within_dp,
parallel_load=self.parallel_load,
load_directly_on_device=self.load_directly_on_device,
)
if async_save:
self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io)
26 changes: 13 additions & 13 deletions nemo/utils/export_utils.py
@@ -72,10 +72,12 @@ def __init__(self, weight, bias, skip_bias_add):
self.weight = weight
self.skip_bias_add = skip_bias_add

def forward(self, x):
def forward(self, x, weight=None):
if weight is None:
weight = self.weight
if self.skip_bias_add:
return F.linear(x, self.weight), self.bias
return F.linear(x, self.weight, self.bias), None
return F.linear(x, weight), self.bias
return F.linear(x, weight, self.bias), None


def get_export_format(filename: str):
@@ -239,7 +241,8 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
from apex.normalization import MixedFusedRMSNorm
from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as MCoreFusedLayerNorm
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear

def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
@@ -255,21 +258,17 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:

if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
n_state = n.state_dict()
elif isinstance(n, MCoreFusedLayerNorm):
shape, eps, affine = n.weight.shape, n.eps, True
elif isinstance(n, FastLayerNorm):
shape, eps, affine = n.weight.shape, n.epsilon, True
n_state = n.state_dict()
elif isinstance(n, MixedFusedRMSNorm):
shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
tmp_n_state = n.state_dict()
n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])}
else:
return None

n_state = n.state_dict()
mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)

mod.load_state_dict(n_state)
mod.load_state_dict(n_state, strict=True)

return mod

@@ -306,7 +305,7 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)

n_state = n.state_dict()
mod.load_state_dict(n_state)
mod.load_state_dict(n_state, strict=False)
return mod

def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
@@ -318,7 +317,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
Equivalent LayerNorm module
"""
if not isinstance(n, FusedScaleMaskSoftmax):
logging.warning("This function can only change the FusedScaleMaskSoftmax module.")
logging.warning(f"This function can only change the FusedScaleMaskSoftmax module, got: {n.__class__}")
return n

# disable the fusion only
@@ -331,6 +330,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
default_Apex_replacements = {
"FusedLayerNorm": replace_FusedLayerNorm,
"MixedFusedLayerNorm": replace_FusedLayerNorm,
"MCoreFusedLayerNorm": replace_FusedLayerNorm,
"FastLayerNorm": replace_FusedLayerNorm,
"RowParallelLinear": replace_ParallelLinear,
"ColumnParallelLinear": replace_ParallelLinear,
12 changes: 2 additions & 10 deletions tests/lightning/pytorch/callbacks/test_peft.py
@@ -1,6 +1,7 @@
from unittest.mock import MagicMock, call, patch

import torch.nn as nn
from pytorch_lightning.trainer.states import TrainerFn
from nemo.collections.llm import fn
from nemo.lightning.pytorch.callbacks.peft import PEFT, WrappedAdapterIO

@@ -43,6 +44,7 @@ def test_peft_on_train_epoch_start_with_adapter(self, mock_logging):
trainer = MagicMock()
pl_module = MagicMock()
pl_module.model_transform = peft
trainer.state.fn = TrainerFn.FITTING # Mock the trainer to be in FITTING state

peft.setup(trainer, pl_module, "fit")

@@ -70,13 +72,3 @@ def test_peft_on_train_epoch_start_with_adapter(self, mock_logging):
trainer.strategy.load_model_state_dict.assert_called_once_with({"dummy_state": "dummy_value"}, strict=False)
trainer.strategy.init_model_parallel.assert_called_once()
trainer.strategy.setup_optimizers.assert_called_once_with(trainer)

def test_peft_on_load_checkpoint(self):
peft = self.DummyPEFT()
trainer = MagicMock()
pl_module = MagicMock()
checkpoint = {}

peft.on_load_checkpoint(trainer, pl_module, checkpoint)

assert pl_module.strict_loading == False
