-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fix Optimizer & LR scheduler & Consume Samples when Resuming in PEFT (#11631) * Fix Optimizer & LR scheduler Resume * fix unit test Signed-off-by: Chen Cui <[email protected]> * Apply isort and black reformatting Signed-off-by: cuichenx <[email protected]> * typo Signed-off-by: Chen Cui <[email protected]> * Fix consume samples * Fix unit tests * Apply isort and black reformatting Signed-off-by: suiyoubi <[email protected]> --------- Signed-off-by: Chen Cui <[email protected]> Signed-off-by: cuichenx <[email protected]> Signed-off-by: suiyoubi <[email protected]> Co-authored-by: Chen Cui <[email protected]> Co-authored-by: cuichenx <[email protected]> Co-authored-by: suiyoubi <[email protected]> * Utilities to detect and drop deprecated arguments from NeMo 2.0 checkpoint context io.json (#11648) * Utils to detect and drop deprecated arguments in io.json Signed-off-by: Jan Lasek <[email protected]> * Unit tests for drop_unexpected_params Signed-off-by: Jan Lasek <[email protected]> * Apply isort and black reformatting Signed-off-by: janekl <[email protected]> * Add copyright header Signed-off-by: Jan Lasek <[email protected]> --------- Signed-off-by: Jan Lasek <[email protected]> Signed-off-by: janekl <[email protected]> Co-authored-by: janekl <[email protected]> * NIM supporting changes for nemo.export for NeMo 2.0 (part II) (#11669) * Remove trt_compile from __init__ as it triggers imports from nemo.utils Signed-off-by: Jan Lasek <[email protected]> * Get tokenizer for NeMo 2 from model.yaml using local SP or HF classes Signed-off-by: Jan Lasek <[email protected]> * Apply isort and black reformatting Signed-off-by: janekl <[email protected]> --------- Signed-off-by: Jan Lasek <[email protected]> Signed-off-by: janekl <[email protected]> Co-authored-by: janekl <[email protected]> * Add check for symlink in _safe_extract (#11611) Signed-off-by: Abhishree <[email protected]> --------- Signed-off-by: Chen Cui <[email protected]> Signed-off-by: cuichenx <[email 
protected]> Signed-off-by: suiyoubi <[email protected]> Signed-off-by: Jan Lasek <[email protected]> Signed-off-by: janekl <[email protected]> Signed-off-by: Abhishree <[email protected]> Co-authored-by: Ao Tang <[email protected]> Co-authored-by: Chen Cui <[email protected]> Co-authored-by: cuichenx <[email protected]> Co-authored-by: suiyoubi <[email protected]> Co-authored-by: Jan Lasek <[email protected]> Co-authored-by: janekl <[email protected]> Co-authored-by: Abhishree Thittenamane <[email protected]>
- Loading branch information
1 parent
526a525
commit 36511af
Showing
10 changed files
with
303 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import sys | ||
from datetime import datetime | ||
from pathlib import Path | ||
|
||
import fiddle as fdl | ||
from fiddle._src.experimental import serialization | ||
|
||
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir | ||
from nemo.lightning.io import drop_unexpected_params, load | ||
from nemo.utils import logging | ||
|
||
# File name of the model context file; joined onto the context directory when backing up and saving. | ||
IO_FILE = "io.json" | ||
|
||
""" | ||
Script to update NeMo 2.0 model context (stored in io.json) for unexpected | ||
keyword arguments for compatibility with the currently running environment. | ||
It accepts path to a NeMo 2.0 checkpoint and optional flag for building | ||
the updated configuration. It performs the following steps: | ||
1. Loads config from the model context directory. | ||
2. Checks the config for unexpected (e.g. deprecated) arguments and drops them. | ||
3. Attempts to build the updated configuration if --build flag is on. | ||
4. Backs up the existing context file and saves the updated configuration. | ||
""" | ||
|
||
|
||
def get_args(argv=None):
    """
    Parses command line arguments.

    Args:
        argv (list[str] | None): Argument strings to parse. When None (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        argparse.Namespace: Parsed arguments with ``model_path`` (str) and ``build`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Script to drop unexpected arguments from NeMo 2.0 io.json model context."
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to a NeMo 2.0 checkpoint.")
    parser.add_argument("--build", action="store_true", help="Whether to test building the updated config.")
    return parser.parse_args(argv)
|
||
|
||
def save_io(config: fdl.Config, path: str):
    """
    Saves the given configuration object to a specified file path in JSON format.

    The file is always written as UTF-8 so that the result does not depend on
    the locale of the machine running the script.

    Args:
        config (fdl.Config): The configuration object to be saved.
        path (str): The file path where the configuration will be saved. Any
            path-like object accepted by ``open`` (e.g. ``pathlib.Path``) works.
    """
    config_json = serialization.dump_json(config)
    with open(path, "w", encoding="utf-8") as f:
        f.write(config_json)
|
||
|
||
if __name__ == "__main__":
    args = get_args()

    model_path = Path(args.model_path)
    context_path = ckpt_to_context_subdir(model_path)
    logging.info(f"Path to model context: {context_path}.")

    # Load the config without building it so unexpected arguments can be
    # dropped before anything is instantiated.
    config = load(context_path, build=False)
    updated = drop_unexpected_params(config)

    if not updated:
        logging.info("Config does not need any updates.")
        sys.exit(0)

    if args.build:
        # Optional sanity check: make sure the cleaned-up config can be built.
        try:
            fdl.build(config)
        except Exception:
            logging.error("Build for the updated config failed.")
            raise
        else:
            logging.info("Build for the updated config successful.")

    # Backup the existing context file and save the updated config
    io_path = context_path / IO_FILE
    io_path_backup = context_path / f"BACKUP_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_{IO_FILE}"
    io_path.rename(io_path_backup)
    try:
        save_io(config, io_path)
    except Exception:
        # The original file was moved away above; put it back so a failed save
        # does not leave the checkpoint without its context file.
        io_path_backup.replace(io_path)
        logging.error(f"Saving the updated config failed; restored original {io_path}.")
        raise
    logging.info(f"Config saved to {io_path}.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import fiddle as fdl | ||
|
||
from nemo.lightning.io import drop_unexpected_params | ||
|
||
|
||
class TestDropUnexpectedParams:
    """Tests for ``drop_unexpected_params`` on flat and nested fiddle configs."""

    @staticmethod
    def _retarget(cfg, new_cls):
        """Swap the config's target class by writing straight into its instance ``__dict__``."""
        cfg.__dict__["__fn_or_cls__"] = new_cls

    def setup_method(self):
        """
        Define the small target classes shared by all tests.
        """

        class MockClassOld:
            def __init__(self, x, y, deprecated):
                pass

        self.MockClassOld = MockClassOld

        class MockClassNew:
            def __init__(self, x, y):
                pass

        self.MockClassNew = MockClassNew

        class OuterClass:
            def __init__(self, z, t):
                pass

        self.OuterClass = OuterClass

    def test_valid_config_stays_same(self):
        """
        A config whose arguments all match the target signature is left alone.
        """
        cfg = fdl.Config(self.MockClassNew, x=1, y=2)

        changed = drop_unexpected_params(cfg)

        assert not changed, "Expected the config to remain unchanged."
        assert (cfg.x, cfg.y) == (1, 2)

    def test_config_updates(self):
        """
        An argument the target class no longer accepts is dropped from the config.
        """
        cfg = fdl.Config(self.MockClassOld, x=1, y=2, deprecated=3)

        # Simulate deprecation issue by overriding target class
        self._retarget(cfg, self.MockClassNew)

        changed = drop_unexpected_params(cfg)

        assert changed, "Expected the config to be updated."
        assert (cfg.x, cfg.y) == (1, 2)
        assert not hasattr(cfg, "deprecated"), "Expected 'deprecated' to be removed from the config."

    def test_nested_config_updates(self):
        """
        An unexpected argument inside a nested config is dropped as well.
        """
        inner = fdl.Config(self.MockClassOld, x=1, y=2, deprecated=3)
        cfg = fdl.Config(self.OuterClass, z=4, t=inner)

        # Simulate deprecation issue by overriding target class
        self._retarget(cfg.t, self.MockClassNew)

        changed = drop_unexpected_params(cfg)

        assert changed, "Expected the nested config to be updated."
        assert cfg.z == 4
        assert (cfg.t.x, cfg.t.y) == (1, 2)
        assert not hasattr(cfg.t, "deprecated"), "Expected 'deprecated' to be removed from the inner config."
Oops, something went wrong.