Commit 6fdca01
Pass omegaconf object to trainer in nlp_checkpoint_port.py
Signed-off-by: Abhishree <[email protected]>
athitten committed Oct 18, 2023
1 parent 0149f52 commit 6fdca01
Showing 2 changed files with 5 additions and 4 deletions.
nemo/core/config/pytorch_lightning.py (3 additions & 3 deletions)
@@ -54,7 +54,7 @@ class TrainerConfig:
     limit_test_batches: Any = 1.0
     val_check_interval: Any = 1.0
     log_every_n_steps: int = 50
-    accelerator: Optional[str] = None
+    accelerator: Optional[str] = 'auto'
     sync_batchnorm: bool = False
     precision: Any = 32
     num_sanity_val_steps: int = 2
@@ -68,8 +68,8 @@ class TrainerConfig:
     gradient_clip_algorithm: str = 'norm'
     max_time: Optional[Any] = None  # can be one of Union[str, timedelta, Dict[str, int], None]
     reload_dataloaders_every_n_epochs: int = 0
-    devices: Any = None
-    strategy: Any = None
+    devices: Any = 'auto'
+    strategy: Any = 'auto'
     enable_checkpointing: bool = False
     enable_model_summary: bool = True
     inference_mode: bool = True
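
For context, a minimal sketch (not part of the commit): the new 'auto' defaults mirror the defaults of PyTorch Lightning 2.x itself, where Lightning picks the accelerator, device count, and strategy at runtime.

    import pytorch_lightning as pl

    # Equivalent explicit call: 'auto' lets Lightning choose GPU/CPU,
    # use all visible devices, and select a matching strategy.
    trainer = pl.Trainer(accelerator='auto', devices='auto', strategy='auto')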
scripts/nemo_legacy_import/nlp_checkpoint_port.py (2 additions & 1 deletion)
@@ -88,7 +88,8 @@ def nemo_convert(argv):
         logger=False,
         enable_checkpointing=False,
     )
-    trainer = pl.Trainer(cfg_trainer)
+    cfg_trainer = OmegaConf.to_container(OmegaConf.create(cfg_trainer))
+    trainer = pl.Trainer(**cfg_trainer)
 
     logging.info("Restoring NeMo model from '{}'".format(nemo_in))
     try:
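
A minimal standalone sketch of the new pattern (assuming TrainerConfig is importable from nemo.core.config.pytorch_lightning, the file changed above): pl.Trainer takes keyword arguments, not a single config object, so the old positional call failed. The fix wraps the dataclass in an OmegaConf config, converts it to a plain dict, and unpacks it as keyword arguments.

    import pytorch_lightning as pl
    from omegaconf import OmegaConf
    from nemo.core.config.pytorch_lightning import TrainerConfig

    # Build the structured config, roughly as nemo_convert() does.
    cfg_trainer = TrainerConfig(logger=False, enable_checkpointing=False)

    # OmegaConf.create() accepts a dataclass instance; to_container()
    # turns the resulting config into a plain dict of Trainer kwargs.
    trainer_kwargs = OmegaConf.to_container(OmegaConf.create(cfg_trainer))
    trainer = pl.Trainer(**trainer_kwargs)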
