Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add safetensor option when saving and restoring models #11549

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

stevehuang52
Copy link
Collaborator

@stevehuang52 stevehuang52 commented Dec 11, 2024

What does this PR do ?

Re-do changes in a previously closed PR #7812 to add safetensor options when saving and restoring models. as customers requested.

CC @galv

Collection: [core,asr,nlp,common]

Changelog

  • Re-do changes in Add safetensors #7812 for current main branch
  • add nemo.utils.secure.torch_load and nemo.utils.secure.torch_save
  • torch.load and torch.save removed from save_restore_connector and alternates of it
  • secure.torch_load and torch_save use safetensors when available, however if safe=False torch.load and torch.save are used to allow for backwards compatability
  • safe is a named parameter and passed through from save_to and restore_from. The default is safe=False to preserve backwards compatability
  • added unit tests to double check backwards compatability and secure functions work correctly.

Usage

Following usage is copied from #7812

When using nemo with untrusted .nemo files this will greatly reduce the potential for the user to be attacked.

# Existing usage should be preserved. although the result will include a .safetensor inside the .nemo now as well
model.save_to(save_path=model_save_path)
model.save_to(save_path=model_save_path, safe=False)

# the .nemo will only include the .safetensor.
model.save_to(save_path=model_save_path, safe=True)

# Existing usage should be preserved. although if the .nemo contains a .safetensor, it will be used instead of the pytorch version
model.restore_from(model_save_path)
model.restore_from(save_path=model_save_path, safe=False)

# The restore will fail if a .safetensor isn't available to prevent an untrusted .nemo file from resulting in an exploit
model.restore_from(save_path=model_save_path, safe=True)

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Signed-off-by: stevehuang52 <[email protected]>
@@ -508,6 +511,7 @@ def restore_from(
override_config_path: Optional[str] = None,
map_location: Optional[torch.device] = None,
strict: bool = False,
safe: bool = False,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tango4j this won't work for diarization models, since the vad and speaker models are restored in the _init_vad_model and _init_speaker_model methods that don't take safe as additional param. Do you think we should fix it or maybe plan in the future?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My opinion is to leave it for the future. Enabling safe tensors unblocks an important use case for us.

self.msdd_model.save_to(neural_diar_model)
self.clus_diar._vad_model.save_to(vad_model, safe=safe)
self.clus_diar._speaker_model.save_to(spkr_model, safe=safe)
self.msdd_model.save_to(neural_diar_model, safe=safe)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tango4j Similar issue as ClusterDiarizer, here the _init_msdd_model doesn't take safe param

@@ -1010,7 +1010,7 @@
)
super().__init__()

def save_to(self, model, save_path: str):
def save_to(self, model, save_path: str, safe: bool = False):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
save_path = save_path.replace(".nemo", "_XYZ.nemo")
super().save_to(model, save_path, safe=safe)

class MockModelV2(MockModel):

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable MockModelV2 is not used.
tests/core/test_save_restore.py Fixed Show fixed Hide fixed
@FredSRichardson
Copy link

I just wanted to chime in that if there is any way to get this pull request merged into the main branch and a release branch that would be hugely helpful. I can hobble along using my clone of this branch, but that's obviously not ideal. Thank you guys for you great work and for this really helpful contribution in particular!

Copy link
Contributor

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions bot added the stale label Dec 28, 2024
Copy link
Contributor

github-actions bot commented Jan 4, 2025

This PR was closed because it has been inactive for 7 days since being marked as stale.

@github-actions github-actions bot closed this Jan 4, 2025
@galv galv reopened this Jan 8, 2025
@galv
Copy link
Collaborator

galv commented Jan 8, 2025

Reopened. It got stale over the holidays.

@github-actions github-actions bot removed the stale label Jan 9, 2025
if safe:
raise e
else:
logging.info(e)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My opinion is that this is going to create a lot of confusingly noise. Basically every time that someone does an unsafe load on a model that wasn't saved with the safe=True option (which is frequent because it's the default), they will get an unhelpful log message of a file not found error, even though the application is otherwise working correctly.

