-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax): reformat nlist in the TF model #4336
Conversation
Format the neighbor list in the TF model to convert the dynamic shape to the determined shape, so the TF model can accept the neighbor list with a dynamic shape. Signed-off-by: Jinzhe Zeng <[email protected]>
📝 WalkthroughWalkthroughThe changes introduce a new function Changes
Suggested labels
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (8)
deepmd/jax/jax2tf/format_nlist.py (4)
2-4
: Consider documenting the experimental numpy dependency.The code relies on TensorFlow's experimental numpy module. While this is necessary for numpy-like operations in TensorFlow, it's worth documenting this dependency and potential stability considerations.
Add a comment explaining the experimental nature:
+# Using experimental numpy for tensor operations - required for numpy-like functionality in TF import tensorflow.experimental.numpy as tnp
13-37
: Enhance docstring with type and constraint information.The docstring is well-structured but could benefit from additional details about types and constraints.
Consider adding:
- The expected dtype of
nlist
(presumably integer)- Valid range for
nsel
(must be positive)- Valid range for
rcut
(must be positive)- Whether negative values in
nlist
(other than -1) are valid
52-66
: Consider optimizing the truncation logic.While the implementation is correct, there are potential improvements for robustness and performance:
- Replace
float("inf")
withtf.float32.max
or the appropriate dtype max:- rr2 = tnp.where(m_real_nei, rr2, float("inf")) + rr2 = tnp.where(m_real_nei, rr2, tf.float32.max)
- Consider fusing operations to reduce memory usage:
- coord1 = tnp.take_along_axis(extended_coord, index, axis=1) - coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3) + coord1 = tnp.take_along_axis(extended_coord, index, axis=1).reshape(n_nf, n_nloc, n_nsel, 3)
69-71
: Enhance the XLA-related comment.While the comment explains the purpose, it could be more detailed about the XLA implications.
Consider expanding the comment:
- # do a reshape any way; this will tell the xla the shape without any dynamic shape + # Explicitly reshape to help XLA compiler optimize the graph by providing + # static shape information, even though the shape hasn't changed. This + # eliminates the need for dynamic shape handling in downstream operations.source/jax2tf_tests/test_format_nlist.py (3)
16-44
: Add docstrings and comments to explain the test setup.The test setup uses specific numerical values and configurations that would benefit from documentation:
- Purpose of the test class
- Explanation of the test parameters (nf, nloc, ns, etc.)
- Reasoning behind the chosen coordinate values
- Expected behavior of the neighbor list construction
Consider adding a class docstring like:
class TestFormatNlist(tf.test.TestCase): """Tests the format_nlist function with various neighbor list configurations. The test setup simulates a small molecular system with: - 3 local atoms (nloc) - 5x5x3 periodic images (ns) - 2 atom types - Triclinic cell """
45-48
: Add more assertions to test_format_nlist_equal.While testing for equality is good, consider adding more specific assertions:
- Shape of the output
- Data type consistency
- Range of indices
def test_format_nlist_equal(self): nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut) + self.assertEqual(nlist.shape, self.nlist.shape) + self.assertEqual(nlist.dtype, self.nlist.dtype) + self.assertTrue(tf.reduce_all(nlist >= 0)) + self.assertTrue(tf.reduce_all(nlist < len(self.ecoord))) self.assertAllEqual(nlist, self.nlist)
77-91
: Add test cases for edge cases and error conditions.The current tests focus on valid inputs, but it would be valuable to test error conditions:
- Empty neighbor list
- Zero cutoff radius
- Invalid coordinates
Consider adding tests like:
def test_format_nlist_empty(self): empty_nlist = tf.zeros((1, self.nloc, 0), dtype=tf.int32) with self.assertRaises(ValueError): format_nlist(self.ecoord, empty_nlist, sum(self.nsel), self.rcut) def test_format_nlist_zero_cutoff(self): with self.assertRaises(ValueError): format_nlist(self.ecoord, self.nlist, sum(self.nsel), 0.0)deepmd/jax/jax2tf/serialization.py (1)
91-91
: Consider adding error handling for format_nlist calls.While the
format_nlist
calls are correctly placed and use appropriate model parameters, consider adding error handling to gracefully handle potential formatting failures. This is especially important as the function deals with shape transformations that could fail with invalid input.Example error handling:
- nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) + try: + nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) + except ValueError as e: + tf.print("Error formatting neighbor list:", e) + raiseAlso applies to: 116-116
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/jax/jax2tf/format_nlist.py
(1 hunks)deepmd/jax/jax2tf/serialization.py
(4 hunks)source/jax2tf_tests/test_format_nlist.py
(1 hunks)
🔇 Additional comments (5)
deepmd/jax/jax2tf/format_nlist.py (2)
42-50
: LGTM! Efficient padding implementation.
The padding implementation correctly handles the case where we need more neighbors, using efficient tensor operations and preserving the original dtype.
67-68
: LGTM! Efficient handling of the no-change case.
The implementation correctly handles the case where no modification is needed.
source/jax2tf_tests/test_format_nlist.py (1)
3-3
: Consider the stability implications of using experimental numpy.
The code uses tensorflow.experimental.numpy
which might be subject to breaking changes. Consider adding a version check or documenting the minimum required TensorFlow version to ensure compatibility.
deepmd/jax/jax2tf/serialization.py (2)
13-15
: LGTM: Import statement is correctly placed.
The import of format_nlist
is properly organized and aligns with the PR's objective of reformatting neighbor lists.
82-82
: LGTM: Input signature changes enable dynamic neighbor lists.
The modification of nlist
parameter spec to accept dynamic shapes ([None, None, None]
) aligns with the PR objective. The change is consistently applied to both atomic virial and non-atomic virial functions.
Let's verify if there are any other functions in the codebase that expect the old fixed-shape neighbor list format:
Also applies to: 109-109
✅ Verification successful
Based on the search results, I can now generate the final response since I have sufficient information to verify the review comment:
LGTM: Dynamic neighbor list shape change is safe and consistent
The change from fixed shape [None, None, model.get_nnei()]
to dynamic shape [None, None, None]
for the nlist
parameter is safe because:
- The
format_nlist
function is called before any processing to ensure the neighbor list has the correct shape usingmodel.get_nnei()
andmodel.get_rcut()
(see jax2tf/serialization.py) - All internal processing functions properly handle the neighbor list shape by extracting dimensions at runtime rather than assuming fixed sizes
- The dynamic shape provides more flexibility while maintaining compatibility with the existing codebase's shape handling
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for TensorSpec definitions with model.get_nnei()
rg "TensorSpec.*model\.get_nnei\(\)" --type py
Length of output: 48
Script:
#!/bin/bash
# Let's search more broadly for neighbor list related tensor specs and shapes
rg "TensorSpec.*\[None,\s*None,.*\].*int64" -A 2 -B 2 --type py
# Also search for any references to neighbor list shapes
rg "nlist.*shape" --type py -A 2 -B 2
# Look for any model.get_nnei() usage
rg "get_nnei\(\)" --type py -A 2 -B 2
Length of output: 65349
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4336 +/- ##
==========================================
- Coverage 84.51% 84.48% -0.04%
==========================================
Files 575 576 +1
Lines 53398 53429 +31
Branches 3059 3059
==========================================
+ Hits 45129 45138 +9
- Misses 7306 7328 +22
Partials 963 963 ☔ View full report in Codecov by Sentry. |
This PR needs more effort to work with the C++ interface. |
The issue doesn't come from this PR, but come from the C++ interface. I've fixed it. |
Including nlist and no nlist interface. The limitation: A SavedModel created on a device cannot be run on another. For example, a CUDA model cannot be run on the CPU. The model is generated using #4336. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Added support for the JAX backend, including specific model and checkpoint file formats. - Introduced a new shell script for model conversion to enhance usability. - Updated installation documentation to clarify JAX support and requirements. - New section in documentation detailing limitations of the JAX backend with LAMMPS. - **Bug Fixes** - Enhanced error handling for model initialization and backend compatibility. - **Documentation** - Updated backend documentation to include JAX details and limitations. - Improved clarity in installation instructions for both TensorFlow and JAX. - **Tests** - Added comprehensive unit tests for JAX integration with the Deep Potential class. - Expanded test coverage for LAMMPS integration with DeepMD. - **Chores** - Updated CMake configurations and workflow files for improved testing and dependency management. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Your Name <[email protected]>
Reformat the neighbor list in the TF model to convert the dynamic shape to the determined shape so the TF model can accept the neighbor list with a dynamic shape.
Summary by CodeRabbit
New Features
Tests