From 047d832808a715cac9a71be58d7594dcaf02a86f Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 20 Apr 2022 14:58:58 -0700 Subject: [PATCH] Band structure, DOS, and charge density changes (#577) * Enable receiving and handling of byte data * Pull chgcar data from signed url endpoint * Update prefixed query params * Sort fields added to dos rester * Enable receiving and handling of byte data * Pull chgcar data from signed url endpoint * Linting * Fallback to s3 url * Raise MPRester error for charge densities * mypy fix * Bump emmet and maggma in reqs --- requirements.txt | 4 +- setup.py | 4 +- src/mp_api/client.py | 191 ++++++++-------------- src/mp_api/core/client.py | 18 +- src/mp_api/routes/electronic_structure.py | 109 ++++-------- 5 files changed, 118 insertions(+), 208 deletions(-) diff --git a/requirements.txt b/requirements.txt index 65cf35c4..1fe3422b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ pydantic>=1.8.2 pymatgen>=2022.3.7 typing-extensions==4.1.1 -maggma==0.44.1 +maggma==0.46.0 requests==2.27.1 monty==2022.3.12 -emmet-core==0.21.20 +emmet-core==0.24.1 mpcontribs-client>=4.2.9 diff --git a/setup.py b/setup.py index 63d7fd01..bacf206d 100644 --- a/setup.py +++ b/setup.py @@ -28,8 +28,8 @@ "typing-extensions>=3.7.4.1", "requests>=2.23.0", "monty>=2021.3.12", - "emmet-core>=0.21.19", - "maggma>=0.39.1", + "emmet-core>=0.24.1", + "maggma>=0.46.0", "mpcontribs-client", ], classifiers=[ diff --git a/src/mp_api/client.py b/src/mp_api/client.py index 1aca8e92..3ae98527 100644 --- a/src/mp_api/client.py +++ b/src/mp_api/client.py @@ -1,30 +1,34 @@ +import base64 import itertools import warnings +import msgpack +import zlib from functools import lru_cache from os import environ -from requests import get -from typing import List, Optional, Tuple, Union, Dict -from typing_extensions import Literal +from typing import Dict, List, Optional, Tuple, Union +from emmet.core.charge_density import ChgcarDataDoc +from emmet.core.electronic_structure import BSPathType from emmet.core.mpid import MPID from emmet.core.settings import EmmetSettings +from emmet.core.summary import HasProps from emmet.core.symmetry import CrystalSystem from emmet.core.vasp.calc_types import CalcType -from emmet.core.summary import HasProps +from monty.serialization import MontyDecoder from mpcontribs.client import Client from pymatgen.analysis.magnetism import Ordering from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.pourbaix_diagram import IonEntry from pymatgen.core import Element, Structure from pymatgen.core.ion import Ion -from pymatgen.io.vasp import Chgcar from pymatgen.entries.computed_entries import ComputedEntry +from pymatgen.io.vasp import Chgcar from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from requests import get +from typing_extensions import Literal from mp_api.core.client import BaseRester, MPRestError from mp_api.routes import * -from emmet.core.charge_density import ChgcarDataDoc -from emmet.core.electronic_structure import BSPathType _DEPRECATION_WARNING = ( "MPRester is being modernized. Please use the new method suggested and " @@ -119,15 +123,15 @@ def __init__( """ if api_key and len(api_key) == 16: - raise ValueError("Please use a new API key from https://next-gen.materialsproject.org/api " - "Keys for the new API are 32 characters, whereas keys for the legacy " - "API are 16 characters.") + raise ValueError( + "Please use a new API key from https://next-gen.materialsproject.org/api " + "Keys for the new API are 32 characters, whereas keys for the legacy " + "API are 16 characters." + ) self.api_key = api_key self.endpoint = endpoint - self.session = BaseRester._create_session( - api_key=api_key, include_user_agent=include_user_agent - ) + self.session = BaseRester._create_session(api_key=api_key, include_user_agent=include_user_agent) try: self.contribs = Client(api_key) @@ -183,13 +187,9 @@ def get_task_ids_associated_with_material_id( :param calc_types: if specified, will restrict to certain task types, e.g. [CalcType.GGA_STATIC] :return: """ - tasks = self.materials.get_data_by_id( - material_id, fields=["calc_types"] - ).calc_types + tasks = self.materials.get_data_by_id(material_id, fields=["calc_types"]).calc_types if calc_types: - return [ - task for task, calc_type in tasks.items() if calc_type in calc_types - ] + return [task for task, calc_type in tasks.items() if calc_type in calc_types] else: return list(tasks.keys()) @@ -211,19 +211,14 @@ def get_structure_by_material_id( Structure object or list of Structure objects. """ - structure_data = self.materials.get_structure_by_material_id( - material_id=material_id, final=final - ) + structure_data = self.materials.get_structure_by_material_id(material_id=material_id, final=final) if conventional_unit_cell and structure_data: if final: - structure_data = SpacegroupAnalyzer( - structure_data - ).get_conventional_standard_structure() + structure_data = SpacegroupAnalyzer(structure_data).get_conventional_standard_structure() else: structure_data = [ - SpacegroupAnalyzer(structure).get_conventional_standard_structure() - for structure in structure_data + SpacegroupAnalyzer(structure).get_conventional_standard_structure() for structure in structure_data ] return structure_data @@ -262,9 +257,7 @@ def get_materials_id_from_task_id(self, task_id: str) -> Union[str, None]: if len(docs) == 1: # pragma: no cover return str(docs[0].material_id) # type: ignore elif len(docs) > 1: # pragma: no cover - raise ValueError( - f"Multiple documents return for {task_id}, this should not happen, please report it!" - ) + raise ValueError(f"Multiple documents return for {task_id}, this should not happen, please report it!") else: # pragma: no cover warnings.warn( f"No material found containing task {task_id}. Please report it if you suspect a task has gone missing." @@ -295,9 +288,7 @@ def get_materials_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPI List of all materials ids ([MPID]) """ - if isinstance(chemsys_formula, list) or ( - isinstance(chemsys_formula, str) and "-" in chemsys_formula - ): + if isinstance(chemsys_formula, list) or (isinstance(chemsys_formula, str) and "-" in chemsys_formula): input_params = {"chemsys": chemsys_formula} else: input_params = {"formula": chemsys_formula} @@ -309,9 +300,7 @@ def get_materials_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPI ) ) - def get_structures( - self, chemsys_formula: Union[str, List[str]], final=True - ) -> List[Structure]: + def get_structures(self, chemsys_formula: Union[str, List[str]], final=True) -> List[Structure]: """ Get a list of Structures corresponding to a chemical system or formula. @@ -325,9 +314,7 @@ def get_structures( List of Structure objects. ([Structure]) """ - if isinstance(chemsys_formula, list) or ( - isinstance(chemsys_formula, str) and "-" in chemsys_formula - ): + if isinstance(chemsys_formula, list) or (isinstance(chemsys_formula, str) and "-" in chemsys_formula): input_params = {"chemsys": chemsys_formula} else: input_params = {"formula": chemsys_formula} @@ -404,9 +391,7 @@ def get_entries( List of ComputedEntry or ComputedStructureEntry objects. """ - if isinstance(chemsys_formula, list) or ( - isinstance(chemsys_formula, str) and "-" in chemsys_formula - ): + if isinstance(chemsys_formula, list) or (isinstance(chemsys_formula, str) and "-" in chemsys_formula): input_params = {"chemsys": chemsys_formula} else: input_params = {"formula": chemsys_formula} @@ -496,9 +481,7 @@ def get_pourbaix_entries( # build the PhaseDiagram for get_ion_entries ion_ref_comps = [Ion.from_formula(d["data"]["RefSolid"]).composition for d in ion_data] - ion_ref_elts = list( - itertools.chain.from_iterable(i.elements for i in ion_ref_comps) - ) + ion_ref_elts = list(itertools.chain.from_iterable(i.elements for i in ion_ref_comps)) # TODO - would be great if the commented line below would work # However for some reason you cannot process GibbsComputedStructureEntry with # MaterialsProjectAqueousCompatibility @@ -511,15 +494,12 @@ def get_pourbaix_entries( # entries we get from MPRester with warnings.catch_warnings(): warnings.filterwarnings( - "ignore", - message="You did not provide the required O2 and H2O energies.", + "ignore", message="You did not provide the required O2 and H2O energies.", ) compat = MaterialsProjectAqueousCompatibility(solid_compat=solid_compat) # suppress the warning about missing oxidation states with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="Failed to guess oxidation states.*" - ) + warnings.filterwarnings("ignore", message="Failed to guess oxidation states.*") ion_ref_entries = compat.process_entries(ion_ref_entries) # TODO - if the commented line above would work, this conditional block # could be removed @@ -527,9 +507,7 @@ def get_pourbaix_entries( # replace the entries with GibbsComputedStructureEntry from pymatgen.entries.computed_entries import GibbsComputedStructureEntry - ion_ref_entries = GibbsComputedStructureEntry.from_entries( - ion_ref_entries, temp=use_gibbs - ) + ion_ref_entries = GibbsComputedStructureEntry.from_entries(ion_ref_entries, temp=use_gibbs) ion_ref_pd = PhaseDiagram(ion_ref_entries) ion_entries = self.get_ion_entries(ion_ref_pd, ion_ref_data=ion_data) @@ -537,26 +515,15 @@ def get_pourbaix_entries( # Construct the solid pourbaix entries from filtered ion_ref entries ion_ref_comps = [e.composition for e in ion_entries] - ion_ref_elts = list( - itertools.chain.from_iterable(i.elements for i in ion_ref_comps) - ) - extra_elts = ( - set(ion_ref_elts) - - {Element(s) for s in chemsys} - - {Element("H"), Element("O")} - ) + ion_ref_elts = list(itertools.chain.from_iterable(i.elements for i in ion_ref_comps)) + extra_elts = set(ion_ref_elts) - {Element(s) for s in chemsys} - {Element("H"), Element("O")} for entry in ion_ref_entries: entry_elts = set(entry.composition.elements) # Ensure no OH chemsys or extraneous elements from ion references - if not ( - entry_elts <= {Element("H"), Element("O")} - or extra_elts.intersection(entry_elts) - ): + if not (entry_elts <= {Element("H"), Element("O")} or extra_elts.intersection(entry_elts)): # Create new computed entry form_e = ion_ref_pd.get_form_energy(entry) - new_entry = ComputedEntry( - entry.composition, form_e, entry_id=entry.entry_id - ) + new_entry = ComputedEntry(entry.composition, form_e, entry_id=entry.entry_id) pbx_entry = PourbaixEntry(new_entry) pbx_entries.append(pbx_entry) @@ -596,17 +563,13 @@ def get_ion_reference_data(self) -> List[Dict]: ion_data = [ d for d in self.contribs.contributions.get_entries( - project="ion_ref_data", - fields=["identifier", "formula", "data"], - per_page=500, + project="ion_ref_data", fields=["identifier", "formula", "data"], per_page=500, ).result()["data"] ] return ion_data - def get_ion_reference_data_for_chemsys( - self, chemsys: Union[str, List] - ) -> List[Dict]: + def get_ion_reference_data_for_chemsys(self, chemsys: Union[str, List]) -> List[Dict]: """ Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -645,9 +608,7 @@ def get_ion_reference_data_for_chemsys( return [d for d in ion_data if d["data"]["MajElements"] in chemsys] - def get_ion_entries( - self, pd: PhaseDiagram, ion_ref_data: List[dict] = None - ) -> List[IonEntry]: + def get_ion_entries(self, pd: PhaseDiagram, ion_ref_data: List[dict] = None) -> List[IonEntry]: """ Retrieve IonEntry objects that can be used in the construction of Pourbaix Diagrams. The energies of the IonEntry are calculaterd from @@ -681,8 +642,7 @@ def get_ion_entries( # raise ValueError if O and H not in chemsys if "O" not in chemsys or "H" not in chemsys: raise ValueError( - "The phase diagram chemical system must contain O and H! Your" - f" diagram chemical system is {chemsys}." + "The phase diagram chemical system must contain O and H! Your" f" diagram chemical system is {chemsys}." ) if not ion_ref_data: @@ -694,11 +654,7 @@ def get_ion_entries( ion_entries = [] for n, i_d in enumerate(ion_data): ion = Ion.from_formula(i_d["formula"]) - refs = [ - e - for e in pd.all_entries - if e.composition.reduced_formula == i_d["data"]["RefSolid"] - ] + refs = [e for e in pd.all_entries if e.composition.reduced_formula == i_d["data"]["RefSolid"]] if not refs: raise ValueError("Reference solid not contained in entry list") stable_ref = sorted(refs, key=lambda x: x.energy_per_atom)[0] @@ -713,9 +669,7 @@ def get_ion_entries( # convert to eV/formula unit ref_solid_energy = i_d["data"]["ΔGᶠRefSolid"]["value"] / 96485 else: - raise ValueError( - f"Ion reference solid energy has incorrect unit {i_d['data']['ΔGᶠRefSolid']['unit']}" - ) + raise ValueError(f"Ion reference solid energy has incorrect unit {i_d['data']['ΔGᶠRefSolid']['unit']}") solid_diff = pd.get_form_energy(stable_ref) - ref_solid_energy * rf elt = i_d["data"]["MajElements"] correction_factor = ion.composition[elt] / stable_ref.composition[elt] @@ -728,9 +682,7 @@ def get_ion_entries( # convert to eV/formula unit ion_free_energy = i_d["data"]["ΔGᶠ"]["value"] / 96485 else: - raise ValueError( - f"Ion free energy has incorrect unit {i_d['data']['ΔGᶠ']['unit']}" - ) + raise ValueError(f"Ion free energy has incorrect unit {i_d['data']['ΔGᶠ']['unit']}") energy = ion_free_energy + solid_diff * correction_factor ion_entries.append(IonEntry(ion, energy)) @@ -746,11 +698,7 @@ def get_entry_by_material_id(self, material_id: str): Returns: List of ComputedEntry or ComputedStructureEntry object. """ - return list( - self.thermo.get_data_by_id( - document_id=material_id, fields=["entries"] - ).entries.values() - ) + return list(self.thermo.get_data_by_id(document_id=material_id, fields=["entries"]).entries.values()) def get_entries_in_chemsys( self, elements: Union[str, List[str]], use_gibbs: Optional[int] = None, @@ -794,10 +742,7 @@ def get_entries_in_chemsys( return entries def get_bandstructure_by_material_id( - self, - material_id: str, - path_type: BSPathType = BSPathType.setyawan_curtarolo, - line_mode=True, + self, material_id: str, path_type: BSPathType = BSPathType.setyawan_curtarolo, line_mode=True, ): """ Get the band structure pymatgen object associated with a Materials Project ID. @@ -824,9 +769,7 @@ def get_dos_by_material_id(self, material_id: str): Returns: dos (CompleteDos): CompleteDos object """ - return self.electronic_structure_dos.get_dos_from_material_id( # type: ignore - material_id=material_id - ) + return self.electronic_structure_dos.get_dos_from_material_id(material_id=material_id) # type: ignore def get_phonon_dos_by_material_id(self, material_id: str): """ @@ -881,9 +824,7 @@ def query( magnetic_ordering: Optional[Ordering] = None, total_magnetization: Optional[Tuple[float, float]] = None, total_magnetization_normalized_vol: Optional[Tuple[float, float]] = None, - total_magnetization_normalized_formula_units: Optional[ - Tuple[float, float] - ] = None, + total_magnetization_normalized_formula_units: Optional[Tuple[float, float]] = None, num_magnetic_sites: Optional[Tuple[int, int]] = None, num_unique_magnetic_sites: Optional[Tuple[int, int]] = None, k_voigt: Optional[Tuple[float, float]] = None, @@ -1078,12 +1019,8 @@ def get_wulff_shape(self, material_id: str): from pymatgen.symmetry.analyzer import SpacegroupAnalyzer structure = self.get_structure_by_material_id(material_id) - surfaces = surfaces = self.surface_properties.get_data_by_id( - material_id - ).surfaces - lattice = ( - SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice - ) + surfaces = surfaces = self.surface_properties.get_data_by_id(material_id).surfaces + lattice = SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice miller_energy_map = {} for surf in surfaces: miller = tuple(surf.miller_index) @@ -1093,9 +1030,7 @@ def get_wulff_shape(self, material_id: str): millers, energies = zip(*miller_energy_map.items()) return WulffShape(lattice, millers, energies) - def get_charge_density_from_material_id( - self, material_id: str, inc_task_doc: bool = False - ) -> Optional[Chgcar]: + def get_charge_density_from_material_id(self, material_id: str, inc_task_doc: bool = False) -> Optional[Chgcar]: """ Get charge density data for a given Materials Project ID. @@ -1119,15 +1054,29 @@ def get_charge_density_from_material_id( latest_doc = max(results, key=lambda x: x.last_updated) - chg_doc = self.charge_density.get_data_by_id(latest_doc.fs_id) + url_doc = self.charge_density.get_data_by_id(latest_doc.fs_id) + + if url_doc: + + r = get(url_doc.url, stream=True) + + if r.status_code != 200: + r = get(url_doc.s3_url_prefix + url_doc.fs_id, stream=True) + + if r.status_code != 200: + raise MPRestError(f"Cannot retrieve charge density for {material_id}.") + + packed_bytes = r.raw.data + + packed_bytes = zlib.decompress(packed_bytes) + json_data = msgpack.unpackb(packed_bytes, raw=False) + chgcar = MontyDecoder().process_decoded(json_data["data"]) - if chg_doc: - chgcar = chg_doc.data - task_doc = self.tasks.get_data_by_id(latest_doc.task_id) if inc_task_doc: + task_doc = self.tasks.get_data_by_id(latest_doc.task_id) return chgcar, task_doc + return chgcar + else: - raise MPRestError( - "Charge density task_id found but no charge density fetched." - ) + raise MPRestError(f"No charge density fetched for {material_id}.") diff --git a/src/mp_api/core/client.py b/src/mp_api/core/client.py index a707e826..0516b9ad 100644 --- a/src/mp_api/core/client.py +++ b/src/mp_api/core/client.py @@ -110,9 +110,7 @@ def __init__( self._session = None # type: ignore self.document_model = ( - api_sanitize(self.document_model) # type: ignore - if self.document_model is not None - else None + api_sanitize(self.document_model) if self.document_model is not None else None # type: ignore ) @property @@ -476,11 +474,13 @@ def _submit_requests( num_docs_needed = min((max_pages * chunk_size), total_num_docs) # Setup progress bar + pbar_message = ( # type: ignore + f"Retrieving {self.document_model.__name__} documents" # type: ignore + if self.document_model is not None + else "Retrieving documents" + ) pbar = ( - tqdm( - desc=f"Retrieving {self.document_model.__name__} documents", # type: ignore - total=num_docs_needed, - ) + tqdm(desc=pbar_message, total=num_docs_needed,) if not MAPIClientSettings().MUTE_PROGRESS_BARS else None ) @@ -764,7 +764,9 @@ def get_data_by_id( document_id = new_document_id results = self._query_resource_data( - criteria=criteria, fields=fields, suburl=document_id, # type: ignore + criteria=criteria, + fields=fields, + suburl=document_id, # type: ignore ) if not results: diff --git a/src/mp_api/routes/electronic_structure.py b/src/mp_api/routes/electronic_structure.py index 2851422f..f175f191 100644 --- a/src/mp_api/routes/electronic_structure.py +++ b/src/mp_api/routes/electronic_structure.py @@ -1,11 +1,11 @@ +import base64 +import zlib from collections import defaultdict from typing import List, Optional, Tuple, Union -from emmet.core.electronic_structure import ( - BSPathType, - DOSProjectionType, - ElectronicStructureDoc, -) +import msgpack +from emmet.core.electronic_structure import BSPathType, DOSProjectionType, ElectronicStructureDoc +from monty.serialization import MontyDecoder from mp_api.core.client import BaseRester, MPRestError from pymatgen.analysis.magnetism.analyzer import Ordering from pymatgen.core.periodic_table import Element @@ -79,9 +79,7 @@ def search_electronic_structure_docs( query_params.update({"exclude_elements": ",".join(exclude_elements)}) if band_gap: - query_params.update( - {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]} - ) + query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}) if efermi: query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]}) @@ -100,18 +98,10 @@ def search_electronic_structure_docs( {"_sort_fields": ",".join([s.strip() for s in sort_fields])} ) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super().search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) @@ -160,9 +150,7 @@ def search_bandstructure_summary( query_params["path_type"] = path_type.value if band_gap: - query_params.update( - {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]} - ) + query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}) if efermi: query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]}) @@ -181,18 +169,10 @@ def search_bandstructure_summary( {"_sort_fields": ",".join([s.strip() for s in sort_fields])} ) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super().search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) def get_bandstructure_from_task_id(self, task_id: str): @@ -237,55 +217,42 @@ def get_bandstructure_from_material_id( bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - es_rester = ElectronicStructureRester( - endpoint=self.base_endpoint, api_key=self.api_key - ) + es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key) if line_mode: - bs_data = es_rester.get_data_by_id( - document_id=material_id, fields=["bandstructure"] - ).bandstructure + bs_data = es_rester.get_data_by_id(document_id=material_id, fields=["bandstructure"]).bandstructure if bs_data is None: - raise MPRestError( - "No {} band structure data found for {}".format( - path_type.value, material_id - ) - ) + raise MPRestError("No {} band structure data found for {}".format(path_type.value, material_id)) else: bs_data = bs_data.dict() if bs_data.get(path_type.value, None): bs_task_id = bs_data[path_type.value]["task_id"] else: - raise MPRestError( - "No {} band structure data found for {}".format( - path_type.value, material_id - ) - ) + raise MPRestError("No {} band structure data found for {}".format(path_type.value, material_id)) else: - bs_data = es_rester.get_data_by_id( - document_id=material_id, fields=["dos"] - ).dos + bs_data = es_rester.get_data_by_id(document_id=material_id, fields=["dos"]).dos if bs_data is None: - raise MPRestError( - "No uniform band structure data found for {}".format(material_id) - ) + raise MPRestError("No uniform band structure data found for {}".format(material_id)) else: bs_data = bs_data.dict() if bs_data.get("total", None): bs_task_id = bs_data["total"]["1"]["task_id"] else: - raise MPRestError( - "No uniform band structure data found for {}".format(material_id) - ) + raise MPRestError("No uniform band structure data found for {}".format(material_id)) bs_obj = self.get_bandstructure_from_task_id(bs_task_id) if bs_obj: - return bs_obj[0]["data"] + b64_bytes = base64.b64decode(bs_obj[0], validate=True) + packed_bytes = zlib.decompress(b64_bytes) + json_data = msgpack.unpackb(packed_bytes, raw=False) + data = MontyDecoder().process_decoded(json_data["data"]) + + return data else: raise MPRestError("No band structure object found.") @@ -344,9 +311,7 @@ def search_dos_summary( query_params["orbital"] = orbital.value if band_gap: - query_params.update( - {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]} - ) + query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}) if efermi: query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]}) @@ -366,11 +331,7 @@ def search_dos_summary( } return super().search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params ) def get_dos_from_task_id(self, task_id: str): @@ -408,23 +369,21 @@ def get_dos_from_material_id(self, material_id: str): dos (CompleteDos): CompleteDos object """ - es_rester = ElectronicStructureRester( - endpoint=self.base_endpoint, api_key=self.api_key - ) + es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key) - dos_data = es_rester.get_data_by_id( - document_id=material_id, fields=["dos"] - ).dict() + dos_data = es_rester.get_data_by_id(document_id=material_id, fields=["dos"]).dict() if dos_data["dos"]: dos_task_id = dos_data["dos"]["total"]["1"]["task_id"] else: - raise MPRestError( - "No density of states data found for {}".format(material_id) - ) + raise MPRestError("No density of states data found for {}".format(material_id)) dos_obj = self.get_dos_from_task_id(dos_task_id) if dos_obj: - return dos_obj[0]["data"] + b64_bytes = base64.b64decode(dos_obj[0], validate=True) + packed_bytes = zlib.decompress(b64_bytes) + json_data = msgpack.unpackb(packed_bytes, raw=False) + data = MontyDecoder().process_decoded(json_data["data"]) + return data else: raise MPRestError("No density of states object found.")