Skip to content

Commit

Permalink
Showing 1 changed file with 142 additions and 101 deletions.
243 changes: 142 additions & 101 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from functools import lru_cache
from os import environ
from typing import Dict, List, Optional, Union
from json import loads

from emmet.core.charge_density import ChgcarDataDoc
from emmet.core.electronic_structure import BSPathType
@@ -125,9 +126,7 @@ def __init__(

self.api_key = api_key
self.endpoint = endpoint
self.session = BaseRester._create_session(
api_key=api_key, include_user_agent=include_user_agent
)
self.session = BaseRester._create_session(api_key=api_key, include_user_agent=include_user_agent)
self.use_document_model = use_document_model
self.monty_decode = monty_decode

@@ -193,13 +192,10 @@ def __getattr__(self, attr):
)
elif attr == "charge_density":
raise MPRestError(
"boto3 not installed. "
"To query charge density data first install with: 'pip install boto3'"
"boto3 not installed. " "To query charge density data first install with: 'pip install boto3'"
)
else:
raise AttributeError(
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
)
raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {attr!r}")

def get_task_ids_associated_with_material_id(
self, material_id: str, calc_types: Optional[List[CalcType]] = None
@@ -210,13 +206,9 @@ def get_task_ids_associated_with_material_id(
:param calc_types: if specified, will restrict to certain task types, e.g. [CalcType.GGA_STATIC]
:return:
"""
tasks = self.materials.get_data_by_id(
material_id, fields=["calc_types"]
).calc_types
tasks = self.materials.get_data_by_id(material_id, fields=["calc_types"]).calc_types
if calc_types:
return [
task for task, calc_type in tasks.items() if calc_type in calc_types
]
return [task for task, calc_type in tasks.items() if calc_type in calc_types]
else:
return list(tasks.keys())

@@ -238,19 +230,14 @@ def get_structure_by_material_id(
Structure object or list of Structure objects.
"""

structure_data = self.materials.get_structure_by_material_id(
material_id=material_id, final=final
)
structure_data = self.materials.get_structure_by_material_id(material_id=material_id, final=final)

if conventional_unit_cell and structure_data:
if final:
structure_data = SpacegroupAnalyzer(
structure_data
).get_conventional_standard_structure()
structure_data = SpacegroupAnalyzer(structure_data).get_conventional_standard_structure()
else:
structure_data = [
SpacegroupAnalyzer(structure).get_conventional_standard_structure()
for structure in structure_data
SpacegroupAnalyzer(structure).get_conventional_standard_structure() for structure in structure_data
]

return structure_data
@@ -289,9 +276,7 @@ def get_material_id_from_task_id(self, task_id: str) -> Union[str, None]:
if len(docs) == 1: # pragma: no cover
return str(docs[0].material_id) # type: ignore
elif len(docs) > 1: # pragma: no cover
raise ValueError(
f"Multiple documents return for {task_id}, this should not happen, please report it!"
)
raise ValueError(f"Multiple documents return for {task_id}, this should not happen, please report it!")
else: # pragma: no cover
warnings.warn(
f"No material found containing task {task_id}. Please report it if you suspect a task has gone missing."
@@ -330,7 +315,10 @@ def get_materials_id_references(self, material_id: str) -> List[str]:
)
return self.get_material_id_references(material_id)

def get_material_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPID]:
def get_material_ids(
self,
chemsys_formula: Union[str, List[str]],
) -> List[MPID]:
"""
Get all materials ids for a formula or chemsys.
@@ -342,21 +330,24 @@ def get_material_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPID
List of all materials ids ([MPID])
"""

if isinstance(chemsys_formula, list) or (
isinstance(chemsys_formula, str) and "-" in chemsys_formula
):
if isinstance(chemsys_formula, list) or (isinstance(chemsys_formula, str) and "-" in chemsys_formula):
input_params = {"chemsys": chemsys_formula}
else:
input_params = {"formula": chemsys_formula}

return sorted(
doc.material_id
for doc in self.materials.search(
**input_params, all_fields=False, fields=["material_id"], # type: ignore
**input_params, # type: ignore
all_fields=False,
fields=["material_id"],
)
)

def get_materials_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPID]:
def get_materials_ids(
self,
chemsys_formula: Union[str, List[str]],
) -> List[MPID]:
"""
This method is deprecated, please use get_material_ids.
"""
@@ -366,9 +357,7 @@ def get_materials_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPI
)
return self.get_material_ids(chemsys_formula)

