From fc94e7969ef05766b62caa1c78016cfe132d8821 Mon Sep 17 00:00:00 2001
From: kenko911
Date: Sat, 22 Jun 2024 09:24:55 -0700
Subject: [PATCH 01/11] improve TensorNet model coverage

---
 pyproject.toml                 | 2 +-
 tests/models/test_tensornet.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index b25dd4d7..7d7cefa2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,7 +50,7 @@ classifiers = [
 ]
 dependencies = [
     "ase",
-    "dgl>=2.0.0",
+    "dgl",
     "pymatgen",
     "lightning",
     "torch<=2.2.1",
diff --git a/tests/models/test_tensornet.py b/tests/models/test_tensornet.py
index 4edb1a86..2a8dfee7 100644
--- a/tests/models/test_tensornet.py
+++ b/tests/models/test_tensornet.py
@@ -21,6 +21,7 @@ def test_model(self, graph_MoS):
         os.remove("model.pt")
         os.remove("model.json")
         os.remove("state.pt")
+        model = TensorNet(is_intensive=False, equivariance_invariance_group="SO(3)")

     def test_exceptions(self):
         with pytest.raises(ValueError, match="Invalid activation type"):

From 53357d77fea999a36afa9fb85d00a10c0cd54161 Mon Sep 17 00:00:00 2001
From: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com>
Date: Sat, 22 Jun 2024 09:30:39 -0700
Subject: [PATCH 02/11] Update pyproject.toml

Signed-off-by: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com>
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7d7cefa2..b25dd4d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,7 +50,7 @@ classifiers = [
 ]
 dependencies = [
     "ase",
-    "dgl",
+    "dgl>=2.0.0",
     "pymatgen",
     "lightning",
     "torch<=2.2.1",

From f65cdba5d1e34e656cff74789965a8d069169bd2 Mon Sep 17 00:00:00 2001
From: kenko911
Date: Sat, 22 Jun 2024 09:34:43 -0700
Subject: [PATCH 03/11] Improve the unit test for SO(3) equivariance in TensorNet class

---
 tests/models/test_tensornet.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/models/test_tensornet.py b/tests/models/test_tensornet.py
index 2a8dfee7..f7767046 100644
--- a/tests/models/test_tensornet.py
+++ b/tests/models/test_tensornet.py
@@ -22,6 +22,7 @@ def test_model(self, graph_MoS):
         os.remove("model.json")
         os.remove("state.pt")
         model = TensorNet(is_intensive=False, equivariance_invariance_group="SO(3)")
+        assert torch.numel(output) == 1

     def test_exceptions(self):
         with pytest.raises(ValueError, match="Invalid activation type"):

From 16abc3840201bd9c30f645669d862dc694ff844c Mon Sep 17 00:00:00 2001
From: kenko911
Date: Sat, 22 Jun 2024 18:12:32 -0700
Subject: [PATCH 04/11] improve SO3Net model class coverage and simplify TensorNet implementations

---
 src/matgl/models/_tensornet.py |  5 +----
 tests/models/test_so3net.py    | 10 ++++++----
 tests/models/test_tensornet.py |  2 --
 3 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/src/matgl/models/_tensornet.py b/src/matgl/models/_tensornet.py
index bfed4610..e1129b09 100644
--- a/src/matgl/models/_tensornet.py
+++ b/src/matgl/models/_tensornet.py
@@ -123,10 +123,7 @@ def __init__(
                 f"Invalid activation type, please try using one of {[af.name for af in ActivationFunction]}"
             ) from None

-        if element_types is None:
-            self.element_types = DEFAULT_ELEMENTS
-        else:
-            self.element_types = element_types  # type: ignore
+        self.element_types = element_types  # type: ignore

         self.bond_expansion = BondExpansion(
             cutoff=cutoff,
diff --git a/tests/models/test_so3net.py b/tests/models/test_so3net.py
index 9f8c6b0d..4f2c9c0e 100644
--- a/tests/models/test_so3net.py
+++ b/tests/models/test_so3net.py
@@ -37,16 +37,18 @@ def test_model_intensive(self, graph_MoS):
         output = model(g=graph)
         assert torch.numel(output) == 2

-    def test_model_intensive_with_weighted_atom(self, graph_MoS):
+    def test_model_intensive_reduce_atom_classification(self, graph_MoS):
         structure, graph, state = graph_MoS
         lat = torch.tensor(np.array([structure.lattice.matrix]), dtype=matgl.float_th)
         graph.edata["pbc_offshift"] = torch.matmul(graph.edata["pbc_offset"], lat[0])
         graph.ndata["pos"] = graph.ndata["frac_coords"] @ lat[0]
-        model = SO3Net(element_types=["Mo", "S"], is_intensive=True, readout_type="weighted_atom")
+        model = SO3Net(
+            element_types=["Mo", "S"], is_intensive=True, readout_type="reduce_atom", target_property="graph"
+        )
         output = model(g=graph)
-        assert torch.numel(output) == 2
+        assert torch.numel(output) == 1

-    def test_model_intensive_with_classification(self, graph_MoS):
+    def test_model_intensive_weighted_atom_classification(self, graph_MoS):
         structure, graph, state = graph_MoS
         lat = torch.tensor(np.array([structure.lattice.matrix]), dtype=matgl.float_th)
         graph.edata["pbc_offshift"] = torch.matmul(graph.edata["pbc_offset"], lat[0])
diff --git a/tests/models/test_tensornet.py b/tests/models/test_tensornet.py
index f7767046..4edb1a86 100644
--- a/tests/models/test_tensornet.py
+++ b/tests/models/test_tensornet.py
@@ -21,8 +21,6 @@ def test_model(self, graph_MoS):
         os.remove("model.pt")
         os.remove("model.json")
         os.remove("state.pt")
-        model = TensorNet(is_intensive=False, equivariance_invariance_group="SO(3)")
-        assert torch.numel(output) == 1

     def test_exceptions(self):
         with pytest.raises(ValueError, match="Invalid activation type"):
From 2798176026ec99b8deb0f85c8b1e0e02f79e0cfe Mon Sep 17 00:00:00 2001
From: kenko911
Date: Sun, 23 Jun 2024 21:53:02 -0700
Subject: [PATCH 05/11] improve the coverage in MLP_norm class

---
 tests/layers/test_core_and_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/layers/test_core_and_embedding.py b/tests/layers/test_core_and_embedding.py
index 1d74d938..8d94689f 100644
--- a/tests/layers/test_core_and_embedding.py
+++ b/tests/layers/test_core_and_embedding.py
@@ -48,7 +48,7 @@ def test_gated_mlp(self, x):

     @pytest.mark.parametrize("normalization", ["layer", "graph"])
     def test_mlp_norm(self, x, graph, normalization):
-        layer = MLP_norm(dims=[10, 3], normalization=normalization)
+        layer = MLP_norm(dims=[10, 3], normalization=normalization, normalize_hidden=True)
         out = layer(x, g=graph).double()
         assert [out.size()[0], out.size()[1]] == [4, 3]
         assert out.mean().item() == pytest.approx(0, abs=1e-6)
From 9b179635b755a112125d08c0e509b9e2fe9bd554 Mon Sep 17 00:00:00 2001
From: kenko911
Date: Wed, 3 Jul 2024 11:03:03 -0700
Subject: [PATCH 06/11] Improve the implementation of three-body interactions

---
 src/matgl/layers/_three_body.py | 63 +++++++++++++++++++++------------
 1 file changed, 41 insertions(+), 22 deletions(-)

diff --git a/src/matgl/layers/_three_body.py b/src/matgl/layers/_three_body.py
index 7a8151d1..6d196093 100644
--- a/src/matgl/layers/_three_body.py
+++ b/src/matgl/layers/_three_body.py
@@ -19,7 +19,8 @@ class ThreeBodyInteractions(nn.Module):
     """Include 3D interactions to the bond update."""

     def __init__(self, update_network_atom: nn.Module, update_network_bond: nn.Module, **kwargs):
-        """Init ThreeBodyInteractions.
+        """
+        Initialize ThreeBodyInteractions.

         Args:
             update_network_atom: MLP for node features in Eq.2
@@ -31,45 +32,63 @@ def __init__(self, update_network_atom: nn.Module, update_network_bond: nn.Modul
         self.update_network_bond = update_network_bond

     def forward(
-        self,
-        graph: dgl.DGLGraph,
-        line_graph: dgl.DGLGraph,
-        three_basis: torch.Tensor,
-        three_cutoff: float,
-        node_feat: torch.Tensor,
-        edge_feat: torch.Tensor,
+            self,
+            graph: dgl.DGLGraph,
+            line_graph: dgl.DGLGraph,
+            three_basis: torch.Tensor,
+            three_cutoff: torch.Tensor,
+            node_feat: torch.Tensor,
+            edge_feat: torch.Tensor,
     ):
         """
         Forward function for ThreeBodyInteractions.

         Args:
             graph: dgl graph
-            line_graph: line graph.
+            line_graph: line graph
            three_basis: three body basis expansion
            three_cutoff: cutoff radius
            node_feat: node features
-            edge_feat: edge features.
+            edge_feat: edge features
         """
-        end_atom_index = torch.gather(graph.edges()[1], 0, line_graph.edges()[1].to(torch.int64))
-        atoms = self.update_network_atom(node_feat)
-        end_atom_index = torch.unsqueeze(end_atom_index, 1)
-        atoms = torch.squeeze(atoms[end_atom_index])
-        basis = three_basis * atoms
-        three_cutoff = torch.unsqueeze(three_cutoff, dim=1)  # type: ignore
-        weights = three_cutoff[torch.stack(list(line_graph.edges()), dim=1)].view(-1, 2)  # type: ignore
-        weights = torch.prod(weights, dim=-1)  # type: ignore
+        # Get the indices of the end atoms for each bond in the line graph
+        end_atom_indices = graph.edges()[1][line_graph.edges()[1]].to(matgl.int_th)
+
+        # Update node features using the atom update network
+        updated_atoms = self.update_network_atom(node_feat)
+
+        # Gather updated atom features for the end atoms
+        end_atom_features = updated_atoms[end_atom_indices]
+
+        # Compute the basis term
+        basis = three_basis * end_atom_features
+
+        # Reshape and compute weights based on the three-cutoff tensor
+        three_cutoff = three_cutoff.unsqueeze(1)
+        edge_indices = torch.stack(list(line_graph.edges()), dim=1)
+        weights = three_cutoff[edge_indices].view(-1, 2)
+        weights = weights.prod(dim=-1)
+
+        # Compute the weighted basis
         basis = basis * weights[:, None]
+
+        # Aggregate the new bonds using scatter_sum
+        segment_ids = get_segment_indices_from_n(line_graph.ndata["n_triple_ij"])
         new_bonds = scatter_sum(
             basis.to(matgl.float_th),
-            segment_ids=get_segment_indices_from_n(line_graph.ndata["n_triple_ij"]),
+            segment_ids=segment_ids,
             num_segments=graph.num_edges(),
             dim=0,
         )
-        if not new_bonds.data.shape[0]:
+
+        # If no new bonds are generated, return the original edge features
+        if new_bonds.shape[0] == 0:
             return edge_feat
-        edge_feat_updated = edge_feat + self.update_network_bond(new_bonds)
-        return edge_feat_updated
+        # Update edge features using the bond update network
+        updated_edge_feat = edge_feat + self.update_network_bond(new_bonds)
+
+        return updated_edge_feat

 def combine_sbf_shf(sbf, shf, max_n: int, max_l: int, use_phi: bool):
     """Combine the spherical Bessel function and the spherical Harmonics function.
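The rewritten forward pass above reduces the weighted three-body basis onto real bonds with scatter_sum, using segment indices derived from line_graph.ndata["n_triple_ij"]. A minimal torch-only sketch of that aggregation step (toy sizes and values, assuming matgl's get_segment_indices_from_n simply repeats each bond index by its triple count):

```python
import torch

# Toy inputs (hypothetical, not from the patch): 5 line-graph edges contribute to
# 3 real bonds; n_triple_ij[i] is the number of contributions received by bond i.
n_triple_ij = torch.tensor([2, 0, 3])
basis = torch.arange(5, dtype=torch.float32).reshape(5, 1)  # weighted three-body basis terms

# Assumed behaviour of get_segment_indices_from_n: repeat each bond index by its count -> [0, 0, 2, 2, 2]
segment_ids = torch.repeat_interleave(torch.arange(len(n_triple_ij)), n_triple_ij)

# Plain-torch equivalent of scatter_sum(basis, segment_ids, num_segments=3, dim=0)
new_bonds = torch.zeros(3, basis.shape[1]).index_add_(0, segment_ids, basis)
print(new_bonds)  # tensor([[1.], [0.], [9.]]) -- bond 1 receives no three-body update
```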
From dc6ed594ca275bb0f69e7634b543770aa25e9364 Mon Sep 17 00:00:00 2001
From: kenko911
Date: Wed, 3 Jul 2024 11:09:50 -0700
Subject: [PATCH 07/11] fixed black

---
 src/matgl/layers/_three_body.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/matgl/layers/_three_body.py b/src/matgl/layers/_three_body.py
index 6d196093..38235ae2 100644
--- a/src/matgl/layers/_three_body.py
+++ b/src/matgl/layers/_three_body.py
@@ -32,13 +32,13 @@ def __init__(self, update_network_atom: nn.Module, update_network_bond: nn.Modul
         self.update_network_bond = update_network_bond

     def forward(
-            self,
-            graph: dgl.DGLGraph,
-            line_graph: dgl.DGLGraph,
-            three_basis: torch.Tensor,
-            three_cutoff: torch.Tensor,
-            node_feat: torch.Tensor,
-            edge_feat: torch.Tensor,
+        self,
+        graph: dgl.DGLGraph,
+        line_graph: dgl.DGLGraph,
+        three_basis: torch.Tensor,
+        three_cutoff: torch.Tensor,
+        node_feat: torch.Tensor,
+        edge_feat: torch.Tensor,
     ):
         """
         Forward function for ThreeBodyInteractions.
@@ -90,6 +90,7 @@ def forward(
         return updated_edge_feat

+
 def combine_sbf_shf(sbf, shf, max_n: int, max_l: int, use_phi: bool):
     """Combine the spherical Bessel function and the spherical Harmonics function.

From 522809cdd591cf83c5a63c233080844db0804563 Mon Sep 17 00:00:00 2001
From: kenko911
Date: Fri, 5 Jul 2024 10:32:36 -0700
Subject: [PATCH 08/11] Optimize the speed of _compute_3body class

---
 src/matgl/graph/compute.py | 45 +++++++++++++++++---------------------
 1 file changed, 20 insertions(+), 25 deletions(-)

diff --git a/src/matgl/graph/compute.py b/src/matgl/graph/compute.py
index 9a5645cb..8af2a857 100644
--- a/src/matgl/graph/compute.py
+++ b/src/matgl/graph/compute.py
@@ -88,7 +88,7 @@ def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float, directed: bool =
     if directed:
         lg = _create_directed_line_graph(graph_with_three_body, threebody_cutoff)
     else:
-        lg, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = _compute_3body(graph_with_three_body)
+        lg = _compute_3body(graph_with_three_body)

     return lg

@@ -174,49 +174,44 @@ def _compute_3body(g: dgl.DGLGraph):
         n_triple_s (np.ndarray): number of three-body angles for each structure
     """
     n_atoms = g.num_nodes()
-    first_col = g.edges()[0].cpu().numpy().reshape(-1, 1)
-    all_indices = np.arange(n_atoms).reshape(1, -1)
-    n_bond_per_atom = np.count_nonzero(first_col == all_indices, axis=0)
+    first_col = g.edges()[0].cpu().numpy()
+
+    # Count bonds per atom efficiently
+    n_bond_per_atom = np.bincount(first_col, minlength=n_atoms)
+
     n_triple_i = n_bond_per_atom * (n_bond_per_atom - 1)
-    n_triple = np.sum(n_triple_i)
+    n_triple = n_triple_i.sum()
     n_triple_ij = np.repeat(n_bond_per_atom - 1, n_bond_per_atom)
-    triple_bond_indices = np.empty((n_triple, 2), dtype=matgl.int_np)  # type: ignore
+
+    triple_bond_indices = np.empty((n_triple, 2), dtype=matgl.int_np)
     start = 0
     cs = 0
     for n in n_bond_per_atom:
         if n > 0:
-            """
-            triple_bond_indices is generated from all pair permutations of atom indices. The
-            numpy version below does this with much greater efficiency. The equivalent slow
-            code is:
-
-            ```
-            for j, k in itertools.permutations(range(n), 2):
-                triple_bond_indices[index] = [start + j, start + k]
-            ```
-            """
             r = np.arange(n)
             x, y = np.meshgrid(r, r, indexing="xy")
-            c = np.stack([y.ravel(), x.ravel()], axis=1)
-            final = c[c[:, 0] != c[:, 1]]
-            triple_bond_indices[start : start + (n * (n - 1)), :] = final + cs
+            final = np.stack([y.ravel(), x.ravel()], axis=1)
+            mask = final[:, 0] != final[:, 1]
+            final = final[mask]
+            triple_bond_indices[start : start + n * (n - 1)] = final + cs
            start += n * (n - 1)
            cs += n

-    n_triple_s = [np.sum(n_triple_i[0:n_atoms])]
     src_id = torch.tensor(triple_bond_indices[:, 0], dtype=matgl.int_th)
     dst_id = torch.tensor(triple_bond_indices[:, 1], dtype=matgl.int_th)
     l_g = dgl.graph((src_id, dst_id)).to(g.device)
-    three_body_id = torch.concatenate(l_g.edges())
-    n_triple_ij = torch.tensor(n_triple_ij, dtype=matgl.int_th).to(g.device)
-    max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0
+    three_body_id = torch.cat(l_g.edges())
+    n_triple_ij = torch.tensor(n_triple_ij, dtype=matgl.int_th, device=g.device)
+
+    max_three_body_id = three_body_id.max().item() + 1 if three_body_id.numel() > 0 else 0
+
     l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id]
     l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id]
     l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id]
     l_g.ndata["n_triple_ij"] = n_triple_ij[:max_three_body_id]
-    n_triple_s = torch.tensor(n_triple_s, dtype=matgl.int_th)  # type: ignore
-    return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s
+
+    return l_g


 def _create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> dgl.DGLGraph:
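Patch 08 swaps the per-atom bond counting for np.bincount and keeps only the vectorised pair construction, dropping the docstring that showed the slow itertools.permutations equivalent. A quick self-contained check (hypothetical inputs, not part of the patch) that the vectorised forms reproduce what they replace:

```python
import itertools
import numpy as np

# Hypothetical edge list: source atom of each bond in a 3-atom graph.
first_col = np.array([0, 0, 0, 2, 2])
n_atoms = 3

# Old counting: broadcast every edge against every atom index, then count matches.
old_counts = np.count_nonzero(first_col.reshape(-1, 1) == np.arange(n_atoms).reshape(1, -1), axis=0)
# New counting: a single bincount over the source column.
new_counts = np.bincount(first_col, minlength=n_atoms)
assert (old_counts == new_counts).all()  # both give [3, 0, 2]

# Pair construction: meshgrid + diagonal mask reproduces itertools.permutations(range(n), 2).
n = 3
r = np.arange(n)
x, y = np.meshgrid(r, r, indexing="xy")
pairs = np.stack([y.ravel(), x.ravel()], axis=1)
pairs = pairs[pairs[:, 0] != pairs[:, 1]]
assert sorted(map(tuple, pairs.tolist())) == sorted(itertools.permutations(range(n), 2))
```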
From 27c728f8e7a2c6f05c36c9d3f5734704e4fe632b Mon Sep 17 00:00:00 2001
From: kenko911
Date: Mon, 8 Jul 2024 09:03:44 -0700
Subject: [PATCH 09/11] type checking is added for scheduler

---
 src/matgl/utils/training.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/matgl/utils/training.py b/src/matgl/utils/training.py
index 17953e12..27e15283 100644
--- a/src/matgl/utils/training.py
+++ b/src/matgl/utils/training.py
@@ -16,7 +16,7 @@
 if TYPE_CHECKING:
     import dgl
     import numpy as np
-    from torch.optim import Optimizer
+    from torch.optim import LRScheduler, Optimizer


 class MatglLightningModuleMixin:
@@ -147,7 +147,7 @@ def __init__(
         data_std: float = 1.0,
         loss: str = "mse_loss",
         optimizer: Optimizer | None = None,
-        scheduler=None,
+        scheduler: LRScheduler | None = None,
         lr: float = 0.001,
         decay_steps: int = 1000,
         decay_alpha: float = 0.01,
@@ -270,7 +270,7 @@ def __init__(
         loss: str = "mse_loss",
         loss_params: dict | None = None,
         optimizer: Optimizer | None = None,
-        scheduler=None,
+        scheduler: LRScheduler | None = None,
         lr: float = 0.001,
         decay_steps: int = 1000,
         decay_alpha: float = 0.01,

From 3ac4f396cf98b5f6dbda6ca8d1a878b9d283ddfa Mon Sep 17 00:00:00 2001
From: kenko911
Date: Sun, 14 Jul 2024 12:26:38 -0700
Subject: [PATCH 10/11] update M3GNet Potential training notebook for the demonstration of obtaining and using element offsets

---
 .../Training a M3GNet Potential with PyTorch Lightning.ipynb | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb b/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb
index 28325197..9f4832cc 100644
--- a/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb
+++ b/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb
@@ -269,7 +269,10 @@
    "# download a pre-trained M3GNet\n",
    "m3gnet_nnp = matgl.load_model(\"M3GNet-MP-2021.2.8-PES\")\n",
    "model_pretrained = m3gnet_nnp.model\n",
-    "lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-4, include_line_graph=True)"
+    "# obtain element energy offset\n",
+    "property_offset = m3gnet_nnp.element_refs.property_offset\n",
+    "# you should test whether including the original property_offset helps improve training and validation accuracy\n",
+    "lit_module_finetune = PotentialLightningModule(model=model_pretrained, element_refs=property_offset, lr=1e-4, include_line_graph=True)"
   ]
  },
 {
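The notebook cell above pulls the per-element energy offsets out of the pretrained potential and passes them back in when fine-tuning. Conceptually, such offsets are composition-weighted reference energies subtracted from the targets; a toy NumPy sketch of that idea (illustrative numbers only, not the notebook's data or matgl's exact fitting code):

```python
import numpy as np

# Three hypothetical structures described by their (Mo, S) atom counts and total energies (eV).
composition = np.array([[2.0, 1.0], [1.0, 1.0], [0.0, 2.0]])
energies = np.array([-25.1, -16.0, -13.9])

# Least-squares per-element reference energies -- roughly the role the element_refs offsets play.
offsets, *_ = np.linalg.lstsq(composition, energies, rcond=None)

# What the potential is actually left to learn after the offsets are removed.
residual = energies - composition @ offsets
print(offsets, residual)
```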
From 992c17e9a4bf09475f82f7309b9e192e4a57f802 Mon Sep 17 00:00:00 2001
From: kenko911
Date: Sun, 14 Jul 2024 12:47:35 -0700
Subject: [PATCH 11/11] Downgrade sympy to avoid crash of SO3 operations

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index b3f54ece..0208189b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ ase==3.23.0
 pydantic==2.7.1
 boto3==1.34.101
 numpy==1.26.4
+sympy==1.12.1