Add client side checks on material and task IDs #589

Merged · 7 commits · Apr 26, 2022
Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
@@ -78,7 +78,7 @@ jobs:
       - name: Test with pytest
        env:
          MP_API_KEY: ${{ secrets.MP_API_KEY }}
-          MP_API_ENDPOINT: https://api-preview.materialsproject.org/
+          #MP_API_ENDPOINT: https://api-preview.materialsproject.org/
        run: |
          pip install -e .
          pytest --cov=mp_api --cov-report=xml
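Commenting out `MP_API_ENDPOINT` makes the CI job fall back to the client's default endpoint instead of the preview deployment. A minimal sketch of the assumed resolution logic (only the variable name comes from the workflow above; the production URL default is an assumption):

```python
import os

# Assumed fallback behavior: an MP_API_ENDPOINT environment variable, when set,
# overrides the default production endpoint used by the client.
DEFAULT_ENDPOINT = os.environ.get(
    "MP_API_ENDPOINT", "https://api.materialsproject.org/"
)
```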
29 changes: 22 additions & 7 deletions src/mp_api/client.py
@@ -285,7 +285,10 @@ def get_materials_id_references(self, material_id: str) -> List[str]:
        """
        return self.provenance.get_data_by_id(material_id).references

-    def get_materials_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPID]:
+    def get_materials_ids(
+        self,
+        chemsys_formula: Union[str, List[str]],
+    ) -> List[MPID]:
        """
        Get all materials ids for a formula or chemsys.
@@ -307,7 +310,9 @@ def get_materials_ids(self, chemsys_formula: Union[str, List[str]],) -> List[MPID]:
        return sorted(
            doc.material_id
            for doc in self.materials.search_material_docs(
-                **input_params, all_fields=False, fields=["material_id"],  # type: ignore
+                **input_params,  # type: ignore
+                all_fields=False,
+                fields=["material_id"],
            )
        )

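For reference, a usage sketch of the reformatted method; the `MPRester` import path, the context-manager pattern, and the example inputs are illustrative, not part of this diff:

```python
from mp_api.client import MPRester  # import path assumed from this repo layout

with MPRester("your_api_key") as mpr:
    # Accepts a formula ("LiFePO4") or a chemical system ("Li-Fe-O"),
    # or a list of either; returns a sorted list of MPIDs.
    material_ids = mpr.get_materials_ids("Li-Fe-O")
```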
@@ -338,14 +343,18 @@ def get_structures(
            return [
                doc.structure
                for doc in self.materials.search_material_docs(
-                    **input_params, all_fields=False, fields=["structure"],  # type: ignore
+                    **input_params,  # type: ignore
+                    all_fields=False,
+                    fields=["structure"],
                )
            ]
        else:
            structures = []

            for doc in self.materials.search_material_docs(
-                **input_params, all_fields=False, fields=["initial_structures"],  # type: ignore
+                **input_params,  # type: ignore
+                all_fields=False,
+                fields=["initial_structures"],
            ):
                structures.extend(doc.initial_structures)

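A corresponding usage sketch; the `final` keyword is inferred from the two branches above (relaxed `structure` vs. `initial_structures`) and is an assumption:

```python
with MPRester("your_api_key") as mpr:
    # Final (relaxed) structures by default ...
    relaxed = mpr.get_structures("Si-O")
    # ... or the as-submitted input structures (keyword assumed from the
    # "initial_structures" branch in the diff above).
    initial = mpr.get_structures("Si-O", final=False)
```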
@@ -390,7 +399,9 @@ def find_structure(
        )

    def get_entries(
-        self, chemsys_formula: Union[str, List[str]], sort_by_e_above_hull=False,
+        self,
+        chemsys_formula: Union[str, List[str]],
+        sort_by_e_above_hull=False,
    ):
        """
        Get a list of ComputedEntries or ComputedStructureEntries corresponding
@@ -429,7 +440,9 @@ def get_entries(

        else:
            for doc in self.thermo.search_thermo_docs(
-                **input_params, all_fields=False, fields=["entries"],  # type: ignore
+                **input_params,  # type: ignore
+                all_fields=False,
+                fields=["entries"],
            ):
                entries.extend(list(doc.entries.values()))

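Usage sketch for `get_entries`; the entries are pulled from the thermo endpoint's `entries` field as shown above, and the inputs are illustrative:

```python
with MPRester("your_api_key") as mpr:
    # Returns ComputedEntry-type objects, optionally sorted by energy above hull.
    entries = mpr.get_entries("Li-Fe-O", sort_by_e_above_hull=True)
```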
@@ -757,7 +770,9 @@ def get_entry_by_material_id(self, material_id: str):
        )

    def get_entries_in_chemsys(
-        self, elements: Union[str, List[str]], use_gibbs: Optional[int] = None,
+        self,
+        elements: Union[str, List[str]],
+        use_gibbs: Optional[int] = None,
    ):
        """
        Helper method to get a list of ComputedEntries in a chemical system.
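`get_entries_in_chemsys` returns entries for the given system and all of its subsystems, which makes it the natural input for a pymatgen phase diagram. A hedged sketch (the `PhaseDiagram` step is standard pymatgen, not part of this PR):

```python
from pymatgen.analysis.phase_diagram import PhaseDiagram

with MPRester("your_api_key") as mpr:
    # Entries for Li-Fe-O plus every subsystem (Li, Fe, O, Li-Fe, Li-O, Fe-O).
    entries = mpr.get_entries_in_chemsys(["Li", "Fe", "O"])

pd = PhaseDiagram(entries)
```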
100 changes: 36 additions & 64 deletions src/mp_api/core/client.py
@@ -5,26 +5,27 @@
 Materials Project data.
 """

-from hashlib import new
 import itertools
 import json
 import platform
 import sys
 import warnings
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
+from copy import copy
+from hashlib import new
 from json import JSONDecodeError
+from math import ceil
 from os import environ
-from typing import Dict, Generic, List, Optional, TypeVar, Union, Tuple
+from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union
 from urllib.parse import urljoin
-from copy import copy
-from math import ceil
-from matplotlib import use

 import requests
 from emmet.core.utils import jsanitize
 from maggma.api.utils import api_sanitize
+from matplotlib import use
 from monty.json import MontyDecoder
 from mp_api.core.settings import MAPIClientSettings
+from mp_api.core.utils import validate_ids
 from pydantic import BaseModel
 from requests.adapters import HTTPAdapter
 from requests.exceptions import RequestException

@@ -130,9 +131,7 @@ def _create_session(api_key, include_user_agent):
Expand Down Expand Up @@ -130,9 +131,7 @@ def _create_session(api_key, include_user_agent):
        sys.version_info.major, sys.version_info.minor, sys.version_info.micro
    )
    platform_info = "{}/{}".format(platform.system(), platform.release())
-    session.headers["user-agent"] = "{} ({} {})".format(
-        pymatgen_info, python_info, platform_info
-    )
+    session.headers["user-agent"] = "{} ({} {})".format(pymatgen_info, python_info, platform_info)

    max_retry_num = MAPIClientSettings().MAX_RETRIES
    retry = Retry(
@@ -221,10 +220,7 @@ def _post_resource(
                    message = data
                else:
                    try:
-                        message = ", ".join(
-                            "{} - {}".format(entry["loc"][1], entry["msg"])
-                            for entry in data
-                        )
+                        message = ", ".join("{} - {}".format(entry["loc"][1], entry["msg"]) for entry in data)
                    except (KeyError, IndexError):
                        message = str(data)

@@ -338,9 +334,7 @@ def _submit_requests(
            # criteria dicts.
            if parallel_param is not None:
                param_length = len(criteria[parallel_param].split(","))
-                slice_size = (
-                    int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1
-                )
+                slice_size = int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1

                new_param_values = [
                    entry
@@ -365,11 +359,7 @@ def _submit_requests(
                # Split list and generate multiple criteria
                new_criteria = [
                    {
-                        **{
-                            key: criteria[key]
-                            for key in criteria
-                            if key not in [parallel_param, "_limit"]
-                        },
+                        **{key: criteria[key] for key in criteria if key not in [parallel_param, "_limit"]},
                        parallel_param: ",".join(list_chunk),
                        "_limit": new_limits[list_num],
                    }
@@ -392,13 +382,9 @@ def _submit_requests(
            subtotals = []
            remaining_docs_avail = {}

-            initial_params_list = [
-                {"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria
-            ]
+            initial_params_list = [{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria]

-            initial_data_tuples = self._multi_thread(
-                use_document_model, initial_params_list
-            )
+            initial_data_tuples = self._multi_thread(use_document_model, initial_params_list)

            for data, subtotal, crit_ind in initial_data_tuples:

@@ -411,9 +397,7 @@ def _submit_requests(

            # Rebalance if some parallel queries produced too few results
            if len(remaining_docs_avail) > 1 and len(total_data["data"]) < chunk_size:
-                remaining_docs_avail = dict(
-                    sorted(remaining_docs_avail.items(), key=lambda item: item[1])
-                )
+                remaining_docs_avail = dict(sorted(remaining_docs_avail.items(), key=lambda item: item[1]))

                # Redistribute missing docs from initial chunk among queries
                # which have head room with respect to remaining document number.
@@ -440,19 +424,15 @@ def _submit_requests(
                        new_limits[crit_ind] += fill_docs
                        fill_docs = 0

-                    rebalance_params.append(
-                        {"url": url, "verify": True, "params": copy(crit)}
-                    )
+                    rebalance_params.append({"url": url, "verify": True, "params": copy(crit)})

                    new_criteria[crit_ind]["_skip"] += crit["_limit"]
                    new_criteria[crit_ind]["_limit"] = chunk_size

            # Obtain missing initial data after rebalancing
            if len(rebalance_params) > 0:

-                rebalance_data_tuples = self._multi_thread(
-                    use_document_model, rebalance_params
-                )
+                rebalance_data_tuples = self._multi_thread(use_document_model, rebalance_params)

                for data, _, _ in rebalance_data_tuples:
                    total_data["data"].extend(data["data"])
@@ -466,9 +446,7 @@ def _submit_requests(
            total_data["meta"] = last_data_entry["meta"]

        # Get max number of response pages
-        max_pages = (
-            num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)
-        )
+        max_pages = num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)

        # Get total number of docs needed
        num_docs_needed = min((max_pages * chunk_size), total_num_docs)
@@ -480,7 +458,10 @@ def _submit_requests(
            else "Retrieving documents"
        )
        pbar = (
-            tqdm(desc=pbar_message, total=num_docs_needed,)
+            tqdm(
+                desc=pbar_message,
+                total=num_docs_needed,
+            )
            if not MAPIClientSettings().MUTE_PROGRESS_BARS
            else None
        )
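The code above splits one comma-separated query parameter across `NUM_PARALLEL_REQUESTS` workers, then rebalances limits when early slices come back short. A standalone sketch of just the slicing step (a hypothetical helper that mirrors the `slice_size` arithmetic; not the method body itself):

```python
from typing import List

def split_parallel_param(values: List[str], num_requests: int) -> List[List[str]]:
    # Mirrors: slice_size = int(param_length / NUM_PARALLEL_REQUESTS) or 1
    slice_size = int(len(values) / num_requests) or 1
    return [values[i : i + slice_size] for i in range(0, len(values), slice_size)]

# e.g. 10 IDs across 4 parallel requests -> 5 slices of 2, which the
# rebalancing pass above then reconciles against the requested _limit.
chunks = split_parallel_param([f"mp-{n}" for n in range(10)], 4)
```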
@@ -579,21 +560,15 @@ def _multi_thread(

        return_data = []

-        params_gen = iter(
-            params_list
-        )  # Iter necessary for islice to keep track of what has been accessed
+        params_gen = iter(params_list)  # Iter necessary for islice to keep track of what has been accessed

        params_ind = 0

-        with ThreadPoolExecutor(
-            max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS
-        ) as executor:
+        with ThreadPoolExecutor(max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS) as executor:

            # Get list of initial futures defined by max number of parallel requests
            futures = set({})
-            for params in itertools.islice(
-                params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS
-            ):
+            for params in itertools.islice(params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS):

                future = executor.submit(
                    self._submit_request_and_process,
@@ -670,10 +645,7 @@ def _submit_request_and_process(
                    message = data
                else:
                    try:
-                        message = ", ".join(
-                            "{} - {}".format(entry["loc"][1], entry["msg"])
-                            for entry in data
-                        )
+                        message = ", ".join("{} - {}".format(entry["loc"][1], entry["msg"]) for entry in data)
                    except (KeyError, IndexError):
                        message = str(data)

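`_multi_thread` keeps a bounded pool of in-flight requests: it seeds `NUM_PARALLEL_REQUESTS` futures from an iterator, waits with `FIRST_COMPLETED`, and refills from the same iterator as results arrive. A self-contained sketch of that pattern (hypothetical `fetch` callable and `run_parallel` name; not the exact method body):

```python
import itertools
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def run_parallel(fetch, params_list, num_parallel=8):
    """Submit/refill pattern: at most num_parallel requests in flight."""
    params_gen = iter(params_list)  # an iterator lets islice resume where it left off
    results = []
    with ThreadPoolExecutor(max_workers=num_parallel) as executor:
        # Seed the pool with the first num_parallel requests.
        futures = {
            executor.submit(fetch, params)
            for params in itertools.islice(params_gen, num_parallel)
        }
        while futures:
            done, futures = wait(futures, return_when=FIRST_COMPLETED)
            results.extend(f.result() for f in done)
            # Top up the pool with one new request per completed one.
            for params in itertools.islice(params_gen, len(done)):
                futures.add(executor.submit(fetch, params))
    return results
```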
@@ -713,7 +685,9 @@ def _query_resource_data(
        ).get("data")

    def get_data_by_id(
-        self, document_id: str, fields: Optional[List[str]] = None,
+        self,
+        document_id: str,
+        fields: Optional[List[str]] = None,
    ) -> T:
        """
        Query the endpoint for a single document.
@@ -727,10 +701,10 @@ def get_data_by_id(
        """

        if document_id is None:
-            raise ValueError(
-                "Please supply a specific id. You can use the query method to find "
-                "ids of interest."
-            )
+            raise ValueError("Please supply a specific ID. You can use the query method to find " "ids of interest.")
+
+        if self.primary_key in ["material_id", "task_id"]:
+            validate_ids([document_id])

        if fields is None:
            criteria = {"_all_fields": True, "_limit": 1}  # type: dict
@@ -744,7 +718,9 @@ def get_data_by_id(

        try:
            results = self._query_resource_data(
-                criteria=criteria, fields=fields, suburl=document_id,  # type: ignore
+                criteria=criteria,
+                fields=fields,
+                suburl=document_id,  # type: ignore
            )
        except MPRestError:

@@ -772,9 +748,7 @@ def get_data_by_id(
        if not results:
            raise MPRestError(f"No result for record {document_id}.")
        elif len(results) > 1:  # pragma: no cover
-            raise ValueError(
-                f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug."
-            )
+            raise ValueError(f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug.")
        else:
            return results[0]

@@ -881,9 +855,7 @@ def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
                False,
                False,
            )  # do not waste cycles decoding
-            results = self._query_resource(
-                criteria=criteria, num_chunks=1, chunk_size=1
-            )
+            results = self._query_resource(criteria=criteria, num_chunks=1, chunk_size=1)
            self.monty_decode, self.use_document_model = user_preferences
            return results["meta"]["total_doc"]
        except Exception:  # pragma: no cover
23 changes: 23 additions & 0 deletions src/mp_api/core/utils.py
@@ -0,0 +1,23 @@
+import re
+from typing import List
+
+
+def validate_ids(id_list: List[str]):
+    """Function to validate material and task IDs.
+
+    Args:
+        id_list (List[str]): List of material or task IDs.
+
+    Raises:
+        ValueError: If at least one ID is not formatted correctly.
+
+    Returns:
+        id_list: Returns original ID list if everything is formatted correctly.
+    """
+    pattern = "(mp|mvc|mol)-.*"
+
+    for entry in id_list:
+        if re.match(pattern, entry) is None:
+            raise ValueError(f"{entry} is not formatted correctly!")
+
+    return id_list
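Note that the pattern only anchors the prefix: `re.match` checks from the start of the string, so any `mp-`, `mvc-`, or `mol-` prefix passes, whatever follows the hyphen. Expected behavior, sketched with illustrative IDs:

```python
from mp_api.core.utils import validate_ids

# Well-formed IDs are returned unchanged.
validate_ids(["mp-149", "mvc-2970", "mol-45873"])

# An ID without a recognized prefix fails fast on the client side.
try:
    validate_ids(["149"])
except ValueError as exc:
    print(exc)  # "149 is not formatted correctly!"
```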
10 changes: 5 additions & 5 deletions src/mp_api/routes/charge_density.py
@@ -22,7 +22,7 @@
class ChargeDensityRester(BaseRester[ChgcarDataDoc]):

    suffix = "charge_density"
-    primary_key = "task_id"
+    primary_key = "fs_id"
    document_model = ChgcarDataDoc  # type: ignore
    boto_resource = None

@@ -87,11 +87,11 @@ def get_charge_density_from_file_id(self, fs_id: str):
        if environ.get("AWS_EXECUTION_ENV", None) == "AWS_ECS_FARGATE":

            if self.boto_resource is None:
-                self.boto_resource = self._get_s3_resource(use_minio=False, unsigned=False)
+                self.boto_resource = self._get_s3_resource(
+                    use_minio=False, unsigned=False
+                )

-            bucket, obj_prefix = self._extract_s3_url_info(
-                url_doc, use_minio=False
-            )
+            bucket, obj_prefix = self._extract_s3_url_info(url_doc, use_minio=False)

        else:
            try:
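Switching `primary_key` to `"fs_id"` also matters for the new client-side checks: `get_data_by_id` only runs `validate_ids` when the primary key is `material_id` or `task_id`, so charge-density file IDs are exempt from the `mp-`/`mvc-`/`mol-` pattern. A usage sketch (the `charge_density` attribute name is assumed from the rester's suffix, and the `fs-…` ID is illustrative):

```python
with MPRester("your_api_key") as mpr:
    # Charge-density documents are keyed by fs_id, not by a material/task ID,
    # so this lookup bypasses the new ID-format check.
    chgcar = mpr.charge_density.get_charge_density_from_file_id("fs-12345678")
```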
3 changes: 2 additions & 1 deletion src/mp_api/routes/grain_boundary.py
@@ -2,6 +2,7 @@
from collections import defaultdict

from mp_api.core.client import BaseRester
+from mp_api.core.utils import validate_ids

from emmet.core.grain_boundary import GBTypeEnum, GrainBoundaryDoc

@@ -58,7 +59,7 @@ def search_grain_boundary_docs(
        query_params = defaultdict(dict)  # type: dict

        if material_ids:
-            query_params.update({"task_ids": ",".join(material_ids)})
+            query_params.update({"task_ids": ",".join(validate_ids(material_ids))})

        if gb_plane:
            query_params.update({"gb_plane": ",".join([str(n) for n in gb_plane])})
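With `validate_ids` wrapping `material_ids`, a malformed ID now fails fast in the client instead of producing a server-side error. A usage sketch (the `grain_boundary` attribute name is assumed from the rester's registration, and the IDs are illustrative):

```python
with MPRester("your_api_key") as mpr:
    # Valid IDs pass through validate_ids unchanged before being joined
    # into the task_ids query parameter.
    docs = mpr.grain_boundary.search_grain_boundary_docs(material_ids=["mp-81"])

    try:
        mpr.grain_boundary.search_grain_boundary_docs(material_ids=["81"])
    except ValueError:
        pass  # the malformed ID is rejected before any request is sent
```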