Skip to content

Commit

Permalink
sync mapping interface with deepmodeling#4307
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 13, 2024
1 parent 1e6f069 commit 76cd66a
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 3 deletions.
16 changes: 13 additions & 3 deletions source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extern "C" {
/** C API version. Bumped whenever the API is changed.
* @since API version 22
*/
#define DP_C_API_VERSION 24
#define DP_C_API_VERSION 25

/**
* @brief Neighbor list.
Expand All @@ -31,7 +31,7 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_);
/*
/**
* @brief Create a new neighbor list with communication capabilities.
* @details This function extends DP_NewNlist by adding support for parallel
* communication, allowing the neighbor list to be used in distributed
Expand Down Expand Up @@ -68,7 +68,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* recvproc,
void* world);

/*
/**
* @brief Set mask for a neighbor list.
*
* @param nl Neighbor list.
Expand All @@ -78,6 +78,16 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
**/
extern void DP_NlistSetMask(DP_Nlist* nl, int mask);

/**
* @brief Set mapping for a neighbor list.
*
* @param nl Neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
* @since API version 25
*
**/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);

/**
* @brief Delete a neighbor list.
*
Expand Down
5 changes: 5 additions & 0 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,11 @@ struct InputNlist {
* @brief Set mask for this neighbor list.
*/
void set_mask(int mask) { DP_NlistSetMask(nl, mask); };
/**
* @brief Set mapping for this neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
*/
void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };
};

/**
Expand Down
3 changes: 3 additions & 0 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
return new_nl;
}
void DP_NlistSetMask(DP_Nlist* nl, int mask) { nl->nl.set_mask(mask); }
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

// DP Base Model
Expand Down
6 changes: 6 additions & 0 deletions source/lib/include/neighbor_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct InputNlist {
void* world;
/// mask to the neighbor index
int mask = 0xFFFFFFFF;
/// mapping from all atoms to real atoms, in the size of nall
int* mapping = nullptr;
InputNlist()
: inum(0),
ilist(NULL),
Expand Down Expand Up @@ -99,6 +101,10 @@ struct InputNlist {
* @brief Set mask for this neighbor list.
*/
void set_mask(int mask_) { mask = mask_; };
/**
* @brief Set mapping for this neighbor list.
*/
void set_mapping(int* mapping_) { mapping = mapping_; };
};

/**
Expand Down
11 changes: 11 additions & 0 deletions source/lmp/fix_dplr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,14 @@ void FixDPLR::pre_force(int vflag) {
int nghost = atom->nghost;
int nall = nlocal + nghost;

// mapping (for DPA-2 JAX)
std::vector<int> mapping_vec(nall, -1);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
for (size_t ii = 0; ii < nall; ++ii) {
mapping_vec[ii] = atom->map(atom->tag[ii]);
}
}

// if (eflag_atom) {
// error->all(FLERR,"atomic energy calculation is not supported by this
// fix\n");
Expand Down Expand Up @@ -499,6 +507,9 @@ void FixDPLR::pre_force(int vflag) {
deepmd_compat::InputNlist lmp_list(list->inum, list->ilist, list->numneigh,
list->firstneigh);
lmp_list.set_mask(NEIGHMASK);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
lmp_list.set_mapping(mapping_vec.data());
}
// declear output
vector<FLOAT_PREC> tensor;
// compute
Expand Down
11 changes: 11 additions & 0 deletions source/lmp/pair_deepmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ void PairDeepMD::compute(int eflag, int vflag) {
}
}

// mapping (for DPA-2 JAX)
std::vector<int> mapping_vec(nall, -1);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
for (size_t ii = 0; ii < nall; ++ii) {
mapping_vec[ii] = atom->map(atom->tag[ii]);
}
}

if (do_compute_aparam) {
make_aparam_from_compute(daparam);
} else if (aparam.size() > 0) {
Expand Down Expand Up @@ -198,6 +206,9 @@ void PairDeepMD::compute(int eflag, int vflag) {
commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc,
commdata_->recvproc, &world);
lmp_list.set_mask(NEIGHMASK);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
lmp_list.set_mapping(mapping_vec.data());
}
deepmd_compat::InputNlist extend_lmp_list;
if (single_model || multi_models_no_mod_devi) {
// cvflag_atom is the right flag for the cvatom matrix
Expand Down

0 comments on commit 76cd66a

Please sign in to comment.