def get_structures(
self, chemsys_formula: Union[str, List[str]], final=True
) -> List[Structure]:
def get_structures(self, chemsys_formula: Union[str, List[str]], final=True) -> List[Structure]:
"""
Get a list of Structures corresponding to a chemical system or formula.
@@ -382,9 +371,7 @@ def get_structures(
List of Structure objects. ([Structure])
"""

if isinstance(chemsys_formula, list) or (
isinstance(chemsys_formula, str) and "-" in chemsys_formula
):
if isinstance(chemsys_formula, list) or (isinstance(chemsys_formula, str) and "-" in chemsys_formula):
input_params = {"chemsys": chemsys_formula}
else:
input_params = {"formula": chemsys_formula}
@@ -393,14 +380,18 @@ def get_structures(
return [
doc.structure
for doc in self.materials.search(
**input_params, all_fields=False, fields=["structure"], # type: ignore
**input_params, # type: ignore
all_fields=False,
fields=["structure"],
)
]
else:
structures = []

for doc in self.materials.search(
**input_params, all_fields=False, fields=["initial_structures"], # type: ignore
**input_params, # type: ignore
all_fields=False,
fields=["initial_structures"],
):
structures.extend(doc.initial_structures)

@@ -525,7 +516,9 @@ def get_entries(
)
else:
docs = self.thermo.search(
**input_params, all_fields=False, fields=fields, # type: ignore
**input_params,
all_fields=False,
fields=fields, # type: ignore
)

for doc in docs:
@@ -538,8 +531,9 @@ def get_entries(

if property_data:
for property in property_data:
entry_dict["data"][property] = doc.dict()[property] if self.use_document_model else doc[
property]
entry_dict["data"][property] = (
doc.dict()[property] if self.use_document_model else doc[property]
)

if conventional_unit_cell:

@@ -626,12 +620,8 @@ def get_pourbaix_entries(
ion_data = self.get_ion_reference_data_for_chemsys(chemsys)

# build the PhaseDiagram for get_ion_entries
ion_ref_comps = [
Ion.from_formula(d["data"]["RefSolid"]).composition for d in ion_data
]
ion_ref_elts = set(
itertools.chain.from_iterable(i.elements for i in ion_ref_comps)
)
ion_ref_comps = [Ion.from_formula(d["data"]["RefSolid"]).composition for d in ion_data]
ion_ref_elts = set(itertools.chain.from_iterable(i.elements for i in ion_ref_comps))
# TODO - would be great if the commented line below would work
# However for some reason you cannot process GibbsComputedStructureEntry with
# MaterialsProjectAqueousCompatibility
@@ -650,42 +640,29 @@ def get_pourbaix_entries(
compat = MaterialsProjectAqueousCompatibility(solid_compat=solid_compat)
# suppress the warning about missing oxidation states
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="Failed to guess oxidation states.*"
)
warnings.filterwarnings("ignore", message="Failed to guess oxidation states.*")
ion_ref_entries = compat.process_entries(ion_ref_entries)
# TODO - if the commented line above would work, this conditional block
# could be removed
if use_gibbs:
# replace the entries with GibbsComputedStructureEntry
from pymatgen.entries.computed_entries import GibbsComputedStructureEntry

ion_ref_entries = GibbsComputedStructureEntry.from_entries(
ion_ref_entries, temp=use_gibbs
)
ion_ref_entries = GibbsComputedStructureEntry.from_entries(ion_ref_entries, temp=use_gibbs)
ion_ref_pd = PhaseDiagram(ion_ref_entries)

ion_entries = self.get_ion_entries(ion_ref_pd, ion_ref_data=ion_data)
pbx_entries = [PourbaixEntry(e, f"ion-{n}") for n, e in enumerate(ion_entries)]

# Construct the solid pourbaix entries from filtered ion_ref entries
extra_elts = (
set(ion_ref_elts)
- {Element(s) for s in chemsys}
- {Element("H"), Element("O")}
)
extra_elts = set(ion_ref_elts) - {Element(s) for s in chemsys} - {Element("H"), Element("O")}
for entry in ion_ref_entries:
entry_elts = set(entry.composition.elements)
# Ensure no OH chemsys or extraneous elements from ion references
if not (
entry_elts <= {Element("H"), Element("O")}
or extra_elts.intersection(entry_elts)
):
if not (entry_elts <= {Element("H"), Element("O")} or extra_elts.intersection(entry_elts)):
# Create new computed entry
form_e = ion_ref_pd.get_form_energy(entry)
new_entry = ComputedEntry(
entry.composition, form_e, entry_id=entry.entry_id
)
new_entry = ComputedEntry(entry.composition, form_e, entry_id=entry.entry_id)
pbx_entry = PourbaixEntry(new_entry)
pbx_entries.append(pbx_entry)

@@ -733,9 +710,7 @@ def get_ion_reference_data(self) -> List[Dict]:

return ion_data

def get_ion_reference_data_for_chemsys(
self, chemsys: Union[str, List]
) -> List[Dict]:
def get_ion_reference_data_for_chemsys(self, chemsys: Union[str, List]) -> List[Dict]:
"""
Download aqueous ion reference data used in the construction of Pourbaix diagrams.
@@ -774,9 +749,7 @@ def get_ion_reference_data_for_chemsys(

return [d for d in ion_data if d["data"]["MajElements"] in chemsys]

def get_ion_entries(
self, pd: PhaseDiagram, ion_ref_data: List[dict] = None
) -> List[IonEntry]:
def get_ion_entries(self, pd: PhaseDiagram, ion_ref_data: List[dict] = None) -> List[IonEntry]:
"""
Retrieve IonEntry objects that can be used in the construction of
Pourbaix Diagrams. The energies of the IonEntry are calculaterd from
@@ -810,8 +783,7 @@ def get_ion_entries(
# raise ValueError if O and H not in chemsys
if "O" not in chemsys or "H" not in chemsys:
raise ValueError(
"The phase diagram chemical system must contain O and H! Your"
f" diagram chemical system is {chemsys}."
"The phase diagram chemical system must contain O and H! Your" f" diagram chemical system is {chemsys}."
)

if not ion_ref_data:
@@ -823,11 +795,7 @@ def get_ion_entries(
ion_entries = []
for n, i_d in enumerate(ion_data):
ion = Ion.from_formula(i_d["formula"])
refs = [
e
for e in pd.all_entries
if e.composition.reduced_formula == i_d["data"]["RefSolid"]
]
refs = [e for e in pd.all_entries if e.composition.reduced_formula == i_d["data"]["RefSolid"]]
if not refs:
raise ValueError("Reference solid not contained in entry list")
stable_ref = sorted(refs, key=lambda x: x.energy_per_atom)[0]
@@ -842,9 +810,7 @@ def get_ion_entries(
# convert to eV/formula unit
ref_solid_energy = i_d["data"]["ΔGᶠRefSolid"]["value"] / 96485
else:
raise ValueError(
f"Ion reference solid energy has incorrect unit {i_d['data']['ΔGᶠRefSolid']['unit']}"
)
raise ValueError(f"Ion reference solid energy has incorrect unit {i_d['data']['ΔGᶠRefSolid']['unit']}")
solid_diff = pd.get_form_energy(stable_ref) - ref_solid_energy * rf
elt = i_d["data"]["MajElements"]
correction_factor = ion.composition[elt] / stable_ref.composition[elt]
@@ -857,9 +823,7 @@ def get_ion_entries(
# convert to eV/formula unit
ion_free_energy = i_d["data"]["ΔGᶠ"]["value"] / 96485
else:
raise ValueError(
f"Ion free energy has incorrect unit {i_d['data']['ΔGᶠ']['unit']}"
)
raise ValueError(f"Ion free energy has incorrect unit {i_d['data']['ΔGᶠ']['unit']}")
energy = ion_free_energy + solid_diff * correction_factor
ion_entries.append(IonEntry(ion, energy))

@@ -1080,12 +1044,8 @@ def get_wulff_shape(self, material_id: str):
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

structure = self.get_structure_by_material_id(material_id)
surfaces = surfaces = self.surface_properties.get_data_by_id(
material_id
).surfaces
lattice = (
SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice
)
surfaces = surfaces = self.surface_properties.get_data_by_id(material_id).surfaces
lattice = SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice
miller_energy_map = {}
for surf in surfaces:
miller = tuple(surf.miller_index)
@@ -1095,9 +1055,7 @@ def get_wulff_shape(self, material_id: str):
millers, energies = zip(*miller_energy_map.items())
return WulffShape(lattice, millers, energies)

def get_charge_density_from_material_id(
self, material_id: str, inc_task_doc: bool = False
) -> Optional[Chgcar]:
def get_charge_density_from_material_id(self, material_id: str, inc_task_doc: bool = False) -> Optional[Chgcar]:
"""
Get charge density data for a given Materials Project ID.
@@ -1110,10 +1068,7 @@ def get_charge_density_from_material_id(
"""

if not hasattr(self, "charge_density"):
raise MPRestError(
"boto3 not installed. "
"To query charge density data install the boto3 package."
)
raise MPRestError("boto3 not installed. " "To query charge density data install the boto3 package.")

# TODO: really we want a recommended task_id for charge densities here
# this could potentially introduce an ambiguity
@@ -1138,11 +1093,97 @@ def get_charge_density_from_material_id(

return chgcar

def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
"""
Get a list of URLs to retrieve raw VASP output files from the NoMaD repository
Args:
material_ids (list): list of material identifiers (mp-id's)
task_types (list): list of task types to include in download (see CalcType Enum class)
file_patterns (list): list of wildcard file names to include for each task
Returns:
a tuple of 1) a dictionary mapping material_ids to task_ids and
calc_types, and 2) a list of URLs to download zip archives from
NoMaD repository. Each zip archive will contain a manifest.json with
metadata info, e.g. the task/external_ids that belong to a directory
"""
# task_id's correspond to NoMaD external_id's
calc_types = [t.value for t in calc_types if isinstance(t, CalcType)] if calc_types else []

meta = {}
for doc in self.materials.search(
task_ids=material_ids, fields=["calc_types", "deprecated_tasks", "material_id"]
):

for task_id, calc_type in doc.calc_types.items():
if calc_types and calc_type not in calc_types:
continue
mp_id = doc.material_id
if meta.get(mp_id) is None:
meta[mp_id] = [{"task_id": task_id, "calc_type": calc_type}]
else:
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})
if not meta:
raise ValueError(f"No tasks found for material id {material_ids}.")

# return a list of URLs for NoMaD Downloads containing the list of files
# for every external_id in `task_ids`
# For reference, please visit https://nomad-lab.eu/prod/rae/api/

# check if these task ids exist on NOMAD
prefix = "https://nomad-lab.eu/prod/rae/api/repo/?"
if file_patterns is not None:
for file_pattern in file_patterns:
prefix += f"file_pattern={file_pattern}&"
prefix += "external_id="

task_ids = [t["task_id"] for tl in meta.values() for t in tl]
nomad_exist_task_ids = self._check_get_download_info_url_by_task_id(prefix=prefix, task_ids=task_ids)
if len(nomad_exist_task_ids) != len(task_ids):
self._print_help_message(nomad_exist_task_ids, task_ids, file_patterns, calc_types)

# generate download links for those that exist
prefix = "https://nomad-lab.eu/prod/rae/api/raw/query?"
if file_patterns is not None:
for file_pattern in file_patterns:
prefix += f"file_pattern={file_pattern}&"
prefix += "external_id="

urls = [prefix + tids for tids in nomad_exist_task_ids]
return meta, urls

def _check_get_download_info_url_by_task_id(self, prefix, task_ids) -> List[str]:
nomad_exist_task_ids: List[str] = []
prefix = prefix.replace("/raw/query", "/repo/")
for task_id in task_ids:
url = prefix + task_id
if self._check_nomad_exist(url):
nomad_exist_task_ids.append(task_id)
return nomad_exist_task_ids

@staticmethod
def _check_nomad_exist(url) -> bool:
response = get(url=url)
if response.status_code != 200:
return False
content = loads(response.text)
if content["pagination"]["total"] == 0:
return False
return True

@staticmethod
def _print_help_message(nomad_exist_task_ids, task_ids, file_patterns, calc_types):
non_exist_ids = set(task_ids) - set(nomad_exist_task_ids)
warnings.warn(
f"For file patterns [{file_patterns}] and calc_types [{calc_types}], \n"
f"the following ids are not found on NOMAD [{list(non_exist_ids)}]. \n"
f"If you need to upload them, please contact Patrick Huck at phuck@lbl.gov"
)

def query(*args, **kwargs):
"""
The MPRester().query method has been replaced with the MPRester().summary.search method.
Note this method also no longer supports direct MongoDB-type queries. For more information,
please see the new documentation.
The MPRester().query method has been replaced with the MPRester().summary.search method.
Note this method also no longer supports direct MongoDB-type queries. For more information,
please see the new documentation.
"""
raise NotImplementedError(
"""

0 comments on commit 17dd2cc

Please sign in to comment.