Skip to content

Commit

Permalink
initial stab at merge, untested
Browse files Browse the repository at this point in the history
  • Loading branch information
anjohan committed Oct 18, 2024
1 parent 20538c9 commit 10fc99b
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 132 deletions.
6 changes: 3 additions & 3 deletions compute_allegro.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ ComputeAllegro<peratom>::ComputeAllegro(LAMMPS *lmp, int narg, char **arg) : Com
error->all(FLERR, "no pair style; compute allegro must be defined after pair style");
}

((PairAllegro<lowhigh> *) force->pair)->add_custom_output(quantity);
((PairAllegro<0> *) force->pair)->add_custom_output(quantity);
}

template<int peratom>
Expand Down Expand Up @@ -109,7 +109,7 @@ void ComputeAllegro<peratom>::compute_vector()
}
} else {
const torch::Tensor &quantity_tensor =
((PairAllegro<lowhigh> *) force->pair)->custom_output.at(quantity).cpu().ravel();
((PairAllegro<0> *) force->pair)->custom_output.at(quantity).cpu().ravel();

auto quantity = quantity_tensor.data_ptr<double>();

Expand Down Expand Up @@ -140,7 +140,7 @@ void ComputeAllegro<peratom>::compute_peratom()
// guard against empty domain (pair style won't store tensor)
if (atom->nlocal > 0) {
const torch::Tensor &quantity_tensor =
((PairAllegro<lowhigh> *) force->pair)->custom_output.at(quantity).cpu().contiguous().reshape({-1,nperatom});
((PairAllegro<0> *) force->pair)->custom_output.at(quantity).cpu().contiguous().reshape({-1,nperatom});

auto quantity = quantity_tensor.accessor<double,2>();
quantityptr = quantity_tensor.data_ptr<double>();
Expand Down
227 changes: 137 additions & 90 deletions pair_allegro.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

using namespace LAMMPS_NS;

template <Precision precision> PairAllegro<precision>::PairAllegro(LAMMPS *lmp) : Pair(lmp)
template <int nequip_mode> PairAllegro<nequip_mode>::PairAllegro(LAMMPS *lmp) : Pair(lmp)
{
restartinfo = 0;
manybody_flag = 1;
Expand Down Expand Up @@ -103,7 +103,7 @@ template <Precision precision> PairAllegro<precision>::PairAllegro(LAMMPS *lmp)
if (debug_mode) std::cout << "Allegro is using device " << device << "\n";
}

template <Precision precision> PairAllegro<precision>::~PairAllegro()
template <int nequip_mode> PairAllegro<nequip_mode>::~PairAllegro()
{
if (copymode) return;
if (allocated) {
Expand All @@ -113,28 +113,28 @@ template <Precision precision> PairAllegro<precision>::~PairAllegro()
}
}

template <Precision precision> void PairAllegro<precision>::init_style()
template <int nequip_mode> void PairAllegro<nequip_mode>::init_style()
{
if (atom->tag_enable == 0) error->all(FLERR, "Pair style Allegro requires atom IDs");

// Request a full neighbor list.
if (lmp->kokkos) {
// Only request full to avoid a Kokkos bug; pair_allegro_kokkos.cpp doesn't need GHOST anyway
neighbor->add_request(this, NeighConst::REQ_FULL);
} else {
// Non-kokkos needs ghost to avoid segfaults
neighbor->add_request(this, NeighConst::REQ_FULL | NeighConst::REQ_GHOST);
}

if (force->newton_pair == 0) error->all(FLERR, "Pair style Allegro requires newton pair on");
if (!nequip_mode && force->newton_pair == 0) error->all(FLERR, "Pair style allegro requires newton pair on");
if (nequip_mode && force->newton_pair) error->all(FLERR, "Pair style nequip requires newton pair off");
}

template <Precision precision> double PairAllegro<precision>::init_one(int i, int j)
template <int nequip_mode> double PairAllegro<nequip_mode>::init_one(int i, int j)
{
return cutoff;
}

template <Precision precision> void PairAllegro<precision>::allocate()
template <int nequip_mode> void PairAllegro<nequip_mode>::allocate()
{
allocated = 1;
int n = atom->ntypes;
Expand All @@ -144,13 +144,13 @@ template <Precision precision> void PairAllegro<precision>::allocate()
memory->create(cutoff_matrix, n, n, "pair:cutoff_matrix");
}

template <Precision precision> void PairAllegro<precision>::settings(int narg, char ** /*arg*/)
template <int nequip_mode> void PairAllegro<nequip_mode>::settings(int narg, char ** /*arg*/)
{
// "allegro" should be the only word after "pair_style" in the input file.
if (narg > 0) error->all(FLERR, "Illegal pair_style command, too many arguments");
}

template <Precision precision> void PairAllegro<precision>::coeff(int narg, char **arg)
template <int nequip_mode> void PairAllegro<nequip_mode>::coeff(int narg, char **arg)
{
if (!allocated) allocate();

Expand Down Expand Up @@ -241,7 +241,6 @@ template <Precision precision> void PairAllegro<precision>::coeff(int narg, char

cutoff = std::stod(metadata["r_max"]);

//TODO: This
type_mapper.resize(ntypes, -1);
std::stringstream ss;
int n_species = std::stod(metadata["n_species"]);
Expand Down Expand Up @@ -288,7 +287,7 @@ template <Precision precision> void PairAllegro<precision>::coeff(int narg, char
arg[reverse_type_mapper[j] + 3], i, j, reverse_type_mapper[i],
reverse_type_mapper[j], cutij);
}
cutoff_matrix[reverse_type_mapper[i]][reverse_type_mapper[j]] = cutij; //TODO
cutoff_matrix[reverse_type_mapper[i]][reverse_type_mapper[j]] = cutij;
}
}
}
Expand All @@ -300,26 +299,87 @@ template <Precision precision> void PairAllegro<precision>::coeff(int narg, char
}

