Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason committed Jul 5, 2023
2 parents da0f4df + bf5bfee commit 98443a0
Show file tree
Hide file tree
Showing 27 changed files with 705 additions and 600 deletions.
110 changes: 109 additions & 1 deletion mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Materials Project data.
"""

import gzip
import itertools
import json
import platform
Expand All @@ -13,10 +14,13 @@
from json import JSONDecodeError
from math import ceil
from os import environ
from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from urllib.parse import quote, urljoin

import boto3
import requests
from botocore import UNSIGNED
from botocore.config import Config
from emmet.core.utils import jsanitize
from monty.json import MontyDecoder
from pydantic import BaseModel, create_model
Expand Down Expand Up @@ -55,6 +59,7 @@ def __init__(
endpoint: str = DEFAULT_ENDPOINT,
include_user_agent: bool = True,
session: Optional[requests.Session] = None,
s3_resource: Optional[Any] = None,
debug: bool = False,
monty_decode: bool = True,
use_document_model: bool = True,
Expand Down Expand Up @@ -108,6 +113,11 @@ def __init__(
else:
self._session = None # type: ignore

if s3_resource:
self._s3_resource = s3_resource
else:
self._s3_resource = None

self.document_model = (
api_sanitize(self.document_model) if self.document_model is not None else None # type: ignore
)
Expand All @@ -120,6 +130,14 @@ def session(self) -> requests.Session:
)
return self._session

@property
def s3_resource(self):
    """Return the S3 resource for this client, creating it on first use.

    If no resource has been stored on the instance yet, one is built with
    ``boto3.resource`` using an unsigned signature configuration, cached
    on ``self._s3_resource`` and returned. Later accesses reuse the cached
    object.
    """
    existing = self._s3_resource
    if existing:
        return existing

    unsigned_config = Config(signature_version=UNSIGNED)
    self._s3_resource = boto3.resource("s3", config=unsigned_config)
    return self._s3_resource

@staticmethod
def _create_session(api_key, include_user_agent, headers):
session = requests.Session()
Expand Down Expand Up @@ -230,6 +248,96 @@ def _post_resource(
except RequestException as ex:
raise MPRestError(str(ex))

def _patch_resource(
    self,
    body: Optional[Dict] = None,
    params: Optional[Dict] = None,
    suburl: Optional[str] = None,
    use_document_model: Optional[bool] = None,
) -> Dict:
    """Patch data to the endpoint for a Resource.

    Arguments:
        body: body json to send in patch request
        params: extra params to send in patch request
        suburl: make a request to a specified sub-url
        use_document_model: if None, will defer to the self.use_document_model attribute

    Returns:
        The decoded JSON response body. When a document model is in use,
        the "data" entry (a dict or a list of dicts) is parsed into
        document model instances.

    Raises:
        MPRestError: if the request itself fails, or if the server answers
            with any status code other than 200.
    """
    if use_document_model is None:
        use_document_model = self.use_document_model

    payload = jsanitize(body)

    url = self.endpoint
    if suburl:
        url = urljoin(self.endpoint, suburl)
        if not url.endswith("/"):
            url += "/"

    # Only the network call can raise RequestException, so keep the try
    # block limited to it and chain the cause for easier debugging.
    try:
        response = self.session.patch(url, json=payload, verify=True, params=params)
    except RequestException as ex:
        raise MPRestError(str(ex)) from ex

    if response.status_code == 200:
        if self.monty_decode:
            data = json.loads(response.text, cls=MontyDecoder)
        else:
            data = json.loads(response.text)

        if self.document_model and use_document_model:
            if isinstance(data["data"], dict):
                data["data"] = self.document_model.parse_obj(data["data"])  # type: ignore
            elif isinstance(data["data"], list):
                data["data"] = [self.document_model.parse_obj(d) for d in data["data"]]  # type: ignore

        return data

    # Non-200: build the most informative message the response allows.
    try:
        data = json.loads(response.text)["detail"]
    except (JSONDecodeError, KeyError, TypeError):
        # Body is not JSON, has no "detail" key, or is not a JSON object.
        data = f"Response {response.text}"

    if isinstance(data, str):
        message = data
    else:
        try:
            message = ", ".join(
                f"{entry['loc'][1]} - {entry['msg']}" for entry in data
            )
        except (KeyError, IndexError, TypeError):
            # "detail" is not a list of {"loc": ..., "msg": ...} entries.
            message = str(data)

    raise MPRestError(
        f"REST patch query returned with error status code {response.status_code} "
        f"on URL {response.url} with message:\n{message}"
    )

def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:
    """Fetch one gzipped JSON object from an S3 bucket and decode it.

    The object is looked up at ``{prefix}/{key}.json.gz`` in ``bucket``
    through ``self.s3_resource``, decompressed with gzip and decoded with
    ``MontyDecoder``.

    Args:
        bucket (str): Materials project bucket name
        prefix (str): Full set of file prefixes
        key (str): Key for file (without the ``.json.gz`` suffix)

    Returns:
        dict: MontyDecoded data
    """
    s3_object = self.s3_resource.Object(bucket, f"{prefix}/{key}.json.gz")  # type: ignore
    body_stream = s3_object.get()["Body"]  # type: ignore

    with gzip.GzipFile(fileobj=body_stream) as unzipped:
        raw_json = unzipped.read()
        decoded = MontyDecoder().decode(raw_json)

    return decoded

def _query_resource(
self,
criteria: Optional[Dict] = None,
Expand Down
8 changes: 6 additions & 2 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def api_sanitize(
fields_to_leave: list of strings for model fields as "model__name__.field"
"""
models = [
model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel)
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
] # type: List[Type[BaseModel]]

