diff --git a/emmet-builders/emmet/builders/materials/provenance.py b/emmet-builders/emmet/builders/materials/provenance.py
index 6480f9f3fc..314c588941 100644
--- a/emmet-builders/emmet/builders/materials/provenance.py
+++ b/emmet-builders/emmet/builders/materials/provenance.py
@@ -1,14 +1,17 @@
 from collections import defaultdict
 from itertools import chain
 from typing import Dict, Iterable, List, Optional, Tuple
+from math import ceil
+from datetime import datetime
 
 from maggma.core import Builder, Store
 from maggma.utils import grouper
+from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher
 from pymatgen.core.structure import Structure
 
 from emmet.builders.settings import EmmetBuildSettings
 from emmet.core.provenance import ProvenanceDoc, SNLDict
-from emmet.core.utils import group_structures
+from emmet.core.utils import group_structures, get_sg
 
 
 class ProvenanceBuilder(Builder):
@@ -62,7 +65,7 @@ def prechunk(self, number_splits: int) -> Iterable[Dict]:
         # Find all formulas for materials that have been updated since this
         # builder was last ran
         q = self.query
-        updated_materials = self.provenance.newer_in(self.materials, criteria=q, exhaustive=True,)
+        updated_materials = self.provenance.newer_in(self.materials, criteria=q, exhaustive=True)
         forms_to_update = set(self.materials.distinct("formula_pretty", {"material_id": {"$in": updated_materials}}))
 
         # Find all new SNL formulas since the builder was last run
@@ -74,13 +77,16 @@ def prechunk(self, number_splits: int) -> Iterable[Dict]:
         forms_avail = set(self.materials.distinct("formula_pretty", self.query))
         forms_to_update = forms_to_update & forms_avail
 
