From c83f738e82418aa28fbe6d737b09d4affa1ca2e5 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 26 Apr 2022 13:52:24 -0700 Subject: [PATCH 1/7] Add client side checks on material and task IDs --- src/mp_api/client.py | 29 ++++++-- src/mp_api/core/client.py | 100 ++++++++++------------------ src/mp_api/core/utils.py | 23 +++++++ src/mp_api/routes/grain_boundary.py | 3 +- src/mp_api/routes/materials.py | 3 +- src/mp_api/routes/summary.py | 23 ++----- src/mp_api/routes/thermo.py | 3 +- src/mp_api/routes/xas.py | 3 +- 8 files changed, 95 insertions(+), 92 deletions(-) create mode 100644 src/mp_api/core/utils.py diff --git a/src/mp_api/client.py b/src/mp_api/client.py index 797c4079..1598ac18 100644 --- a/src/mp_api/client.py +++ b/src/mp_api/client.py @@ -285,7 +285,10 @@ def get_materials_id_references(self, material_id: str) -> List[str]: """ return self.provenance.get_data_by_id(material_id).references - def get_materials_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPID]: + def get_materials_ids( + self, + chemsys_formula: Union[str, List[str]], + ) -> List[MPID]: """ Get all materials ids for a formula or chemsys. @@ -307,7 +310,9 @@ def get_materials_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPI return sorted( doc.material_id for doc in self.materials.search_material_docs( - **input_params, all_fields=False, fields=["material_id"], # type: ignore + **input_params, + all_fields=False, + fields=["material_id"], # type: ignore ) ) @@ -338,14 +343,18 @@ def get_structures( return [ doc.structure for doc in self.materials.search_material_docs( - **input_params, all_fields=False, fields=["structure"], # type: ignore + **input_params, + all_fields=False, + fields=["structure"], # type: ignore ) ] else: structures = [] for doc in self.materials.search_material_docs( - **input_params, all_fields=False, fields=["initial_structures"], # type: ignore + **input_params, + all_fields=False, + fields=["initial_structures"], # type: ignore ): structures.extend(doc.initial_structures) @@ -390,7 +399,9 @@ def find_structure( ) def get_entries( - self, chemsys_formula: Union[str, List[str]], sort_by_e_above_hull=False, + self, + chemsys_formula: Union[str, List[str]], + sort_by_e_above_hull=False, ): """ Get a list of ComputedEntries or ComputedStructureEntries corresponding @@ -429,7 +440,9 @@ def get_entries( else: for doc in self.thermo.search_thermo_docs( - **input_params, all_fields=False, fields=["entries"], # type: ignore + **input_params, + all_fields=False, + fields=["entries"], # type: ignore ): entries.extend(list(doc.entries.values())) @@ -757,7 +770,9 @@ def get_entry_by_material_id(self, material_id: str): ) def get_entries_in_chemsys( - self, elements: Union[str, List[str]], use_gibbs: Optional[int] = None, + self, + elements: Union[str, List[str]], + use_gibbs: Optional[int] = None, ): """ Helper method to get a list of ComputedEntries in a chemical system. diff --git a/src/mp_api/core/client.py b/src/mp_api/core/client.py index 0516b9ad..ab5c1542 100644 --- a/src/mp_api/core/client.py +++ b/src/mp_api/core/client.py @@ -5,26 +5,27 @@ Materials Project data. 
""" -from hashlib import new import itertools import json import platform import sys import warnings from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait +from copy import copy +from hashlib import new from json import JSONDecodeError +from math import ceil from os import environ -from typing import Dict, Generic, List, Optional, TypeVar, Union, Tuple +from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union from urllib.parse import urljoin -from copy import copy -from math import ceil -from matplotlib import use import requests from emmet.core.utils import jsanitize from maggma.api.utils import api_sanitize +from matplotlib import use from monty.json import MontyDecoder from mp_api.core.settings import MAPIClientSettings +from mp_api.core.utils import validate_ids from pydantic import BaseModel from requests.adapters import HTTPAdapter from requests.exceptions import RequestException @@ -130,9 +131,7 @@ def _create_session(api_key, include_user_agent): sys.version_info.major, sys.version_info.minor, sys.version_info.micro ) platform_info = "{}/{}".format(platform.system(), platform.release()) - session.headers["user-agent"] = "{} ({} {})".format( - pymatgen_info, python_info, platform_info - ) + session.headers["user-agent"] = "{} ({} {})".format(pymatgen_info, python_info, platform_info) max_retry_num = MAPIClientSettings().MAX_RETRIES retry = Retry( @@ -221,10 +220,7 @@ def _post_resource( message = data else: try: - message = ", ".join( - "{} - {}".format(entry["loc"][1], entry["msg"]) - for entry in data - ) + message = ", ".join("{} - {}".format(entry["loc"][1], entry["msg"]) for entry in data) except (KeyError, IndexError): message = str(data) @@ -338,9 +334,7 @@ def _submit_requests( # criteria dicts. if parallel_param is not None: param_length = len(criteria[parallel_param].split(",")) - slice_size = ( - int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1 - ) + slice_size = int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1 new_param_values = [ entry @@ -365,11 +359,7 @@ def _submit_requests( # Split list and generate multiple criteria new_criteria = [ { - **{ - key: criteria[key] - for key in criteria - if key not in [parallel_param, "_limit"] - }, + **{key: criteria[key] for key in criteria if key not in [parallel_param, "_limit"]}, parallel_param: ",".join(list_chunk), "_limit": new_limits[list_num], } @@ -392,13 +382,9 @@ def _submit_requests( subtotals = [] remaining_docs_avail = {} - initial_params_list = [ - {"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria - ] + initial_params_list = [{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria] - initial_data_tuples = self._multi_thread( - use_document_model, initial_params_list - ) + initial_data_tuples = self._multi_thread(use_document_model, initial_params_list) for data, subtotal, crit_ind in initial_data_tuples: @@ -411,9 +397,7 @@ def _submit_requests( # Rebalance if some parallel queries produced too few results if len(remaining_docs_avail) > 1 and len(total_data["data"]) < chunk_size: - remaining_docs_avail = dict( - sorted(remaining_docs_avail.items(), key=lambda item: item[1]) - ) + remaining_docs_avail = dict(sorted(remaining_docs_avail.items(), key=lambda item: item[1])) # Redistribute missing docs from initial chunk among queries # which have head room with respect to remaining document number. 
@@ -440,9 +424,7 @@ def _submit_requests( new_limits[crit_ind] += fill_docs fill_docs = 0 - rebalance_params.append( - {"url": url, "verify": True, "params": copy(crit)} - ) + rebalance_params.append({"url": url, "verify": True, "params": copy(crit)}) new_criteria[crit_ind]["_skip"] += crit["_limit"] new_criteria[crit_ind]["_limit"] = chunk_size @@ -450,9 +432,7 @@ def _submit_requests( # Obtain missing initial data after rebalancing if len(rebalance_params) > 0: - rebalance_data_tuples = self._multi_thread( - use_document_model, rebalance_params - ) + rebalance_data_tuples = self._multi_thread(use_document_model, rebalance_params) for data, _, _ in rebalance_data_tuples: total_data["data"].extend(data["data"]) @@ -466,9 +446,7 @@ def _submit_requests( total_data["meta"] = last_data_entry["meta"] # Get max number of reponse pages - max_pages = ( - num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size) - ) + max_pages = num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size) # Get total number of docs needed num_docs_needed = min((max_pages * chunk_size), total_num_docs) @@ -480,7 +458,10 @@ def _submit_requests( else "Retrieving documents" ) pbar = ( - tqdm(desc=pbar_message, total=num_docs_needed,) + tqdm( + desc=pbar_message, + total=num_docs_needed, + ) if not MAPIClientSettings().MUTE_PROGRESS_BARS else None ) @@ -579,21 +560,15 @@ def _multi_thread( return_data = [] - params_gen = iter( - params_list - ) # Iter necessary for islice to keep track of what has been accessed + params_gen = iter(params_list) # Iter necessary for islice to keep track of what has been accessed params_ind = 0 - with ThreadPoolExecutor( - max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS - ) as executor: + with ThreadPoolExecutor(max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS) as executor: # Get list of initial futures defined by max number of parallel requests futures = set({}) - for params in itertools.islice( - params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS - ): + for params in itertools.islice(params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS): future = executor.submit( self._submit_request_and_process, @@ -670,10 +645,7 @@ def _submit_request_and_process( message = data else: try: - message = ", ".join( - "{} - {}".format(entry["loc"][1], entry["msg"]) - for entry in data - ) + message = ", ".join("{} - {}".format(entry["loc"][1], entry["msg"]) for entry in data) except (KeyError, IndexError): message = str(data) @@ -713,7 +685,9 @@ def _query_resource_data( ).get("data") def get_data_by_id( - self, document_id: str, fields: Optional[List[str]] = None, + self, + document_id: str, + fields: Optional[List[str]] = None, ) -> T: """ Query the endpoint for a single document. @@ -727,10 +701,10 @@ def get_data_by_id( """ if document_id is None: - raise ValueError( - "Please supply a specific id. You can use the query method to find " - "ids of interest." - ) + raise ValueError("Please supply a specific ID. 
You can use the query method to find " "ids of interest.") + + if self.primary_key in ["material_id", "task_id"]: + validate_ids([document_id]) if fields is None: criteria = {"_all_fields": True, "_limit": 1} # type: dict @@ -744,7 +718,9 @@ def get_data_by_id( try: results = self._query_resource_data( - criteria=criteria, fields=fields, suburl=document_id, # type: ignore + criteria=criteria, + fields=fields, + suburl=document_id, # type: ignore ) except MPRestError: @@ -772,9 +748,7 @@ def get_data_by_id( if not results: raise MPRestError(f"No result for record {document_id}.") elif len(results) > 1: # pragma: no cover - raise ValueError( - f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug." - ) + raise ValueError(f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug.") else: return results[0] @@ -881,9 +855,7 @@ def count(self, criteria: Optional[Dict] = None) -> Union[int, str]: False, False, ) # do not waste cycles decoding - results = self._query_resource( - criteria=criteria, num_chunks=1, chunk_size=1 - ) + results = self._query_resource(criteria=criteria, num_chunks=1, chunk_size=1) self.monty_decode, self.use_document_model = user_preferences return results["meta"]["total_doc"] except Exception: # pragma: no cover diff --git a/src/mp_api/core/utils.py b/src/mp_api/core/utils.py new file mode 100644 index 00000000..3fb37e63 --- /dev/null +++ b/src/mp_api/core/utils.py @@ -0,0 +1,23 @@ +import re +from typing import List + + +def validate_ids(id_list: List[str]): + """Function to validate material and task IDs + + Args: + id_list (List[str]): List of material or task IDs. + + Raises: + ValueError: If at least one ID is not formatted correctly. + + Returns: + id_list: Returns original ID list if everything is formatted correctly. 
+ """ + pattern = "(mp|mvc)-.*" + + for entry in id_list: + if re.match(pattern, entry) is None: + raise ValueError(f"{entry} is not formatted correctly!") + + return id_list diff --git a/src/mp_api/routes/grain_boundary.py b/src/mp_api/routes/grain_boundary.py index 319cb312..e95dba25 100644 --- a/src/mp_api/routes/grain_boundary.py +++ b/src/mp_api/routes/grain_boundary.py @@ -2,6 +2,7 @@ from collections import defaultdict from mp_api.core.client import BaseRester +from mp_api.core.utils import validate_ids from emmet.core.grain_boundary import GBTypeEnum, GrainBoundaryDoc @@ -58,7 +59,7 @@ def search_grain_boundary_docs( query_params = defaultdict(dict) # type: dict if material_ids: - query_params.update({"task_ids": ",".join(material_ids)}) + query_params.update({"task_ids": ",".join(validate_ids(material_ids))}) if gb_plane: query_params.update({"gb_plane": ",".join([str(n) for n in gb_plane])}) diff --git a/src/mp_api/routes/materials.py b/src/mp_api/routes/materials.py index 22b0de4c..7125a61d 100644 --- a/src/mp_api/routes/materials.py +++ b/src/mp_api/routes/materials.py @@ -7,6 +7,7 @@ from emmet.core.settings import EmmetSettings from mp_api.core.client import BaseRester, MPRestError +from mp_api.core.utils import validate_ids _EMMET_SETTINGS = EmmetSettings() @@ -107,7 +108,7 @@ def search_material_docs( query_params.update({"exclude_elements": ",".join(exclude_elements)}) if task_ids: - query_params.update({"task_ids": ",".join(task_ids)}) + query_params.update({"task_ids": ",".join(validate_ids(task_ids))}) query_params.update( { diff --git a/src/mp_api/routes/summary.py b/src/mp_api/routes/summary.py index c0ee0292..6fb11e02 100644 --- a/src/mp_api/routes/summary.py +++ b/src/mp_api/routes/summary.py @@ -5,6 +5,7 @@ from emmet.core.summary import HasProps, SummaryDoc from emmet.core.symmetry import CrystalSystem from mp_api.core.client import BaseRester +from mp_api.core.utils import validate_ids from pymatgen.analysis.magnetism import Ordering @@ -42,9 +43,7 @@ def search_summary_docs( magnetic_ordering: Optional[Ordering] = None, total_magnetization: Optional[Tuple[float, float]] = None, total_magnetization_normalized_vol: Optional[Tuple[float, float]] = None, - total_magnetization_normalized_formula_units: Optional[ - Tuple[float, float] - ] = None, + total_magnetization_normalized_formula_units: Optional[Tuple[float, float]] = None, num_magnetic_sites: Optional[Tuple[int, int]] = None, num_unique_magnetic_sites: Optional[Tuple[int, int]] = None, k_voigt: Optional[Tuple[float, float]] = None, @@ -200,7 +199,7 @@ def search_summary_docs( ) if material_ids: - query_params.update({"material_ids": ",".join(material_ids)}) + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) if deprecated is not None: query_params.update({"deprecated": deprecated}) @@ -253,20 +252,10 @@ def search_summary_docs( query_params.update({"theoretical": theoretical}) if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) + query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])}) - query_params = { - entry: query_params[entry] - for entry in query_params - if query_params[entry] is not None - } + query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None} return super().search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params + num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, 
fields=fields, **query_params ) diff --git a/src/mp_api/routes/thermo.py b/src/mp_api/routes/thermo.py index 09cb2089..81077cef 100644 --- a/src/mp_api/routes/thermo.py +++ b/src/mp_api/routes/thermo.py @@ -1,6 +1,7 @@ from collections import defaultdict from typing import Optional, List, Tuple, Union from mp_api.core.client import BaseRester +from mp_api.core.utils import validate_ids from emmet.core.thermo import ThermoDoc from pymatgen.analysis.phase_diagram import PhaseDiagram @@ -71,7 +72,7 @@ def search_thermo_docs( query_params.update({"chemsys": ",".join(chemsys)}) if material_ids: - query_params.update({"material_ids": ",".join(material_ids)}) + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) if nelements: query_params.update( diff --git a/src/mp_api/routes/xas.py b/src/mp_api/routes/xas.py index 5a0b3774..19c78668 100644 --- a/src/mp_api/routes/xas.py +++ b/src/mp_api/routes/xas.py @@ -2,6 +2,7 @@ from emmet.core.xas import Edge, XASDoc from mp_api.core.client import BaseRester +from mp_api.core.utils import validate_ids from pymatgen.core.periodic_table import Element @@ -67,7 +68,7 @@ def search_xas_docs( query_params.update({"elements": ",".join(elements)}) if material_ids is not None: - query_params["material_ids"] = ",".join(material_ids) + query_params["material_ids"] = ",".join(validate_ids(material_ids)) if sort_fields: query_params.update( From 187b4b94f1bd94eeb3dc4239df7e0e7de1e2dc1d Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 26 Apr 2022 14:23:29 -0700 Subject: [PATCH 2/7] Mypy --- src/mp_api/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mp_api/client.py b/src/mp_api/client.py index 1598ac18..a80f0377 100644 --- a/src/mp_api/client.py +++ b/src/mp_api/client.py @@ -343,9 +343,9 @@ def get_structures( return [ doc.structure for doc in self.materials.search_material_docs( - **input_params, + **input_params, # type: ignore all_fields=False, - fields=["structure"], # type: ignore + fields=["structure"], ) ] else: From 384ed3afd785c6e6de84b2daf5988fe3d61651cd Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 26 Apr 2022 14:26:45 -0700 Subject: [PATCH 3/7] Revert to testing on deployed API endpoint --- .github/workflows/testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 0573e05d..d1cd55d4 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -78,7 +78,7 @@ jobs: - name: Test with pytest env: MP_API_KEY: ${{ secrets.MP_API_KEY }} - MP_API_ENDPOINT: https://api-preview.materialsproject.org/ + #MP_API_ENDPOINT: https://api-preview.materialsproject.org/ run: | pip install -e . 
pytest --cov=mp_api --cov-report=xml From cfd2a1fab0844568109dec6d99be0f3056c86830 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 26 Apr 2022 14:37:40 -0700 Subject: [PATCH 4/7] More linting --- src/mp_api/client.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mp_api/client.py b/src/mp_api/client.py index a80f0377..731bb709 100644 --- a/src/mp_api/client.py +++ b/src/mp_api/client.py @@ -310,9 +310,9 @@ def get_materials_ids( return sorted( doc.material_id for doc in self.materials.search_material_docs( - **input_params, + **input_params, # type: ignore all_fields=False, - fields=["material_id"], # type: ignore + fields=["material_id"], ) ) @@ -352,9 +352,9 @@ def get_structures( structures = [] for doc in self.materials.search_material_docs( - **input_params, + **input_params, # type: ignore all_fields=False, - fields=["initial_structures"], # type: ignore + fields=["initial_structures"], ): structures.extend(doc.initial_structures) @@ -440,9 +440,9 @@ def get_entries( else: for doc in self.thermo.search_thermo_docs( - **input_params, + **input_params, # type: ignore all_fields=False, - fields=["entries"], # type: ignore + fields=["entries"], ): entries.extend(list(doc.entries.values())) From b1945d33bf8b1a3b56c08309031291479bb14297 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 26 Apr 2022 14:52:29 -0700 Subject: [PATCH 5/7] Add molecule IDs to validation function --- src/mp_api/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mp_api/core/utils.py b/src/mp_api/core/utils.py index 3fb37e63..6531c98d 100644 --- a/src/mp_api/core/utils.py +++ b/src/mp_api/core/utils.py @@ -14,7 +14,7 @@ def validate_ids(id_list: List[str]): Returns: id_list: Returns original ID list if everything is formatted correctly. 
""" - pattern = "(mp|mvc)-.*" + pattern = "(mp|mvc|mol)-.*" for entry in id_list: if re.match(pattern, entry) is None: From 65cefa7de0c443b9342882dc8e8e281899a95f61 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 26 Apr 2022 14:52:41 -0700 Subject: [PATCH 6/7] Change primary key in charge density rester --- src/mp_api/routes/charge_density.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/mp_api/routes/charge_density.py b/src/mp_api/routes/charge_density.py index 7ae7ceff..177a0dd6 100644 --- a/src/mp_api/routes/charge_density.py +++ b/src/mp_api/routes/charge_density.py @@ -22,7 +22,7 @@ class ChargeDensityRester(BaseRester[ChgcarDataDoc]): suffix = "charge_density" - primary_key = "task_id" + primary_key = "fs_id" document_model = ChgcarDataDoc # type: ignore boto_resource = None @@ -87,11 +87,11 @@ def get_charge_density_from_file_id(self, fs_id: str): if environ.get("AWS_EXECUTION_ENV", None) == "AWS_ECS_FARGATE": if self.boto_resource is None: - self.boto_resource = self._get_s3_resource(use_minio=False, unsigned=False) + self.boto_resource = self._get_s3_resource( + use_minio=False, unsigned=False + ) - bucket, obj_prefix = self._extract_s3_url_info( - url_doc, use_minio=False - ) + bucket, obj_prefix = self._extract_s3_url_info(url_doc, use_minio=False) else: try: From 50aaedc6560884cdc601d8c07215cba67c85f4b2 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 26 Apr 2022 15:03:41 -0700 Subject: [PATCH 7/7] Switch to full materials doc --- src/mp_api/routes/materials.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mp_api/routes/materials.py b/src/mp_api/routes/materials.py index 7125a61d..8c33f2fc 100644 --- a/src/mp_api/routes/materials.py +++ b/src/mp_api/routes/materials.py @@ -1,9 +1,8 @@ from typing import List, Optional, Tuple, Union from pymatgen.core.structure import Structure -from emmet.core.material import MaterialsDoc +from emmet.core.vasp.material import MaterialsDoc from emmet.core.symmetry import CrystalSystem -from emmet.core.utils import jsanitize from emmet.core.settings import EmmetSettings from mp_api.core.client import BaseRester, MPRestError