-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax): SavedModel C++ interface (including DPA-2 supports) #4307
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 WalkthroughWalkthroughThe pull request introduces support for the JAX backend in DeePMD-kit, updating documentation and codebase to accommodate this addition. Key changes include the inclusion of JAX-specific model and checkpoint file formats, installation instructions, and the implementation of a new Changes
Possibly related PRs
Suggested reviewers
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 10
🧹 Outside diff range and nitpick comments (14)
source/cmake/googletest.cmake.in (1)
14-14
: Document the reason for upgrading Google Test.Consider adding a comment explaining why the upgrade to v1.14.0 was necessary. This helps future maintainers understand the motivation behind version changes.
- GIT_TAG v1.14.0 + # Upgraded to v1.14.0 to support new testing features needed for JAX backend + GIT_TAG v1.14.0source/api_cc/include/common.h (2)
16-16
: Document JAX backend limitations.Given the known limitations mentioned in the PR description:
- Neighbor list only support
- Device compatibility restrictions
- Memory leak concerns
Consider adding documentation comments above the enum to clarify these limitations for API users.
Example addition:
+/** + * @brief Backend types supported by DeePMD-kit + * @note JAX backend has the following limitations: + * - Only supports neighbor list operations + * - Models created on one device (e.g., CUDA) cannot run on different devices (e.g., CPU) + * - Known memory leaks exist in the TensorFlow C API implementation + */ enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
16-16
: Consider adding runtime checks for JAX limitations.Given the device compatibility and neighbor list limitations, consider adding runtime validation to prevent misuse.
Would you like help implementing:
- Device compatibility checks
- Neighbor list validation
- Memory leak detection utilities
doc/install/install-from-source.md (2)
300-302
: Consider clarifying the JAX backend dependency.The documentation accurately combines TensorFlow and JAX backends, but it would be helpful to explicitly mention that the JAX backend requires TensorFlow's C++ library as a dependency. This would help users better understand the system requirements.
Consider adding a note like:
The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library. +Note: The JAX backend requires TensorFlow's C++ library as a dependency, even if you're not using TensorFlow directly.
380-380
: Add version compatibility information.The documentation correctly indicates that both TensorFlow and JAX backends use these CMake variables, but it would be beneficial to add information about version compatibility requirements.
Consider adding version compatibility notes:
{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend. +Note: Ensure that the TensorFlow C++ library version is compatible with both your TensorFlow and JAX Python packages.
Also applies to: 396-396
source/api_cc/include/DeepPotJAX.h (4)
89-92
: Correct the Doxygen comment to match the method signatureThe method
is_aparam_nall()
does not take any parameters, but the comment includes@param[out] aparam_nall
, which is incorrect. Please update the comment to reflect the actual method signature.Apply this diff:
/** - * @brief Get whether the atom dimension of aparam is nall instead of fparam. - * @param[out] aparam_nall whether the atom dimension of aparam is nall - *instead of fparam. + * @brief Check if the atom dimension of `aparam` is `nall` instead of `fparam`. **/
49-49
: Remove unnecessary semicolons after method definitionsThe semicolons after the closing braces in the method definitions are unnecessary and can be removed to maintain consistency and readability.
Apply this diff:
} - }; + }Also applies to: 57-57, 65-65, 73-73, 81-81, 96-96
225-226
: Clarify the Doxygen comments forfparam
andaparam
parametersThe descriptions for the
fparam
andaparam
parameters appear incomplete or unclear. The lines seem to be missing words or have formatting issues.For
fparam
:* dim_fparam. Then all frames are assumed to be provided with the same *fparam.
For
aparam
:* natoms x dim_aparam. Then all frames are assumed to be provided with the *same aparam.
Please revise the comments to provide clear and complete descriptions of the expected parameter formats.
Suggested correction:
* @param[in] fparam The frame parameter. The array can be of size: * - nframes x dim_fparam, or - * dim_fparam. Then all frames are assumed to be provided with the same - *fparam. + * - dim_fparam (if all frames share the same `fparam`), in which case all frames are assumed to use the same `fparam`.* @param[in] aparam The atomic parameter. The array can be of size: * - nframes x natoms x dim_aparam, or - * natoms x dim_aparam. Then all frames are assumed to be provided with the - *same aparam. + * - natoms x dim_aparam (if all frames share the same `aparam`), in which case all frames are assumed to use the same `aparam`.Also applies to: 229-230
192-204
: Consider using RAII wrappers for TensorFlow C API resourcesTo improve resource management and exception safety, consider encapsulating the TensorFlow C API resources (e.g.,
TF_Graph*
,TF_Session*
, etc.) in RAII-style wrapper classes. This ensures that resources are automatically released when they go out of scope, helping to prevent memory leaks and simplifying the destructor implementation.source/api_cc/tests/test_deeppot_jax.cc (2)
100-110
: Refactor repeated variable assignments in test casesTo enhance maintainability and reduce code duplication, consider removing the repeated variable assignments at the beginning of each test case by directly accessing the class member variables.
Apply this diff to each test case:
- std::vector<VALUETYPE>& coord = this->coord; - std::vector<int>& atype = this->atype; - std::vector<VALUETYPE>& box = this->box; - std::vector<VALUETYPE>& expected_e = this->expected_e; - std::vector<VALUETYPE>& expected_f = this->expected_f; - std::vector<VALUETYPE>& expected_v = this->expected_v; - int& natoms = this->natoms; - double& expected_tot_e = this->expected_tot_e; - std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v; - deepmd::DeepPot& dp = this->dp; - float rc = dp.cutoff(); + float rc = this->dp.cutoff();Also applies to: 162-172, 244-255, 306-317, 369-380
76-76
: Add error handling fordp.init(file_name)
Ensure that
dp.init(file_name)
successfully loads the model file and handle any potential exceptions that may occur during initialization.Apply this diff to enhance error handling:
- dp.init(file_name); + try { + dp.init(file_name); + } catch (const std::exception& e) { + FAIL() << "Failed to initialize DeepPot: " << e.what(); + }source/api_cc/src/DeepPotJAX.cc (2)
35-38
: Optimize string truncation infind_function
The use of
substr
to truncatename_
can be replaced with the more efficientresize
method to avoid unnecessary copying.Apply this diff to optimize the string truncation:
- name_ = name_.substr(0, pos + 1); + name_.resize(pos + 1);🧰 Tools
🪛 cppcheck
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
29-29
: Pass string parameters by const referencePassing string parameters by const reference can improve performance by avoiding unnecessary copies.
Apply this diff to modify the function signatures:
-inline void find_function(TF_Function*& found_func, - const std::vector<TF_Function*>& funcs, - const std::string func_name) { +inline void find_function(TF_Function*& found_func, + const std::vector<TF_Function*>& funcs, + const std::string& func_name) { ... -inline TFE_Op* get_func_op(TFE_Context* ctx, - const std::string func_name, - const std::vector<TF_Function*>& funcs, - const std::string device, - TF_Status* status) { +inline TFE_Op* get_func_op(TFE_Context* ctx, + const std::string& func_name, + const std::vector<TF_Function*>& funcs, + const std::string& device, + TF_Status* status) { ... -template <typename T> -inline T get_scalar(TFE_Context* ctx, - const std::string func_name, - const std::vector<TF_Function*>& funcs, - const std::string device, - TF_Status* status) { +template <typename T> +inline T get_scalar(TFE_Context* ctx, + const std::string& func_name, + const std::vector<TF_Function*>& funcs, + const std::string& device, + TF_Status* status) { ...Repeat similar changes for the functions
get_vector
,get_vector_string
, and any other functions wherestd::string
parameters are passed by value.Also applies to: 64-64, 66-66, 86-86, 88-88, 109-109, 111-111, 130-130, 132-132
🧰 Tools
🪛 cppcheck
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
source/api_cc/src/DeepPot.cc (1)
65-72
: Ensure consistent backend initialization and error handling.The code correctly initializes the JAX backend when
BUILD_TENSORFLOW
is defined and throws an appropriate exception when it is not. However, consider clarifying the exception message for better understanding.Apply this diff to improve the exception message clarity:
} else if (deepmd::DPBackend::JAX == backend) { #ifdef BUILD_TENSORFLOW dp = std::make_shared<deepmd::DeepPotJAX>(model, gpu_rank, file_content); #else throw deepmd::deepmd_exception( - "TensorFlow backend is not built, which is used to load JAX2TF " - "SavedModels"); + "Cannot load JAX2TF SavedModels because TensorFlow backend is not built."); #endif
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (11)
.github/workflows/test_cc.yml
(1 hunks)doc/backend.md
(1 hunks)doc/install/install-from-source.md
(4 hunks)source/api_cc/include/DeepPotJAX.h
(1 hunks)source/api_cc/include/common.h
(1 hunks)source/api_cc/src/DeepPot.cc
(3 hunks)source/api_cc/src/DeepPotJAX.cc
(1 hunks)source/api_cc/tests/test_deeppot_jax.cc
(1 hunks)source/cmake/googletest.cmake.in
(1 hunks)source/tests/infer/deeppot_sea.savedmodel/.gitignore
(1 hunks)source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
(1 hunks)
✅ Files skipped from review due to trivial changes (2)
- source/tests/infer/deeppot_sea.savedmodel/.gitignore
- source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 109-109: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 130-130: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 132-132: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (5)
source/cmake/googletest.cmake.in (1)
14-14
: Verify test compatibility with Google Test v1.14.0.
The upgrade from release-1.12.1 to v1.14.0 is reasonable, but please ensure all existing tests are compatible with the new version, as it includes breaking changes.
✅ Verification successful
Based on the analysis of the test files, I can now provide a final response regarding the Google Test upgrade compatibility:
Google Test v1.14.0 upgrade is safe to proceed
The codebase uses standard Google Test features that are stable across versions 1.12.1 to 1.14.0:
- Uses standard
TYPED_TEST
macros with consistent syntax - No usage of deprecated features like
TEST_CASE
orFRIEND_TEST
- No death test assertions (
ASSERT_DEATH
/EXPECT_DEATH
) - No custom test event listeners
- No usage of features that had breaking changes in v1.14.0
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for usage of deprecated or changed Google Test features
# that might be affected by the upgrade to v1.14.0
# Look for test files
echo "Searching for test files..."
fd -e cc -e cpp -e h -e hpp --full-path '.*test.*' -x rg -l 'TEST|TEST_F|TEST_P|EXPECT_|ASSERT_'
# Check for specific deprecated features or patterns
echo "Checking for potentially affected test patterns..."
rg -A 2 'GTEST_SKIP|INSTANTIATE_TEST_SUITE|INSTANTIATE_TEST_CASE'
Length of output: 3284
Script:
#!/bin/bash
# Let's check the actual test implementations for any potential compatibility issues
# Focus on specific patterns that changed between 1.12.1 and 1.14.0
# Check for usage of TEST_CASE (deprecated in favor of TEST_SUITE)
echo "Checking for TEST_CASE usage..."
rg "TEST_CASE\(" -A 2
# Check for FRIEND_TEST usage (syntax changed in 1.14.0)
echo "Checking for FRIEND_TEST usage..."
rg "FRIEND_TEST\(" -A 2
# Check for TYPED_TEST usage (behavior changes in 1.14.0)
echo "Checking for TYPED_TEST patterns..."
rg "TYPED_TEST\(" -A 2
# Check for ASSERT_DEATH and EXPECT_DEATH (behavior changes in 1.14.0)
echo "Checking death test patterns..."
rg "(ASSERT_DEATH|EXPECT_DEATH)\(" -A 2
# Check for custom test event listeners (API changes in 1.14.0)
echo "Checking for custom test listeners..."
rg "testing::TestEventListener" -A 2
Length of output: 49962
.github/workflows/test_cc.yml (1)
48-48
: Clarify strategy for memory leak detection
The addition of -fsanitize=leak -DENABLE_SANITIZE=1
enables leak detection, which is valuable for maintaining code quality. However, this seems to conflict with the PR's stated intention to skip memory leak checks due to known issues with the TensorFlow C API.
Could you clarify:
- How should the CI handle known memory leaks from the TensorFlow C API?
- Should we consider conditionally enabling leak sanitizer only for non-TensorFlow components?
- What's the strategy for distinguishing between TensorFlow C API leaks and potential leaks in the new SavedModel interface?
Let's check if there are any existing leak sanitizer suppressions:
✅ Verification successful
Based on the search results, I can now generate the final response:
Memory leak sanitizer configuration appears appropriate
The addition of -fsanitize=leak -DENABLE_SANITIZE=1
in the CI workflow is well-aligned with the codebase's needs:
- The codebase already has sanitizer-aware code, as evidenced by the
ENABLE_SANITIZE
guards in test files - No existing leak sanitizer suppressions were found, suggesting memory leaks should be actively detected rather than suppressed
- The TensorFlow C API usage in the codebase (e.g., in DeepPotJAX) shows careful memory management with proper cleanup
The conditional application (matrix.check_memleak && ...
) provides the flexibility to:
- Run tests with leak detection when investigating memory issues
- Skip leak detection when focusing on other aspects or when known TensorFlow C API leaks would be disruptive
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Look for leak sanitizer suppression configurations
rg -i "leak.*suppress|suppress.*leak"
# Check for any LSAN-related configurations
fd -e h -e cpp -e cc -e hpp | xargs rg "LSAN_OPTIONS|__lsan"
Length of output: 102
Script:
#!/bin/bash
# Look for any memory leak related configurations or handling in CMake files
rg -i "sanitize.*leak|leak.*sanitize" --type cmake
# Check for any TensorFlow-related memory management or leak handling
rg -i "tensorflow.*memory|memory.*leak" --type cpp --type hpp -A 3
# Look for any test configurations or setup files that might handle memory checks
fd "test|Test" -e cmake -e txt | xargs rg -i "sanitize|leak|memory"
# Check if there are any existing memory leak related preprocessor definitions
rg "ENABLE_SANITIZE|DISABLE_SANITIZE" --type cpp --type hpp --type cmake
Length of output: 68638
Script:
#!/bin/bash
# Check for any leak sanitizer related flags or configurations in CMake files
rg -i "sanitize.*leak|leak.*sanitize|ENABLE_SANITIZE" --type cmake
# Look for any memory leak related test configurations
fd -e cpp -e cc | xargs rg "ENABLE_SANITIZE|DISABLE_SANITIZE|fsanitize=leak"
# Check for any TensorFlow C API related memory management code
rg -i "tensorflow.*memory|memory.*leak" --type cpp -A 3
Length of output: 336
Script:
#!/bin/bash
# Check CMakeLists.txt files for sanitizer configurations
fd CMakeLists.txt | xargs rg -i "sanitize|ENABLE_SANITIZE"
# Look for any memory leak handling in C++ source files
rg -i "memory.*leak|leak.*check" --type cpp -B 2 -A 2
# Check for any TensorFlow C API usage and error handling
rg "TF_\w+\s*\(" --type cpp -B 2 -A 2
Length of output: 8410
doc/backend.md (1)
33-34
: LGTM! Documentation accurately reflects the SavedModel C++ interface dependency.
The documentation clearly states that only the .savedmodel
format supports C++ inference and requires the TensorFlow C++ interface, which aligns with the implementation.
source/api_cc/src/DeepPot.cc (2)
10-10
: Including DeepPotJAX.h
conditionally is appropriate.
Including DeepPotJAX.h
when BUILD_TENSORFLOW
is defined is correct since the JAX backend depends on TensorFlow.
45-47
: Correctly identify JAX backend based on model file extension.
The condition properly checks for the '.savedmodel' extension and sets the backend to JAX.
📝 WalkthroughWalkthroughThe pull request introduces significant enhancements across various components of the DeePMD-kit project. Key changes include the addition of JAX as a new backend option in the documentation and codebase, modifications to the workflow configuration for improved memory leak checks, and the introduction of a new class, Changes
Possibly related PRs
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 12
🧹 Outside diff range and nitpick comments (21)
source/cmake/googletest.cmake.in (1)
14-14
: Leverage improved memory leak detectionThis version update is particularly beneficial as Google Test v1.14.0 includes enhanced memory leak detection capabilities, which could help identify potential memory leaks mentioned in the PR objectives regarding the TensorFlow C API usage.
Consider enabling the following Google Test features in your test configurations:
- Use
testing::FLAGS_gtest_death_test_style = "threadsafe"
for better leak detection in death tests- Enable
GTEST_FLAG(detect_leaks)
when available on your platform.github/workflows/test_cc.yml (1)
48-48
: Consider enhancing memory leak detection.Given the PR's focus on C++ interface and known memory leak concerns, consider these improvements:
- Add a timeout specifically for sanitizer runs as they can take longer
- Consider running leak checks on multiple platforms to catch platform-specific memory issues
Example configuration:
strategy: matrix: check_memleak: [true, false] + # Add platform matrix when check_memleak is true + include: + - check_memleak: true + os: macos-latest + - check_memleak: true + os: windows-latest steps: + # Add timeout for sanitizer runs + timeout-minutes: ${{ matrix.check_memleak && 30 || 15 }}source/api_cc/include/common.h (1)
16-16
: Consider adding documentation for backend-specific limitations.Given the PR objectives mentioning device compatibility restrictions and memory leak concerns with the TF C API, it would be helpful to document these limitations in the header file.
Add documentation above the enum:
+/** + * @brief Supported deep learning backends + * @note JAX backend has the following limitations: + * - Models created on one device type cannot be executed on different devices + * - Only supports neighbor list functionality + * - May have memory leaks when using TensorFlow C API + */ enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };source/api_cc/include/DeepPotJAX.h (5)
25-25
: Avoid passing primitive types by const reference.Passing primitive types like
int
by const reference (const int&
) does not provide performance benefits and can slightly hinder performance. It's recommended to pass these types by value instead.Apply the following diff to update parameter declarations:
-const int& gpu_rank +int gpu_rank -const int& ago +int agoAlso applies to: 30-30, 35-35, 40-40, 131-131, 244-244
62-65
: Clarify the constant return value innumb_types_spin()
.The method
numb_types_spin()
always returns0
, indicating that spin types are not supported. Consider documenting this behavior in the method's description to inform users.
27-27
: Maintain consistent formatting in Doxygen comments for better readability.In the Doxygen comments, ensure that each line begins with
*
(including a space) for consistency and improved readability.Example fix:
- *DP will read from the string instead of the file. + * DP will read from the string instead of the file.Also applies to: 37-37, 91-91, 226-226, 230-230
221-221
: Correct parameter name mismatch between documentation and code.In the documentation for the
compute
method, the parameter is referred to aslmp_list
, but in the code, it is namedinlist
. Please update the documentation to match the code to prevent confusion.Apply the following diff:
- * @param[in] lmp_list The input neighbour list. + * @param[in] inlist The input neighbour list.Also applies to: 243-243
86-86
: Declareget_type_map
as aconst
method.The
get_type_map
method does not modify any member variables and can be declared asconst
to reflect its non-mutating behavior.Apply the following change:
-void get_type_map(std::string& type_map); +void get_type_map(std::string& type_map) const;source/api_cc/tests/test_deeppot_jax.cc (5)
21-34
: Remove unnecessary commented-out codeThe block of commented-out Python code between lines 21-34 is not needed in the C++ test file and can be removed to improve readability.
Apply this diff to remove the commented code:
- // import numpy as np - // from deepmd.infer import DeepPot - // coord = np.array([ - // 12.83, 2.56, 2.18, 12.09, 2.87, 2.74, - // 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, - // 3.51, 2.51, 2.60, 4.27, 3.22, 1.56 - // ]).reshape(1, -1) - // atype = np.array([0, 1, 1, 0, 1, 1]) - // box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.]).reshape(1, -1) - // dp = DeepPot("deeppot_sea.savedmodel") - // e, f, v, ae, av = dp.eval(coord, box, atype, atomic=True) - // np.set_printoptions(precision=16) - // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=} {ae.ravel()=} - // {av.ravel()=}")
36-36
: Correct the leading zero in floating-point literalThe coordinate value
00.25
on line 36 has an unnecessary leading zero, which may cause confusion. It should be written as0.25
for clarity.Apply this diff:
- 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 0.25, 3.32, 1.68, 3.36, 3.00, 1.81,
94-94
: Remove unnecessary emptyTearDown()
methodSince the
TearDown()
method is empty, it can be omitted to simplify the code.Apply this diff:
- void TearDown() override {};
121-121
: Use.data()
method for vector pointersWhen obtaining pointers to the underlying data of a
std::vector
, prefer using the.data()
method over&vector[0]
for clarity and safety, especially if the vector could be empty.Apply this diff to update the code:
-deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); +deepmd::InputNlist inlist(nloc, ilist.data(), numneigh.data(), firstneigh.data());Also applies to: 184-184, 266-266, 345-345, 406-406
99-304
: Refactor test cases to reduce code duplicationThe test cases share significant portions of code, particularly in variable declarations and initializations. Consider extracting common code into helper functions or setting up shared fixtures to improve maintainability and readability.
source/api_cc/src/DeepPotJAX.cc (7)
29-29
: Pass 'func_name' by const reference to improve performanceIn the function
find_function
, the parameterfunc_name
is passed by value. Sincefunc_name
is astd::string
and is not modified within the function, consider passing it byconst
reference to avoid unnecessary copies.🧰 Tools
🪛 cppcheck
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
37-37
: Optimize string manipulation to avoid unnecessary copyingThe call to
substr
assigns a prefix ofname_
back to itself. You can make this more efficient by modifyingname_
in place usingerase
orresize
, which avoids creating a new string object.Apply this diff to optimize the code:
-if (pos != std::string::npos) { - name_ = name_.substr(0, pos + 1); -} +if (pos != std::string::npos) { + name_.erase(pos + 1); +}🧰 Tools
🪛 cppcheck
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
64-65
: Pass parameters by const reference to improve performanceIn the function
get_func_op
, the parametersfunc_name
anddevice
are passed by value. Since these arestd::string
objects and are not modified within the function, consider passing them byconst
reference to avoid unnecessary copying.🧰 Tools
🪛 cppcheck
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
86-89
: Pass parameters by const reference to improve performanceIn the function
get_scalar
, the parametersfunc_name
anddevice
are passed by value. Passing them byconst
reference can improve performance by avoiding unnecessary string copies.🧰 Tools
🪛 cppcheck
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
109-112
: Pass parameters by const reference to improve performanceIn the function
get_vector
, the parametersfunc_name
anddevice
are passed by value. Modify them to be passed byconst
reference to enhance performance.🧰 Tools
🪛 cppcheck
[performance] 109-109: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'device' should be passed by const reference.
(passedByValue)
130-133
: Pass parameters by const reference to improve performanceIn
get_vector_string
, passingfunc_name
anddevice
byconst
reference will prevent unnecessary copying ofstd::string
objects.🧰 Tools
🪛 cppcheck
[performance] 130-130: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 132-132: Function parameter 'device' should be passed by const reference.
(passedByValue)
305-305
: Remove unused local variable 'nloc'The variable
nloc
is declared but not used in the code. Removing it will clean up the code and eliminate any compiler warnings about unused variables.Apply this diff to remove the unused variable:
-int nloc = nall_real - nghost_real;
🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 305-305: Unused local variable
Variable nloc is not used.source/api_cc/src/DeepPot.cc (1)
65-72
: Clarify the exception message when TensorFlow backend is not builtThe code checks if the TensorFlow backend is built before initializing
DeepPotJAX
, which is necessary because JAX models rely on TensorFlow. However, the exception message can be improved for clarity.Suggested change:
throw deepmd::deepmd_exception( - "TensorFlow backend is not built, which is used to load JAX2TF " - "SavedModels"); + "TensorFlow backend is not built. It is required to load JAX2TF SavedModels.");
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (11)
.github/workflows/test_cc.yml
(1 hunks)doc/backend.md
(1 hunks)doc/install/install-from-source.md
(4 hunks)source/api_cc/include/DeepPotJAX.h
(1 hunks)source/api_cc/include/common.h
(1 hunks)source/api_cc/src/DeepPot.cc
(3 hunks)source/api_cc/src/DeepPotJAX.cc
(1 hunks)source/api_cc/tests/test_deeppot_jax.cc
(1 hunks)source/cmake/googletest.cmake.in
(1 hunks)source/tests/infer/deeppot_sea.savedmodel/.gitignore
(1 hunks)source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
(1 hunks)
✅ Files skipped from review due to trivial changes (2)
- source/tests/infer/deeppot_sea.savedmodel/.gitignore
- source/tests/infer/deeppot_sea.savedmodel/fingerprint.pb
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 109-109: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 130-130: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 132-132: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🪛 GitHub Check: CodeQL
source/api_cc/src/DeepPotJAX.cc
[notice] 305-305: Unused local variable
Variable nloc is not used.
[failure] 334-334: Multiplication result converted to larger type
Multiplication result may overflow 'int' before it is converted to 'size_type'.
🔇 Additional comments (6)
source/cmake/googletest.cmake.in (1)
14-14
: Verify test compatibility with Google Test v1.14.0
The version update from 1.12.1 to 1.14.0 is a good improvement. However, since this is a major version update, we should ensure all existing tests are compatible with the new version.
✅ Verification successful
Based on the analysis of the test files, I can now provide a final response:
Google Test v1.14.0 update is safe to proceed
The codebase uses standard Google Test features like TEST, TEST_F, and TYPED_TEST macros for test definitions. No usage of death tests, custom matchers, or printers was found that could be affected by breaking changes in v1.14.0. The test suite appears to use basic assertions and test fixtures which have remained stable across Google Test versions.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for potentially incompatible Google Test features
# Look for death tests that might be affected by the new implementation
rg -l "EXPECT_DEATH|ASSERT_DEATH|EXPECT_DEBUG_DEATH|ASSERT_DEBUG_DEATH"
# Check for custom printers that might need updates
rg -l "MATCHER_P|MATCHER_P2|MATCHER_P3"
Length of output: 115
Script:
#!/bin/bash
# Let's try a different approach to check for test files and Google Test usage
# Find test files
fd -e cpp -e cc -e h test
# Look for Google Test includes and macros
rg -l "gtest|EXPECT_|ASSERT_|TEST|TEST_F"
# Check CMake files for any Google Test related configurations
fd -e cmake . -x cat {} \;
Length of output: 27330
Script:
#!/bin/bash
# Let's examine test files to check for Google Test usage patterns
# Look for specific Google Test features that might be affected by version update
rg -A 2 "TEST\(|TEST_F\(" source/api_c/tests source/api_cc/tests source/lib/tests
# Check for death tests specifically
rg -A 2 "EXPECT_DEATH|ASSERT_DEATH|EXPECT_DEBUG_DEATH|ASSERT_DEBUG_DEATH" source/api_c/tests source/api_cc/tests source/lib/tests
# Look for custom matchers/printers
rg -A 2 "MATCHER_P|MATCHER_P2|MATCHER_P3|PrintTo" source/api_c/tests source/api_cc/tests source/lib/tests
Length of output: 86120
.github/workflows/test_cc.yml (1)
48-48
: LGTM! Good addition of leak sanitizer.
The addition of -fsanitize=leak
flag when check_memleak
is true is a good approach to detect memory leaks, especially given the PR's known limitation regarding potential memory leaks in the TensorFlow C API.
source/api_cc/include/common.h (1)
16-16
: LGTM! The JAX backend addition looks good.
The placement of JAX in the DPBackend enum is correct, maintaining the Unknown value as the last enum option.
source/api_cc/tests/test_deeppot_jax.cc (1)
2-3
: Confirmed: Tests are skipped when memory sanitizer is enabled
The use of #ifndef ENABLE_SANITIZE
ensures that the tests are correctly skipped when memory sanitizer is enabled, as indicated by the comment on line 2.
source/api_cc/src/DeepPot.cc (2)
10-10
: Conditional inclusion of DeepPotJAX.h
is appropriate
Including DeepPotJAX.h
within the #ifdef BUILD_TENSORFLOW
block ensures that the header file is only included when the TensorFlow backend is built, which is necessary because DeepPotJAX
depends on TensorFlow.
45-47
: Properly handle .savedmodel
files for JAX backend selection
The code correctly identifies model files with the .savedmodel
extension and assigns the JAX backend, ensuring that JAX models are appropriately initialized.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (7)
source/api_cc/tests/test_deeppot_jax.cc (1)
100-110
: Refactor to eliminate redundant variable declarations in test casesIn each test case, the initial declarations of variables that reference class members are repetitive. To improve maintainability and reduce code duplication, consider accessing the member variables directly using
this->
or removingthis->
if not necessary.Apply this diff to remove the redundant local references:
using VALUETYPE = TypeParam; -std::vector<VALUETYPE>& coord = this->coord; -std::vector<int>& atype = this->atype; -std::vector<VALUETYPE>& box = this->box; -std::vector<VALUETYPE>& expected_e = this->expected_e; -std::vector<VALUETYPE>& expected_f = this->expected_f; -std::vector<VALUETYPE>& expected_v = this->expected_v; -int& natoms = this->natoms; -double& expected_tot_e = this->expected_tot_e; -std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v; -deepmd::DeepPot& dp = this->dp;After removing these declarations, you can directly use the member variables in your test cases. For example, replace
coord
withthis->coord
or simplycoord
ifthis->
is not required.Also applies to: 162-172, 245-255, 307-317, 369-379
source/api_cc/src/DeepPotJAX.cc (6)
27-45
: Passfunc_name
by const reference infind_function
Passing
func_name
asconst std::string&
instead ofconst std::string
avoids unnecessary copying, improving performance.🧰 Tools
🪛 cppcheck
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
35-38
: Optimize string manipulation withresize
orpop_back
In the
find_function
method, usingsubstr
may be inefficient when trimming the string. Consider usingresize
orpop_back
for better performance.🧰 Tools
🪛 cppcheck
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
63-82
: Passfunc_name
anddevice
by const reference inget_func_op
Passing
func_name
anddevice
asconst std::string&
enhances performance by avoiding unnecessary string copies.🧰 Tools
🪛 cppcheck
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
84-107
: Passfunc_name
anddevice
by const reference inget_scalar
Passing
func_name
anddevice
asconst std::string&
avoids unnecessary copying, improving performance.🧰 Tools
🪛 cppcheck
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
109-128
: Passfunc_name
anddevice
by const reference inget_vector
Passing
func_name
anddevice
asconst std::string&
improves performance by avoiding unnecessary string copies.🧰 Tools
🪛 cppcheck
[performance] 111-111: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'device' should be passed by const reference.
(passedByValue)
130-160
: Passfunc_name
anddevice
by const reference inget_vector_string
Passing
func_name
anddevice
asconst std::string&
enhances performance by avoiding unnecessary string copies.🧰 Tools
🪛 cppcheck
[performance] 132-132: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 134-134: Function parameter 'device' should be passed by const reference.
(passedByValue)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
source/api_cc/include/DeepPotJAX.h
(1 hunks)source/api_cc/src/DeepPotJAX.cc
(1 hunks)source/api_cc/tests/test_deeppot_jax.cc
(1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 132-132: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 134-134: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (3)
source/api_cc/include/DeepPotJAX.h (1)
1-249
: Code Review Completed
The implementation of the DeepPotJAX
class appears comprehensive and well-structured. The class definitions, method declarations, and documentation comments are clear and adhere to the project's coding standards. All public methods are adequately documented, enhancing readability and maintainability.
source/api_cc/src/DeepPotJAX.cc (2)
275-283
: Ensure resources are deallocated regardless of inited
flag
The inited
flag should not guard resource deallocation in the destructor, as resources may have been allocated even if initialization failed. Resources should be cleaned up unconditionally to prevent memory leaks.
369-388
:
Release retvals
to prevent memory leaks in compute
The TFE_TensorHandle*
objects in retvals
obtained from TFE_Execute
are not deleted after use, which can lead to memory leaks. Ensure that you call TFE_DeleteTensorHandle
on each element of retvals
after processing.
Apply this diff to address the issue:
// Process retvals...
+ // Delete TFE_TensorHandle objects to free memory
+ for (int i = 0; i < nretvals; ++i) {
+ TFE_DeleteTensorHandle(retvals[i]);
+ }
Likely invalid or redundant comment.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4307 +/- ##
==========================================
+ Coverage 84.35% 84.47% +0.12%
==========================================
Files 593 596 +3
Lines 55900 56566 +666
Branches 3388 3457 +69
==========================================
+ Hits 47154 47786 +632
- Misses 7636 7653 +17
- Partials 1110 1127 +17 ☔ View full report in Codecov by Sentry. |
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
source/api_cc/src/DeepPotJAX.cc (2)
27-29
: Passfunc_name
anddevice
by const reference to improve performanceIn the functions
find_function
,get_func_op
,get_scalar
,get_vector
, andget_vector_string
, the parametersfunc_name
anddevice
are passed by value. Passing them asconst std::string&
avoids unnecessary copying of strings and can enhance performance.Also applies to: 63-68, 84-90, 109-115, 131-137
🧰 Tools
🪛 cppcheck
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
35-37
: Optimize string truncation usingresize
instead ofsubstr
In the
find_function
function, the call tosubstr
at line 37 may be inefficient since it assigns a prefix of the string to itself. Consider usingname_.resize(pos + 1);
to truncate the string more efficiently.🧰 Tools
🪛 cppcheck
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/api_cc/src/DeepPotJAX.cc
(1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 29-29: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 64-64: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 66-66: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 86-86: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 111-111: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 133-133: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 37-37: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (2)
source/lmp/tests/test_lammps_jax.py (2)
251-274
: Refactor repeated 'units' conditional blocks to improve maintainabilityThe
units
conditional blocks in the_lammps
function are repeated multiple times for setting neighbor settings, masses, and timesteps. Refactoring this code can reduce duplication and enhance readability.Consider using dictionaries to map
units
to their corresponding parameters:def _lammps(data_file, units="metal") -> PyLammps: lammps = PyLammps() lammps.units(units) lammps.boundary("p p p") lammps.atom_style("atomic") + units_params = { + "metal": { + "neighbor": "2.0 bin", + "mass": {"1": "16", "2": "2"}, + "timestep": 0.0005, + }, + "real": { + "neighbor": "2.0 bin", + "mass": {"1": "16", "2": "2"}, + "timestep": 0.5, + }, + "si": { + "neighbor": "2.0e-10 bin", + "mass": { + "1": "%.10e" % (16 * constants.mass_metal2si), + "2": "%.10e" % (2 * constants.mass_metal2si), + }, + "timestep": 5e-16, + }, + } + + if units not in units_params: + raise ValueError("units should be metal, real, or si") + + params = units_params[units] + lammps.neighbor(params["neighbor"]) lammps.neigh_modify("every 10 delay 0 check no") lammps.read_data(data_file.resolve()) - if units == "metal" or units == "real": - lammps.mass("1 16") - lammps.mass("2 2") - elif units == "si": - lammps.mass("1 %.10e" % (16 * constants.mass_metal2si)) - lammps.mass("2 %.10e" % (2 * constants.mass_metal2si)) - else: - raise ValueError("units should be metal, real, or si") + for atom_type, mass in params["mass"].items(): + lammps.mass(f"{atom_type} {mass}") - if units == "metal": - lammps.timestep(0.0005) - elif units == "real": - lammps.timestep(0.5) - elif units == "si": - lammps.timestep(5e-16) - else: - raise ValueError("units should be metal, real, or si") + lammps.timestep(params["timestep"]) lammps.fix("1 all nve") return lammps
681-682
: Usepytest.importorskip
for cleaner skipping of testsInstead of manually checking for
mpi4py
, you can usepytest.importorskip
to skip the test if the module is not installed.Replace the manual check with:
-@pytest.mark.skipif( - importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" -) +mpi4py = pytest.importorskip("mpi4py", reason="mpi4py is not installed")
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (3)
source/api_cc/src/DeepPotJAX.cc (3)
31-49
: Optimize string parameter passing and manipulationThe function parameters could be more efficiently passed by const reference, and the string manipulation could be improved.
Apply this diff to improve efficiency:
-inline void find_function(TF_Function*& found_func, - const std::vector<TF_Function*>& funcs, - const std::string func_name) { +inline void find_function(TF_Function*& found_func, + const std::vector<TF_Function*>& funcs, + const std::string& func_name) { for (size_t i = 0; i < funcs.size(); i++) { TF_Function* func = funcs[i]; const char* name = TF_FunctionName(func); std::string name_(name); // remove trailing integer e.g. _123 - std::string::size_type pos = name_.find_last_not_of("0123456789_"); - if (pos != std::string::npos) { - name_ = name_.substr(0, pos + 1); - } + name_.erase(std::find_if(name_.rbegin(), name_.rend(), + [](char c) { return c != '_' && !std::isdigit(c); }).base() + 1);🧰 Tools
🪛 cppcheck
[performance] 33-33: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
195-209
: Optimize vector data copyingThe manual loop for copying data could be replaced with std::copy for better efficiency and readability.
Apply this diff to improve the implementation:
result.resize(TF_TensorElementCount(tensor)); - for (int i = 0; i < TF_TensorElementCount(tensor); i++) { - result[i] = data[i]; - } + std::copy(data, data + TF_TensorElementCount(tensor), result.begin());
751-778
: Improve exception messages for unimplemented methodsThe exception messages for unimplemented mixed-type computation methods could be more informative.
Apply this diff to improve error messages:
- throw deepmd::deepmd_exception("not implemented"); + throw deepmd::deepmd_exception("Mixed-type computation is not yet implemented in DeepPotJAX");
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/api_cc/src/DeepPotJAX.cc
(1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[error] 253-253: Null pointer dereference
(nullPointer)
[performance] 33-33: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 68-68: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 70-70: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 90-90: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 115-115: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 137-137: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (2)
source/api_cc/src/DeepPotJAX.cc (2)
331-476
: Consider using RAII for tensor managementThe tensor management code could benefit from RAII to ensure proper resource cleanup, especially in case of exceptions.
Consider creating a RAII wrapper for tensor handles:
class TensorHandleRAII { private: TFE_TensorHandle* handle; TF_Tensor* tensor; public: TensorHandleRAII(TFE_TensorHandle* h, TF_Tensor* t) : handle(h), tensor(t) {} ~TensorHandleRAII() { if (handle) TFE_DeleteTensorHandle(handle); if (tensor) TF_DeleteTensor(tensor); } TFE_TensorHandle* get_handle() { return handle; } }; // Usage in compute: std::vector<TensorHandleRAII> input_handles; for (int i = 0; i < 5; i++) { TF_Tensor* tensor = create_tensor(...); TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor, status); input_handles.emplace_back(handle, tensor); }
752-779
: Improve error messages for unimplemented methodsThe error messages for unimplemented mixed type compute methods could be more descriptive to help users understand why the functionality is not available.
Apply this diff to improve the error messages:
- throw deepmd::deepmd_exception("not implemented"); + throw deepmd::deepmd_exception("Mixed type computation is not yet implemented for JAX backend. Please use single precision types.");
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/api_cc/src/DeepPotJAX.cc
(1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 33-33: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 68-68: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 70-70: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 90-90: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 115-115: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 137-137: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (2)
source/api_cc/src/DeepPotJAX.cc (2)
317-329
:
Ensure unconditional resource cleanup in destructor
The destructor currently guards resource cleanup with the inited
flag. This could lead to resource leaks if initialization fails before setting inited = true
. Resources should be cleaned up unconditionally.
Apply this diff to fix the issue:
deepmd::DeepPotJAX::~DeepPotJAX() {
- if (inited) {
- TF_DeleteSession(session, status);
- TF_DeleteGraph(graph);
- TF_DeleteSessionOptions(sessionopts);
- TF_DeleteStatus(status);
- TFE_DeleteContext(ctx);
- TFE_DeleteContextOptions(ctx_opts);
- for (size_t i = 0; i < func_vector.size(); i++) {
- TF_DeleteFunction(func_vector[i]);
- }
- }
+ TF_DeleteSession(session, status);
+ TF_DeleteGraph(graph);
+ TF_DeleteSessionOptions(sessionopts);
+ TF_DeleteStatus(status);
+ TFE_DeleteContext(ctx);
+ TFE_DeleteContextOptions(ctx_opts);
+ for (size_t i = 0; i < func_vector.size(); i++) {
+ TF_DeleteFunction(func_vector[i]);
+ }
}
Likely invalid or redundant comment.
523-529
:
Add safety checks for padding calculation
The padding calculation loop could potentially run indefinitely if there are numerical issues. Add bounds checking to prevent this.
Apply this diff to add safety checks:
+ const size_t MAX_PADDING_ITERATIONS = 100; // Reasonable upper limit
+ size_t iteration_count = 0;
while (padding_to_nall < nall_real) {
padding_to_nall *= PADDING_FACTOR;
+ if (++iteration_count > MAX_PADDING_ITERATIONS) {
+ throw deepmd::deepmd_exception("Exceeded maximum padding iterations. Check for numerical issues.");
+ }
}
Likely invalid or redundant comment.
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
source/api_cc/src/DeepPotJAX.cc (3)
22-22
: Consider using constexpr instead of macroReplace the macro with a constexpr for better type safety and debugging:
-#define PADDING_FACTOR 1.05 +constexpr double PADDING_FACTOR = 1.05;
331-476
: Consider using RAII for resource managementThe compute method manages multiple resources manually. Consider using RAII patterns with smart pointers or scope guards to ensure proper cleanup even in case of exceptions.
Example approach:
template <typename T> class ScopedTensorHandle { TFE_TensorHandle* handle; public: explicit ScopedTensorHandle(TFE_TensorHandle* h) : handle(h) {} ~ScopedTensorHandle() { if(handle) TFE_DeleteTensorHandle(handle); } TFE_TensorHandle* get() { return handle; } };
764-779
: Add TODO comments for unimplemented methodsThe
computew_mixed_type
methods throw "not implemented" exceptions. Consider adding TODO comments explaining the planned implementation or requirements.Would you like me to help create GitHub issues to track the implementation of these methods?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/api_cc/src/DeepPotJAX.cc
(1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 33-33: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 68-68: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 70-70: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 90-90: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 115-115: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 137-137: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (2)
source/api_cc/src/DeepPotJAX.cc (2)
523-529
:
Add safety checks for padding calculations
The padding calculation loop lacks bounds checking which could lead to infinite loops or integer overflow.
+ const size_t MAX_PADDING_ITERATIONS = 100;
+ size_t iteration_count = 0;
while (padding_to_nall < nall_real) {
padding_to_nall *= PADDING_FACTOR;
+ if (++iteration_count > MAX_PADDING_ITERATIONS) {
+ throw deepmd::deepmd_exception("Maximum padding iterations exceeded");
+ }
}
Likely invalid or redundant comment.
317-329
:
Remove inited flag check from destructor
The destructor should unconditionally clean up resources as they are allocated during initialization before the inited flag is set. This prevents resource leaks if initialization fails.
deepmd::DeepPotJAX::~DeepPotJAX() {
- if (inited) {
TF_DeleteSession(session, status);
TF_DeleteGraph(graph);
TF_DeleteSessionOptions(sessionopts);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
TFE_DeleteContextOptions(ctx_opts);
for (size_t i = 0; i < func_vector.size(); i++) {
TF_DeleteFunction(func_vector[i]);
}
- }
}
Likely invalid or redundant comment.
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (1)
source/api_cc/src/DeepPotJAX.cc (1)
31-49
: Optimize string operations and parameter passingThe function can be optimized for better performance:
- Use
std::string_view
for string parameters that are only read- Use
resize()
instead ofsubstr()
for better performance- Pass string parameters by const reference
-inline void find_function(TF_Function*& found_func, - const std::vector<TF_Function*>& funcs, - const std::string func_name) { +inline void find_function(TF_Function*& found_func, + const std::vector<TF_Function*>& funcs, + std::string_view func_name) { for (size_t i = 0; i < funcs.size(); i++) { TF_Function* func = funcs[i]; const char* name = TF_FunctionName(func); std::string name_(name); // remove trailing integer e.g. _123 std::string::size_type pos = name_.find_last_not_of("0123456789_"); if (pos != std::string::npos) { - name_ = name_.substr(0, pos + 1); + name_.resize(pos + 1); } - if (name_ == "__inference_" + func_name) { + if (name_ == "__inference_" + std::string(func_name)) { found_func = func; return; } } found_func = NULL; }🧰 Tools
🪛 cppcheck
[performance] 33-33: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/api_cc/src/DeepPotJAX.cc
(1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 33-33: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 68-68: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 70-70: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 90-90: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 115-115: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 137-137: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (2)
source/api_cc/src/DeepPotJAX.cc (2)
243-255
:
Add bounds checking for GPU memory configuration
The GPU memory configuration manipulates raw bytes without proper bounds checking, which could lead to buffer overflows.
std::array<std::uint8_t, 19> config = {
{0x10, static_cast<std::uint8_t>(num_intra_nthreads), 0x28,
static_cast<std::uint8_t>(num_inter_nthreads), 0x32, 0xb, 0x9, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x20, 0x1, 0x38, 0x1}};
double gpu_memory_fraction = 0.9;
+ if (sizeof(gpu_memory_fraction) > 8) {
+ throw deepmd::deepmd_exception("GPU memory fraction size exceeds buffer capacity");
+ }
auto bytes = reinterpret_cast<std::uint8_t*>(&gpu_memory_fraction);
for (std::size_t i = 0; i < sizeof(gpu_memory_fraction); ++i) {
+ if (i + 3 >= config.size()) {
+ throw deepmd::deepmd_exception("Buffer overflow in GPU memory configuration");
+ }
config[i + 3] = bytes[i];
}
Likely invalid or redundant comment.
526-532
:
Add safety check for padding calculation loop
The while loop for padding calculation could potentially run indefinitely if there are numerical issues.
+ const size_t MAX_PADDING_ITERATIONS = 100; // Reasonable upper limit
+ size_t iteration_count = 0;
while (padding_to_nall < nall_real) {
padding_to_nall *= PADDING_FACTOR;
+ if (++iteration_count > MAX_PADDING_ITERATIONS) {
+ throw deepmd::deepmd_exception("Exceeded maximum padding iterations. Check for numerical issues.");
+ }
}
Likely invalid or redundant comment.
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (2)
source/api_cc/src/DeepPotJAX.cc (2)
31-49
: Optimize string operations and parameter passingConsider these performance improvements:
- Pass string parameters by const reference to avoid unnecessary copies
- Use
string_view
for string parameters that are only read- Use
resize()
instead ofsubstr()
for better performance-inline void find_function(TF_Function*& found_func, - const std::vector<TF_Function*>& funcs, - const std::string func_name) { +inline void find_function(TF_Function*& found_func, + const std::vector<TF_Function*>& funcs, + std::string_view func_name) { for (size_t i = 0; i < funcs.size(); i++) { TF_Function* func = funcs[i]; const char* name = TF_FunctionName(func); std::string name_(name); // remove trailing integer e.g. _123 std::string::size_type pos = name_.find_last_not_of("0123456789_"); if (pos != std::string::npos) { - name_ = name_.substr(0, pos + 1); + name_.resize(pos + 1); } - if (name_ == "__inference_" + func_name) { + if (name_ == "__inference_" + std::string(func_name)) { found_func = func; return; } } found_func = NULL; }🧰 Tools
🪛 cppcheck
[performance] 33-33: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
755-782
: Document unimplemented methodsThe
computew_mixed_type
methods throw "not implemented" exceptions but lack documentation explaining why they're not implemented or when they will be.Add documentation comments explaining:
- Why these methods are not implemented
- When they will be implemented
- What alternatives users should use in the meantime
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/api_cc/src/DeepPotJAX.cc
(1 hunks)
🧰 Additional context used
🪛 cppcheck
source/api_cc/src/DeepPotJAX.cc
[performance] 33-33: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 68-68: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 70-70: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 88-88: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 90-90: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 113-113: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 115-115: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 135-135: Function parameter 'func_name' should be passed by const reference.
(passedByValue)
[performance] 137-137: Function parameter 'device' should be passed by const reference.
(passedByValue)
[performance] 41-41: Ineffective call of function 'substr' because a prefix of the string is assigned to itself. Use resize() or pop_back() instead.
(uselessCallsSubstr)
🔇 Additional comments (1)
source/api_cc/src/DeepPotJAX.cc (1)
526-532
:
Add safety checks for padding calculations
The while loop for padding calculation could potentially run indefinitely if there are numerical issues.
+ const size_t MAX_PADDING_ITERATIONS = 100; // Reasonable upper limit
+ size_t iteration_count = 0;
while (padding_to_nall < nall_real) {
padding_to_nall *= PADDING_FACTOR;
+ if (++iteration_count > MAX_PADDING_ITERATIONS) {
+ throw deepmd::deepmd_exception("Exceeded maximum padding iterations. Check for numerical issues.");
+ }
}
Likely invalid or redundant comment.
This reverts commit 6d5b45a.
Signed-off-by: Jinzhe Zeng <[email protected]>
As discussed, this PR passes mapping from LAMMPS to the PT C++ interface, which is helpful for the external GNN models. The mapping interface is synced from #4307. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced a new function `DP_NlistSetMapping` for setting mappings in neighbor lists. - Added `set_mapping` method to `InputNlist` for mapping atoms to real atoms. - Enhanced `compute` methods in `DeepPotPT`, `PairDeepMD`, and `FixDPLR` classes to support new mapping functionalities. - **Bug Fixes** - Improved error handling in various classes to ensure robustness during execution. - **Documentation** - Updated and added comments for clarity and consistency in new and existing functions. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
Including nlist and no nlist interface.
The limitation: A SavedModel created on a device cannot be run on another. For example, a CUDA model cannot be run on the CPU.
The model is generated using #4336.
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Documentation
Tests
Chores