// Force and energy computation
template <Precision precision> void PairAllegro<precision>::compute(int eflag, int vflag)
template <int nequip_mode> void PairAllegro<nequip_mode>::compute(int eflag, int vflag)
{
ev_init(eflag, vflag);

// Get info from lammps:
// Atom forces
double **f = atom->f;

int inum = list->inum;
if (inum==0) return;

// Number of ghost atoms
int nghost = list->gnum;
// Total number of atoms
int ntotal = inum + nghost;
// Mapping from neigh list ordering to x/f ordering
int *ilist = list->ilist;


auto input = preprocess();
std::vector<torch::IValue> input_vector(1, input);
auto output = model.forward(input_vector).toGenericDict();

torch::Tensor forces_tensor = output.at("forces").toTensor().cpu();
auto forces = forces_tensor.accessor<outputtype, 2>();

torch::Tensor atomic_energy_tensor = output.at("atomic_energy").toTensor().cpu();
auto atomic_energies = atomic_energy_tensor.accessor<outputtype, 2>();
outputtype atomic_energy_sum = atomic_energy_tensor.sum().data_ptr<outputtype>()[0];


eng_vdwl = 0.0;
int nforces = nequip_mode ? inum : ntotal;
#pragma omp parallel for reduction(+ : eng_vdwl)
for (int ii = 0; ii < nforces; ii++) {
int i = ilist[ii];

f[i][0] += forces[i][0];
f[i][1] += forces[i][1];
f[i][2] += forces[i][2];
if (eflag_atom && ii < inum) eatom[i] = atomic_energies[i][0];
if (ii < inum) eng_vdwl += atomic_energies[i][0];
}

if (vflag) {
torch::Tensor v_tensor = output.at("virial").toTensor().cpu();
auto v = v_tensor.accessor<outputtype, 3>();
// Convert from 3x3 symmetric tensor format, which NequIP outputs, to the flattened form LAMMPS expects
// First [0] index on v is batch
virial[0] = v[0][0][0];
virial[1] = v[0][1][1];
virial[2] = v[0][2][2];
virial[3] = v[0][0][1];
virial[4] = v[0][0][2];
virial[5] = v[0][1][2];
}
if (vflag_atom) { error->all(FLERR, "Pair style Allegro does not support per-atom virial"); }

if (debug_mode) {
std::cout << "ALLEGRO CUSTOM OUTPUT" << std::endl;
for (const auto &elem : output) {
std::cout << elem.key() << "\n" << elem.value() << std::endl;
}
}

for (const std::string &output_name : custom_output_names) {
if (!output.contains(output_name)) error->all(FLERR, "missing {}", output_name);
// printf("pair_allegro inserting %s\n", output_name.data()); fflush(stdout);
custom_output.insert_or_assign(output_name, output.at(output_name).toTensor().detach());
}
}

