Skip to content

Commit

Permalink
Resource class overhaul (#236)
Browse files Browse the repository at this point in the history
* Base resource and get resource added

* Task endpoint updated to match new resource

* Base resource and get resource added

* Task endpoint updated to match new resource

* Consumer post resource added

* User settings post endpoint added

* Generaic post method added to core client

* Find structure endpoint updates

* Remove consumer endpoints from docs

* Linting

* pycodestyle linting

* Mypy cleanup

* flake8 fixes

* Consumer settings endpoint data model change

* mypy fixes
  • Loading branch information
Jason Munro authored Mar 30, 2021
1 parent 89e6fa0 commit f63376a
Show file tree
Hide file tree
Showing 37 changed files with 590 additions and 283 deletions.
21 changes: 21 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@
s3_chgcar_json = os.environ.get("S3_CHGCAR_STORE", "s3_chgcar.json")


consumer_settings_store_json = os.environ.get(
"CONSUMER_SETTINGS_STORE", "consumer_settings_store.json"
)


if db_uri:
from maggma.stores import MongoURIStore, S3Store

Expand Down Expand Up @@ -268,6 +273,13 @@
searchable_fields=["task_id", "fs_id"],
)

consumer_settings_store = MongoURIStore(
uri=f"mongodb+srv://{db_uri}",
database="mp_consumers",
key="consumer_id",
collection_name="settings",
)


else:
materials_store = loadfn(materials_store_json)
Expand Down Expand Up @@ -304,6 +316,8 @@
s3_chgcar_index = loadfn(s3_chgcar_index_json)
s3_chgcar = loadfn(s3_chgcar_json)

consumer_settings_store = loadfn(consumer_settings_store_json)

# Materials
from mp_api.materials.resources import materials_resource

Expand Down Expand Up @@ -444,5 +458,12 @@

resources.update({"dos": dos_resource(dos_store, s3_dos)})


# Consumers
from mp_api._consumer.resources import set_settings_resource

resources.update({"user_settings": set_settings_resource(consumer_settings_store)})


api = MAPI(resources=resources)
app = api.app
23 changes: 23 additions & 0 deletions src/mp_api/_consumer/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from mp_api._consumer.models import UserSettingsDoc
from mp_api.core.client import BaseRester


class UserSettingsRester(BaseRester):

suffix = "user_settings"
document_model = UserSettingsDoc

def set_user_settings(self, consumer_id, settings):
"""
Set user settings.
Args:
consumer_id: Consumer ID for the user
settings: Dictionary with user settings
Returns:
Dictionary with consumer_id and write status.
Raises:
MPRestError
"""
return self._post_resource(
body=settings, params={"consumer_id": consumer_id}
).get("data")
17 changes: 17 additions & 0 deletions src/mp_api/_consumer/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pydantic import BaseModel, Field


class UserSettingsDoc(BaseModel):
"""
Defines data for user settings
"""

consumer_id: str = Field(
None, title="Consumer ID", description="Consumer ID for a specific user."
)

settings: dict = Field(
None,
title="Consumer ID settings",
description="Settings defined for a specific user.",
)
27 changes: 27 additions & 0 deletions src/mp_api/_consumer/query_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Dict
from fastapi import Query, Body
from mp_api.core.utils import STORE_PARAMS
from mp_api.core.query_operator import QueryOperator


class UserSettingsQuery(QueryOperator):
"""Query operators to provide user settings information"""

def query(
self,
consumer_id: str = Query(..., title="Consumer ID",),
settings: Dict = Body(..., title="User settings",),
) -> STORE_PARAMS:

self.cid = consumer_id
self.settings = settings

crit = {"consumer_id": consumer_id, "settings": settings}

return {"criteria": crit}

def post_process(self, written):

d = [{"consumer_id": self.cid, "settings": self.settings}]

return d
14 changes: 14 additions & 0 deletions src/mp_api/_consumer/resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from mp_api.core.resource import ConsumerPostResource
from mp_api._consumer.models import UserSettingsDoc
from mp_api._consumer.query_operator import UserSettingsQuery


def set_settings_resource(consumer_settings_store):
resource = ConsumerPostResource(
consumer_settings_store,
UserSettingsDoc,
query_operators=[UserSettingsQuery()],
tags=["Consumer"],
)

return resource
4 changes: 2 additions & 2 deletions src/mp_api/bandstructure/resources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from fastapi.param_functions import Query
from mp_api.core.resource import Resource
from mp_api.core.resource import GetResource
from mp_api.bandstructure.models.doc import BSDoc, BSObjectReturn
from mp_api.bandstructure.models.core import BSPathType

Expand Down Expand Up @@ -74,7 +74,7 @@ async def get_object(
tags=self.tags,
)(get_object)

