Skip to content

Commit

Permalink
Provenance overhaul (#255)
Browse files Browse the repository at this point in the history
* Provenance last_updated bugs

* Last updated changed to build time

* Speed up snl structure matching

* Fix snl doc model

* Append only unque snls

* Provenance builder overhaul

* Linting

* More linting

* Fix provenance tests
  • Loading branch information
Jason Munro authored Sep 2, 2021
1 parent 661c0b3 commit ce8a71c
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 78 deletions.
155 changes: 86 additions & 69 deletions emmet-builders/emmet/builders/materials/provenance.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from collections import defaultdict
from itertools import chain
from typing import Dict, Iterable, List, Optional, Tuple
from math import ceil
from datetime import datetime

from maggma.core import Builder, Store
from maggma.utils import grouper
from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher
from pymatgen.core.structure import Structure

from emmet.builders.settings import EmmetBuildSettings
from emmet.core.provenance import ProvenanceDoc, SNLDict
from emmet.core.utils import group_structures
from emmet.core.utils import group_structures, get_sg


class ProvenanceBuilder(Builder):
Expand Down Expand Up @@ -62,7 +65,7 @@ def prechunk(self, number_splits: int) -> Iterable[Dict]:
# Find all formulas for materials that have been updated since this
# builder was last ran
q = self.query
updated_materials = self.provenance.newer_in(self.materials, criteria=q, exhaustive=True,)
updated_materials = self.provenance.newer_in(self.materials, criteria=q, exhaustive=True)
forms_to_update = set(self.materials.distinct("formula_pretty", {"material_id": {"$in": updated_materials}}))

# Find all new SNL formulas since the builder was last run
Expand All @@ -74,13 +77,16 @@ def prechunk(self, number_splits: int) -> Iterable[Dict]:
forms_avail = set(self.materials.distinct("formula_pretty", self.query))
forms_to_update = forms_to_update & forms_avail

self.logger.info(
f"Found {len(forms_to_update)} new/updated systems to distribute to workers "
f"in chunks of {len(forms_to_update)/number_splits}"
mat_ids = set(self.materials.distinct("material_id", {"formula_pretty": {"$in": list(forms_to_update)}})) & set(
updated_materials
)

for chunk in grouper(forms_to_update, number_splits):
yield {"formula_pretty": {"$in": chunk}}
N = ceil(len(mat_ids) / number_splits)

self.logger.info(f"Found {len(mat_ids)} new/updated systems to distribute to workers " f"in {N} chunks.")

for chunk in grouper(mat_ids, N):
yield {"query": {"material_id": {"$in": chunk}}}

def get_items(self) -> Tuple[List[Dict], List[Dict]]: # type: ignore
"""
Expand All @@ -96,7 +102,7 @@ def get_items(self) -> Tuple[List[Dict], List[Dict]]: # type: ignore
# Find all formulas for materials that have been updated since this
# builder was last ran
q = self.query
updated_materials = self.provenance.newer_in(self.materials, criteria=q, exhaustive=True,)
updated_materials = self.provenance.newer_in(self.materials, criteria=q, exhaustive=True)
forms_to_update = set(self.materials.distinct("formula_pretty", {"material_id": {"$in": updated_materials}}))

# Find all new SNL formulas since the builder was last run
Expand All @@ -108,82 +114,91 @@ def get_items(self) -> Tuple[List[Dict], List[Dict]]: # type: ignore
forms_avail = set(self.materials.distinct("formula_pretty", self.query))
forms_to_update = forms_to_update & forms_avail

self.logger.info(f"Found {len(forms_to_update)} new/updated systems to process")
mat_ids = set(self.materials.distinct("material_id", {"formula_pretty": {"$in": list(forms_to_update)}})) & set(
updated_materials
)

self.total = len(mat_ids)

self.total = len(forms_to_update)
self.logger.info(f"Found {self.total} new/updated systems to process")

for mat_id in mat_ids:

mat = self.materials.query_one(
properties=[
"material_id",
"last_updated",
"structure",
"initial_structures",
"formula_pretty",
"deprecated",
],
criteria={"material_id": mat_id},
)

for formulas in grouper(forms_to_update, self.chunk_size):
snls = [] # type: list
for source in self.source_snls:
snls.extend(source.query(criteria={"formula_pretty": {"$in": formulas}}))

mats = list(
self.materials.query(
properties=[
"material_id",
"last_updated",
"structure",
"initial_structures",
"formula_pretty",
"deprecated",
],
criteria={"formula_pretty": {"$in": formulas}},
)
)
snls.extend(source.query(criteria={"formula_pretty": mat["formula_pretty"]}))

form_groups = defaultdict(list)
snl_groups = defaultdict(list)
for snl in snls:
form_groups[snl["formula_pretty"]].append(snl)
struc = Structure.from_dict(snl)
snl_sg = get_sg(struc)
struc.snl = SNLDict(**snl)
snl_groups[snl_sg].append(struc)

mat_groups = defaultdict(list)
for mat in mats:
mat_groups[mat["formula_pretty"]].append(mat)
mat_sg = get_sg(Structure.from_dict(mat["structure"]))

for formula, snl_group in form_groups.items():
snl_structs = snl_groups[mat_sg]

mat_group = mat_groups[formula]
self.logger.debug(f"Found {len(snl_structs)} potential snls for {mat_id}")
yield mat, snl_structs

self.logger.debug(f"Found {len(snl_group)} snls and {len(mat_group)} mats")
yield mat_group, snl_group

def process_item(self, item) -> List[Dict]:
def process_item(self, item) -> Dict:
"""
Matches SNLS and Materials
Args:
item (tuple): a tuple of materials and snls
Returns:
list(dict): a list of collected snls with material ids
"""
mats, source_snls = item
formula_pretty = mats[0]["formula_pretty"]
snl_docs = list()
mat, snl_structs = item
formula_pretty = mat["formula_pretty"]
snl_doc = None
self.logger.debug(f"Finding Provenance {formula_pretty}")

# Match up SNLS with materials
for mat in mats:
matched_snls = list(self.match(source_snls, mat))

if len(matched_snls) > 0:
doc = ProvenanceDoc.from_SNLs(
material_id=mat["material_id"],
structure=Structure.from_dict(mat["structure"]),
snls=matched_snls,
deprecated=mat["deprecated"],
)
matched_snls = self.match(snl_structs, mat)

if len(matched_snls) > 0:
doc = ProvenanceDoc.from_SNLs(
material_id=mat["material_id"],
structure=Structure.from_dict(mat["structure"]),
snls=matched_snls,
deprecated=mat["deprecated"],
)
else:
doc = ProvenanceDoc(
material_id=mat["material_id"],
structure=Structure.from_dict(mat["structure"]),
deprecated=mat["deprecated"],
created_at=datetime.utcnow(),
)

doc.authors.append(self.settings.DEFAULT_AUTHOR)
doc.history.append(self.settings.DEFAULT_HISTORY)
doc.references.append(self.settings.DEFAULT_REFERENCE)
doc.authors.append(self.settings.DEFAULT_AUTHOR)
doc.history.append(self.settings.DEFAULT_HISTORY)
doc.references.append(self.settings.DEFAULT_REFERENCE)

snl_docs.append(doc.dict(exclude_none=True))
snl_doc = doc.dict(exclude_none=True)

return snl_docs
return snl_doc

def match(self, snls, mat):
def match(self, snl_structs, mat):
"""
Finds a material doc that matches with the given snl
Args:
snl ([dict]): the snls list
snl_structs ([dict]): the snls struct list
mat (dict): a materials doc
Returns:
generator of materials doc keys
Expand All @@ -192,22 +207,25 @@ def match(self, snls, mat):
m_strucs = [Structure.from_dict(mat["structure"])] + [
Structure.from_dict(init_struc) for init_struc in mat["initial_structures"]
]
snl_strucs = []
for snl in snls:
struc = Structure.from_dict(snl)
struc.snl = SNLDict(**snl)
snl_strucs.append(struc)

groups = group_structures(
m_strucs + snl_strucs,

sm = StructureMatcher(
ltol=self.settings.LTOL,
stol=self.settings.STOL,
angle_tol=self.settings.ANGLE_TOL,
# comparator=OrderDisorderElementComparator(),
primitive_cell=True,
scale=True,
attempt_supercell=False,
allow_subset=False,
comparator=ElementComparator(),
)

matched_groups = [group for group in groups if any(not hasattr(struc, "snl") for struc in group)]
snls = [struc.snl for group in matched_groups for struc in group if hasattr(struc, "snl")]
snls = []

for s in m_strucs:
for snl_struc in snl_structs:
if sm.fit(s, snl_struc):
if snl_struc.snl not in snls:
snls.append(snl_struc.snl)

self.logger.debug(f"Found {len(snls)} SNLs for {mat['material_id']}")
return snls
Expand All @@ -216,8 +234,7 @@ def update_targets(self, items):
"""
Inserts the new SNL docs into the SNL collection
"""

snls = list(filter(None, chain.from_iterable(items)))
snls = list(filter(None, items))

if len(snls) > 0:
self.logger.info(f"Found {len(snls)} SNLs to update")
Expand Down
17 changes: 9 additions & 8 deletions emmet-core/emmet/core/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class History(BaseModel):
name: str
url: str
description: Optional[Dict] = Field(None, description="Dictionary of exra data for this history node")
experimental: Optional[bool] = Field(
False, description="Whether this node dictates this is an experimental history not",
)

@root_validator(pre=True)
def str_to_dict(cls, values):
Expand Down Expand Up @@ -133,7 +130,7 @@ def from_SNLs(cls, material_id: MPID, structure: Structure, snls: List[SNLDict],

# Choose earliest created_at
created_at = min([snl.about.created_at for snl in snls])
last_updated = max([snl.about.created_at for snl in snls])
# last_updated = max([snl.about.created_at for snl in snls])

# Choose earliest history
history = sorted(snls, key=lambda snl: snl.about.created_at)[0].about.history
Expand All @@ -160,7 +157,13 @@ def from_SNLs(cls, material_id: MPID, structure: Structure, snls: List[SNLDict],
authors = [entry for snl in snls for entry in snl.about.authors]

# Check if this entry is experimental
experimental = any(history.experimental for snl in snls for history in snl.about.history)
exp_vals = []
for snl in snls:
for entry in snl.about.history:
if entry.description is not None:
exp_vals.append(entry.description.get("experimental", False))

experimental = any(exp_vals)

# Aggregate all the database IDs
snl_ids = {snl.snl_id for snl in snls}
Expand All @@ -181,6 +184,4 @@ def from_SNLs(cls, material_id: MPID, structure: Structure, snls: List[SNLDict],
"history": history,
}

return super().from_structure(
material_id=material_id, meta_structure=structure, last_updated=last_updated, **fields, **kwargs,
)
return super().from_structure(material_id=material_id, meta_structure=structure, **fields, **kwargs,)
2 changes: 1 addition & 1 deletion tests/emmet-core/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_from_snls(snls, structure):
}

# Test experimental detection
snls[0].about.history[0].experimental = True
snls[0].about.history[0].description["experimental"] = True
assert (
ProvenanceDoc.from_SNLs(material_id="mp-3", snls=snls, structure=structure, deprecated=False).theoretical
is False
Expand Down

0 comments on commit ce8a71c

Please sign in to comment.