fields_to_leave = fields_to_leave or []
Expand Down Expand Up @@ -92,7 +94,9 @@ def validate_monty(cls, v):
errors.append("@class")

if len(errors) > 0:
raise ValueError("Missing Monty seriailzation fields in dictionary: {errors}")
raise ValueError(
"Missing Monty seriailzation fields in dictionary: {errors}"
)

return v
else:
Expand Down
30 changes: 23 additions & 7 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from os import environ
from typing import Dict, List, Literal, Optional, Union

from emmet.core.charge_density import ChgcarDataDoc
from emmet.core.electronic_structure import BSPathType
from emmet.core.mpid import MPID
from emmet.core.settings import EmmetSettings
from emmet.core.tasks import TaskDoc
from emmet.core.vasp.calc_types import CalcType
from packaging import version
from pymatgen.analysis.phase_diagram import PhaseDiagram
Expand Down Expand Up @@ -207,6 +207,9 @@ def __init__(
self.use_document_model = use_document_model
self.monty_decode = monty_decode

# Check if emmet version of server is compatible
emmet_version = version.parse(self.get_emmet_version())

try:
from mpcontribs.client import Client

Expand All @@ -222,9 +225,6 @@ def __init__(
self.contribs = None
warnings.warn(f"Problem loading MPContribs client: {error}")

# Check if emmet version of server is compatible
emmet_version = version.parse(self.get_emmet_version())

if version.parse(emmet_version.base_version) < version.parse(
_MAPI_SETTINGS.MIN_EMMET_VERSION
):
Expand Down Expand Up @@ -390,7 +390,14 @@ def get_emmet_version(self):
Returns: version as a string
"""
return get(url=self.endpoint + "heartbeat").json()["version"]

response = get(url=self.endpoint + "heartbeat").json()

error = response.get("error", None)
if error:
raise MPRestError(error)

return response["version"]

def get_material_id_from_task_id(self, task_id: str) -> Union[str, None]:
"""Returns the current material_id from a given task_id. The
Expand Down Expand Up @@ -1244,14 +1251,23 @@ def get_charge_density_from_material_id(
task_ids = self.get_task_ids_associated_with_material_id(
material_id, calc_types=[CalcType.GGA_Static, CalcType.GGA_U_Static]
)
results: List[ChgcarDataDoc] = self.charge_density.search(task_ids=task_ids) # type: ignore
results: List[TaskDoc] = self.tasks.search(task_ids=task_ids, fields=["last_updated", "task_id"]) # type: ignore

if len(results) == 0:
return None

latest_doc = max(results, key=lambda x: x.last_updated)

chgcar = self.charge_density.get_charge_density_from_file_id(latest_doc.fs_id)
result = (
self.tasks._query_open_data(
bucket="materialsproject-parsed",
prefix="chgcars",
key=str(latest_doc.task_id),
)
or {}
)

chgcar = result.get("data", None)

if chgcar is None:
raise MPRestError(f"No charge density fetched for {material_id}.")
Expand Down
7 changes: 4 additions & 3 deletions mp_api/client/routes/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ def set_message(
title: str,
body: str,
type: MessageType = MessageType.generic,
authors: List[str] = [],
authors: List[str] = None,
): # pragma: no cover
"""Set user settings
"""
Set user settings
Args:
title: Message title
Expand All @@ -34,7 +35,7 @@ def set_message(
Raises:
MPRestError.
"""
d = {"title": title, "body": body, "type": type.value, "authors": authors}
d = {"title": title, "body": body, "type": type.value, "authors": authors or []}

return self._post_resource(body=d).get("data")

Expand Down
19 changes: 19 additions & 0 deletions mp_api/client/routes/_user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@ def set_user_settings(self, consumer_id, settings): # pragma: no cover
body=settings, params={"consumer_id": consumer_id}
).get("data")

def patch_user_time_settings(self, consumer_id, time):  # pragma: no cover
    """Patch the "message last read" timestamp in a user's settings.

    Args:
        consumer_id: Consumer ID for the user
        time: utc datetime object for when the user last saw messages;
            it is sent in ISO 8601 form via ``time.isoformat()``

    Returns:
        The "data" entry of the patch response.

    Raises:
        MPRestError: propagated from the underlying patch request.
    """
    patch_body = {"settings.message_last_read": time.isoformat()}
    query_params = {"consumer_id": consumer_id}

    response = self._patch_resource(body=patch_body, params=query_params)
    return response.get("data")

def get_user_settings(self, consumer_id): # pragma: no cover
"""Get user settings.
Expand Down
73 changes: 1 addition & 72 deletions mp_api/client/routes/materials/charge_density.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import zlib
from os import environ
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

import boto3
import msgpack
from botocore import UNSIGNED
from botocore.client import Config
from botocore.exceptions import ConnectionError
from emmet.core.charge_density import ChgcarDataDoc
from monty.serialization import MontyDecoder, dumpfn
from monty.serialization import dumpfn

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import validate_ids
Expand Down Expand Up @@ -76,67 +69,3 @@ def search( # type: ignore
fields=["last_updated", "task_id", "fs_id"],
**query_params,
)

def get_charge_density_from_file_id(self, fs_id: str):
    """Download and decode the charge density stored under ``fs_id``.

    The URL document for ``fs_id`` is looked up first; when none is found
    the method returns None. Otherwise the object is fetched from S3,
    zlib-decompressed, unpacked with msgpack, and its "data" entry is
    processed by ``MontyDecoder``.

    Args:
        fs_id: file id of the charge density object

    Returns:
        The decoded charge density object, or None when no URL document
        exists for ``fs_id``.
    """
    url_doc = self.get_data_by_id(fs_id)
    if not url_doc:
        return None

    # The check below is performed to see if the client is being
    # used by our internal AWS deployment. If it is, we pull charge
    # density data from a private S3 bucket. Else, we pull data
    # from public MinIO buckets.
    on_internal_aws = environ.get("AWS_EXECUTION_ENV", None) == "AWS_ECS_FARGATE"

    if on_internal_aws:
        if self.boto_resource is None:
            self.boto_resource = self._get_s3_resource(use_minio=False, unsigned=False)
        bucket, obj_prefix = self._extract_s3_url_info(url_doc, use_minio=False)
    else:
        try:
            if self.boto_resource is None:
                self.boto_resource = self._get_s3_resource()
            bucket, obj_prefix = self._extract_s3_url_info(url_doc)
        except ConnectionError:
            # Fall back to the non-MinIO resource and URL layout.
            self.boto_resource = self._get_s3_resource(use_minio=False)
            bucket, obj_prefix = self._extract_s3_url_info(url_doc, use_minio=False)

    body_stream = self.boto_resource.Object(bucket, f"{obj_prefix}/{url_doc.fs_id}").get()["Body"]  # type: ignore

    compressed = body_stream.read()
    unpacked = msgpack.unpackb(zlib.decompress(compressed), raw=False)

    return MontyDecoder().process_decoded(unpacked["data"])

def _extract_s3_url_info(self, url_doc, use_minio: bool = True):
if use_minio:
url_list = url_doc.url.split("/")
bucket = url_list[3]
obj_prefix = url_list[4]
else:
url_list = url_doc.s3_url_prefix.split("/")
bucket = url_list[2].split(".")[0]
obj_prefix = url_list[3]

return (bucket, obj_prefix)

def _get_s3_resource(self, use_minio: bool = True, unsigned: bool = True):
    """Build a boto3 S3 resource.

    Args:
        use_minio: when True, point the resource at
            ``https://minio.materialsproject.org``; otherwise no explicit
            endpoint URL is passed
        unsigned: when True, pass a config with an unsigned signature
            version; otherwise no explicit config is passed

    Returns:
        The object returned by ``boto3.resource("s3", ...)``.
    """
    if use_minio:
        endpoint = "https://minio.materialsproject.org"
    else:
        endpoint = None

    if unsigned:
        s3_config = Config(signature_version=UNSIGNED)
    else:
        s3_config = None

    return boto3.resource("s3", endpoint_url=endpoint, config=s3_config)
Loading

0 comments on commit 98443a0

Please sign in to comment.