template <int nequip_mode> c10::Dict<std::string, torch::Tensor> PairAllegro<nequip_mode>::preprocess() {
// Atom positions, including ghost atoms
double **x = atom->x;
// Atom forces
double **f = atom->f;
// Atom IDs, unique, reproducible, the "real" indices
// Probably 1-based
tagint *tag = atom->tag;
// Atom types, 1-based
int *type = atom->type;
// Number of local/real atoms
int nlocal = atom->nlocal;
// Whether Newton is on (i.e. reverse "communication" of forces on ghost atoms).
// Should be on.
int newton_pair = force->newton_pair;

// Number of local/real atoms
int inum = list->inum;
Expand All @@ -335,15 +395,11 @@ template <Precision precision> void PairAllegro<precision>::compute(int eflag, i
// Neighbor list per atom
int **firstneigh = list->firstneigh;

// Skip calculation if empty domain
if (inum==0) return;

// Total number of bonds (sum of number of neighbors)
int nedges = 0;

// Number of bonds per atom
std::vector<int> neigh_per_atom(nlocal, 0);
int ntypes = atom->ntypes;

#pragma omp parallel for reduction(+ : nedges)
for (int ii = 0; ii < nlocal; ii++) {
Expand Down Expand Up @@ -390,6 +446,20 @@ template <Precision precision> void PairAllegro<precision>::compute(int eflag, i
auto edges = edges_tensor.accessor<long, 2>();
auto ij2type = ij2type_tensor.accessor<long, 1>();

std::vector<int> tag2i(nequip_mode ? inum : 0);
torch::Tensor cell_tensor, cell_inv_tensor;
torch::Tensor edge_cell_shifts_tensor;
inputtype* edge_cell_shifts, *cell_inv;
inputtype periodic_shift[3];
if (nequip_mode) {
cell_tensor = get_cell();
cell_inv_tensor = cell_tensor.inverse().transpose(0,1);
cell_inv = cell_inv_tensor.data_ptr<inputtype>();
edge_cell_shifts_tensor = torch::zeros({nedges,3}, torch::TensorOptions().dtype(inputtorchtype));
edge_cell_shifts = edge_cell_shifts_tensor.data_ptr<inputtype>();
get_tag2i(tag2i);
}

// Loop over atoms and neighbors,
// store edges and _cell_shifts
// ii follows the order of the neighbor lists,
Expand Down Expand Up @@ -426,12 +496,25 @@ template <Precision precision> void PairAllegro<precision>::compute(int eflag, i
double rsq = dx * dx + dy * dy + dz * dz;

double cutij =
cutoff_matrix[itype - 1][jtype - 1]; //cutoff_matrix[(type[i]-1)*ntypes + type[j]-1];
cutoff_matrix[itype - 1][jtype - 1];
if (rsq > cutij * cutij) { continue; }

// TODO: double check order
edges[0][edge_counter] = i;
edges[1][edge_counter] = j;
edges[1][edge_counter] = nequip_mode ? tag2i[jtag] : j;

if constexpr (nequip_mode) {
for (int d = 0; d < 3; d++)
periodic_shift[d] = x[j][d] - x[tag2i[jtag]][d];

// edge_cell_shift[e] = round(cell_inv.matmul(periodic_shift))
for (int d = 0; d < 3; d++) {
inputtype tmp = 0;
for (int k = 0; k < 3; k++)
tmp += cell_inv[3*d+k] * periodic_shift[k];

edge_cell_shifts[3*edge_counter+d] = std::round(tmp);
}
}

edge_counter++;

Expand All @@ -440,89 +523,53 @@ template <Precision precision> void PairAllegro<precision>::compute(int eflag, i
}
if (debug_mode) printf("end Allegro edges\n");

torch::Tensor compute_custom_tensor =
torch::full({1}, false, torch::TensorOptions().dtype(torch::kBool));
if (update->ntimestep == output->next && custom_output_names.size() > 0) {
// error->message(FLERR, "computing custom output");
auto tmp = compute_custom_tensor.accessor<bool, 1>();
tmp[0] = true;
}

c10::Dict<std::string, torch::Tensor> input;
input.insert("pos", pos_tensor.to(device));
input.insert("edge_index", edges_tensor.to(device));
input.insert("atom_types", ij2type_tensor.to(device));
input.insert("compute_custom_output", compute_custom_tensor);
std::vector<torch::IValue> input_vector(1, input);
if (nequip_mode) {
input.insert("edge_cell_shift", edge_cell_shifts_tensor.to(device));
input.insert("cell", cell_tensor.to(device));
}

auto output = model.forward(input_vector).toGenericDict();
return input;
}

torch::Tensor forces_tensor = output.at("forces").toTensor().cpu();
auto forces = forces_tensor.accessor<outputtype, 2>();
template <int nequip_mode> torch::Tensor PairAllegro<nequip_mode>::get_cell(){
torch::Tensor cell_tensor = torch::zeros({3,3}, torch::TensorOptions().dtype(inputtorchtype));
auto cell = cell_tensor.accessor<inputtype,2>();

//torch::Tensor total_energy_tensor = output.at("total_energy").toTensor().cpu(); WRONG WITH MPI
cell[0][0] = domain->boxhi[0] - domain->boxlo[0];

torch::Tensor atomic_energy_tensor = output.at("atomic_energy").toTensor().cpu();
auto atomic_energies = atomic_energy_tensor.accessor<outputtype, 2>();
outputtype atomic_energy_sum = atomic_energy_tensor.sum().data_ptr<outputtype>()[0];
cell[1][0] = domain->xy;
cell[1][1] = domain->boxhi[1] - domain->boxlo[1];

//std::cout << "atomic energy sum: " << atomic_energy_sum << std::endl;
//std::cout << "Total energy: " << total_energy_tensor << "\n";
//std::cout << "atomic energy shape: " << atomic_energy_tensor.sizes()[0] << "," << atomic_energy_tensor.sizes()[1] << std::endl;
//std::cout << "atomic energies: " << atomic_energy_tensor << std::endl;
cell[2][0] = domain->xz;
cell[2][1] = domain->yz;
cell[2][2] = domain->boxhi[2] - domain->boxlo[2];

// Write forces and per-atom energies (0-based tags here)
eng_vdwl = 0.0;
#pragma omp parallel for reduction(+ : eng_vdwl)
for (int ii = 0; ii < ntotal; ii++) {
int i = ilist[ii];
return cell_tensor;
}

f[i][0] += forces[i][0];
f[i][1] += forces[i][1];
f[i][2] += forces[i][2];
if (eflag_atom && ii < inum) eatom[i] = atomic_energies[i][0];
if (ii < inum) eng_vdwl += atomic_energies[i][0];
}
template <int nequip_mode> void PairAllegro<nequip_mode>::get_tag2i(std::vector<int> &tag2i){
int inum = list->inum;
int *ilist = list->ilist;
tagint *tag = atom->tag;
for(int ii = 0; ii < inum; ii++){
int i = ilist[ii];
int itag = tag[i];

if (vflag) {
torch::Tensor v_tensor = output.at("virial").toTensor().cpu();
auto v = v_tensor.accessor<outputtype, 3>();
// Convert from 3x3 symmetric tensor format, which NequIP outputs, to the flattened form LAMMPS expects
// First [0] index on v is batch
virial[0] = v[0][0][0];
virial[1] = v[0][1][1];
virial[2] = v[0][2][2];
virial[3] = v[0][0][1];
virial[4] = v[0][0][2];
virial[5] = v[0][1][2];
// Inverse mapping from tag to x/f atom index
tag2i[itag-1] = i; // tag is probably 1-based
}
if (vflag_atom) { error->all(FLERR, "Pair style Allegro does not support per-atom virial"); }

// TODO: Figure out reliable solution
// if (update->ntimestep == this->output->next || update->ntimestep==0) {
if (debug_mode) {
std::cout << "ALLEGRO CUSTOM OUTPUT" << std::endl;
for (const auto &elem : output) {
std::cout << elem.key() << "\n" << elem.value() << std::endl;
}
}

for (const std::string &output_name : custom_output_names) {
if (!output.contains(output_name)) error->all(FLERR, "missing {}", output_name);
// printf("pair_allegro inserting %s\n", output_name.data()); fflush(stdout);
custom_output.insert_or_assign(output_name, output.at(output_name).toTensor().detach());
}
// }
}

template <Precision precision> void PairAllegro<precision>::add_custom_output(std::string name)
template <int nequip_mode> void PairAllegro<nequip_mode>::add_custom_output(std::string name)
{
custom_output_names.push_back(name);
}

namespace LAMMPS_NS {
template class PairAllegro<lowlow>;
template class PairAllegro<highhigh>;
template class PairAllegro<lowhigh>;
template class PairAllegro<highlow>;
template class PairAllegro<0>;
template class PairAllegro<1>;
} // namespace LAMMPS_NS
Loading

0 comments on commit 10fc99b

Please sign in to comment.