Skip to content

Commit

Permalink
Add thermo type consideration to alloy builder (#663)
Browse files: browse the repository at this point in the history
* Add thermo type consideration to alloy builder

* Fix missing index issue in alloys

* Linting
  • Loading branch information
Jason Munro authored Feb 24, 2023
1 parent 615a62b commit 0319ec9
Showing 1 changed file with 38 additions and 57 deletions.
95 changes: 38 additions & 57 deletions emmet-builders/emmet/builders/materials/alloys.py
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,13 @@
from itertools import combinations, chain
from typing import Tuple, List, Dict
from typing import Tuple, List, Dict, Union

from tqdm import tqdm
from maggma.builders import Builder
from pymatgen.core.structure import Structure
from matminer.datasets import load_dataset
from emmet.core.thermo import ThermoType

from pymatgen.analysis.alloys.core import (
AlloyPair,
InvalidAlloy,
KNOWN_ANON_FORMULAS,
AlloyMember,
AlloySystem,
)
from pymatgen.analysis.alloys.core import AlloyPair, InvalidAlloy, KNOWN_ANON_FORMULAS, AlloyMember, AlloySystem

# Order the known anonymous formulas from shortest to longest string,
# using string length as a rough proxy for "complexity".
ANON_FORMULAS = sorted(KNOWN_ANON_FORMULAS, key=len)
Expand All @@ -31,15 +26,9 @@ class AlloyPairBuilder(Builder):
It does not look for members of an AlloyPair.
"""

def __init__(
self,
materials,
thermo,
electronic_structure,
provenance,
oxi_states,
alloy_pairs,
):
def __init__(self, materials, thermo, electronic_structure,
provenance, oxi_states, alloy_pairs,
thermo_type: Union[ThermoType, str] = ThermoType.GGA_GGA_U_R2SCAN):

self.materials = materials
self.thermo = thermo
Expand All @@ -48,6 +37,13 @@ def __init__(
self.oxi_states = oxi_states
self.alloy_pairs = alloy_pairs

t_type = thermo_type if isinstance(thermo_type, str) else thermo_type.value
valid_types = {*map(str, ThermoType.__members__.values())}
if invalid_types := {t_type} - valid_types:
raise ValueError(f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}")

self.thermo_type = t_type

super().__init__(
sources=[materials, thermo, electronic_structure, provenance, oxi_states],
targets=[alloy_pairs],
Expand All @@ -60,9 +56,12 @@ def ensure_indexes(self):
self.alloy_pairs.ensure_index("_search.id")
self.alloy_pairs.ensure_index("_search.formula")
self.alloy_pairs.ensure_index("_search.member_ids")
self.alloy_pairs.ensure_index("alloy_pair.chemsys")

def get_items(self):

self.ensure_indexes()

for idx, af in enumerate(ANON_FORMULAS):

# if af != "AB":
Expand All @@ -80,30 +79,23 @@ def get_items(self):
mpids = list(docs.keys())

thermo_docs = self.thermo.query(
{"material_id": {"$in": mpids}},
properties=[
"material_id",
"energy_above_hull",
"formation_energy_per_atom",
],
{"material_id": {"$in": mpids}, "thermo_type": self.thermo_type},
properties=["material_id", "energy_above_hull", "formation_energy_per_atom"],
)
thermo_docs = {d["material_id"]: d for d in thermo_docs}

electronic_structure_docs = self.electronic_structure.query(
{"material_id": {"$in": mpids}},
properties=["material_id", "band_gap", "is_gap_direct"],
{"material_id": {"$in": mpids}}, properties=["material_id", "band_gap", "is_gap_direct"]
)
electronic_structure_docs = {d["material_id"]: d for d in electronic_structure_docs}

provenance_docs = self.provenance.query(
{"material_id": {"$in": mpids}},
properties=["material_id", "theoretical", "database_IDs"],
{"material_id": {"$in": mpids}}, properties=["material_id", "theoretical", "database_IDs"]
)
provenance_docs = {d["material_id"]: d for d in provenance_docs}

oxi_states_docs = self.oxi_states.query(
{"material_id": {"$in": mpids}, "state": "successful"},
properties=["material_id", "structure"],
{"material_id": {"$in": mpids}, "state": "successful"}, properties=["material_id", "structure"]
)
oxi_states_docs = {d["material_id"]: d for d in oxi_states_docs}

Expand Down Expand Up @@ -152,27 +144,15 @@ def process_item(self, item):
# if (item[mpids[0]]["band_gap"] > 0) or (item[mpids[1]]["band_gap"] > 0):
try:
pair = AlloyPair.from_structures(
structures=[
item[mpids[0]]["structure"],
item[mpids[1]]["structure"],
],
structures=[item[mpids[0]]["structure"], item[mpids[1]]["structure"]],
structures_with_oxidation_states=[
item[mpids[0]]["structure_oxi"],
item[mpids[1]]["structure_oxi"],
],
ids=[mpids[0], mpids[1]],
properties=[
item[mpids[0]]["properties"],
item[mpids[1]]["properties"],
],
)
pairs.append(
{
"alloy_pair": pair.as_dict(),
"_search": pair.search_dict(),
"pair_id": pair.pair_id,
}
properties=[item[mpids[0]]["properties"], item[mpids[1]]["properties"]],
)
pairs.append({"alloy_pair": pair.as_dict(), "_search": pair.search_dict(), "pair_id": pair.pair_id})
except InvalidAlloy:
pass
except Exception as exc:
Expand Down Expand Up @@ -203,14 +183,20 @@ def __init__(self, alloy_pairs, materials, snls, alloy_pair_members):
self.snls = snls
self.alloy_pair_members = alloy_pair_members

super().__init__(
sources=[alloy_pairs, materials, snls],
targets=[alloy_pair_members],
)
super().__init__(sources=[alloy_pairs, materials, snls], targets=[alloy_pair_members])

def ensure_indexes(self):
    """Ensure the ``alloy_pairs`` store has an index on each queried field.

    Calls ``self.alloy_pairs.ensure_index`` once per key, in the order
    listed below.
    """
    index_keys = (
        "pair_id",
        "_search.id",
        "_search.formula",
        "_search.member_ids",
        "alloy_pair.chemsys",
        "alloy_pair.anonymous_formula",
    )
    for index_key in index_keys:
        self.alloy_pairs.ensure_index(index_key)

def get_items(self):

all_alloy_chemsys = set(alloy_pairs.distinct("alloy_pair.chemsys"))
all_alloy_chemsys = set(self.alloy_pairs.distinct("alloy_pair.chemsys"))
all_known_chemsys = set(self.materials.distinct("chemsys")) | set(self.snls.distinct("chemsys"))
possible_chemsys = all_known_chemsys.intersection(all_alloy_chemsys)

Expand All @@ -225,14 +211,11 @@ def get_items(self):
pairs = [AlloyPair.from_dict(d["alloy_pair"]) for d in pairs]

mp_docs = self.materials.query(
criteria={"chemsys": chemsys, "deprecated": False},
properties=["structure", "material_id"],
criteria={"chemsys": chemsys, "deprecated": False}, properties=["structure", "material_id"]
)
mp_structures = {d["material_id"]: Structure.from_dict(d["structure"]) for d in mp_docs}

snl_docs = self.snls.query(
{"chemsys": chemsys},
)
snl_docs = self.snls.query({"chemsys": chemsys})
snl_structures = {d["snl_id"]: Structure.from_dict(d) for d in snl_docs}

structures = mp_structures
Expand Down Expand Up @@ -290,9 +273,7 @@ def __init__(self, alloy_pairs, alloy_pair_members, alloy_pairs_merged, alloy_sy
self.alloy_systems = alloy_systems

super().__init__(
sources=[alloy_pairs, alloy_pair_members],
targets=[alloy_pairs_merged, alloy_systems],
chunk_size=8,
sources=[alloy_pairs, alloy_pair_members], targets=[alloy_pairs_merged, alloy_systems], chunk_size=8
)

def get_items(self):
Expand Down

0 comments on commit 0319ec9

Please sign in to comment.