-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax): checkpoint I/O #4236
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough📝 WalkthroughWalkthroughThe pull request introduces significant updates to the Changes
Possibly related PRs
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 8
🧹 Outside diff range and nitpick comments (19)
deepmd/jax/descriptor/base_descriptor.py (1)
2-7
: Consider using absolute imports for better maintainability.While the current relative imports work, consider using absolute imports for improved maintainability and to avoid potential issues with circular imports in the future.
Here's a suggested change:
-from deepmd.dpmodel.descriptor.make_base_descriptor import ( +from deepmd.dpmodel.descriptor.make_base_descriptor import ( make_base_descriptor, ) -from deepmd.jax.env import ( +from deepmd.jax.env import ( jnp, )deepmd/jax/atomic_model/base_atomic_model.py (1)
11-18
: LGTM: Well-implemented attribute processing function.The
base_atomic_model_set_attr
function effectively processes attribute values based on their names. It handles special cases appropriately, including conversion to JAX arrays and creation of exclusion mask instances.Consider adding a docstring to improve the function's documentation. For example:
def base_atomic_model_set_attr(name: str, value: Any) -> Any: """ Process attribute values based on their names for atomic models. Args: name (str): The name of the attribute. value (Any): The value to be processed. Returns: Any: The processed value. """ # ... existing implementation ...This would enhance the function's self-documentation and make it easier for other developers to understand its purpose and usage.
deepmd/jax/model/ener_model.py (1)
21-24
: LGTM:__setattr__
implementation ensures consistent atomic model initialization.The method correctly handles the "atomic_model" attribute by serializing and deserializing it using
DPAtomicModel
. This approach ensures consistency and proper initialization.Consider the performance impact of serializing and deserializing for large models. If performance becomes an issue, you might want to explore more efficient ways to ensure proper initialization without the full serialization cycle.
deepmd/jax/descriptor/se_e2_a.py (1)
Line range hint
25-38
: LGTM: Custom setattr method with a minor suggestion.The custom
__setattr__
method is well-implemented, providing type-specific processing for various attributes. It ensures JAX compatibility, handles complex data structures, and implements specific logic for different attributes.Consider adding a brief explanation for why "env_mat" doesn't store any value, as this might not be immediately clear to other developers.
deepmd/backend/jax.py (4)
35-40
: LGTM. Consider adding a comment explaining the feature change.The addition of
Backend.Feature.IO
to thefeatures
class variable is appropriate. This change indicates that the JAX backend now supports I/O operations.Consider adding a brief comment explaining why this feature was added and the implications for the JAX backend's capabilities.
96-100
: LGTM. Consider adding a docstring forserialize_from_file
.The implementation of the
serialize_hook
property usingserialize_from_file
fromdeepmd.jax.utils.serialization
is appropriate and consistent with the newly added I/O feature support.Consider adding a brief docstring for the
serialize_from_file
function, explaining its purpose and any important details about its usage.
111-115
: LGTM. Consider adding a docstring fordeserialize_to_file
.The implementation of the
deserialize_hook
property usingdeserialize_to_file
fromdeepmd.jax.utils.serialization
is appropriate and consistent with the newly added I/O feature support and theserialize_hook
implementation.Consider adding a brief docstring for the
deserialize_to_file
function, explaining its purpose and any important details about its usage.
Line range hint
35-115
: Overall implementation looks good. Consider adding more documentation.The changes to
JAXBackend
successfully implement I/O support for the JAX backend. The additions ofBackend.Feature.IO
tofeatures
,".jax"
tosuffixes
, and the implementation ofserialize_hook
anddeserialize_hook
properties are consistent and well-structured.To improve code readability and maintainability:
- Add a comment explaining the addition of the I/O feature and its implications.
- Include brief docstrings for the imported
serialize_from_file
anddeserialize_to_file
functions.- Consider adding a class-level docstring or updating the existing one to reflect the new I/O capabilities of the
JAXBackend
.source/tests/consistent/io/test_io.py (1)
71-71
: LGTM: Added JAX backend support in test_data_equalThe changes appropriately include JAX in the backend testing:
- Adding "jax" to the backend list ensures JAX models are tested alongside other backends.
- Including "jax_version" in the excluded keys maintains consistency in version-specific data handling across backends.
These additions improve the test coverage by including JAX support.
Consider adding a comment explaining why these specific keys are excluded from comparison, to improve code readability and maintainability.
Also applies to: 86-86
deepmd/dpmodel/atomic_model/dp_atomic_model.py (2)
172-176
: LGTM! Consider enhancing docstrings for clarity.The addition of
base_descriptor_cls
andbase_fitting_cls
as class attributes is a good approach to enhance the extensibility of theDPAtomicModel
class. This allows subclasses to override the base descriptor and fitting classes easily.Consider expanding the docstrings slightly to provide more context:
base_descriptor_cls = BaseDescriptor """The base descriptor class. Can be overridden by subclasses to use custom descriptors.""" base_fitting_cls = BaseFitting """The base fitting class. Can be overridden by subclasses to use custom fitting methods."""
184-185
: LGTM! Consider a minor adjustment for consistency.The changes to use
cls.base_descriptor_cls
andcls.base_fitting_cls
in thedeserialize
method align well with the new class attributes. This modification enhances the flexibility of the deserialization process, allowing subclasses to control which descriptor and fitting classes are used.For consistency with the class attribute names, consider renaming the variables:
descriptor_obj = cls.base_descriptor_cls.deserialize(data.pop("descriptor")) fitting_obj = cls.base_fitting_cls.deserialize(data.pop("fitting"))to:
base_descriptor = cls.base_descriptor_cls.deserialize(data.pop("descriptor")) base_fitting = cls.base_fitting_cls.deserialize(data.pop("fitting"))This naming would more closely reflect their relationship to the class attributes.
deepmd/jax/atomic_model/dp_atomic_model.py (1)
22-22
: Add a class docstring toDPAtomicModel
The class
DPAtomicModel
lacks a class-level docstring. Including a docstring will improve code readability and provide valuable context about the class's purpose and usage.deepmd/jax/model/model.py (2)
17-24
: Enhance the docstring for better clarity and guidanceThe docstring for
get_standard_model
provides a basic description but could be expanded for improved clarity. This enhancement would assist users in understanding the function's purpose and how to use it effectively.Consider including:
- A brief explanation of what a "standard model" is within the context of the project.
- Detailed descriptions of the expected structure and required keys in the
data
dictionary.- Information about supported descriptor and fitting types.
- An example illustrating how to call the function with sample data.
48-55
: Expand the docstring to provide comprehensive usage informationSimilarly, the docstring for
get_model
could be made more informative. Providing additional details would help users navigate different model types and understand how to extend or customize models.Suggestions:
- Explain the purpose of the
get_model
function and how it differentiates between model types.- Outline the expected contents of the
data
dictionary, including optional and required keys.- Describe how the
"type"
key influences the model construction.- Provide examples for creating both standard and custom models.
source/tests/consistent/model/common.py (2)
9-11
: Consider aliasingto_numpy_array
to avoid confusionThe function
to_numpy_array
is imported fromdeepmd.dpmodel.common
. Since a similar function is imported from PyTorch and aliased astorch_to_numpy
, consider aliasing this import for consistency and to prevent potential confusion.
75-87
: Remove unused parameternatoms
fromeval_jax_model
The parameter
natoms
ineval_jax_model
is not used within the method. Consider removing it to simplify the function signature and avoid confusion.deepmd/dpmodel/model/transform_output.py (2)
36-36
: Typo in comment: Correct 'brefore' to 'before'There is a typographical error in the comment. The word 'brefore' should be corrected to 'before'.
Apply this diff to fix the typo:
- # cast to energy prec brefore reduction + # cast to energy prec before reduction
Line range hint
27-36
: Consider adding unit tests to validate array backend compatibilitySince the code now utilizes
array_api_compat
to support different array backends, it would be beneficial to add unit tests that ensure correct functionality across the supported array libraries.deepmd/jax/fitting/fitting.py (1)
39-44
: Add docstrings to fitting classes for improved clarityIncluding docstrings for
EnergyFittingNet
andDOSFittingNet
will enhance code readability and provide valuable information about their purpose and usage to other developers.Also applies to: 47-52
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (25)
- deepmd/backend/jax.py (3 hunks)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (4 hunks)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/dpmodel/model/make_model.py (7 hunks)
- deepmd/dpmodel/model/transform_output.py (3 hunks)
- deepmd/jax/atomic_model/init.py (1 hunks)
- deepmd/jax/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/jax/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/jax/descriptor/init.py (1 hunks)
- deepmd/jax/descriptor/base_descriptor.py (1 hunks)
- deepmd/jax/descriptor/dpa1.py (2 hunks)
- deepmd/jax/descriptor/se_e2_a.py (2 hunks)
- deepmd/jax/env.py (1 hunks)
- deepmd/jax/fitting/init.py (1 hunks)
- deepmd/jax/fitting/base_fitting.py (1 hunks)
- deepmd/jax/fitting/fitting.py (2 hunks)
- deepmd/jax/model/init.py (1 hunks)
- deepmd/jax/model/base_model.py (1 hunks)
- deepmd/jax/model/ener_model.py (1 hunks)
- deepmd/jax/model/model.py (1 hunks)
- deepmd/jax/utils/serialization.py (1 hunks)
- pyproject.toml (2 hunks)
- source/tests/consistent/io/test_io.py (3 hunks)
- source/tests/consistent/model/common.py (3 hunks)
- source/tests/consistent/model/test_ener.py (5 hunks)
✅ Files skipped from review due to trivial changes (5)
- deepmd/jax/atomic_model/init.py
- deepmd/jax/descriptor/init.py
- deepmd/jax/fitting/base_fitting.py
- deepmd/jax/model/init.py
- deepmd/jax/model/base_model.py
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/atomic_model/base_atomic_model.py
202-202: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
deepmd/dpmodel/model/make_model.py
368-368: Local variable
nall
is assigned to but never usedRemove assignment to unused variable
nall
(F841)
deepmd/jax/utils/serialization.py
73-73: Local variable
state
is assigned to but never usedRemove assignment to unused variable
state
(F841)
🔇 Additional comments (46)
deepmd/jax/fitting/__init__.py (3)
2-5
: LGTM: Imports are clear and specific.The imports are well-structured, importing specific classes from the correct module. This approach helps maintain a clean namespace and improves code readability.
7-10
: LGTM: all list correctly updated.The all list has been properly updated to include the newly imported classes. This ensures that these classes are explicitly part of the public API, which is a good practice for controlling what gets exported when using
from module import *
.
2-10
: Summary: Appropriate updates to module exports.The changes to this
__init__.py
file are minimal but important. By adding imports forDOSFittingNet
andEnergyFittingNet
and including them in the__all__
list, the module now explicitly exposes these classes as part of its public API. This update aligns well with the pull request's objective of enhancing the JAX-based functionality in the DeepMD framework.These changes improve the module's usability and make it clear which components are intended for external use. Good job on maintaining a clean and explicit public interface!
deepmd/jax/descriptor/base_descriptor.py (2)
1-9
: LGTM! Well-structured file for JAX-based descriptor.The file is concise and well-organized, effectively setting up a JAX-compatible base descriptor. It demonstrates good separation of concerns by importing the necessary components and utilizing a factory function for flexibility.
9-9
: Verify the integration of BaseDescriptor with other components.The
BaseDescriptor
is correctly defined using themake_base_descriptor
factory function withjnp.ndarray
. This setup allows for JAX-optimized operations.To ensure proper integration, let's verify its usage:
✅ Verification successful
BaseDescriptor is properly integrated with other components.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify the usage of BaseDescriptor in other files # Test: Search for BaseDescriptor usage rg -A 5 "BaseDescriptor" # Test: Check for any potential circular imports rg -A 5 "from deepmd.jax.descriptor.base_descriptor import BaseDescriptor"Length of output: 28747
deepmd/jax/env.py (2)
21-21
: Consistent export of newly imported module.The addition of
jax2tf
to__all__
is consistent with its import, making it accessible when usingfrom deepmd.jax.env import *
. This change is appropriate and maintains the module's interface consistency.
11-13
: Consider the implications of using experimental JAX features.The addition of
jax2tf
fromjax.experimental
suggests plans to use JAX-to-TensorFlow conversion capabilities. While this can be powerful, be aware that experimental features may be subject to changes or instability in future JAX releases.To ensure this import is used elsewhere in the project, run:
deepmd/jax/atomic_model/base_atomic_model.py (2)
1-8
: LGTM: Imports are appropriate and well-organized.The imports are relevant to the function's implementation and follow Python best practices for relative imports.
1-18
: Summary: Solid implementation of attribute processing for JAX compatibility.This new file introduces a utility function
base_atomic_model_set_attr
that plays a crucial role in processing model attributes for JAX compatibility. It handles special cases for certain attributes, converting them to JAX arrays or creating appropriate mask instances.The function appears to be part of a broader effort to enhance the DeepMD framework's compatibility with JAX, as mentioned in the PR summary. It's likely utilized by other parts of the codebase, such as the
DPAtomicModel
class, to ensure proper attribute handling in JAX-based models.Overall, the implementation is clean, efficient, and well-aligned with the PR's objectives.
deepmd/jax/model/ener_model.py (5)
1-5
: LGTM: Imports and license identifier look good.The SPDX license identifier is correctly placed at the top of the file, and the
Any
import fromtyping
is appropriately used in the code.
6-9
: LGTM: Imports are appropriate and align with usage.The imports of
EnergyModelDP
andDPAtomicModel
are correctly used in the class definition and method implementation.
10-15
: LGTM: Imports are correctly used in the code.The
flax_module
decorator andBaseModel
for registration are properly imported and utilized in the class definition.
18-20
: LGTM: Class definition is well-structured and properly decorated.The
EnergyModel
class is correctly registered withBaseModel
, decorated withflax_module
, and inherits fromEnergyModelDP
. This structure aligns well with the JAX framework and the overall design of the DeepMD-kit.
1-24
: Overall, the implementation ofEnergyModel
is well-designed and aligns with JAX integration objectives.The new
EnergyModel
class successfully extends the existingEnergyModelDP
class, integrating seamlessly with the JAX framework through appropriate decorators and registrations. The custom__setattr__
method ensures consistent initialization of the atomic model, which is crucial for maintaining the integrity of the energy calculations.The code is clean, well-structured, and follows good practices in terms of imports, class definition, and method implementation. It effectively achieves the goal of enhancing the DeepMD-kit's compatibility with JAX-based models.
deepmd/jax/descriptor/se_e2_a.py (2)
11-13
: LGTM: Import statement for BaseDescriptor.The import statement for
BaseDescriptor
is correctly added and follows Python conventions.
22-23
: LGTM: Class registration decorators.The
DescrptSeA
class is correctly registered with two identifiers using the@BaseDescriptor.register
decorators. This approach provides flexibility in referencing the class and aligns with the class naming convention.deepmd/backend/jax.py (1)
41-41
: LGTM. Suffix addition is consistent with new I/O feature.The addition of
".jax"
to thesuffixes
class variable is appropriate and consistent with the newly added I/O feature support.source/tests/consistent/io/test_io.py (3)
3-3
: LGTM: Import shutil for enhanced file operationsThe addition of the
shutil
import is appropriate, as it provides high-level file operations that will be used in the updatedtearDown
method.
64-67
: LGTM: Improved cleanup process in tearDown methodThe changes in the
tearDown
method enhance the cleanup process:
- Using
is_file()
is more explicit and safer than the previous existence check.- The addition of
is_dir()
check allows for proper handling of directories.- Utilizing
shutil.rmtree()
ensures complete removal of directory contents.These improvements make the cleanup process more robust and comprehensive.
Line range hint
1-190
: Summary: Enhancements to IO testing with JAX supportThe changes in this file successfully integrate JAX support into the IO testing framework:
- The cleanup process in
tearDown
has been improved to handle both files and directories more robustly.- JAX has been added as a backend for testing in the
test_data_equal
method, expanding the test coverage.- Appropriate exclusions for JAX-specific version information have been added to maintain consistency with other backends.
These modifications align well with the PR objectives of enhancing JAX support in the DeepMD framework. The changes are well-implemented and improve the overall testing capabilities of the IO module.
deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)
172-185
: Overall, excellent changes that enhance extensibility.The modifications to the
DPAtomicModel
class, including the addition ofbase_descriptor_cls
andbase_fitting_cls
attributes and the corresponding changes in thedeserialize
method, significantly improve the extensibility and flexibility of the class. These changes allow for easier subclassing and customization, which is a valuable improvement to the codebase.The implementation is clean and well-thought-out. The minor suggestions provided earlier for docstring enhancement and variable naming are purely for further clarity and consistency, but do not detract from the quality of the changes.
Great job on improving the design of this class!
pyproject.toml (3)
141-142
: Summary of changes to pyproject.tomlThe modifications to
pyproject.toml
are focused on enhancing JAX support:
- Two new JAX-related dependencies added: 'orbax-checkpoint' and 'jax-ai-stack'.
- A new linting rule exception for the JAX-related code directory.
These changes are consistent with the existing structure and patterns in the file. They suggest a significant addition of JAX-related functionality to the project.
Please ensure that these changes are thoroughly tested and documented, especially considering their potential impact on the project's dependencies and coding standards.
Also applies to: 407-407
407-407
: Verify the necessity of the new linting rule exception for JAX.A new path has been added to ignore the TID253 linting rule (banned module-level imports) for the JAX-related code:
"deepmd/jax/**" = ["TID253"]This is consistent with existing exceptions for other backend-specific directories (tf, pt). However, please confirm:
- Is this exception necessary for the JAX implementation?
- Have you considered alternatives to avoid module-level imports in the JAX code?
- Does this align with the project's coding standards and best practices?
To check for module-level imports in the JAX code that might trigger this rule, you can run:
#!/bin/bash # Search for module-level imports in the JAX-related code rg "^import\s+|^from\s+" deepmd/jaxIf possible, consider refactoring to avoid module-level imports, or document why they are necessary in this case.
141-142
: Verify the necessity of new JAX-related dependencies.Two new dependencies have been added to the
jax
optional dependencies:
- 'orbax-checkpoint'
- 'jax-ai-stack'
Both are conditional on Python version 3.10 or higher, which aligns with JAX requirements. However, please confirm:
- Are these dependencies essential for the new JAX-related features introduced in this PR?
- Have you tested the functionality with these new dependencies?
- Are there any specific version constraints needed for these packages?
To ensure these dependencies are used in the codebase, you can run the following command:
deepmd/jax/atomic_model/dp_atomic_model.py (1)
28-30
:⚠️ Potential issueVerify that overriding
__setattr__
does not interfere with Flax module behaviorOverriding
__setattr__
in a class that inherits from a Flax module may affect parameter management, serialization, and other internal mechanisms of Flax. Ensure that this override is necessary and does not introduce unintended side effects on Flax's functionality.deepmd/jax/model/model.py (2)
56-63
: Correctly handles unimplemented features with clear exceptionsThe
get_model
function appropriately checks for unimplemented features, such as the presence of"spin"
in thedata
, and raises aNotImplementedError
with an explicit message.
1-63
: Overall implementation is clean and follows best practicesThe module is well-structured, and the use of class methods to instantiate descriptors, fittings, and models based on dynamic types is effective. The code is readable and aligns with the project's design patterns.
source/tests/consistent/model/common.py (2)
14-14
: LGTMAdding
INSTALLED_JAX
to the imported variables enables conditional JAX functionality as intended.
27-31
: LGTMThe addition of JAX-specific imports within the
if INSTALLED_JAX
block ensures that JAX dependencies are only imported when available.deepmd/dpmodel/model/transform_output.py (3)
3-3
: Importingarray_api_compat
enhances compatibilityThe addition of
import array_api_compat
allows the code to be compatible with different array libraries, such as NumPy, JAX, or others that conform to the array API standard.
27-27
: Usingget_namespace
to obtain array namespaceAssigning
xp = array_api_compat.get_namespace(coord_ext)
ensures that subsequent array operations use the appropriate array namespace, enhancing flexibility and compatibility with different array backends.
36-36
: Utilizingxp.sum
for array operationsReplacing
np.sum
withxp.sum
enables the summation to be performed using the array namespacexp
, supporting various array libraries and improving the code's adaptability.deepmd/jax/utils/serialization.py (1)
21-47
: LGTM!The
deserialize_to_file
function correctly handles model deserialization and file saving for the JAX backend.deepmd/jax/descriptor/dpa1.py (2)
19-21
: ImportingBaseDescriptor
for class registrationThe addition of the import statement for
BaseDescriptor
is appropriate, ensuring that descriptors can be properly registered.
82-83
: RegisteringDescrptDPA1
under multiple identifiersRegistering
DescrptDPA1
with both"dpa1"
and"se_atten"
identifiers allows for flexibility in accessing the descriptor using different names. This is acceptable if intentional.source/tests/consistent/model/test_ener.py (7)
Line range hint
16-21
: Approved: ImportingINSTALLED_JAX
The addition of
INSTALLED_JAX
to the imports ensures that the JAX installation status is correctly handled.
40-45
: Approved: Conditional Import of JAX ModulesThe conditional import statements for JAX modules ensure compatibility when JAX is installed.
94-95
: Approved: Settingjax_class
and Initializingargs
Assigning
jax_class
toEnergyModelJAX
and initializingargs
integrates the JAX backend into the test class.
104-107
: Approved: Addingskip_jax
PropertyThe
skip_jax
property correctly determines whether to skip JAX tests based on the installation status.
115-116
: Approved: HandlingEnergyModelJAX
inpass_data_to_cls
Adding support for
EnergyModelJAX
inpass_data_to_cls
allows constructing JAX models appropriately.
186-194
: Approved: Implementingeval_jax
MethodThe
eval_jax
method enables evaluation of JAX models, consistent with other backends.
203-204
: Approved: Updatingextract_ret
for JAX BackendEnsures that outputs from the JAX backend are properly extracted and formatted.
deepmd/dpmodel/atomic_model/base_atomic_model.py (2)
204-204
:⚠️ Potential issueSimplify iteration over dictionary keys
When iterating over a dictionary, you can iterate directly over it instead of using
dict.keys()
. This is more Pythonic and slightly more efficient.Apply this diff to simplify the loop:
- for kk in ret_dict.keys(): + for kk in ret_dict:Likely invalid or redundant comment.
3-3
:⚠️ Potential issuePotential Python version compatibility issue with
math.prod
The
math.prod
function is available only in Python 3.8 and above. If your project needs to support earlier Python versions, consider usingnumpy.prod
instead to maintain compatibility.Apply this diff to replace
math.prod
withnumpy.prod
:- import math + # Remove import math if not used elsewhere.And modify the usage:
- out_shape2 = math.prod(out_shape[2:]) + out_shape2 = np.prod(out_shape[2:])Likely invalid or redundant comment.
deepmd/jax/fitting/fitting.py (2)
12-14
: ImportBaseFitting
to enable class registrationThe addition of
BaseFitting
import facilitates the registration of fitting classes, enhancing the extensibility of the fitting framework.
39-44
: Verify the correctness of attribute handling in__setattr__
Ensure that all necessary attributes are appropriately handled in
setattr_for_general_fitting
. Attributes not explicitly processed may lead to unexpected behavior if not correctly managed.Run the following script to check for attributes set elsewhere in the codebase:
Also applies to: 47-52
✅ Verification successful
Attribute handling in
__setattr__
is correctly implemented.All attribute assignments within subclasses of
BaseFitting
are appropriately managed throughsetattr_for_general_fitting
, ensuring consistent and expected behavior.
- Location:
deepmd/jax/fitting/fitting.py
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Find all attribute assignments in subclasses of BaseFitting. # Test: Search for `__setattr__` usages in subclasses. Expect: Attributes are consistently handled. ast-grep --pattern $'class $_ extends BaseFitting { $$$ def __setattr__($_, $_): $$$ $$$ }'Length of output: 102
Script:
#!/bin/bash # Description: Find all attribute assignments in subclasses of BaseFitting. # Test: Search for `__setattr__` usages in subclasses. Expect: Attributes are consistently handled. ast-grep --pattern 'class $_(BaseFitting): $$$ def __setattr__($_, $_): $$$ $$$'Length of output: 92
Script:
#!/bin/bash # Description: Find all usages of 'setattr_for_general_fitting' in the codebase. rg 'setattr_for_general_fitting' --context 2Length of output: 2193
Signed-off-by: Jinzhe Zeng <[email protected]>
for more information, see https://pre-commit.ci
One way is to save the model parameters, opt_state and all the runtime status of the training to compressed file. |
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
We can manually call |
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (10)
deepmd/jax/atomic_model/base_atomic_model.py (1)
15-16
: LGTM: Robust null-value handling addedThe addition of the null check before wrapping values in
ArrayAPIVariable
improves robustness and is consistent with the existing pattern used for other attributes.Consider these improvements for better code clarity:
-def base_atomic_model_set_attr(name, value): +def base_atomic_model_set_attr(name: str, value: Any) -> Any: + """Set attribute for base atomic model with proper type conversion. + + Args: + name: Attribute name to set + value: Value to set, which may be wrapped in ArrayAPIVariable for JAX serialization + Returns: + Processed value, potentially wrapped in appropriate container + """deepmd/jax/utils/type_embed.py (1)
22-23
: Add documentation for attribute handling.The wrapping of
econf_tebd
inArrayAPIVariable
looks correct and aligns with JAX checkpoint implementation. However, consider adding documentation to explain:
- The expected type/shape of
econf_tebd
- Why wrapping in
ArrayAPIVariable
is necessary- The implications for serialization/deserialization
Add docstring to explain the attribute handling:
def __setattr__(self, name: str, value: Any) -> None: + """Set attributes with special handling for JAX arrays. + + Args: + name: Attribute name + value: Attribute value. For 'econf_tebd', expects JAX-compatible array + which will be wrapped in ArrayAPIVariable for consistent + serialization in JAX checkpoint format. + """ if name in {"econf_tebd"}:deepmd/jax/descriptor/se_e2_a.py (1)
30-31
: Consider adding type hints and documentation.The wrapping of
dstd
anddavg
withArrayAPIVariable
is a good practice for JAX array handling. However, consider these improvements:
- Add type hints for better code clarity:
- def __setattr__(self, name: str, value: Any) -> None: + def __setattr__(self, name: str, value: Any | ArrayAPIVariable) -> None:
- Add a docstring explaining the purpose of wrapping these attributes with
ArrayAPIVariable
.deepmd/jax/fitting/fitting.py (2)
33-34
: Add documentation for ArrayAPIVariable usageWhile the null check is a good addition, it would be helpful to document why certain attributes need to be wrapped in
ArrayAPIVariable
.Add docstring explaining the purpose:
def setattr_for_general_fitting(name: str, value: Any) -> Any: + """Handle attribute setting for fitting networks. + + Args: + name: Attribute name + value: Attribute value + + Returns: + Processed value based on attribute type: + - Model parameters (bias_atom_e, fparam_avg, etc.): Wrapped in ArrayAPIVariable for JAX compatibility + - emask: Converted to AtomExcludeMask + - nets: Deserialized NetworkCollection + """
Line range hint
22-40
: Consider splitting attribute handling into separate functionsThe function handles multiple different types of attribute processing. Consider splitting it into more focused functions for better maintainability.
+def _wrap_model_param(value: Any) -> Any: + """Wrap model parameters in ArrayAPIVariable.""" + value = to_jax_array(value) + return ArrayAPIVariable(value) if value is not None else value + +def _create_exclude_mask(value: Any) -> AtomExcludeMask: + """Create an atom exclude mask.""" + return AtomExcludeMask(value.ntypes, value.exclude_types) + +def _deserialize_networks(value: Any) -> NetworkCollection: + """Deserialize network collection.""" + return NetworkCollection.deserialize(value.serialize()) + def setattr_for_general_fitting(name: str, value: Any) -> Any: if name in { "bias_atom_e", @@ -29,13 +42,9 @@ "aparam_avg", "aparam_inv_std", }: - value = to_jax_array(value) - if value is not None: - value = ArrayAPIVariable(value) + return _wrap_model_param(value) elif name == "emask": - value = AtomExcludeMask(value.ntypes, value.exclude_types) + return _create_exclude_mask(value) elif name == "nets": - value = NetworkCollection.deserialize(value.serialize()) - return value + return _deserialize_networks(value) + return valuedeepmd/jax/common.py (2)
86-97
: Add documentation and type hints to improve code clarity.The
ArrayAPIVariable
class implementation looks correct but could benefit from improved documentation and type safety:
- Add a class docstring explaining:
- Purpose of the class
- Usage examples
- Requirements for the
value
attribute- Add type hints for method parameters and return types
Here's the suggested improvement:
class ArrayAPIVariable(nnx.Variable): + """A Variable that implements Array API and DLPack protocols. + + This class wraps a value that supports Array API and DLPack protocols, + delegating all array-related operations to the underlying value. + + Examples + -------- + >>> var = ArrayAPIVariable(jnp.array([1, 2, 3])) + >>> np.asarray(var) # converts to numpy via __array__ + """ - def __array__(self, *args, **kwargs): + def __array__(self, dtype: Optional[np.dtype] = None) -> np.ndarray: return self.value.__array__(*args, **kwargs) - def __array_namespace__(self, *args, **kwargs): + def __array_namespace__(self) -> Any: return self.value.__array_namespace__(*args, **kwargs) - def __dlpack__(self, *args, **kwargs): + def __dlpack__(self, stream: Optional[int] = None) -> Any: return self.value.__dlpack__(*args, **kwargs) - def __dlpack_device__(self, *args, **kwargs): + def __dlpack_device__(self) -> tuple[str, int]: return self.value.__dlpack_device__(*args, **kwargs)
86-97
: Consider adding validation to ensure value compatibility.To improve robustness, consider validating that the wrapped value supports the required protocols during initialization. This would prevent runtime errors when array operations are attempted on incompatible values.
Example implementation:
def __init__(self, value: Any): """Initialize with validation for required protocols. Parameters ---------- value : Any A value that supports Array API and DLPack protocols Raises ------ TypeError If value doesn't support required protocols """ required_methods = ['__array__', '__array_namespace__', '__dlpack__', '__dlpack_device__'] missing = [method for method in required_methods if not hasattr(value, method)] if missing: raise TypeError(f"Value must support {', '.join(missing)} methods") super().__init__(value)deepmd/jax/descriptor/dpa1.py (3)
Line range hint
32-36
: Consider adding type validation before deserialization.While the deserialization logic is correct, it might be safer to verify that the value has a
serialize
method before attempting to use it. This could prevent cryptic errors if an invalid value is passed.def __setattr__(self, name: str, value: Any) -> None: if name in {"in_proj", "out_proj"}: + if not hasattr(value, 'serialize'): + raise ValueError(f"Expected serializable object for {name}, got {type(value)}") value = NativeLayer.deserialize(value.serialize()) return super().__setattr__(name, value)
Line range hint
50-56
: Consider optimizing list comprehension and adding validation.The current implementation could be improved for better efficiency and safety:
- Validate that value is a list before processing
- Use generator expression instead of list comprehension for memory efficiency with large lists
def __setattr__(self, name: str, value: Any) -> None: if name == "attention_layers": + if not isinstance(value, (list, tuple)): + raise ValueError(f"Expected list for {name}, got {type(value)}") - value = [ - NeighborGatedAttentionLayer.deserialize(ii.serialize()) for ii in value - ] + value = list(NeighborGatedAttentionLayer.deserialize(ii.serialize()) + for ii in value if hasattr(ii, 'serialize')) return super().__setattr__(name, value)
Line range hint
68-81
: LGTM: Comprehensive attribute handling with proper JAX integration.The implementation correctly handles various attribute types and ensures proper JAX compatibility. A few suggestions for improvement:
- Consider grouping related attributes in dictionaries for cleaner code
- Add docstring explaining the purpose of each attribute type
def __setattr__(self, name: str, value: Any) -> None: + # Group related attributes + ARRAY_ATTRS = {"mean", "stddev"} + NETWORK_ATTRS = {"embeddings", "embeddings_strip"} + - if name in {"mean", "stddev"}: + if name in ARRAY_ATTRS: value = to_jax_array(value) if value is not None: value = ArrayAPIVariable(value) - elif name in {"embeddings", "embeddings_strip"}: + elif name in NETWORK_ATTRS:
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (9)
- deepmd/jax/atomic_model/base_atomic_model.py (2 hunks)
- deepmd/jax/common.py (1 hunks)
- deepmd/jax/descriptor/dpa1.py (2 hunks)
- deepmd/jax/descriptor/se_e2_a.py (2 hunks)
- deepmd/jax/fitting/fitting.py (2 hunks)
- deepmd/jax/utils/exclude_mask.py (3 hunks)
- deepmd/jax/utils/serialization.py (1 hunks)
- deepmd/jax/utils/type_embed.py (2 hunks)
- pyproject.toml (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/jax/utils/serialization.py
- pyproject.toml
🧰 Additional context used
🔇 Additional comments (16)
deepmd/jax/atomic_model/base_atomic_model.py (2)
3-3
: LGTM: Import addition is well-organizedThe addition of
ArrayAPIVariable
import is properly grouped with related imports and aligns with the new JAX checkpoint functionality.
Line range hint
11-21
: Verify ArrayAPIVariable serialization supportSince this function processes attributes for the new JAX checkpoint format, we should verify that
ArrayAPIVariable
properly supports serialization.✅ Verification successful
Let me gather more specific information about serialization in the JAX implementation.
This will help us understand:
- The complete implementation of ArrayAPIVariable
- The JAX-specific serialization utilities
- The base Variable class that ArrayAPIVariable extends
Let me check one more aspect to understand how the JAX model handles serialization through the base model class.
Based on the gathered information, I can now provide the final response:
ArrayAPIVariable provides sufficient serialization support through JAX ecosystem
The implementation shows that:
ArrayAPIVariable
extendsnnx.Variable
and implements the necessary array interface methods (__array__
,__array_namespace__
,__dlpack__
)- JAX backend has dedicated serialization utilities in
deepmd/jax/utils/serialization.py
that handle model state serialization using the Orbax Checkpoint (ocp) system- The serialization is handled at the model level where variables are automatically included in the state dictionary through the JAX/Orbax checkpointing system
The reason for verification is that:
ArrayAPIVariable
is properly integrated into JAX's array ecosystem through standard array protocol methods- The serialization is handled by JAX's robust checkpointing system (Orbax) which is designed to handle all JAX array types
- The implementation follows JAX's standard practices for model serialization, ensuring proper handling of all model variables
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Search for ArrayAPIVariable implementation to verify serialization support echo "Checking ArrayAPIVariable implementation:" rg -A 10 "class ArrayAPIVariable" # Look for serialization-related methods echo "Checking serialization methods:" rg -l "(de)?serialize|to_bytes|from_bytes" --type pyLength of output: 6296
Script:
#!/bin/bash # Check JAX serialization implementation echo "Checking JAX serialization implementation:" rg -A 10 "class ArrayAPIVariable" deepmd/jax/common.py echo "\nChecking JAX serialization utils:" rg -A 10 "serialize|deserialize" deepmd/jax/utils/serialization.py echo "\nChecking Variable base class implementation:" rg -A 10 "class Variable" deepmd/jax/common.pyLength of output: 2724
Script:
#!/bin/bash # Check BaseModel serialization implementation echo "Checking BaseModel implementation:" rg -A 10 "class BaseModel" deepmd/jax # Check if nnx.Variable (parent of ArrayAPIVariable) handles serialization echo "\nChecking nnx imports and usage:" rg "^import.*nnx" deepmd/jax rg "^from.*nnx" deepmd/jax # Check model serialization methods echo "\nChecking model serialization methods:" rg -A 5 "def (de)?serialize" deepmd/jax/modelLength of output: 392
deepmd/jax/utils/type_embed.py (2)
8-8
: LGTM: Import addition is well-organized.The
ArrayAPIVariable
import is properly grouped with related imports from the same module.
22-23
: Verify consistent ArrayAPIVariable usage across the codebase.Let's ensure this pattern of wrapping values in
ArrayAPIVariable
is consistently applied across the codebase for JAX checkpoint compatibility.✅ Verification successful
Based on the results, I can now generate the final response as we have sufficient information about the
ArrayAPIVariable
usage patterns across the codebase.ArrayAPIVariable usage is consistent across JAX implementation
The pattern of wrapping values in
ArrayAPIVariable
is consistently implemented across the JAX codebase:
deepmd/jax/fitting/fitting.py
deepmd/jax/utils/type_embed.py
deepmd/jax/utils/exclude_mask.py
deepmd/jax/descriptor/se_e2_a.py
deepmd/jax/descriptor/dpa1.py
deepmd/jax/atomic_model/base_atomic_model.py
All these files follow the same pattern of wrapping values with
ArrayAPIVariable
when handling attributes, ensuring consistent serialization behavior across the JAX implementation.🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check for consistent ArrayAPIVariable usage patterns # Expected: Similar wrapping patterns in other JAX-related files # Check for other occurrences of ArrayAPIVariable wrapping rg -A 3 "ArrayAPIVariable\(" --type py # Check for similar attribute handling patterns ast-grep --pattern 'if $value is not None: $value = ArrayAPIVariable($value)'Length of output: 2408
deepmd/jax/utils/exclude_mask.py (2)
9-11
: LGTM: Import changes are appropriate.The addition of
ArrayAPIVariable
import aligns with the new functionality being added.
20-21
: Verify compatibility with parent class implementation.The null check before wrapping in
ArrayAPIVariable
looks correct, but we should verify that this maintains compatibility with the parentAtomExcludeMaskDP
class's expectations.✅ Verification successful
Let me gather more information about the JAX implementation to ensure full compatibility.
Let me check one more thing to understand the inheritance hierarchy and the ArrayAPIVariable usage.
Let me check one final thing about the class inheritance.
The null check is compatible with parent class implementation
The verification confirms that:
- The JAX implementation (
AtomExcludeMask
) inherits fromAtomExcludeMaskDP
- The parent class uses
type_mask
as a numpy array- The null check in
__setattr__
safely handles the conversion:
- First converts to JAX array using
to_jax_array
- Only wraps non-null values in
ArrayAPIVariable
- Uses
super().__setattr__
to maintain parent class behaviorThe change maintains compatibility while adding type safety for JAX-specific array handling.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check parent class implementation to ensure compatibility # Test: Look for type_mask usage in parent class rg -A 5 "type_mask" "deepmd/dpmodel/utils/exclude_mask.py"Length of output: 1215
Script:
#!/bin/bash # Check JAX implementation and usage of type_mask rg -A 5 "type_mask" "deepmd/jax/utils/exclude_mask.py"Length of output: 464
Script:
#!/bin/bash # Check class definition and inheritance ast-grep --pattern 'class AtomExcludeMask$_' deepmd/jax/utils/exclude_mask.py # Check ArrayAPIVariable usage rg -B 2 "ArrayAPIVariable" deepmd/jax/utils/exclude_mask.pyLength of output: 526
Script:
#!/bin/bash # Get the full class definition to understand inheritance cat deepmd/jax/utils/exclude_mask.pyLength of output: 1011
deepmd/jax/descriptor/se_e2_a.py (2)
8-8
: LGTM: Import addition is appropriate.The addition of
ArrayAPIVariable
import is well-placed and necessary for the new attribute wrapping functionality.
30-31
: Verify serialization compatibility.Since this is part of the checkpoint I/O feature, let's verify that
ArrayAPIVariable
-wrapped attributes are correctly handled during serialization/deserialization.✅ Verification successful
Let me try different search paths since the previous attempt didn't find the test directory.
Let's examine the ArrayAPIVariable class definition and the serialization implementation in JAX.
Based on the code analysis, I can now provide a final response:
ArrayAPIVariable
wrapping is safe for serializationThe review concern can be resolved because:
ArrayAPIVariable
is a thin wrapper around JAX arrays that only implements array protocol methods (__array__
,__array_namespace__
,__dlpack__
, etc.)- The JAX serialization system (in
deepmd/jax/utils/serialization.py
) handles model state through Orbax checkpointing, which works at a lower level with pure dictionaries- The test file
se_e2_a.py
shows that array values are properly converted usingto_array_api_strict_array
before assignment, ensuring compatibility🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Search for serialization tests involving ArrayAPIVariable rg -l "ArrayAPIVariable.*serialize" test/ # Look for existing checkpoint tests rg -l "checkpoint.*jax" test/ # Check if there are any serialization-related functions in ArrayAPIVariable ast-grep --pattern 'class ArrayAPIVariable { $$$ serialize($$$) { $$$ } $$$ }'Length of output: 260
Script:
#!/bin/bash # First, let's find where the tests might be located fd -t d "test" # Search for ArrayAPIVariable definition and usage rg -l "class ArrayAPIVariable" # Look for checkpoint-related code rg -l "checkpoint.*jax" # Search for serialization/deserialization related code rg -l "serialize|deserialize"Length of output: 5702
Script:
#!/bin/bash # Check ArrayAPIVariable implementation cat deepmd/jax/common.py # Check JAX serialization implementation cat deepmd/jax/utils/serialization.py # Look at the test for se_e2_a descriptor cat source/tests/array_api_strict/descriptor/se_e2_a.pyLength of output: 6120
deepmd/jax/fitting/fitting.py (3)
9-9
: LGTM: Import is properly organizedThe new import is correctly grouped with other imports from the same module.
Line range hint
42-54
: Skip: Duplicate setattr methodsA previous review comment already addressed the code duplication in
__setattr__
methods.
33-34
: Verify consistent ArrayAPIVariable usageLet's verify that the
ArrayAPIVariable
wrapping is consistently applied across the codebase for similar model parameters.✅ Verification successful
Let me gather more information to verify the consistency of ArrayAPIVariable usage, particularly focusing on the initialization and handling of model parameters.
This will help us understand:
- The overall usage pattern of ArrayAPIVariable
- If there's a consistent pattern for parameter initialization
- If there are any inconsistencies in how parameters are wrapped
Based on the comprehensive search results, I can now generate the final response:
ArrayAPIVariable usage is consistent across the codebase
The wrapping of values with
ArrayAPIVariable
follows a consistent pattern throughout the codebase:
- All numerical parameters are first converted using
to_jax_array
and then wrapped withArrayAPIVariable
- The pattern
if value is not None: value = ArrayAPIVariable(value)
is consistently used in:
- Fitting parameters (bias_atom_e, fparam_avg, etc.)
- Descriptor parameters (mean, stddev, dstd, davg)
- Type embedding parameters (econf_tebd)
- Atomic model parameters (out_bias, out_std)
- Exclude mask parameters (type_mask)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check for consistent ArrayAPIVariable usage across the codebase # Test: Search for similar parameter handling in other files rg -l "bias_atom_e|fparam_avg|fparam_inv_std|aparam_avg|aparam_inv_std" | \ xargs rg "ArrayAPIVariable" -C 5Length of output: 1458
Script:
#!/bin/bash # Search for all ArrayAPIVariable usages to understand the pattern rg "ArrayAPIVariable" -B 3 -A 3 # Also check for similar parameter initialization patterns ast-grep --pattern 'if $_ is not None: $_ = ArrayAPIVariable($_)'Length of output: 7543
deepmd/jax/common.py (1)
84-85
: LGTM! Clean integration with existing code.The placement of
ArrayAPIVariable
and its integration with the existing utilities maintains good separation of concerns.deepmd/jax/descriptor/dpa1.py (4)
16-18
: LGTM: Import additions are appropriate.The new imports support the JAX checkpoint functionality by providing necessary utilities for array handling and module decoration.
Line range hint
40-46
: Similar type validation could be added here.The implementation is correct but could benefit from the same type validation suggested above.
Line range hint
89-97
: LGTM: Clean implementation with proper registration.The implementation correctly handles both attributes and the class is properly registered for both "dpa1" and "se_atten" descriptors.
Line range hint
1-97
: Verify checkpoint serialization/deserialization functionality.The implementation looks solid, but we should verify that the checkpoint I/O works correctly across all components.
✅ Verification successful
Let me gather more specific information about the serialization implementation to verify the checkpoint functionality properly.
Based on the gathered information, I can now provide a final verification response.
Checkpoint serialization/deserialization implementation is properly structured and secure.
The implementation shows:
- A robust serialization utility (
serialization.py
) using Orbax checkpoint handler for JAX models- Consistent serialization/deserialization pattern across the descriptor hierarchy
- Proper type conversion and error handling for JAX-specific formats
- Clean inheritance from DP implementations with appropriate JAX adaptations
The code in
dpa1.py
correctly implements the checkpoint I/O by:
- Converting all nested components through their respective serialization methods
- Maintaining the component hierarchy during deserialization
- Properly handling special cases like environment matrices and masks
- Using type-safe conversions for arrays and network components
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify that all necessary components for checkpoint I/O are present # and properly implemented across the codebase. # Check for consistent serialization/deserialization implementation echo "Checking for consistent serialize/deserialize implementation..." rg -l "serialize|deserialize" | grep "jax" # Look for potential checkpoint-related test files echo "Looking for checkpoint-related tests..." fd -e py "test.*checkpoint" "tests/" # Check for any TODO/FIXME comments related to serialization echo "Checking for outstanding TODOs..." rg -i "todo|fixme|xxx" -A 2 "serialize|deserialize"Length of output: 803
Script:
#!/bin/bash # Check the implementation of serialization in the core utility file echo "Checking serialization utility implementation..." cat deepmd/jax/utils/serialization.py # Check how serialization is used in this specific descriptor echo -e "\nChecking descriptor serialization usage..." rg -A 5 "serialize|deserialize" deepmd/jax/descriptor/dpa1.py # Check for any test files related to JAX serialization echo -e "\nChecking for JAX-related tests..." fd -e py "test.*jax" . # Check the base descriptor implementation echo -e "\nChecking base descriptor implementation..." cat deepmd/jax/descriptor/base_descriptor.pyLength of output: 4965
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4236 +/- ##
==========================================
- Coverage 84.58% 84.23% -0.35%
==========================================
Files 547 548 +1
Lines 51327 51392 +65
Branches 3047 3051 +4
==========================================
- Hits 43413 43292 -121
- Misses 6967 7148 +181
- Partials 947 952 +5 ☔ View full report in Codecov by Sentry. |
Implement a JAX checkpoint format. I name it
*.jax
as I don't find existing conventions.Summary by CodeRabbit
Release Notes
New Features
.jax
file suffix in the backend configuration.Bug Fixes
Chores