resource = Resource(
resource = GetResource(
bs_store,
BSDoc,
query_operators=[
Expand Down
4 changes: 2 additions & 2 deletions src/mp_api/charge_density/resources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from fastapi.param_functions import Path
from mp_api.core.resource import Resource
from mp_api.core.resource import GetResource
from mp_api.charge_density.models import ChgcarDataDoc
from mp_api.core.query_operator import PaginationQuery, SparseFieldsQuery, SortQuery
from mp_api.core.utils import STORE_PARAMS
Expand Down Expand Up @@ -63,7 +63,7 @@ async def get_chgcar_data(
tags=self.tags,
)(get_chgcar_data)

resource = Resource(
resource = GetResource(
s3_store,
ChgcarDataDoc,
query_operators=[
Expand Down
4 changes: 2 additions & 2 deletions src/mp_api/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Dict
from datetime import datetime
from monty.json import MSONable
from mp_api.core.resource import Resource
from mp_api.core.resource import GetResource
from pymatgen.core import __version__ as pmg_version # type: ignore
from fastapi.openapi.utils import get_openapi

Expand All @@ -21,7 +21,7 @@ class MAPI(MSONable):

def __init__(
self,
resources: Dict[str, Resource],
resources: Dict[str, GetResource],
title="Materials Project API",
version="3.0.0-dev",
):
Expand Down
92 changes: 86 additions & 6 deletions src/mp_api/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import warnings

import requests
from monty.json import MontyDecoder
from monty.json import MontyEncoder, MontyDecoder
from requests.exceptions import RequestException
from pydantic import BaseModel

Expand Down Expand Up @@ -154,12 +154,84 @@ def _make_request(self, sub_url, monty_decode: bool = True):

raise MPRestError(str(ex))

def _post_resource(
self,
body: Dict = None,
params: Optional[Dict] = None,
monty_decode: bool = True,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = True,
):
"""
Post data to the endpoint for a Resource.
Arguments:
body: body json to send in post request
params: extra params to send in post request
monty_decode: Decode the data using monty into python objects
suburl: make a request to a specified sub-url
use_document_model: whether to use the core document model for data reconstruction
Returns:
A Resource, a dict with two keys, "data" containing a list of documents, and
"meta" containing meta information, e.g. total number of documents
available.
"""

payload = json.dumps(body, cls=MontyEncoder)

try:
url = self.endpoint
if suburl:
url = urljoin(self.endpoint, suburl)
if not url.endswith("/"):
url += "/"
response = self.session.post(url, data=payload, verify=True, params=params)

if response.status_code == 200:

if monty_decode:
data = json.loads(response.text, cls=MontyDecoder)
else:
data = json.loads(response.text)

if self.document_model and use_document_model:
data["data"] = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore

return data

else:
try:
data = json.loads(response.text)["detail"]
except (JSONDecodeError, KeyError):
data = "Response {}".format(response.text)
if isinstance(data, str):
message = data
else:
try:
message = ", ".join(
"{} - {}".format(entry["loc"][1], entry["msg"])
for entry in data
)
except (KeyError, IndexError):
message = str(data)

raise MPRestError(
f"REST post query returned with error status code {response.status_code} "
f"on URL {response.url} with message:\n{message}"
)

except RequestException as ex:

raise MPRestError(str(ex))

def _query_resource(
self,
criteria: Optional[Dict] = None,
fields: Optional[List[str]] = None,
monty_decode: bool = True,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = True,
):
"""
Query the endpoint for a Resource containing a list of documents
Expand All @@ -173,6 +245,7 @@ def _query_resource(
fields: list of fields to return
monty_decode: Decode the data using monty into python objects
suburl: make a request to a specified sub-url
use_document_model: whether to use the core document model for data reconstruction
Returns:
A Resource, a dict with two keys, "data" containing a list of documents, and
Expand Down Expand Up @@ -203,7 +276,7 @@ def _query_resource(
else:
data = json.loads(response.text)

if self.document_model:
if self.document_model and use_document_model:
data["data"] = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore

return data
Expand Down Expand Up @@ -278,8 +351,10 @@ def get_document_by_id(
"""

if document_id is None:
raise ValueError("Please supply a specific id. You can use the query method to find "
"ids of interest.")
raise ValueError(
"Please supply a specific id. You can use the query method to find "
"ids of interest."
)

if fields is None:
criteria = {"all_fields": True, "limit": 1} # type: dict
Expand All @@ -293,7 +368,10 @@ def get_document_by_id(
fields = (fields,)

results = self.query(
criteria=criteria, fields=fields, monty_decode=monty_decode, suburl=document_id
criteria=criteria,
fields=fields,
monty_decode=monty_decode,
suburl=document_id,
)

if not results:
Expand All @@ -307,7 +385,9 @@ def get_document_by_id(
return results[0]

def query_by_task_id(self, *args, **kwargs):
print("query_by_task_id has been renamed to get_document_by_id to be more general")
print(
"query_by_task_id has been renamed to get_document_by_id to be more general"
)
return self.get_document_by_id(*args, **kwargs)

def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
Expand Down
10 changes: 4 additions & 6 deletions src/mp_api/core/query_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def meta(self, store: Store, query: Dict) -> Dict:
"""
return {}

def post_process(self, doc: Dict) -> Dict:
def post_process(self, docs: List[Dict]) -> List[Dict]:
"""
An optional post-processing function for the data
"""
return doc
return docs


class PaginationQuery(QueryOperator):
Expand Down Expand Up @@ -169,8 +169,7 @@ class VersionQuery(QueryOperator):
def query(
self,
version: Optional[str] = Query(
None,
description="Database version to query on formatted as YYYY.MM.DD",
None, description="Database version to query on formatted as YYYY.MM.DD",
),
) -> STORE_PARAMS:

Expand All @@ -191,8 +190,7 @@ def query(
self,
field: Optional[str] = Query(None, description="Field to sort with"),
ascending: Optional[bool] = Query(
None,
description="Whether the sorting should be ascending",
None, description="Whether the sorting should be ascending",
),
) -> STORE_PARAMS:

Expand Down
Loading

0 comments on commit f63376a

Please sign in to comment.