None
"""
try:
storch.save_file(tensors, filename + SAFE_EXTENSION)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use the safetensors format if the user does not request it? You will basically double the size of the model on the disk this way. This does not seem like a good idea to me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old PR enables safetensor by default, but yes I agree that it should be disabled by default, and only enable when required.

@FredSRichardson
Copy link

I just wanted to say I fully endorse this pull request. I tested it and it works well for my use case. It will greatly enable my work if some form of this safetensor PR can make it into the main branch. Thank you!!!

Copy link
Contributor

beep boop 🤖: 🙏 The following files have warnings. In case you are familiar with these, please try helping us to improve the code base.


Your code was analyzed with PyLint. The following annotations have been identified:

************* Module nemo.collections.asr.models.clustering_diarizer
nemo/collections/asr/models/clustering_diarizer.py:238:0: C0301: Line too long (135/119) (line-too-long)
nemo/collections/asr/models/clustering_diarizer.py:242:0: C0301: Line too long (164/119) (line-too-long)
nemo/collections/asr/models/clustering_diarizer.py:328:0: C0301: Line too long (121/119) (line-too-long)
nemo/collections/asr/models/clustering_diarizer.py:549:4: C0116: Missing function or method docstring (missing-function-docstring)
************* Module nemo.collections.asr.modules.audio_preprocessing
nemo/collections/asr/modules/audio_preprocessing.py:98:0: C0301: Line too long (667/119) (line-too-long)
nemo/collections/asr/modules/audio_preprocessing.py:95:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:106:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:304:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:545:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:616:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:667:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:724:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:757:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:776:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:792:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/audio_preprocessing.py:798:0: C0115: Missing class docstring (missing-class-docstring)
************* Module nemo.collections.asr.modules.conv_asr
nemo/collections/asr/modules/conv_asr.py:197:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:239:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:399:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:459:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:503:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:507:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:603:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:677:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:689:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:758:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:858:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:881:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/asr/modules/conv_asr.py:900:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/conv_asr.py:945:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/conv_asr.py:969:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/conv_asr.py:983:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/asr/modules/conv_asr.py:992:0: C0115: Missing class docstring (missing-class-docstring)
************* Module nemo.collections.nlp.models.nlp_model
nemo/collections/nlp/models/nlp_model.py:187:0: C0301: Line too long (121/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:191:0: C0301: Line too long (120/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:196:0: C0301: Line too long (135/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:286:0: C0301: Line too long (132/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:423:0: C0301: Line too long (133/119) (line-too-long)
nemo/collections/nlp/models/nlp_model.py:306:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/models/nlp_model.py:454:4: C0116: Missing function or method docstring (missing-function-docstring)
************* Module nemo.collections.nlp.parts.nlp_overrides
nemo/collections/nlp/parts/nlp_overrides.py:212:0: C0301: Line too long (140/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:217:0: C0301: Line too long (149/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:233:0: C0301: Line too long (123/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:424:0: C0301: Line too long (136/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:510:0: C0301: Line too long (152/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:730:0: C0301: Line too long (140/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:735:0: C0301: Line too long (149/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1001:0: C0301: Line too long (128/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1005:0: C0301: Line too long (141/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1009:0: C0301: Line too long (149/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1066:0: C0301: Line too long (135/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1193:0: C0301: Line too long (121/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:1778:0: C0301: Line too long (152/119) (line-too-long)
nemo/collections/nlp/parts/nlp_overrides.py:248:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:381:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:425:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:602:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:618:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:637:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:878:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:997:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:1689:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:1777:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/nlp/parts/nlp_overrides.py:18:0: W0611: Unused import re (unused-import)
nemo/collections/nlp/parts/nlp_overrides.py:106:4: W0611: Unused tensorstore imported from megatron.core.dist_checkpointing.strategies (unused-import)
************* Module nemo.core.classes.common
nemo/core/classes/common.py:693:0: C0301: Line too long (120/119) (line-too-long)
nemo/core/classes/common.py:819:0: C0301: Line too long (124/119) (line-too-long)
nemo/core/classes/common.py:926:0: C0301: Line too long (120/119) (line-too-long)
nemo/core/classes/common.py:471:0: C0115: Missing class docstring (missing-class-docstring)
nemo/core/classes/common.py:567:0: C0115: Missing class docstring (missing-class-docstring)
nemo/core/classes/common.py:647:0: C0115: Missing class docstring (missing-class-docstring)
nemo/core/classes/common.py:1026:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/common.py:1141:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/common.py:29:0: W0611: Unused Iterable imported from typing (unused-import)
************* Module nemo.core.classes.modelPT
nemo/core/classes/modelPT.py:82:0: C0301: Line too long (130/119) (line-too-long)
nemo/core/classes/modelPT.py:177:0: C0301: Line too long (121/119) (line-too-long)
nemo/core/classes/modelPT.py:184:0: C0301: Line too long (154/119) (line-too-long)
nemo/core/classes/modelPT.py:258:0: C0301: Line too long (131/119) (line-too-long)
nemo/core/classes/modelPT.py:262:0: C0301: Line too long (132/119) (line-too-long)
nemo/core/classes/modelPT.py:310:0: C0301: Line too long (160/119) (line-too-long)
nemo/core/classes/modelPT.py:390:0: C0301: Line too long (135/119) (line-too-long)
nemo/core/classes/modelPT.py:1250:0: C0301: Line too long (123/119) (line-too-long)
nemo/core/classes/modelPT.py:1487:0: C0301: Line too long (140/119) (line-too-long)
nemo/core/classes/modelPT.py:1698:0: C0301: Line too long (128/119) (line-too-long)
nemo/core/classes/modelPT.py:1717:0: C0301: Line too long (135/119) (line-too-long)
nemo/core/classes/modelPT.py:1727:0: C0301: Line too long (122/119) (line-too-long)
nemo/core/classes/modelPT.py:1853:0: C0301: Line too long (166/119) (line-too-long)
nemo/core/classes/modelPT.py:1922:0: C0301: Line too long (120/119) (line-too-long)
nemo/core/classes/modelPT.py:2076:0: C0301: Line too long (151/119) (line-too-long)
nemo/core/classes/modelPT.py:223:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:878:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:944:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:948:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:955:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:1223:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:1647:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/classes/modelPT.py:1765:4: C0116: Missing function or method docstring (missing-function-docstring)
************* Module nemo.core.connectors.save_restore_connector
nemo/core/connectors/save_restore_connector.py:51:0: C0301: Line too long (135/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:208:0: C0301: Line too long (141/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:315:0: C0301: Line too long (140/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:430:0: C0301: Line too long (141/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:435:0: C0301: Line too long (141/119) (line-too-long)
nemo/core/connectors/save_restore_connector.py:38:0: C0115: Missing class docstring (missing-class-docstring)
nemo/core/connectors/save_restore_connector.py:704:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/connectors/save_restore_connector.py:712:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/connectors/save_restore_connector.py:720:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/core/connectors/save_restore_connector.py:728:4: C0116: Missing function or method docstring (missing-function-docstring)

-----------------------------------
Your code has been rated at 9.75/10

Mitigation guide:

  • Add sensible and useful docstrings to functions and methods
  • For trivial methods like getter/setters, consider adding # pylint: disable=C0116 inside the function itself
  • To disable multiple functions/methods at once, put a # pylint: disable=C0116 before the first and a # pylint: enable=C0116 after the last.

By applying these rules, we reduce the occurance of this message in future.

Thank you for improving NeMo's documentation!

Copy link
Contributor

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions bot added the stale label Jan 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants