diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index 2f848639..1d724cfb 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -1,6 +1,8 @@ """Computing various graph based operations.""" from __future__ import annotations +from typing import Callable + import dgl import numpy as np import torch @@ -8,7 +10,155 @@ import matgl -def compute_3body(g: dgl.DGLGraph): +def compute_pair_vector_and_distance(g: dgl.DGLGraph): + """Calculate bond vectors and distances using dgl graphs. + + Args: + g: DGL graph + + Returns: + bond_vec (torch.tensor): bond distance between two atoms + bond_dist (torch.tensor): vector from src node to dst node + """ + dst_pos = g.ndata["pos"][g.edges()[1]] + g.edata["pbc_offshift"] + src_pos = g.ndata["pos"][g.edges()[0]] + bond_vec = (dst_pos - src_pos).float() + bond_dist = torch.norm(bond_vec, dim=1) + + return bond_vec, bond_dist + + +def compute_theta_and_phi(edges: dgl.udf.EdgeBatch): + """Calculate bond angle Theta and Phi using dgl graphs. + + Args: + edges: DGL graph edges + + Returns: + cos_theta: torch.Tensor + phi: torch.Tensor + triple_bond_lengths (torch.tensor): + """ + angles = compute_theta(edges, cosine=True, directed=False) + angles["phi"] = torch.zeros_like(angles["cos_theta"]) + return angles + + +def compute_theta( + edges: dgl.udf.EdgeBatch, cosine: bool = False, directed: bool = True, eps=1e-7 +) -> dict[str, torch.Tensor]: + """User defined dgl function to calculate bond angles from edges in a graph. + + Args: + edges: DGL graph edges + cosine: Whether to return the cosine of the angle or the angle itself + directed: Whether to the line graph was created with create directed line graph. + In which case bonds (only those that are not self bonds) need to + have their bond vectors flipped. + eps: eps value used to clamp cosine values to avoid acos of values > 1.0 + + Returns: + dict[str, torch.Tensor]: Dictionary containing bond angles and distances + """ + vec1 = edges.src["bond_vec"] * edges.src["src_bond_sign"] if directed else edges.src["bond_vec"] + vec2 = edges.dst["bond_vec"] + key = "cos_theta" if cosine else "theta" + val = torch.sum(vec1 * vec2, dim=1) / (torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1)) + val = val.clamp_(min=-1 + eps, max=1 - eps) # stability for floating point numbers > 1.0 + if not cosine: + val = torch.acos(val) + return {key: val, "triple_bond_lengths": edges.dst["bond_dist"]} + + +def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False) -> dgl.DGLGraph: + """ + Calculate the three body indices from pair atom indices. + + Args: + g: DGL graph + threebody_cutoff (float): cutoff for three-body interactions + directed (bool): Whether to create a directed line graph, or an m3gnet 3body line graph (default: False, m3gnet) + + Returns: + l_g: DGL graph containing three body information from graph + """ + graph_with_three_body = prune_edges_by_features(g, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff) + if directed: + lg = _create_directed_line_graph(graph_with_three_body, threebody_cutoff) + else: + lg, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = _compute_3body(graph_with_three_body) + + return lg + + +def ensure_line_graph_compatibility( + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False, tol: float = 5e-7 +) -> dgl.DGLGraph: + """Ensure that line graph is compatible with graph. + + Sets edge data in line graph to be consistent with graph. The line graph is updated in place. + + Args: + graph: atomistic graph + line_graph: line graph of atomistic graph + threebody_cutoff: cutoff for three-body interactions + directed (bool): Whether to create a directed line graph, or an m3gnet 3body line graph (default: False, m3gnet) + tol: numerical tolerance for cutoff + """ + if directed: + line_graph = _ensure_directed_line_graph_compatibility(graph, line_graph, threebody_cutoff, tol) + else: + line_graph = _ensure_3body_line_graph_compatibility(graph, line_graph, threebody_cutoff) + + return line_graph + + +def prune_edges_by_features( + graph: dgl.DGLGraph, + feat_name: str, + condition: Callable[[torch.Tensor], torch.Tensor], + keep_ndata: bool = False, + keep_edata: bool = True, + *args, + **kwargs, +) -> dgl.DGLGraph: + """Removes edges graph that do satisfy given condition based on a specified feature value. + + Returns a new graph with edges removed. + + Args: + graph: DGL graph + feat_name: edge field name + condition: condition function. Must be a function where the first is the value + of the edge field data and returns a Tensor of boolean values. + keep_ndata: whether to keep node features + keep_edata: whether to keep edge features + *args: additional arguments to pass to condition function + **kwargs: additional keyword arguments to pass to condition function + + Returns: dgl.Graph with removed edges. + """ + if feat_name not in graph.edata: + raise ValueError(f"Edge field {feat_name} not an edge feature in given graph.") + + valid_edges = torch.logical_not(condition(graph.edata[feat_name], *args, **kwargs)) + src, dst = graph.edges() + src, dst = src[valid_edges], dst[valid_edges] + e_ids = valid_edges.nonzero().squeeze() + new_g = dgl.graph((src, dst), device=graph.device) + new_g.edata["edge_ids"] = e_ids # keep track of original edge ids + + if keep_ndata: + for key, value in graph.ndata.items(): + new_g.ndata[key] = value + if keep_edata: + for key, value in graph.edata.items(): + new_g.edata[key] = value[valid_edges] + + return new_g + + +def _compute_3body(g: dgl.DGLGraph): """Calculate the three body indices from pair atom indices. Args: @@ -67,76 +217,132 @@ def compute_3body(g: dgl.DGLGraph): return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s -def compute_pair_vector_and_distance(g: dgl.DGLGraph): - """Calculate bond vectors and distances using dgl graphs. +def _create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> dgl.DGLGraph: + """Creates a line graph from a graph, considers periodic boundary conditions. Args: - g: DGL graph + graph: DGL graph representing atom graph + threebody_cutoff: cutoff for three-body interactions Returns: - bond_vec (torch.tensor): bond distance between two atoms - bond_dist (torch.tensor): vector from src node to dst node + line_graph: DGL graph line graph of pruned graph to three body cutoff """ - dst_pos = g.ndata["pos"][g.edges()[1]] + g.edata["pbc_offshift"] - src_pos = g.ndata["pos"][g.edges()[0]] - bond_vec = (dst_pos - src_pos).float() - bond_dist = torch.norm(bond_vec, dim=1) + with torch.no_grad(): + pg = prune_edges_by_features(graph, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff) + src_indices, dst_indices = pg.edges() + images = pg.edata["pbc_offset"] + all_indices = torch.arange(pg.number_of_nodes(), device=graph.device).unsqueeze(dim=0) + num_bonds_per_atom = torch.count_nonzero(src_indices.unsqueeze(dim=1) == all_indices, dim=0) + num_edges_per_bond = (num_bonds_per_atom - 1).repeat_interleave(num_bonds_per_atom) + lg_src = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device) + lg_dst = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device) - return bond_vec, bond_dist + incoming_edges = src_indices.unsqueeze(1) == dst_indices + is_self_edge = src_indices == dst_indices + not_self_edge = ~is_self_edge + n = 0 + # create line graph edges for bonds that are self edges in atom graph + if is_self_edge.any(): + edge_inds_s = is_self_edge.nonzero() + lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge] + 1) + lg_src_s = incoming_edges[is_self_edge].nonzero()[:, 1].squeeze() + lg_src_s = lg_src_s[lg_src_s != lg_dst_s] + lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge]) + n = len(lg_dst_s) + lg_src[:n], lg_dst[:n] = lg_src_s, lg_dst_s -def compute_theta_and_phi(edges: dgl.udf.EdgeBatch): - """Calculate bond angle Theta and Phi using dgl graphs. + # create line graph edges for bonds that are not self edges in atom graph + shared_src = src_indices.unsqueeze(1) == src_indices + back_tracking = (dst_indices.unsqueeze(1) == src_indices) & torch.all(-images.unsqueeze(1) == images, axis=2) + incoming = incoming_edges & (shared_src | ~back_tracking) - Args: - edges: DGL graph edges + edge_inds_ns = not_self_edge.nonzero().squeeze() + lg_src_ns = incoming[not_self_edge].nonzero()[:, 1].squeeze() + lg_dst_ns = edge_inds_ns.repeat_interleave(num_edges_per_bond[not_self_edge]) + lg_src[n:], lg_dst[n:] = lg_src_ns, lg_dst_ns + lg = dgl.graph((lg_src, lg_dst)) - Returns: - cos_theta: torch.Tensor - phi: torch.Tensor - triple_bond_lengths (torch.tensor): - """ - angles = compute_theta(edges, cosine=True) - angles["phi"] = torch.zeros_like(angles["cos_theta"]) - return angles + for key in pg.edata: + lg.ndata[key] = pg.edata[key][: lg.number_of_nodes()] + # we need to store the sign of bond vector when a bond is a src node in the line + # graph in order to appropriately calculate angles when self edges are involved + lg.ndata["src_bond_sign"] = torch.ones( + (lg.number_of_nodes(), 1), dtype=lg.ndata["bond_vec"].dtype, device=lg.device + ) + # if we flip self edges then we need to correct computed angles by pi - angle + # lg.ndata["src_bond_sign"][edge_inds_s] = -lg.ndata["src_bond_sign"][edge_ind_s] + # find the intersection for the rare cases where not all edges end up as nodes in the line graph + all_ns, counts = torch.cat([torch.arange(lg.number_of_nodes(), device=graph.device), edge_inds_ns]).unique( + return_counts=True + ) + lg_inds_ns = all_ns[torch.where(counts > 1)] + lg.ndata["src_bond_sign"][lg_inds_ns] = -lg.ndata["src_bond_sign"][lg_inds_ns] -def compute_theta(edges: dgl.udf.EdgeBatch, cosine: bool = False) -> dict[str, torch.Tensor]: - """User defined dgl function to calculate bond angles from edges in a graph. + return lg - Args: - edges: DGL graph edges - cosine: Whether to return the cosine of the angle or the angle itself - Returns: - dict[str, torch.Tensor]: Dictionary containing bond angles and distances - """ - vec1 = edges.src["bond_vec"] - vec2 = edges.dst["bond_vec"] - key = "cos_theta" if cosine else "theta" - val = torch.sum(vec1 * vec2, dim=1) / (torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1)) - if not cosine: - val = torch.acos(val) - return {key: val, "triple_bond_lengths": edges.dst["bond_dist"]} +def _ensure_3body_line_graph_compatibility(graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float): + """Ensure that 3body line graph is compatible with a given graph. + Sets edge data in line graph to be consistent with graph. The line graph is updated in place. -def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float): + Args: + graph: atomistic graph + line_graph: line graph of atomistic graph + threebody_cutoff: cutoff for three-body interactions """ - Calculate the three body indices from pair atom indices. + valid_three_body = graph.edata["bond_dist"] <= threebody_cutoff + if line_graph.num_nodes() == graph.edata["bond_vec"][valid_three_body].shape[0]: + line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][valid_three_body] + line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][valid_three_body] + line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][valid_three_body] + else: + three_body_id = torch.concatenate(line_graph.edges()) + max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 + line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][:max_three_body_id] + line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][:max_three_body_id] + line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][:max_three_body_id] - Args: - g: DGL graph - threebody_cutoff (float): cutoff for three-body interactions + return line_graph - Returns: - l_g: DGL graph containing three body information from graph + +def _ensure_directed_line_graph_compatibility( + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-7 +) -> dgl.DGLGraph: + """Ensure that line graph is compatible with graph. + + Sets edge data in line graph to be consistent with graph. The line graph is updated in place. + + Args: + graph: atomistic graph + line_graph: line graph of atomistic graph + threebody_cutoff: cutoff for three-body interactions + tol: numerical tolerance for cutoff """ - valid_three_body = g.edata["bond_dist"] <= threebody_cutoff - src_id_with_three_body = g.edges()[0][valid_three_body] - dst_id_with_three_body = g.edges()[1][valid_three_body] - graph_with_three_body = dgl.graph((src_id_with_three_body, dst_id_with_three_body)) - graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][valid_three_body] - graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body] - graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body] - l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body) - return l_g + valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + + # this means there probably is a bond that is just at the cutoff + # this should only really occur when batching graphs + if line_graph.number_of_nodes() > sum(valid_edges): + valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol + + # check again and raise if invalid + if line_graph.number_of_nodes() > sum(valid_edges): + raise RuntimeError("Line graph is not compatible with graph.") + + edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()] + line_graph.ndata["edge_ids"] = edge_ids + + for key in graph.edata: + line_graph.ndata[key] = graph.edata[key][edge_ids] + + src_indices, dst_indices = graph.edges() + ns_edge_ids = (src_indices[edge_ids] != dst_indices[edge_ids]).nonzero().squeeze() + line_graph.ndata["src_bond_sign"] = torch.ones( + (line_graph.number_of_nodes(), 1), dtype=graph.edata["bond_vec"].dtype, device=line_graph.device + ) + line_graph.ndata["src_bond_sign"][ns_edge_ids] = -line_graph.ndata["src_bond_sign"][ns_edge_ids] + + return line_graph diff --git a/matgl/models/_m3gnet.py b/matgl/models/_m3gnet.py index 65e392a5..0ea981e0 100644 --- a/matgl/models/_m3gnet.py +++ b/matgl/models/_m3gnet.py @@ -22,6 +22,7 @@ compute_pair_vector_and_distance, compute_theta_and_phi, create_line_graph, + ensure_line_graph_compatibility, ) from matgl.layers import ( MLP, @@ -232,17 +233,7 @@ def forward( if l_g is None: l_g = create_line_graph(g, self.threebody_cutoff) else: - valid_three_body = g.edata["bond_dist"] <= self.threebody_cutoff - if l_g.num_nodes() == g.edata["bond_vec"][valid_three_body].shape[0]: - l_g.ndata["bond_vec"] = g.edata["bond_vec"][valid_three_body] - l_g.ndata["bond_dist"] = g.edata["bond_dist"][valid_three_body] - l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body] - else: - three_body_id = torch.concatenate(l_g.edges()) - max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 - l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id] - l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id] - l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id] + l_g = ensure_line_graph_compatibility(g, l_g, self.threebody_cutoff) l_g.apply_edges(compute_theta_and_phi) g.edata["rbf"] = expanded_dists three_body_basis = self.basis_expansion(l_g) diff --git a/tests/graph/test_compute.py b/tests/graph/test_compute.py index be0d118a..70344d22 100644 --- a/tests/graph/test_compute.py +++ b/tests/graph/test_compute.py @@ -3,7 +3,9 @@ from functools import partial import numpy as np +import pytest import torch +import torch.testing as tt from pymatgen.core import Lattice, Structure from matgl.ext.pymatgen import Structure2Graph, get_element_list @@ -12,6 +14,8 @@ compute_theta, compute_theta_and_phi, create_line_graph, + ensure_line_graph_compatibility, + prune_edges_by_features, ) @@ -46,8 +50,8 @@ def _calculate_cos_loop(graph, threebody_cutoff=4.0): for j in range(n_site): if i == j: continue - vi = graph.edata["bond_vec"][i + start_index].numpy() - vj = graph.edata["bond_vec"][j + start_index].numpy() + vi = graph.edata["bond_vec"][i + start_index].detach().numpy() + vj = graph.edata["bond_vec"][j + start_index].detach().numpy() di = np.linalg.norm(vi) dj = np.linalg.norm(vj) if (di <= threebody_cutoff) and (dj <= threebody_cutoff): @@ -114,14 +118,13 @@ def test_compute_angle(self, graph_Mo, graph_CH4): ) # test only compute theta - line_graph.apply_edges(compute_theta) - np.testing.assert_array_almost_equal( - np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata["theta"])), decimal=4 - ) + line_graph.apply_edges(partial(compute_theta, directed=False)) + theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7)) + np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata["theta"])), decimal=4) # test only compute theta with cosine _ = line_graph.edata.pop("cos_theta") - line_graph.apply_edges(partial(compute_theta, cosine=True)) + line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False)) np.testing.assert_array_almost_equal( np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata["cos_theta"])) ) @@ -140,14 +143,14 @@ def test_compute_angle(self, graph_Mo, graph_CH4): ) # test only compute theta - line_graph.apply_edges(compute_theta) + line_graph.apply_edges(partial(compute_theta, directed=False)) np.testing.assert_array_almost_equal( np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata["theta"])) ) # test only compute theta with cosine _ = line_graph.edata.pop("cos_theta") - line_graph.apply_edges(partial(compute_theta, cosine=True)) + line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False)) np.testing.assert_array_almost_equal( np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata["cos_theta"])) ) @@ -186,3 +189,75 @@ def test_line_graph_extensive(): assert 2 * g1.number_of_edges() == g2.number_of_edges() assert 2 * lg1.number_of_nodes() == lg2.number_of_nodes() assert 2 * lg1.number_of_edges() == lg2.number_of_edges() + + +@pytest.mark.parametrize("keep_ndata", [True, False]) +@pytest.mark.parametrize("keep_edata", [True, False]) +def test_remove_edges_by_features(graph_Mo, keep_ndata, keep_edata): + s1, g1, state1 = graph_Mo + bv, bd = compute_pair_vector_and_distance(g1) + g1.edata["bond_vec"] = bv + g1.edata["bond_dist"] = bd + + new_cutoff = 3.0 + converter = Structure2Graph(element_types=get_element_list([s1]), cutoff=new_cutoff) + g2, state2 = converter.get_graph(s1) + + # remove edges by features + new_g = prune_edges_by_features( + g1, "bond_dist", condition=lambda x: x > new_cutoff, keep_ndata=keep_ndata, keep_edata=keep_edata + ) + valid_edges = g1.edata["bond_dist"] <= new_cutoff + + assert new_g.num_edges() == g2.num_edges() + assert new_g.num_nodes() == g2.num_nodes() + assert torch.allclose(new_g.edata["edge_ids"], valid_edges.nonzero().squeeze()) + + if keep_ndata: + assert new_g.ndata.keys() == g1.ndata.keys() + + if keep_edata: + for key in g1.edata: + if key != "edge_ids": + assert torch.allclose(new_g.edata[key], g1.edata[key][valid_edges]) + + +@pytest.mark.parametrize("cutoff", [2.0, 3.0, 4.0]) +@pytest.mark.parametrize("graph_data", ["graph_Mo", "graph_CH4", "graph_MoS", "graph_LiFePO4", "graph_MoSH"]) +def test_directed_line_graph(graph_data, cutoff, request): + s1, g1, state1 = request.getfixturevalue(graph_data) + bv, bd = compute_pair_vector_and_distance(g1) + g1.edata["bond_vec"] = bv + g1.edata["bond_dist"] = bd + cos_loop = _calculate_cos_loop(g1, cutoff) + theta_loop = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7)) + + line_graph = create_line_graph(g1, cutoff, directed=True) + line_graph.apply_edges(compute_theta) + + # this test might be lax with just 4 decimal places + np.testing.assert_array_almost_equal(np.sort(theta_loop), np.sort(np.array(line_graph.edata["theta"])), decimal=4) + + +@pytest.mark.parametrize("graph_data", ["graph_Mo", "graph_CH4", "graph_LiFePO4", "graph_MoSH"]) +def test_ensure_directed_line_graph_compat(graph_data, request): + s, g, state = request.getfixturevalue(graph_data) + bv, bd = compute_pair_vector_and_distance(g) + g.edata["bond_vec"] = bv + g.edata["bond_dist"] = bd + line_graph = create_line_graph(g, 3.0, directed=True) + edge_ids = line_graph.ndata["edge_ids"].clone() + src_bond_sign = line_graph.ndata["src_bond_sign"].clone() + line_graph.ndata["edge_ids"] = torch.zeros(line_graph.num_nodes(), dtype=torch.long) + line_graph.ndata["src_bond_sign"] = torch.zeros(line_graph.num_nodes()) + + assert not torch.allclose(line_graph.ndata["edge_ids"], edge_ids) + assert not torch.allclose(line_graph.ndata["src_bond_sign"], src_bond_sign) + + # test that the line graph is not compatible + line_graph = ensure_line_graph_compatibility(g, line_graph, 3.0, directed=True) + tt.assert_allclose(line_graph.ndata["edge_ids"], edge_ids) + tt.assert_allclose(line_graph.ndata["src_bond_sign"], src_bond_sign) + + with pytest.raises(RuntimeError): + ensure_line_graph_compatibility(g, line_graph, 1.0, directed=True)