r2.1.0 cherrypick (#11680)
* Fix Optimizer & LR scheduler & Consume Samples when Resuming in PEFT (#11631)

* Fix Optimizer & LR scheduler Resume

* fix unit test

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* typo

Signed-off-by: Chen Cui <[email protected]>

* Fix consume samples

* Fix unit tests

* Apply isort and black reformatting

Signed-off-by: suiyoubi <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Signed-off-by: suiyoubi <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: suiyoubi <[email protected]>

* Utilities to detect and drop deprecated arguments from NeMo 2.0 checkpoint context io.json (#11648)

* Utils to detect and drop deprecated arguments in io.json

Signed-off-by: Jan Lasek <[email protected]>

* Unit tests for drop_unexpected_params

Signed-off-by: Jan Lasek <[email protected]>

* Apply isort and black reformatting

Signed-off-by: janekl <[email protected]>

* Add copyright header

Signed-off-by: Jan Lasek <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
Signed-off-by: janekl <[email protected]>
Co-authored-by: janekl <[email protected]>

* NIM supporting changes for nemo.export for NeMo 2.0 (part II) (#11669)

* Remove trt_compile from __init__ as it triggers imports from nemo.utils

Signed-off-by: Jan Lasek <[email protected]>

* Get tokenizer for NeMo 2 from model.yaml using local SP or HF classes

Signed-off-by: Jan Lasek <[email protected]>

* Apply isort and black reformatting

Signed-off-by: janekl <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
Signed-off-by: janekl <[email protected]>
Co-authored-by: janekl <[email protected]>

* Add check for symlink in _safe_extract (#11611)

Signed-off-by: Abhishree <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Signed-off-by: suiyoubi <[email protected]>
Signed-off-by: Jan Lasek <[email protected]>
Signed-off-by: janekl <[email protected]>
Signed-off-by: Abhishree <[email protected]>
Co-authored-by: Ao Tang <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: suiyoubi <[email protected]>
Co-authored-by: Jan Lasek <[email protected]>
Co-authored-by: janekl <[email protected]>
Co-authored-by: Abhishree Thittenamane <[email protected]>
8 people authored Dec 20, 2024
1 parent 526a525 commit 36511af
Showing 10 changed files with 303 additions and 16 deletions.
7 changes: 4 additions & 3 deletions nemo/collections/llm/gpt/data/fine_tuning.py
@@ -93,6 +93,7 @@ def __init__(
         self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
         self.validate_batch_size_for_packed_sequence()
         self.dataset_kwargs = dataset_kwargs or {}
+        self.init_global_step = 0

     def validate_batch_size_for_packed_sequence(self):
         """
@@ -163,9 +164,7 @@ def state_dict(self) -> Dict[str, Any]:
             A dictionary containing datamodule state.
         """
-        consumed_samples = self.data_sampler.compute_consumed_samples(
-            self.trainer.global_step - self.data_sampler.init_global_step
-        )
+        consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
         return {"consumed_samples": consumed_samples}

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -240,6 +239,8 @@ def _create_dataset(self, path, is_test=False, **kwargs):

     def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader:
         # pylint: disable=C0115,C0116
+        self.init_global_step = self.trainer.global_step
+        self.data_sampler.init_global_step = self.init_global_step
         return WrappedDataLoader(
             mode=mode,
             dataset=dataset,
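For context, a minimal sketch (not NeMo code) of the consumed-samples arithmetic this change relies on: the data module now records `trainer.global_step` when the dataloader is built, so `state_dict()` counts only the steps taken with the current dataloaders. The batch size and step numbers below are hypothetical.

```python
# Sketch of the resume arithmetic, assuming a hypothetical global batch size of 128
# and a PEFT run whose dataloaders were built at global step 1000.
global_batch_size = 128

init_global_step = 1000      # trainer.global_step captured in _create_dataloader
current_global_step = 1250   # step reached later in the same run

# compute_consumed_samples(steps) is roughly steps * global batch size.
steps_since_init = current_global_step - init_global_step
consumed_samples = steps_since_init * global_batch_size  # 250 * 128 = 32000

# state_dict() stores this value so a later resume can skip already-seen samples.
checkpoint_state = {"consumed_samples": consumed_samples}
print(checkpoint_state)
```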
7 changes: 6 additions & 1 deletion nemo/core/connectors/save_restore_connector.py
@@ -601,7 +601,12 @@ def _is_safe_path(member, extract_to):
         # Construct the full path where the member would be extracted
         full_path = os.path.join(extract_to, member_path)
         # Ensure the member would be extracted within the intended directory
-        return os.path.commonprefix([full_path, extract_to]) == extract_to
+        if os.path.commonprefix([full_path, extract_to]) != extract_to:
+            return False
+        # Check if the member is a symbolic link
+        if member.issym() or member.islnk():
+            return False
+        return True

     @staticmethod
     def _safe_extract(tar, out_folder: str, members=None):
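A self-contained sketch of the same guard outside the connector class, for readers who want to see the check in isolation; the helper names here are illustrative, not NeMo API.

```python
import os
import tarfile


def is_safe_member(member: tarfile.TarInfo, extract_to: str) -> bool:
    """Illustrative version of the path-traversal and link check added to _is_safe_path."""
    full_path = os.path.abspath(os.path.join(extract_to, member.name))
    extract_root = os.path.abspath(extract_to)
    # Reject members that would be written outside the extraction directory.
    if os.path.commonprefix([full_path, extract_root]) != extract_root:
        return False
    # Reject symbolic and hard links, which could point at files outside the archive.
    if member.issym() or member.islnk():
        return False
    return True


def safe_extract(tar_path: str, out_folder: str) -> None:
    # Extract only the members that pass the safety check.
    with tarfile.open(tar_path) as tar:
        members = [m for m in tar.getmembers() if is_safe_member(m, out_folder)]
        tar.extractall(path=out_folder, members=members)
```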
2 changes: 0 additions & 2 deletions nemo/export/__init__.py
@@ -11,5 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from nemo.export.tensorrt_lazy_compiler import trt_compile
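With the re-export removed, code that still needs `trt_compile` would import it from its defining module rather than from the `nemo.export` package root, so that importing `nemo.export` no longer pulls in `nemo.utils` as a side effect. A hedged sketch, assuming the module path shown in the removed line:

```python
# Import directly from the defining module; the package root no longer re-exports it.
try:
    from nemo.export.tensorrt_lazy_compiler import trt_compile
except (ImportError, ModuleNotFoundError):
    trt_compile = None  # TensorRT-based lazy compilation unavailable in this environment
```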
57 changes: 52 additions & 5 deletions nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -38,6 +38,13 @@
 from nemo.export.tarutils import TarPath, ZarrPathStore
 from nemo.export.tiktoken_tokenizer import TiktokenTokenizer

+try:
+    from nemo.lightning import io
+
+    HAVE_NEMO2 = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_NEMO2 = False
+
 LOGGER = logging.getLogger("NeMo")

@@ -287,14 +294,54 @@ def copy_tokenizer_files(config, out_dir):
     return config


+def get_tokenizer_from_nemo2_context(model_context_dir: Path):
+    """
+    Retrieve tokenizer configuration from NeMo 2.0 context and instantiate the tokenizer.
+
+    Args:
+        model_context_dir (Path): Path to the model context directory.
+
+    Returns:
+        The instantiated tokenizer (various classes possible).
+    """
+
+    if HAVE_NEMO2:
+        # Use NeMo tokenizer loaded from the NeMo 2.0 model context
+        tokenizer_spec = io.load_context(model_context_dir, subpath="model.tokenizer")
+        return build_tokenizer(tokenizer_spec)
+    else:
+        # Use local nemo.export SentencePieceTokenizer implementation
+        # or directly a HuggingFace tokenizer based on the model config
+        with (model_context_dir / "model.yaml").open("r") as stream:
+            model_config = yaml.safe_load(stream)
+
+        tokenizer_config = model_config["tokenizer"]
+        target_class = tokenizer_config["_target_"]
+        tokenizer_module = "nemo.collections.common.tokenizers."
+        assert target_class.startswith(tokenizer_module)
+        target_class = target_class.removeprefix(tokenizer_module)
+
+        if target_class == "sentencepiece_tokenizer.SentencePieceTokenizer":
+            tokenizer = SentencePieceTokenizer(
+                model_path=str(model_context_dir / tokenizer_config["model_path"]),
+                special_tokens=tokenizer_config.get("special_tokens", None),
+                legacy=tokenizer_config.get("legacy", False),
+            )
+        elif target_class == "huggingface.auto_tokenizer.AutoTokenizer":
+            tokenizer = AutoTokenizer.from_pretrained(
+                str(model_context_dir / tokenizer_config["pretrained_model_name"])
+            )
+        else:
+            raise ValueError(f"Unsupported tokenizer type: {tokenizer_module}{target_class}.")
+
+        return tokenizer
+
+
 def get_tokenizer(tokenizer_dir_or_path: Union[str, Path]) -> PreTrainedTokenizer:
     """Loads the tokenizer from the decoded NeMo weights dir."""
     tokenizer_dir_or_path = Path(tokenizer_dir_or_path)
     if (tokenizer_dir_or_path / "nemo_context").exists():
-        from nemo.lightning import io
-
-        tokenizer_spec = io.load_context((tokenizer_dir_or_path / "nemo_context"), subpath="model.tokenizer")
-        return build_tokenizer(tokenizer_spec)
+        return get_tokenizer_from_nemo2_context(tokenizer_dir_or_path / "nemo_context")
     elif os.path.exists(os.path.join(tokenizer_dir_or_path, "vocab.json")):
         vocab_path = tokenizer_dir_or_path / "vocab.json" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path
         tokenizer_config = {"library": "tiktoken", "vocab_file": str(vocab_path)}
@@ -474,7 +521,7 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
             elif k == "activation_func":
                 nemo_model_config["activation"] = v["_target_"].rsplit('.', 1)[-1]
         else:
-            from nemo.lightning import io
+            assert HAVE_NEMO2, "nemo_toolkit>=2.0.0 is required to load the model context."

             config = io.load_context(io_folder, subpath="model.config")

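A hedged usage sketch of the updated entry point: `get_tokenizer` still takes the decoded weights directory and, when a `nemo_context` subdirectory is present, now goes through `get_tokenizer_from_nemo2_context`, which reads `model.yaml` when the `nemo.lightning` import is unavailable. The checkpoint path below is hypothetical; the import path follows the file location shown in this diff.

```python
from pathlib import Path

from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_tokenizer

# Hypothetical directory produced when unpacking a NeMo 2.0 checkpoint for export;
# it is expected to contain a "nemo_context" subdirectory with model.yaml inside.
weights_dir = Path("/tmp/extracted_nemo2_model")

tokenizer = get_tokenizer(weights_dir)
# Depending on model.yaml, this is a SentencePieceTokenizer, a HuggingFace AutoTokenizer,
# or a tokenizer built from the full NeMo 2.0 context via io.load_context.
print(type(tokenizer).__name__)
```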
3 changes: 2 additions & 1 deletion nemo/lightning/io/__init__.py
@@ -2,14 +2,15 @@
 from nemo.lightning.io.api import export_ckpt, import_ckpt, load, load_context, model_exporter, model_importer
 from nemo.lightning.io.capture import reinit
 from nemo.lightning.io.connector import Connector, ModelConnector
-from nemo.lightning.io.mixin import ConnectorMixin, IOMixin, track_io
+from nemo.lightning.io.mixin import ConnectorMixin, IOMixin, drop_unexpected_params, track_io
 from nemo.lightning.io.pl import TrainerContext, is_distributed_ckpt
 from nemo.lightning.io.state import TransformCTX, apply_transforms, state_transform

 __all__ = [
     "apply_transforms",
     "Connector",
     "ConnectorMixin",
+    "drop_unexpected_params",
     "IOMixin",
     "track_io",
     "import_ckpt",
39 changes: 39 additions & 0 deletions nemo/lightning/io/mixin.py
@@ -687,6 +687,45 @@ def _artifact_transform_load(cfg: fdl.Config, path: Path):
             pass


+def drop_unexpected_params(config: fdl.Config) -> bool:
+    """
+    Analyzes config to detect unexpected keyword arguments -- for example, deprecated parameters -- and
+    updates the config by dropping them. Returns True if the config gets updated and False otherwise.
+
+    Args:
+        config (fdl.Config): The configuration object to analyze.
+    """
+
+    updated = False
+
+    def analyze(config: fdl.Config, prefix: str):
+
+        if isinstance(config, fdl.Config):
+            signature = inspect.signature(config.__fn_or_cls__)
+
+            accept_kwargs = any(param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values())
+
+            if not accept_kwargs:
+                to_drop = [param for param in config.__arguments__ if param not in signature.parameters]
+
+                if to_drop:
+                    nonlocal updated
+                    updated = True
+                    logging.warning(f"Deprecated parameters to drop from {prefix}: {to_drop}")
+                    for param in to_drop:
+                        del config.__arguments__[param]
+            else:
+                logging.info(f"Skip analyzing {prefix} as it accepts arbitrary keyword arguments.")
+
+            # Proceed recursively for all arguments
+            for key, value in config.__arguments__.items():
+                analyze(value, prefix + "." + key)
+
+    analyze(config, "<root>")
+
+    return updated
+
+
 def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = None, build: bool = True) -> CkptType:
     """
     Loads a configuration from a pickle file and constructs an object of the specified type.
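The helper decides whether to prune by inspecting the target callable's signature: a callable that takes `**kwargs` is left alone, otherwise any entry in `config.__arguments__` that is not in the signature is dropped. A small sketch of that signature check on its own, with hypothetical constructors:

```python
import inspect


def accepts_arbitrary_kwargs(fn) -> bool:
    """Return True if fn declares a **kwargs-style (VAR_KEYWORD) parameter."""
    signature = inspect.signature(fn)
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in signature.parameters.values())


def strict_init(x, y):            # hypothetical current signature; an old "deprecated" kwarg is gone
    pass


def flexible_init(x, y, **kwargs):
    pass


print(accepts_arbitrary_kwargs(strict_init))    # False -> unexpected config arguments get dropped
print(accepts_arbitrary_kwargs(flexible_init))  # True  -> the config is left untouched
```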
7 changes: 6 additions & 1 deletion nemo/lightning/pytorch/callbacks/peft.py
@@ -204,7 +204,12 @@ def apply_transform(self, trainer):
         )
         trainer.strategy.load_model_state_dict(adapter_state, strict=False)
         if trainer.state.fn == TrainerFn.FITTING:
-            trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True)
+            # Load optimizer
+            trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=False)
+            # Load lr scheduler
+            if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None:
+                for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers):
+                    config.scheduler.load_state_dict(lrs_state)

         for cb in trainer.callbacks[::-1]:
             if isinstance(cb, MegatronOptimizerModule):
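A minimal sketch of the scheduler-restore pairing the new lines perform, using plain PyTorch objects; the `SchedulerConfig` shim below is illustrative and stands in for Lightning's `lr_scheduler_configs` entries, not the NeMo strategy API.

```python
import torch

# Hypothetical stand-ins for a live training setup and a saved adapter state.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# What a checkpoint might carry under 'lr_schedulers': one state dict per scheduler.
saved_state = {"lr_schedulers": [scheduler.state_dict()]}


class SchedulerConfig:
    """Illustrative shim for a Lightning LR-scheduler config holding a .scheduler attribute."""

    def __init__(self, scheduler):
        self.scheduler = scheduler


lr_scheduler_configs = [SchedulerConfig(scheduler)]

# Mirrors the added code path: pair saved states with live scheduler configs in order.
if (lr_schedulers := saved_state.get("lr_schedulers", None)) is not None:
    for config, lrs_state in zip(lr_scheduler_configs, lr_schedulers):
        config.scheduler.load_state_dict(lrs_state)
```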
94 changes: 94 additions & 0 deletions scripts/llm/update_io_context.py
@@ -0,0 +1,94 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys
from datetime import datetime
from pathlib import Path

import fiddle as fdl
from fiddle._src.experimental import serialization

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.io import drop_unexpected_params, load
from nemo.utils import logging

IO_FILE = "io.json"

"""
Script to update NeMo 2.0 model context (stored in io.json) for unexpected
keword arguments for compatibility with the currently running environment.
It accepts path to a NeMo 2.0 checkpoint and optional flag for building
the updated configuration. It performs the following steps:
1. Loads config from the model context directory.
2. Checks the config for unexpected (e.g. deprecated) arguments and drops them.
3. Attempts to build the updated configuration if --build flag is on.
4. Backs up the existing context file and saves the updated configuration.
"""


def get_args():
"""Parses command line arguments."""
parser = argparse.ArgumentParser(
description="Script to drop unexpected arguments from NeMo 2.0 io.json model context."
)
parser.add_argument("--model_path", type=str, required=True, help="Path to a NeMo 2.0 checkpoint.")
parser.add_argument("--build", action="store_true", help="Whether to test building the updated config.")
return parser.parse_args()


def save_io(config: fdl.Config, path: str):
"""
Saves the given configuration object to a specified file path in JSON format.
Args:
config (fdl.Config): The configuration object to be saved.
path (str): The file path where the configuration will be saved.
"""
config_json = serialization.dump_json(config)
with open(path, "w") as f:
f.write(config_json)


if __name__ == "__main__":
args = get_args()

model_path = Path(args.model_path)
context_path = ckpt_to_context_subdir(model_path)
logging.info(f"Path to model context: {context_path}.")

config = load(context_path, build=False)
updated = drop_unexpected_params(config)

if not updated:
logging.info("Config does not need any updates.")
sys.exit(0)

if args.build:
try:
fdl.build(config)
except Exception as e:
logging.error("Build for the updated config failed.")
raise
else:
logging.info("Build for the updated config successful.")

# Backup the existing context file and save the updated config
io_path = context_path / IO_FILE
io_path_backup = context_path / f"BACKUP_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_{IO_FILE}"
io_path.rename(io_path_backup)
save_io(config, io_path)
logging.info(f"Config saved to {io_path}.")
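The script is invoked along the lines of `python scripts/llm/update_io_context.py --model_path /path/to/ckpt --build`. The same steps can also be driven programmatically with the helpers it imports; a hedged sketch with a hypothetical checkpoint path:

```python
from pathlib import Path

import fiddle as fdl

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.io import drop_unexpected_params, load

ckpt = Path("/checkpoints/my_nemo2_model")   # hypothetical NeMo 2.0 checkpoint directory
context_path = ckpt_to_context_subdir(ckpt)

config = load(context_path, build=False)     # read io.json without instantiating objects
if drop_unexpected_params(config):
    fdl.build(config)                        # optional sanity check, equivalent to --build
    # The script would then back up io.json and re-serialize `config` in its place.
```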
84 changes: 84 additions & 0 deletions tests/collections/llm/io/test_drop_unexpected_params.py
@@ -0,0 +1,84 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fiddle as fdl

from nemo.lightning.io import drop_unexpected_params


class TestDropUnexpectedParams:

    def setup_method(self):
        """
        Setup common test resources.
        """

        class MockClassOld:
            def __init__(self, x, y, deprecated):
                pass

        class MockClassNew:
            def __init__(self, x, y):
                pass

        class OuterClass:
            def __init__(self, z, t):
                pass

        self.MockClassOld = MockClassOld
        self.MockClassNew = MockClassNew
        self.OuterClass = OuterClass

    def test_valid_config_stays_same(self):
        """
        Test that a valid config remains unchanged.
        """

        config = fdl.Config(self.MockClassNew, x=1, y=2)
        updated = drop_unexpected_params(config)

        assert not updated, "Expected the config to remain unchanged."
        assert config.x == 1
        assert config.y == 2

    def test_config_updates(self):
        """
        Test that a config with unexpected parameters gets updated.
        """
        config = fdl.Config(self.MockClassOld, x=1, y=2, deprecated=3)

        # Simulate deprecation issue by overriding target class
        config.__dict__['__fn_or_cls__'] = self.MockClassNew

        updated = drop_unexpected_params(config)
        assert updated, "Expected the config to be updated."
        assert config.x == 1
        assert config.y == 2
        assert not hasattr(config, "deprecated"), "Expected 'deprecated' to be removed from the config."

    def test_nested_config_updates(self):
        """
        Test that a nested config with unexpected parameters gets updated.
        """
        config = fdl.Config(self.OuterClass, z=4, t=fdl.Config(self.MockClassOld, x=1, y=2, deprecated=3))

        # Simulate deprecation issue by overriding target class
        config.t.__dict__["__fn_or_cls__"] = self.MockClassNew

        updated = drop_unexpected_params(config)
        assert updated, "Expected the nested config to be updated."
        assert config.z == 4
        assert config.t.x == 1
        assert config.t.y == 2
        assert not hasattr(config.t, "deprecated"), "Expected 'deprecated' to be removed from the inner config."