feat(jax/array-api): DPA-2 #4294

Merged: 7 commits, Nov 2, 2024

Conversation

njzjz
Member

@njzjz njzjz commented Oct 31, 2024

Summary by CodeRabbit

  • New Features

    • Introduced new classes for enhanced descriptor functionality, including DescrptDPA2, DescrptBlockRepformers, and DescrptBlockSeTTebd.
    • Added serialization and deserialization methods for better state management of descriptor objects.
  • Improvements

    • Enhanced compatibility with various array backends through the integration of array_api_compat.
    • Refactored existing methods to utilize new array API functions for improved performance.
    • Updated documentation to reflect JAX as a supported backend alongside PyTorch.
  • Bug Fixes

    • Updated handling of attributes in several classes to ensure correct deserialization and type safety.
  • Tests

    • Enhanced testing capabilities for JAX and Array API Strict backend integration, including conditional imports and new evaluation methods.

Signed-off-by: Jinzhe Zeng <[email protected]>
data.pop("type")
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")

Check notice (Code scanning / CodeQL): Unused local variable

Variable env_mat is not used.
Contributor

coderabbitai bot commented Oct 31, 2024

Walkthrough

The changes in this pull request primarily enhance the functionality of several descriptor classes within the deepmd library by integrating the array_api_compat library. Key modifications include refactoring methods to utilize the array API for array operations, adding serialization and deserialization methods, and introducing new classes that extend existing functionality. The updates aim to improve compatibility with various array backends and streamline attribute handling in descriptor classes.

Changes

File Path | Change Summary
--- | ---
deepmd/dpmodel/descriptor/dpa2.py | Enhanced DescrptDPA2 class with updated import statements and modified call and serialize methods to use array_api_compat.
deepmd/dpmodel/descriptor/repformers.py | Updated DescrptBlockRepformers class to use array_api_compat for array operations, added serialize and deserialize methods.
deepmd/dpmodel/descriptor/se_t_tebd.py | Refactored DescrptSeTTebd and DescrptBlockSeTTebd classes to utilize array_api_compat, added serialize and deserialize methods.
deepmd/dpmodel/utils/nlist.py | Modified build_multiple_neighbor_list function to replace NumPy operations with array_api_compat equivalents.
deepmd/jax/descriptor/dpa2.py | Introduced DescrptDPA2 class with custom __setattr__ method for attribute management and deserialization.
deepmd/jax/descriptor/repformers.py | Added classes extending existing ones from deepmd.dpmodel.descriptor.repformers, implementing custom __setattr__ methods for attribute processing.
deepmd/jax/descriptor/se_t_tebd.py | Created DescrptBlockSeTTebd and DescrptSeTTebd classes with custom __setattr__ methods for attribute handling and deserialization.
source/tests/consistent/descriptor/test_dpa2.py | Updated test class to conditionally import DescrptDPA2JAX, added properties and methods to handle JAX backend testing.
source/tests/array_api_strict/descriptor/dpa2.py | Introduced DescrptDPA2 class with custom __setattr__ method for attribute management and deserialization.
source/tests/array_api_strict/descriptor/repformers.py | Added classes extending existing ones from deepmd.dpmodel.descriptor.repformers, implementing custom __setattr__ methods for attribute processing.
source/tests/array_api_strict/descriptor/se_t_tebd.py | Created DescrptBlockSeTTebd and DescrptSeTTebd classes with custom __setattr__ methods for attribute handling and deserialization.
doc/model/dpa2.md | Updated documentation to reflect the addition of JAX as a supported backend alongside PyTorch.

Possibly related PRs

  • feat(dp/pt): refactor se_e3 descriptor #3813: The changes in the DescrptSeT class in se_t.py involve significant updates to array handling, which aligns with the modifications made to the DescrptDPA2 class in dpa2.py regarding array operations and compatibility with the array API.
  • feat(jax/array-api): hybrid descriptor #4275: The introduction of the DescrptHybrid class in hybrid.py includes updates to array operations, similar to the enhancements made in the main PR to the DescrptDPA2 class, focusing on compatibility with array API standards.
  • feat(jax/array-api): dipole/polarizability fitting #4278: The modifications in the DipoleFitting and PolarFitting classes to utilize array_api_compat for array operations reflect the same focus on enhancing compatibility with array operations as seen in the main PR's changes to the DescrptDPA2 class.

Suggested reviewers

  • wanghan-iapcm
  • iProzd

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 14

🧹 Outside diff range and nitpick comments (6)
deepmd/dpmodel/utils/nlist.py (1)

218-241: Consider adding backend-specific optimizations

While the current implementation using array_api_compat provides good backend compatibility, consider adding backend-specific optimizations (e.g., JAX's vmap or jit) through optional imports. This could improve performance while maintaining the current clean abstraction.

🧰 Tools
🪛 Ruff

228-228: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

deepmd/jax/descriptor/se_t_tebd.py (2)

29-46: Consider adding docstrings to DescrptBlockSeTTebd and its methods

Adding docstrings to the DescrptBlockSeTTebd class and its methods enhances code readability and maintainability by providing context and explanations for future developers and users of the class.


50-56: Consider adding docstrings to DescrptSeTTebd and its methods

For better clarity and documentation, consider adding docstrings to the DescrptSeTTebd class and its methods. This practice improves understanding and ease of use for other developers interacting with your code.

deepmd/dpmodel/descriptor/se_t_tebd.py (2)

381-381: Consistency in variable naming: Replace to_numpy_array with to_np_array

In line with the project's coding conventions, consider using to_np_array for consistency, assuming that other parts of the codebase use this naming.

Apply this diff if applicable:

- "davg": to_numpy_array(obj["davg"]),
+ "davg": to_np_array(obj["davg"]),

Line range hint 12-46: Add docstrings for class methods

Several methods in the DescrptBlockSeTTebd class lack docstrings. Adding docstrings enhances code readability and maintainability.

For example, add a docstring to the __init__ method:

def __init__(...):
    """Initialize the DescrptBlockSeTTebd with the given parameters."""
    ...

Also applies to: 49-87

deepmd/dpmodel/descriptor/repformers.py (1)

399-399: Rename unused loop variable idx to _

The loop control variable idx in for idx, ll in enumerate(self.layers): is not used within the loop body. To indicate that it is intentionally unused, consider renaming it to _ or _idx.

Apply this diff to rename idx to _:

- for idx, ll in enumerate(self.layers):
+ for _, ll in enumerate(self.layers):
🧰 Tools
🪛 Ruff

399-399: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)
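When the index is not needed at all, dropping enumerate is an alternative to renaming the variable. The sketch below illustrates that option only; Layer and run_layers are hypothetical stand-ins, not names from the PR.

```python
class Layer:
    """Hypothetical stand-in for a descriptor layer."""

    def __init__(self, scale):
        self.scale = scale

    def call(self, x):
        return x * self.scale


def run_layers(layers, x):
    # Iterate over the layers directly; no index variable is created,
    # so there is nothing left for Ruff's B007 check to flag.
    for ll in layers:
        x = ll.call(x)
    return x


result = run_layers([Layer(2), Layer(3)], 1)
```

Renaming to _idx keeps the enumerate call for later use; removing it is simpler when no index is expected to be needed.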

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between ff04d8b and 1cac90b.

📒 Files selected for processing (8)
  • deepmd/dpmodel/descriptor/dpa2.py (7 hunks)
  • deepmd/dpmodel/descriptor/repformers.py (31 hunks)
  • deepmd/dpmodel/descriptor/se_t_tebd.py (7 hunks)
  • deepmd/dpmodel/utils/nlist.py (1 hunks)
  • deepmd/jax/descriptor/dpa2.py (1 hunks)
  • deepmd/jax/descriptor/repformers.py (1 hunks)
  • deepmd/jax/descriptor/se_t_tebd.py (1 hunks)
  • source/tests/consistent/descriptor/test_dpa2.py (4 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/descriptor/repformers.py

399-399: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)

deepmd/dpmodel/descriptor/se_t_tebd.py

712-712: Local variable ng is assigned to but never used

Remove assignment to unused variable ng

(F841)


803-803: Local variable env_mat is assigned to but never used

Remove assignment to unused variable env_mat

(F841)

deepmd/dpmodel/utils/nlist.py

228-228: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

deepmd/jax/descriptor/repformers.py

97-98: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

🪛 GitHub Check: CodeQL
deepmd/dpmodel/descriptor/se_t_tebd.py

[notice] 803-803: Unused local variable
Variable env_mat is not used.

🔇 Additional comments (16)
deepmd/dpmodel/utils/nlist.py (3)

218-225: LGTM: Array namespace initialization and padding implementation

The array namespace initialization and padding logic is well-implemented, ensuring compatibility across different array backends while maintaining correct neighbor list dimensions.


231-236: LGTM: Robust coordinate manipulation and distance calculation

The implementation correctly handles coordinate transformations and distance calculations while maintaining backend compatibility. The use of float("inf") for masked values is an elegant solution.


240-241: LGTM: Efficient neighbor list filtering

The neighbor list filtering implementation is well-vectorized and correctly handles the cutoff distance check while maintaining backend compatibility.
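The masking and cutoff steps praised in these comments can be sketched roughly as follows. This is an illustration under assumptions, not the code from nlist.py: NumPy stands in for the xp namespace, and the shapes, values, and variable names are made up for the example.

```python
import numpy as np

# One frame, one local atom, three neighbor slots; -1 marks an empty slot.
nlist = np.array([[[1, 2, -1]]])
coord = np.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]]])
rcut = 2.0

nb, nloc, nnei = nlist.shape
pad = nlist == -1
safe = np.where(pad, 0, nlist)  # keep the gather index in range

# Gather neighbor coordinates, then take distances to the central atom.
idx = np.tile(safe.reshape(nb, nloc * nnei, 1), (1, 1, 3))
neigh = np.take_along_axis(coord, idx, axis=1).reshape(nb, nloc, nnei, 3)
diff = neigh - coord[:, :nloc, None, :]
rr = np.sqrt(np.sum(diff * diff, axis=-1))

# Padded slots get an infinite distance, so they can never pass the cutoff.
rr = np.where(pad, float("inf"), rr)
filtered = np.where(rr > rcut, -1, nlist)
```

Here the neighbor at distance 3.0 and the padded slot are both set to -1, while the neighbor at distance 1.0 is kept.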

source/tests/consistent/descriptor/test_dpa2.py (3)

18-18: LGTM! Clean implementation of conditional JAX support.

The conditional import pattern is consistent with other backends and ensures graceful handling when JAX is not installed.

Also applies to: 32-35
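A conditional-import guard of this kind might look like the sketch below. How the project actually computes its flag is not shown in this thread, so the find_spec check and the placeholder value are assumptions made for illustration.

```python
import importlib.util

# Assumed way to derive the flag: True only when the optional package can be
# located, without importing it eagerly.
INSTALLED_JAX = importlib.util.find_spec("jax") is not None

if INSTALLED_JAX:
    # The real test module would import its JAX descriptor class here;
    # a placeholder string keeps this sketch free of that dependency.
    jax_class = "placeholder for the JAX descriptor class"
else:
    jax_class = None


def skip_jax():
    # Tests for the backend are skipped when the package is missing.
    return not INSTALLED_JAX
```

The point of the pattern is that the module still imports cleanly when JAX is absent, and the skip decision sits next to the class reference.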


278-279: LGTM! Properties follow established patterns.

The skip_jax and jax_class properties are well-integrated with the existing backend properties.

Also applies to: 283-283


379-388: LGTM! Consistent implementation of JAX evaluation method.

The eval_jax method follows the same pattern as other backend evaluation methods and correctly handles all required parameters.

deepmd/jax/descriptor/se_t_tebd.py (1)

48-49: Verify the order of decorators for DescrptSeTTebd

The class DescrptSeTTebd is decorated with @BaseDescriptor.register("se_e3_tebd") followed by @flax_module. The order of decorators affects how the class is registered and initialized. Ensure that this order is intentional and that DescrptSeTTebd is correctly registered and behaves as expected.
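Why decorator order matters can be seen in a small, self-contained sketch. The register and wrap functions below are hypothetical; they only mimic the general shape of a registry decorator and a class-wrapping decorator, not the behavior of BaseDescriptor.register or flax_module.

```python
REGISTRY = {}


def register(name):
    """Hypothetical registry decorator: stores whatever class it receives."""

    def decorator(cls):
        REGISTRY[name] = cls
        return cls

    return decorator


def wrap(cls):
    """Hypothetical wrapper that returns a new subclass of the input class."""

    class Wrapped(cls):
        wrapped = True

    Wrapped.__name__ = cls.__name__
    return Wrapped


# Decorators apply bottom-up: wrap runs first, so register sees the wrapped class.
@register("outer")
@wrap
class A:
    pass


# Here register runs first and stores the unwrapped class.
@wrap
@register("inner")
class B:
    pass
```

With the registry decorator outermost, lookups by name return the wrapped class; with the opposite order they return the original one.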

deepmd/dpmodel/descriptor/se_t_tebd.py (1)

329-329: Potential division by zero in calculation of nall

When computing nall, there's a division operation that could potentially result in a division by zero if coord_ext is improperly shaped.

Please ensure that coord_ext is correctly shaped and that nf and nall are properly computed in all scenarios.

deepmd/dpmodel/descriptor/dpa2.py (7)

Line range hint 794-798: Proper initialization of array namespace for backend compatibility

The use of xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) correctly obtains a backend-agnostic array namespace, ensuring compatibility across different array backends for subsequent array operations.


831-831: Ensure shapes are compatible for concatenation

At line 831, g1 = xp.concatenate([g1, g1_three_body], axis=-1), please verify that g1 and g1_three_body have compatible shapes along the last axis. This ensures that concatenation proceeds without errors when use_three_body is True.


839-840: Confirm correct usage of xp.tile and xp.take_along_axis

The operations using xp.tile and xp.take_along_axis are crucial for aligning g1_ext with the extended mapping. Double-check that mapping.reshape(nframes, nall, 1) and the subsequent tiling align correctly with g1's dimensions to prevent indexing errors.
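The alignment this comment asks about can be checked on a toy case. The sketch assumes NumPy as the array namespace and invents small shapes and values; only the tile and take_along_axis pattern comes from the review.

```python
import numpy as np

nframes, nloc, nall, ng = 1, 2, 4, 3

# Hypothetical local features: one row of ng values per local atom.
g1 = np.arange(nframes * nloc * ng, dtype=float).reshape(nframes, nloc, ng)

# mapping[f, i] is the local atom that extended atom i is an image of.
mapping = np.array([[0, 1, 1, 0]])

# Repeat each index along the feature axis so it matches g1's last dimension.
mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, ng))
g1_ext = np.take_along_axis(g1, mapping_ext, axis=1)
```

Each extended atom receives the feature row of the local atom it maps to, so g1_ext has shape (nframes, nall, ng).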


854-854: Validate concatenation of g1 and g1_inp

At line 854, the concatenation g1 = xp.concatenate([g1, g1_inp], axis=-1) combines the updated representations with the original input features. Ensure that this operation aligns with the intended architecture and that the dimensions of g1 and g1_inp match appropriately along the concatenation axis.


891-892: Convert davg and dstd to NumPy arrays for serialization

Converting repinit["davg"] and repinit["dstd"] to NumPy arrays using to_numpy_array ensures that these statistics are properly serialized. This change enhances compatibility with different array backends during the serialization process.


904-905: Ensure proper serialization of repformers statistics

Similarly, converting repformers["davg"] and repformers["dstd"] to NumPy arrays is necessary for consistent serialization of the repformers component. This ensures that the statistics are correctly stored and retrievable.


921-922: Handle serialization of three-body repinit statistics

For the three-body repinit, converting repinit_three_body["davg"] and repinit_three_body["dstd"] to NumPy arrays is essential. This modification ensures that all statistical data is serialized uniformly, maintaining consistency across different descriptor components.

deepmd/dpmodel/descriptor/repformers.py (1)

396-396: Verify the replacement of -1 with zeros in nlist

Replacing -1 values in nlist with zeros may cause invalid neighbor indices to reference the first atom (index 0). Ensure that this replacement does not introduce unintended behavior in subsequent computations, especially if the first atom is a valid index.
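The concern can be made concrete with a toy gather. This is a sketch with invented values and NumPy standing in for the array namespace; it shows why the placeholder index is harmless only if a mask is applied afterwards.

```python
import numpy as np

nlist = np.array([[0, 2, -1]])        # -1 marks a missing neighbor
feat = np.array([10.0, 20.0, 30.0])   # hypothetical per-atom feature

mask = nlist != -1
safe_nlist = np.where(mask, nlist, 0)   # index 0 is only a placeholder
gathered = feat[safe_nlist]             # the placeholder slot reads atom 0's value
masked = np.where(mask, gathered, 0.0)  # the mask removes that contribution
```

Without the final masking step, the padded slot would silently contribute atom 0's feature.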


codecov bot commented Nov 1, 2024

Codecov Report

Attention: Patch coverage is 93.26146% with 25 lines in your changes missing coverage. Please review.

Project coverage is 84.38%. Comparing base (8355947) to head (0941a0f).
Report is 6 commits behind head on devel.

Files with missing lines | Patch % | Lines
--- | --- | ---
deepmd/dpmodel/descriptor/repformers.py | 93.19% | 10 Missing ⚠️
deepmd/jax/descriptor/dpa2.py | 83.78% | 6 Missing ⚠️
deepmd/jax/descriptor/se_t_tebd.py | 84.37% | 5 Missing ⚠️
deepmd/jax/descriptor/repformers.py | 94.59% | 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4294      +/-   ##
==========================================
+ Coverage   84.35%   84.38%   +0.03%     
==========================================
  Files         559      562       +3     
  Lines       52509    52758     +249     
  Branches     3054     3054              
==========================================
+ Hits        44292    44520     +228     
- Misses       7260     7279      +19     
- Partials      957      959       +2     


njzjz and others added 3 commits November 1, 2024 00:50
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🧹 Outside diff range and nitpick comments (3)
source/tests/array_api_strict/descriptor/se_t_tebd.py (1)

25-25: Add docstring to explain attribute transformations

The class would benefit from documentation explaining the special handling of different attributes.

Add a docstring:

 class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
+    """Descriptor block for SE-T-TEBD with array API support.
+
+    This class extends DescrptBlockSeTTebdDP with custom attribute handling:
+    - mean, stddev: Converted to array API arrays
+    - embeddings, embeddings_strip: Deserialized as NetworkCollection
+    - env_mat: No value storage
+    - emask: Created as PairExcludeMask
+    """
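
The attribute handling described in that docstring follows a general pattern: a subclass overrides __setattr__ and converts values as they are assigned. A minimal sketch of the pattern, with hypothetical class and helper names and plain tuples in place of real backend arrays:

```python
class Base:
    """Hypothetical stand-in for a backend-independent descriptor block."""

    def __init__(self, mean, env_mat):
        self.mean = mean
        self.env_mat = env_mat


def to_backend_array(value):
    # Hypothetical converter; a real backend would build its own array type.
    return tuple(float(v) for v in value)


class BackendBlock(Base):
    def __setattr__(self, name, value):
        if name in {"mean", "stddev"}:
            value = to_backend_array(value)  # convert statistics on assignment
        elif name == "env_mat":
            value = None  # nothing is stored for env_mat
        super().__setattr__(name, value)


block = BackendBlock(mean=[0, 1], env_mat=object())
```

Because the base class assigns through self, the subclass hook runs without any change to the base constructor.
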
deepmd/dpmodel/descriptor/dpa2.py (1)

844-848: Consider moving the mapping assertion earlier.

The assertion assert mapping is not None should be moved before using the mapping variable to fail fast and provide clearer error messages.

 def call(self, coord_ext: np.ndarray, atype_ext: np.ndarray, nlist: np.ndarray, mapping: Optional[np.ndarray] = None):
+    assert mapping is not None, "mapping must be provided"
     xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
     use_three_body = self.use_three_body
     nframes, nloc, nnei = nlist.shape
     # ...
-    assert mapping is not None
     mapping_ext = xp.tile(xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1]))
deepmd/dpmodel/descriptor/repformers.py (1)

448-450: Remove commented code

The old transpose implementation is left as a comment but is no longer needed since the code has been migrated to use array_api_compat.

Apply this diff to clean up the code:

- # rot_mat = xp.transpose(h2g2, (0, 1, 3, 2))
  rot_mat = xp.matrix_transpose(h2g2)
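
That the two forms agree for a 4-D input can be checked on a small array. In this sketch np.swapaxes stands in for xp.matrix_transpose, since older NumPy releases do not provide a function of that name; the array contents are arbitrary.

```python
import numpy as np

h2g2 = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

explicit = np.transpose(h2g2, (0, 1, 3, 2))  # the commented-out form
swapped = np.swapaxes(h2g2, -1, -2)          # swaps only the last two axes
```

Both results have shape (2, 3, 5, 4) and identical contents, which is why the commented-out line carries no extra information.
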
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 1cac90b and 84c1900.

📒 Files selected for processing (9)
  • deepmd/dpmodel/descriptor/dpa2.py (8 hunks)
  • deepmd/dpmodel/descriptor/repformers.py (31 hunks)
  • deepmd/dpmodel/utils/nlist.py (1 hunks)
  • deepmd/jax/descriptor/dpa2.py (1 hunks)
  • deepmd/jax/descriptor/repformers.py (1 hunks)
  • source/tests/array_api_strict/descriptor/dpa2.py (1 hunks)
  • source/tests/array_api_strict/descriptor/repformers.py (1 hunks)
  • source/tests/array_api_strict/descriptor/se_t_tebd.py (1 hunks)
  • source/tests/consistent/descriptor/test_dpa2.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/jax/descriptor/dpa2.py
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/descriptor/repformers.py

424-424: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)

deepmd/dpmodel/utils/nlist.py

228-228: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

deepmd/jax/descriptor/repformers.py

104-105: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

source/tests/array_api_strict/descriptor/repformers.py

95-96: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

🔇 Additional comments (31)
source/tests/array_api_strict/descriptor/se_t_tebd.py (2)

1-23: LGTM: Well-organized imports with proper type hints

The imports are logically grouped, properly typed, and include all necessary dependencies for the descriptor classes.


11-13: Verify consistent array API usage across the codebase

Let's ensure the array API integration is consistent across related descriptor classes.

✅ Verification successful

Array API integration is consistently implemented across descriptor classes

The codebase shows uniform implementation of array API strict handling:

  • All descriptor classes (se_e2_r.py, se_t_tebd.py, hybrid.py, se_e2_a.py, repformers.py, dpa1.py, dpa2.py) properly import and use to_array_api_strict_array
  • Consistent attribute handling pattern where values are converted using to_array_api_strict_array
  • Special handling in hybrid.py and repformers.py for list values with list comprehension
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for consistent array API usage in descriptor classes

# Test 1: Check for array_api_compat usage in descriptor files
echo "Checking array_api_compat usage in descriptors..."
rg -l "array_api_compat" "source/tests/array_api_strict/descriptor/"

# Test 2: Check for to_array_api_strict_array usage
echo "Checking to_array_api_strict_array usage..."
rg "to_array_api_strict_array" "source/tests/array_api_strict/descriptor/"

# Test 3: Check for similar __setattr__ implementations
echo "Checking for similar attribute handling..."
ast-grep --pattern 'class $_ {
  $$$
  def __setattr__($_, $_, $_) {
    $$$
    to_array_api_strict_array($$$)
    $$$
  }
}'

Length of output: 2019

deepmd/dpmodel/utils/nlist.py (3)

218-227: LGTM: Padding logic is well-implemented

The padding implementation correctly handles variable-sized neighbor lists using -1 as sentinel values, maintaining compatibility with array_api_compat.


230-237: LGTM: Robust coordinate transformation and distance calculation

The implementation correctly:

  1. Handles coordinate reshaping for vectorized operations
  2. Properly masks invalid neighbor indices
  3. Uses array_api_compat's vector_norm for backend-agnostic distance calculations

228-229: ⚠️ Potential issue

Remove unused variable nall

The variable nall is assigned but never used in this scope.

-    coord1 = xp.reshape(coord, (nb, -1, 3))
-    nall = coord1.shape[1]
+    coord1 = xp.reshape(coord, (nb, -1, 3))

Likely invalid or redundant comment.

🧰 Tools
🪛 Ruff

228-228: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

deepmd/dpmodel/descriptor/dpa2.py (3)

7-7: LGTM: Array compatibility imports are well-structured.

The new imports enhance array operations compatibility across different backends and provide necessary conversion utilities.

Also applies to: 13-18


Line range hint 797-812: LGTM: Array operations are properly abstracted.

The code effectively uses the array API compatibility layer for backend-agnostic operations:

  • Obtains array namespace from input tensors
  • Uses xp.reshape and xp.concat consistently
  • Handles type embeddings correctly

Also applies to: 837-837


899-900: LGTM: Consistent array serialization.

The code properly converts array variables to numpy format during serialization using to_numpy_array, ensuring compatibility across different array backends.

Also applies to: 912-913, 929-930

deepmd/dpmodel/descriptor/repformers.py (6)

8-8: Well-structured utility functions for tensor operations!

The new transpose utility functions are well-implemented and properly documented. They handle specific transposition patterns needed for tensor operations in a clean and efficient way.

Also applies to: 15-20, 48-67


Line range hint 392-404: Clean migration to array_api_compat!

The array operations have been properly migrated to use array_api_compat, making the code more backend-agnostic while maintaining the same functionality.


412-421: Proper handling of distance calculations with array API!

The direct distance calculations and array operations have been correctly migrated to use the array API while maintaining the original logic.


1528-1561: Well-implemented array operations in _update_g1_conv!

The array operations are properly implemented using array_api_compat with careful handling of shapes, types, and edge cases. The code is well-documented and maintains good structure.


Line range hint 1294-1474: Robust residual initialization and handling!

The residual lists are properly initialized and consistently handled across different update types. The implementation follows good practices for managing neural network residual connections.


1880-1882: Consistent array serialization!

The residual arrays are properly converted to NumPy arrays during serialization, ensuring consistent data storage and compatibility.

source/tests/array_api_strict/descriptor/dpa2.py (1)

47-47: ⚠️ Potential issue

Typo in attribute name 'g1_shape_tranform'

The attribute name 'g1_shape_tranform' seems misspelled. Did you mean 'g1_shape_transform'?

Run the following script to verify the usage of 'g1_shape_tranform' and 'g1_shape_transform' in the codebase:

source/tests/array_api_strict/descriptor/repformers.py (5)

31-46: Good implementation of custom __setattr__ in DescrptBlockRepformers

The method effectively handles attribute assignments with appropriate transformations and maintains consistency.


48-53: Proper customization of __setattr__ in Atten2Map

The method correctly handles the mapqk attribute by deserializing it appropriately.


55-60: Correct handling of attributes in Atten2MultiHeadApply

The __setattr__ method processes mapv and head_map attributes accurately.


62-67: Effective use of __setattr__ in Atten2EquiVarApply

The method properly handles the head_map attribute through deserialization.


69-74: Appropriate attribute management in LocalAtten

The __setattr__ method correctly processes mapq, mapkv, and head_map attributes.

deepmd/jax/descriptor/repformers.py (7)

32-49: Implementation of DescrptBlockRepformers is correct

The __setattr__ method correctly handles attribute assignments, deserialization, and type conversions for attributes like mean, stddev, layers, g2_embd, and emask.


53-58: Implementation of Atten2Map is correct

The __setattr__ method properly deserializes the mapqk attribute into a NativeLayer instance.


60-65: Implementation of Atten2MultiHeadApply is correct

The __setattr__ method accurately deserializes the mapv and head_map attributes into NativeLayer instances.


68-73: Implementation of Atten2EquiVarApply is correct

The __setattr__ method correctly handles the deserialization of the head_map attribute into a NativeLayer instance.


76-81: Implementation of LocalAtten is correct

The __setattr__ method effectively deserializes the mapq, mapkv, and head_map attributes into NativeLayer instances.


Line range hint 85-108: Implementation of RepformerLayer is consistent and correct

The __setattr__ method appropriately handles attribute assignments, deserializations, and type conversions for various attributes, ensuring correct processing within the class.

🧰 Tools
🪛 Ruff

104-105: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


105-107: 🛠️ Refactor suggestion

Simplify nested if statements by combining conditions

To enhance code readability, combine the nested if statements into a single condition using and. This reduces indentation and makes the code more concise.

Apply the following change:

-            elif name in {"loc_attn"}:
-                if value is not None:
-                    value = LocalAtten.deserialize(value.serialize())
+            elif name in {"loc_attn"} and value is not None:
+                value = LocalAtten.deserialize(value.serialize())

Likely invalid or redundant comment.
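
One reason such a suggestion can be invalid: inside an if/elif chain, folding the inner condition into the elif changes which branch handles a None value. The sketch below assumes a trailing else branch and uses placeholder strings; whether the real __setattr__ has that shape is an assumption.

```python
def nested(name, value):
    if name == "mean":
        value = "array"
    elif name == "loc_attn":
        if value is not None:
            value = "deserialized"
        # a None value stops here and stays None
    else:
        value = "generic"
    return value


def combined(name, value):
    if name == "mean":
        value = "array"
    elif name == "loc_attn" and value is not None:
        value = "deserialized"
    else:
        value = "generic"  # loc_attn=None now falls through to this branch
    return value
```

The two versions agree for non-None values and differ only when loc_attn is assigned None.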

source/tests/consistent/descriptor/test_dpa2.py (4)

18-19: Import INSTALLED_ARRAY_API_STRICT appropriately

The addition of INSTALLED_ARRAY_API_STRICT to the import list is correct and necessary for conditional feature support.


33-41: Conditional imports for new descriptors are correctly implemented

The use of conditional imports based on INSTALLED_JAX and INSTALLED_ARRAY_API_STRICT effectively handles optional dependencies, ensuring that the code remains robust whether or not these libraries are installed.


289-290: Assigning new descriptor classes aligns with existing patterns

The addition of jax_class and array_api_strict_class follows the established convention for managing different backend classes. This enhances the code's scalability and maintainability.


386-405: New evaluation methods are consistent with existing implementation

The eval_jax and eval_array_api_strict methods are properly defined and consistent with other evaluation methods like eval_dp and eval_pt. They correctly pass the necessary parameters for descriptor evaluation.

@njzjz njzjz marked this pull request as ready for review November 1, 2024 05:00
@njzjz njzjz changed the title from "feat(jax): DPA-2" to "feat(jax/array-api): DPA-2" Nov 1, 2024
@njzjz njzjz requested review from iProzd and wanghan-iapcm November 1, 2024 05:02
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
deepmd/dpmodel/descriptor/repformers.py (2)

8-20: Add docstrings to the new transpose functions.

The new transpose functions look good, but they would benefit from docstrings explaining:

  • Input tensor shape and meaning of each dimension
  • Output tensor shape and how dimensions are reordered
  • Example usage

Example docstring:

def xp_transpose_01423(x):
    """Transpose a 5D tensor from (d0,d1,d2,d3,d4) to (d0,d1,d4,d2,d3).
    
    Parameters
    ----------
    x : array_like
        Input tensor of shape (batch, loc, nei1, nei2, feat)
        
    Returns
    -------
    array_like
        Transposed tensor of shape (batch, loc, feat, nei1, nei2)
    """

Also applies to: 48-68
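
To make the axis reordering in the example docstring concrete, here is a minimal sketch using NumPy as a stand-in backend. It is not the helper from `repformers.py`, whose body is not shown in this review; the name `transpose_01423_sketch` is hypothetical. The Array API standard exposes the same operation as `permute_dims`.

```python
import numpy as np


def transpose_01423_sketch(x):
    """Reorder axes (d0, d1, d2, d3, d4) -> (d0, d1, d4, d2, d3)."""
    if x.ndim != 5:
        raise ValueError(f"expected a 5D array, got {x.ndim}D")
    # Output axis i is taken from input axis axes[i].
    return np.transpose(x, (0, 1, 4, 2, 3))


# Distinct sizes per axis make the permutation easy to read off the shape.
x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
y = transpose_01423_sketch(x)
```

Element-wise, `y[a, b, c, d, e]` equals `x[a, b, d, e, c]`, which is the kind of statement a docstring for these helpers could include.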


1880-1883: Add comment explaining numpy conversion.

Consider adding a comment explaining why residuals need to be converted to numpy arrays during serialization.

# Convert residuals to numpy arrays for serialization compatibility
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 84c1900 and db36432.

📒 Files selected for processing (3)
  • deepmd/dpmodel/descriptor/repformers.py (31 hunks)
  • deepmd/jax/descriptor/dpa2.py (1 hunks)
  • source/tests/array_api_strict/descriptor/dpa2.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/jax/descriptor/dpa2.py
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/descriptor/repformers.py

424-424: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)

🔇 Additional comments (8)
deepmd/dpmodel/descriptor/repformers.py (6)

Line range hint 392-404: LGTM! Clean refactoring to use array API.

The changes successfully migrate numpy operations to use the array API, making the code more backend-agnostic. The new transpose helper functions are used appropriately.

Also applies to: 412-450


460-477: LGTM! Comprehensive deserialization implementation.

The deserialize method properly handles the reconstruction of the complex object hierarchy.


478-525: LGTM! Complete serialization implementation.

The serialize method thoroughly captures the object state, including all configuration parameters and nested objects.
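
The round trip that these two comments describe can be sketched with a toy class. `ToyBlock` and its fields are hypothetical and much simpler than the real descriptor block; only the `"type"` and `"@variables"` keys follow the snippet quoted earlier in this pull request.

```python
import copy


class ToyBlock:
    """Hypothetical descriptor-like object; not a DeePMD-kit class."""

    def __init__(self, rcut: float, nsel: int, weights=None):
        self.rcut = rcut
        self.nsel = nsel
        self.weights = list(weights) if weights is not None else [0.0] * nsel

    def serialize(self) -> dict:
        # Configuration goes at the top level, learned values under "@variables".
        return {
            "type": "toy_block",
            "rcut": self.rcut,
            "nsel": self.nsel,
            "@variables": {"weights": list(self.weights)},
        }

    @classmethod
    def deserialize(cls, data: dict) -> "ToyBlock":
        data = copy.deepcopy(data)  # avoid mutating the caller's dict
        data.pop("type")
        variables = data.pop("@variables")
        obj = cls(**data)  # remaining keys are constructor arguments
        obj.weights = list(variables["weights"])
        return obj


block = ToyBlock(rcut=6.0, nsel=3, weights=[0.1, 0.2, 0.3])
payload = block.serialize()
restored = ToyBlock.deserialize(payload)
```

A useful property to test is that `restored.serialize()` equals the original payload, which checks that no field is dropped in either direction.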


1528-1561: LGTM! Clean array operation refactoring.

The array operations have been properly migrated to use array_api_compat while maintaining the original logic.


Line range hint 1294-1474: LGTM! Well-structured residual initialization.

The residual initialization is properly implemented with:

  • Separate lists for different components (g1, g2, h2)
  • Consistent seeding for reproducibility
  • Proper handling of initialization modes

423-427: ⚠️ Potential issue

Fix unused loop variable.

🧰 Tools
🪛 Ruff

424-424: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)
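
A small sketch of the two usual ways to address B007 follows. The `layers` list is a hypothetical stand-in, since the loop body at line 424 is not reproduced in this review.

```python
layers = ["layer0", "layer1", "layer2"]  # hypothetical stand-ins for layer objects

# Option 1: keep enumerate, but prefix the index to mark it as intentionally unused.
visited_a = []
for _idx, layer in enumerate(layers):
    visited_a.append(layer.upper())

# Option 2: drop enumerate entirely when the index is never needed.
visited_b = []
for layer in layers:
    visited_b.append(layer.upper())
```

Both loops do the same work; the second is usually the cleaner choice unless the index is expected to be used again soon.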

source/tests/array_api_strict/descriptor/dpa2.py (2)

35-57: Implementation of __setattr__ method is correct

The overridden __setattr__ method correctly handles attribute assignments with appropriate deserialization and type checking.
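
For readers unfamiliar with this pattern, the sketch below shows a `__setattr__` override that converts selected attributes on assignment. All class and attribute names are hypothetical, and the tuple conversion merely stands in for a backend array conversion; the actual implementation in the test module may differ.

```python
class Transform:
    """Hypothetical serializable component."""

    def __init__(self, scale: float):
        self.scale = scale

    def serialize(self) -> dict:
        return {"scale": self.scale}

    @classmethod
    def deserialize(cls, data: dict) -> "Transform":
        return cls(**data)


class BackendTransform(Transform):
    """Hypothetical backend-specific subclass of the same component."""


class BackendBlock:
    def __setattr__(self, name, value):
        if name in {"mean", "stddev"}:
            # Stand-in for converting to the backend's array type.
            value = tuple(float(v) for v in value)
        elif name in {"tebd_transform"}:
            if value is not None:
                # Rebuild the component as the backend class via a
                # serialize/deserialize round trip.
                value = BackendTransform.deserialize(value.serialize())
        super().__setattr__(name, value)


block = BackendBlock()
block.mean = [1, 2]
block.tebd_transform = Transform(2.0)
block.label = "unchanged"
```

Because the conversion happens in `__setattr__`, a parent-class constructor that assigns these attributes does not need to know which backend is in use.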


47-47: ⚠️ Potential issue

Typo in attribute name 'g1_shape_tranform'

The attribute 'g1_shape_tranform' appears to be misspelled; it should be 'g1_shape_transform'.

Apply this diff to correct the typo:

-            elif name in {"g1_shape_tranform", "tebd_transform"}:
+            elif name in {"g1_shape_transform", "tebd_transform"}:

Run the following script to check for other occurrences of 'g1_shape_tranform' in the codebase:

Signed-off-by: Jinzhe Zeng <[email protected]>
@github-actions github-actions bot added the Docs label Nov 1, 2024
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
doc/model/dpa2.md (1)

Line range hint 11-24: Consider adding JAX-specific installation requirements.

While the document includes PyTorch-specific installation requirements, it should also include any specific requirements or considerations for using the JAX backend.

Consider adding a new section:

## Requirements of installation {{ jax_icon }}

If you plan to use the JAX backend, ensure you have installed JAX with the appropriate hardware support (CPU/GPU/TPU). See [JAX installation guide](https://github.com/google/jax#installation) for details.
🧰 Tools
🪛 Markdownlint

7-7: null
Bare URL used

(MD034, no-bare-urls)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between db36432 and e0f1572.

📒 Files selected for processing (1)
  • doc/model/dpa2.md (1 hunks)
🔇 Additional comments (2)
doc/model/dpa2.md (2)

1-1: LGTM: Backend support clearly documented.

The addition of JAX to the supported backends is clearly indicated both in the header and the note section.

Also applies to: 4-4


Line range hint 26-26: Verify data format compatibility with JAX backend.

The document states that DPA-2 supports both standard and mixed type data formats. Let's verify this is true for the JAX backend as well.

✅ Verification successful

Let me check the specific data format handling in the JAX implementation and tests.


Let me check one more aspect to verify the actual implementation of data format handling in the JAX backend for DPA2.


Based on the evidence gathered from the shell scripts, I can now provide a final response:

DPA-2 data format compatibility with JAX backend is verified

The JAX backend fully supports mixed type data formats for DPA-2, as evidenced by:

  • The JAX model implementation explicitly handles the mixed_types parameter in its core functionality (deepmd/jax/model/hlo.py)
  • Test cases in test_dpa2.py exercise mixed type support through consistency tests across backends
  • The JAX descriptor (DescrptDPA2) inherits from DescrptDPA2DP, maintaining format compatibility with the base implementation
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for data format handling in JAX implementation
rg -l "mixed.*type" --type py | xargs rg "jax" -l

Length of output: 566


Script:

#!/bin/bash
# Check JAX implementation for data format handling
rg "mixed.*type" -A 5 -B 5 deepmd/jax/model/model.py deepmd/jax/model/hlo.py

# Check test cases for DPA2 with JAX backend
rg "mixed.*type|data_format" -A 5 -B 5 source/tests/consistent/descriptor/test_dpa2.py

Length of output: 4283


Script:

#!/bin/bash
# Check DPA2 descriptor implementation for JAX
rg "class.*DPA2" -A 10 deepmd/jax/descriptor/dpa2.py

# Also check if there are any format-specific implementations
rg "format|mixed_type" deepmd/jax/descriptor/dpa2.py

Length of output: 621

🧰 Tools
🪛 Markdownlint

7-7: null
Bare URL used

(MD034, no-bare-urls)

2 participants