-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(pt): set weights_only=True
for torch.load
#4147
Conversation
Fix deepmodeling#4143. Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough📝 WalkthroughWalkthroughThe changes in this pull request involve modifying multiple files to update the Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant Model
participant Torch
User->>Model: Request to load model
Model->>Torch: Load model with weights_only=True
Torch-->>Model: Return model weights
Model-->>User: Provide loaded model
Assessment against linked issues
Possibly related PRs
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
weights_only=True
for torch.load
weights_only=True
for torch.load
Surprisingly, in some place, NumPy arrays are saved to the state dict. cc @iProzd @wanghan-iapcm |
>>> type(torch.load("model.ckpt.pt")["model"]["_extra_state"]["train_infos"]["lr"])
<class 'numpy.float64'> |
See #4147 and #4143. We can first make `state_dict` safe for `weights_only`, then make a breaking change when loading `state_dict` in the future. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced model saving functionality by ensuring learning rates are consistently stored as floats, improving type consistency. - **Bug Fixes** - Updated model loading behavior in tests to focus solely on model weights, which may resolve issues related to state dictionary loading. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4147 +/- ##
==========================================
- Coverage 84.55% 84.55% -0.01%
==========================================
Files 537 537
Lines 51237 51238 +1
Branches 3047 3047
==========================================
- Hits 43324 43323 -1
- Misses 6965 6969 +4
+ Partials 948 946 -2 ☔ View full report in Codecov by Sentry. |
Signed-off-by: Jinzhe Zeng <[email protected]>
for more information, see https://pre-commit.ci
It's almost one month passed. Let's merge this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
deepmd/pt/entrypoints/main.py (3)
286-288
: Approve changes and suggest error handlingThe addition of
weights_only=True
totorch.load
is a good security practice and addresses the warnings mentioned in the PR objectives.Consider adding error handling to catch potential exceptions that might occur if the loaded file doesn't contain the expected structure. For example:
try: init_state_dict = torch.load( init_model, map_location=DEVICE, weights_only=True ) if "model" in init_state_dict: init_state_dict = init_state_dict["model"] config["model"] = init_state_dict["_extra_state"]["model_params"] except KeyError as e: log.error(f"Failed to load model parameters: {e}") raise
385-387
: Approve changes and suggest consistency improvementThe addition of
weights_only=True
totorch.load
is consistent with the changes in thetrain
function and addresses the security concerns mentioned in the PR objectives.For consistency with the
train
function, consider adding similar error handling here:try: old_state_dict = torch.load( input_file, map_location=env.DEVICE, weights_only=True ) model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict)) model_params = model_state_dict["_extra_state"]["model_params"] except KeyError as e: log.error(f"Failed to load model parameters: {e}") raise
Line range hint
1-587
: Summary of changes and their impactThe changes in this file effectively address the PR objectives by adding
weights_only=True
totorch.load
calls in both thetrain
andchange_bias
functions. This modification enhances security by preventing potential arbitrary code execution during model loading.The changes are consistent and focused, minimizing the risk of introducing new issues. They align well with PyTorch's recommendations for secure model loading.
Consider implementing a utility function for loading models with error handling, which can be reused across different parts of the codebase. This would ensure consistent behavior and error handling when loading models. For example:
def load_model_safely(file_path, device): try: state_dict = torch.load(file_path, map_location=device, weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] model_params = state_dict["_extra_state"]["model_params"] return state_dict, model_params except KeyError as e: log.error(f"Failed to load model parameters from {file_path}: {e}") raiseThis function could then be used in both
train
andchange_bias
functions, promoting code reuse and consistency.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- deepmd/pt/entrypoints/main.py (2 hunks)
- deepmd/pt/infer/deep_eval.py (1 hunks)
- deepmd/pt/train/training.py (1 hunks)
🧰 Additional context used
🔇 Additional comments (2)
deepmd/pt/infer/deep_eval.py (1)
106-108
: Approved: Security improvement fortorch.load
The addition of
weights_only=True
to thetorch.load
function call is a positive change that addresses the security concerns raised in the linked issue #4143. This modification aligns with PyTorch's recommendations and helps prevent potential security risks associated with loading untrusted pickle data.To ensure this change doesn't introduce compatibility issues, please run the following verification script:
If the script returns any results, consider updating those instances as well for consistency across the codebase.
Consider adding a comment explaining the security implications of
weights_only=True
for future maintainers. Additionally, you may want to update the documentation to reflect this change in behavior, especially if it affects how users should prepare or load their model files.deepmd/pt/train/training.py (1)
403-405
: LGTM! Verify impact on existing functionality.The addition of
weights_only=True
to thetorch.load
call is correct and addresses the security concern raised in the issue. This change will prevent loading arbitrary objects during unpickling.Please verify that this change doesn't break any existing functionality that might depend on non-weight data in the checkpoint. Run the following script to check for any uses of loaded data that might be affected:
If the script returns any results, carefully review those occurrences to ensure they're not relying on non-weight data that will no longer be loaded.
Fix #4143.
Summary by CodeRabbit
New Features
Bug Fixes
Tests