-        self.logger.info(
-            f"Found {len(forms_to_update)} new/updated systems to distribute to workers "
-            f"in chunks of {len(forms_to_update)/number_splits}"
+        mat_ids = set(self.materials.distinct("material_id", {"formula_pretty": {"$in": list(forms_to_update)}})) & set(
+            updated_materials
         )
 
-        for chunk in grouper(forms_to_update, number_splits):
-            yield {"formula_pretty": {"$in": chunk}}
+        N = ceil(len(mat_ids) / number_splits)
+
+        self.logger.info(f"Found {len(mat_ids)} new/updated systems to distribute to workers " f"in {N} chunks.")
+
+        for chunk in grouper(mat_ids, N):
+            yield {"query": {"material_id": {"$in": chunk}}}
 
     def get_items(self) -> Tuple[List[Dict], List[Dict]]:  # type: ignore
         """
@@ -96,7 +102,7 @@ def get_items(self) -> Tuple[List[Dict], List[Dict]]:  # type: ignore
         # Find all formulas for materials that have been updated since this
         # builder was last ran
         q = self.query
-        updated_materials = self.provenance.newer_in(self.materials, criteria=q, exhaustive=True,)
+        updated_materials = self.provenance.newer_in(self.materials, criteria=q, exhaustive=True)
         forms_to_update = set(self.materials.distinct("formula_pretty", {"material_id": {"$in": updated_materials}}))
 
         # Find all new SNL formulas since the builder was last run
@@ -108,45 +114,47 @@ def get_items(self) -> Tuple[List[Dict], List[Dict]]:  # type: ignore
         forms_avail = set(self.materials.distinct("formula_pretty", self.query))
         forms_to_update = forms_to_update & forms_avail
 
-        self.logger.info(f"Found {len(forms_to_update)} new/updated systems to process")
+        mat_ids = set(self.materials.distinct("material_id", {"formula_pretty": {"$in": list(forms_to_update)}})) & set(
+            updated_materials
+        )
+
+        self.total = len(mat_ids)
 
-        self.total = len(forms_to_update)
+        self.logger.info(f"Found {self.total} new/updated systems to process")
+
+        for mat_id in mat_ids:
+
+            mat = self.materials.query_one(
+                properties=[
+                    "material_id",
+                    "last_updated",
+                    "structure",
+                    "initial_structures",
+                    "formula_pretty",
+                    "deprecated",
+                ],
+                criteria={"material_id": mat_id},
+            )
 
-        for formulas in grouper(forms_to_update, self.chunk_size):
             snls = []  # type: list
             for source in self.source_snls:
-                snls.extend(source.query(criteria={"formula_pretty": {"$in": formulas}}))
-
-            mats = list(
-                self.materials.query(
-                    properties=[
-                        "material_id",
-                        "last_updated",
-                        "structure",
-                        "initial_structures",
-                        "formula_pretty",
-                        "deprecated",
-                    ],
-                    criteria={"formula_pretty": {"$in": formulas}},
-                )
-            )
+                snls.extend(source.query(criteria={"formula_pretty": mat["formula_pretty"]}))
 
-            form_groups = defaultdict(list)
+            snl_groups = defaultdict(list)
             for snl in snls:
-                form_groups[snl["formula_pretty"]].append(snl)
+                struc = Structure.from_dict(snl)
+                snl_sg = get_sg(struc)
+                struc.snl = SNLDict(**snl)
+                snl_groups[snl_sg].append(struc)
 
-            mat_groups = defaultdict(list)
-            for mat in mats:
-                mat_groups[mat["formula_pretty"]].append(mat)
+            mat_sg = get_sg(Structure.from_dict(mat["structure"]))
 
-            for formula, snl_group in form_groups.items():
+            snl_structs = snl_groups[mat_sg]
 
-                mat_group = mat_groups[formula]
+            self.logger.debug(f"Found {len(snl_structs)} potential snls for {mat_id}")
+            yield mat, snl_structs
 
-                self.logger.debug(f"Found {len(snl_group)} snls and {len(mat_group)} mats")
-                yield mat_group, snl_group
-
-    def process_item(self, item) -> List[Dict]:
+    def process_item(self, item) -> Dict:
         """
         Matches SNLS and Materials
         Args:
@@ -154,36 +162,43 @@
         Returns:
             list(dict): a list of collected snls with material ids
         """
-        mats, source_snls = item
-        formula_pretty = mats[0]["formula_pretty"]
-        snl_docs = list()
+        mat, snl_structs = item
+        formula_pretty = mat["formula_pretty"]
+        snl_doc = None
 
         self.logger.debug(f"Finding Provenance {formula_pretty}")
 
         # Match up SNLS with materials
-        for mat in mats:
-            matched_snls = list(self.match(source_snls, mat))
-            if len(matched_snls) > 0:
-                doc = ProvenanceDoc.from_SNLs(
-                    material_id=mat["material_id"],
-                    structure=Structure.from_dict(mat["structure"]),
-                    snls=matched_snls,
-                    deprecated=mat["deprecated"],
-                )
+        matched_snls = self.match(snl_structs, mat)
+
+        if len(matched_snls) > 0:
+            doc = ProvenanceDoc.from_SNLs(
+                material_id=mat["material_id"],
+                structure=Structure.from_dict(mat["structure"]),
+                snls=matched_snls,
+                deprecated=mat["deprecated"],
+            )
+        else:
+            doc = ProvenanceDoc(
+                material_id=mat["material_id"],
+                structure=Structure.from_dict(mat["structure"]),
+                deprecated=mat["deprecated"],
+                created_at=datetime.utcnow(),
+            )
 
-            doc.authors.append(self.settings.DEFAULT_AUTHOR)
-            doc.history.append(self.settings.DEFAULT_HISTORY)
-            doc.references.append(self.settings.DEFAULT_REFERENCE)
+        doc.authors.append(self.settings.DEFAULT_AUTHOR)
+        doc.history.append(self.settings.DEFAULT_HISTORY)
+        doc.references.append(self.settings.DEFAULT_REFERENCE)
 
-            snl_docs.append(doc.dict(exclude_none=True))
+        snl_doc = doc.dict(exclude_none=True)
 
-        return snl_docs
+        return snl_doc
 
-    def match(self, snls, mat):
+    def match(self, snl_structs, mat):
         """
         Finds a material doc that matches with the given snl
         Args:
-            snl ([dict]): the snls list
+            snl_structs ([dict]): the snls struct list
             mat (dict): a materials doc
         Returns:
             generator of materials doc keys
@@ -192,22 +207,25 @@
         m_strucs = [Structure.from_dict(mat["structure"])] + [
             Structure.from_dict(init_struc) for init_struc in mat["initial_structures"]
         ]
-        snl_strucs = []
-        for snl in snls:
-            struc = Structure.from_dict(snl)
-            struc.snl = SNLDict(**snl)
-            snl_strucs.append(struc)
-
-        groups = group_structures(
-            m_strucs + snl_strucs,
+
+        sm = StructureMatcher(
             ltol=self.settings.LTOL,
             stol=self.settings.STOL,
             angle_tol=self.settings.ANGLE_TOL,
-            # comparator=OrderDisorderElementComparator(),
+            primitive_cell=True,
+            scale=True,
+            attempt_supercell=False,
+            allow_subset=False,
+            comparator=ElementComparator(),
         )
 
-        matched_groups = [group for group in groups if any(not hasattr(struc, "snl") for struc in group)]
-        snls = [struc.snl for group in matched_groups for struc in group if hasattr(struc, "snl")]
+        snls = []
+
+        for s in m_strucs:
+            for snl_struc in snl_structs:
+                if sm.fit(s, snl_struc):
+                    if snl_struc.snl not in snls:
+                        snls.append(snl_struc.snl)
 
         self.logger.debug(f"Found {len(snls)} SNLs for {mat['material_id']}")
         return snls
@@ -216,8 +234,7 @@
         """
         Inserts the new SNL docs into the SNL collection
        """
-
-        snls = list(filter(None, chain.from_iterable(items)))
+        snls = list(filter(None, items))
 
         if len(snls) > 0:
             self.logger.info(f"Found {len(snls)} SNLs to update")
diff --git a/emmet-core/emmet/core/provenance.py b/emmet-core/emmet/core/provenance.py
index c95ac19f3a..aacdc42568 100644
--- a/emmet-core/emmet/core/provenance.py
+++ b/emmet-core/emmet/core/provenance.py
@@ -40,9 +40,6 @@ class History(BaseModel):
     name: str
     url: str
     description: Optional[Dict] = Field(None, description="Dictionary of exra data for this history node")
-    experimental: Optional[bool] = Field(
-        False, description="Whether this node dictates this is an experimental history not",
-    )
 
     @root_validator(pre=True)
     def str_to_dict(cls, values):
@@ -133,7 +130,7 @@ def from_SNLs(cls, material_id: MPID, structure: Structure, snls: List[SNLDict],
 
         # Choose earliest created_at
         created_at = min([snl.about.created_at for snl in snls])
-        last_updated = max([snl.about.created_at for snl in snls])
+        # last_updated = max([snl.about.created_at for snl in snls])
 
         # Choose earliest history
         history = sorted(snls, key=lambda snl: snl.about.created_at)[0].about.history
@@ -160,7 +157,13 @@
         authors = [entry for snl in snls for entry in snl.about.authors]
 
         # Check if this entry is experimental
-        experimental = any(history.experimental for snl in snls for history in snl.about.history)
+        exp_vals = []
+        for snl in snls:
+            for entry in snl.about.history:
+                if entry.description is not None:
+                    exp_vals.append(entry.description.get("experimental", False))
+
+        experimental = any(exp_vals)
 
         # Aggregate all the database IDs
         snl_ids = {snl.snl_id for snl in snls}
@@ -181,6 +184,4 @@
             "history": history,
         }
 
-        return super().from_structure(
-            material_id=material_id, meta_structure=structure, last_updated=last_updated, **fields, **kwargs,
-        )
+        return super().from_structure(material_id=material_id, meta_structure=structure, **fields, **kwargs,)
diff --git a/tests/emmet-core/test_provenance.py b/tests/emmet-core/test_provenance.py
index 8704a1efa9..f012f134f9 100644
--- a/tests/emmet-core/test_provenance.py
+++ b/tests/emmet-core/test_provenance.py
@@ -49,7 +49,7 @@ def test_from_snls(snls, structure):
     }
 
     # Test experimental detection
-    snls[0].about.history[0].experimental = True
+    snls[0].about.history[0].description["experimental"] = True
     assert (
         ProvenanceDoc.from_SNLs(material_id="mp-3", snls=snls, structure=structure, deprecated=False).theoretical
         is False