Skip to content

Commit

Permalink
Use explicit subpaths in io for exporting a checkpoint (#11352)
Browse files Browse the repository at this point in the history
* Fix llm.export_ckpt

Signed-off-by: Hemil Desai <[email protected]>

* fix

Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
  • Loading branch information
hemildesai authored Dec 17, 2024
1 parent 06a1491 commit f2169a1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def import_ckpt(


def load_connector_from_trainer_ckpt(path: Path, target: str) -> io.ModelConnector:
    """Return the model's exporter connector for *target* from a trainer checkpoint.

    Loads only the ``model`` subpath of the saved context (rather than
    restoring the full trainer context and then reaching into ``.model``),
    which avoids deserializing trainer state that is not needed for export.

    Args:
        path: Path to the NeMo trainer checkpoint directory.
        target: Name of the export target (e.g. the HF model family) whose
            connector should be resolved.

    Returns:
        The ``io.ModelConnector`` produced by the model's ``exporter`` hook.
    """
    # NOTE(review): the diff residue carried both the old
    # ``io.load_context(path).model.exporter(...)`` line and the new one;
    # the post-commit behavior is the explicit-subpath form kept here.
    return io.load_context(path, subpath="model").exporter(target, path)


@run.cli.entrypoint(name="export", namespace="llm")
Expand Down
9 changes: 6 additions & 3 deletions nemo/collections/llm/gpt/model/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,10 @@ def apply(self, output_path: Path) -> Path:

target = target.cpu()
target.save_pretrained(output_path)
self.tokenizer.save_pretrained(output_path)
try:
self.tokenizer.save_pretrained(output_path)
except Exception:
logging.warning("Failed to save tokenizer")

return output_path

Expand All @@ -366,11 +369,11 @@ def convert_state(self, source, target):

@property
def tokenizer(self):
    """Return the underlying Hugging Face tokenizer from the checkpoint.

    Loads only the ``model`` subpath of the saved context at ``str(self)``
    (this connector is path-like), then unwraps the NeMo tokenizer wrapper
    via the trailing ``.tokenizer`` attribute to get the raw HF tokenizer.
    """
    # NOTE(review): the diff residue carried both the old
    # ``io.load_context(str(self)).model.tokenizer.tokenizer`` line and the
    # new one; the post-commit explicit-subpath form is kept here.
    return io.load_context(str(self), subpath="model").tokenizer.tokenizer

@property
def config(self) -> "HFLlamaConfig":
source: LlamaConfig = io.load_context(str(self)).model.config
source: LlamaConfig = io.load_context(str(self), subpath="model.config")

from transformers import LlamaConfig as HFLlamaConfig

Expand Down
2 changes: 1 addition & 1 deletion nemo/lightning/io/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def nemo_load(
from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib
from nemo.lightning.io.api import load_context

model = load_context(path).model
model = load_context(path, subpath="model")
_trainer = trainer or Trainer(
devices=1,
accelerator="cpu" if cpu else "gpu",
Expand Down

0 comments on commit f2169a1

Please sign in to comment.