Skip to content

Commit

Permalink
Fix material sorting post-process (#391)
Browse files Browse the repository at this point in the history
* Add es steps to trajectory and fix struc frames

* Fix material_ids post-process

* Fix material_ids query op test

* Fox query in test
  • Loading branch information
Jason Munro authored Apr 8, 2022
1 parent ec63ace commit cfa03ca
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 16 deletions.
68 changes: 53 additions & 15 deletions emmet-api/emmet/api/routes/summary/query_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ class HasPropsQuery(QueryOperator):
def query(
self,
has_props: Optional[str] = Query(
None, description="Comma-delimited list of possible properties given by HasPropsEnum to search for.",
None,
description="Comma-delimited list of possible properties given by HasPropsEnum to search for.",
),
) -> STORE_PARAMS:

crit = {}

if has_props:
crit = {"has_props": {"$all": [prop.strip() for prop in has_props.split(",")]}}
crit = {
"has_props": {"$all": [prop.strip() for prop in has_props.split(",")]}
}

return {"criteria": crit}

Expand All @@ -46,22 +49,36 @@ class MaterialIDsSearchQuery(QueryOperator):
"""

def query(
self, material_ids: Optional[str] = Query(None, description="Comma-separated list of material_ids to query on"),
self,
material_ids: Optional[str] = Query(
None, description="Comma-separated list of material_ids to query on"
),
) -> STORE_PARAMS:

crit = {}

if material_ids:
crit.update({"material_id": {"$in": [material_id.strip() for material_id in material_ids.split(",")]}})
crit.update(
{
"material_id": {
"$in": [
material_id.strip()
for material_id in material_ids.split(",")
]
}
}
)

return {"criteria": crit}

def post_process(self, docs, query):

if not query.get("sort", None):
mpid_list = query.get("criteria", {}).get("material_id", {}).get("$in", None)
mpid_list = (
query.get("criteria", {}).get("material_id", {}).get("$in", None)
)

if mpid_list is not None:
if mpid_list is not None and "material_id" in query.get("properties", []):
mpid_mapping = {mpid: ind for ind, mpid in enumerate(mpid_list)}

docs = sorted(docs, key=lambda d: mpid_mapping[d["material_id"]])
Expand All @@ -75,7 +92,10 @@ class SearchIsStableQuery(QueryOperator):
"""

def query(
self, is_stable: Optional[bool] = Query(None, description="Whether the material is stable."),
self,
is_stable: Optional[bool] = Query(
None, description="Whether the material is stable."
),
):

crit = {}
Expand All @@ -96,7 +116,9 @@ class SearchHasReconstructedQuery(QueryOperator):

def query(
self,
has_reconstructed: Optional[bool] = Query(None, description="Whether the material has reconstructed surfaces."),
has_reconstructed: Optional[bool] = Query(
None, description="Whether the material has reconstructed surfaces."
),
):

crit = {}
Expand All @@ -116,7 +138,10 @@ class SearchMagneticQuery(QueryOperator):
"""

def query(
self, ordering: Optional[Ordering] = Query(None, description="Magnetic ordering of the material."),
self,
ordering: Optional[Ordering] = Query(
None, description="Magnetic ordering of the material."
),
) -> STORE_PARAMS:

crit = defaultdict(dict) # type: dict
Expand All @@ -136,7 +161,10 @@ class SearchIsTheoreticalQuery(QueryOperator):
"""

def query(
self, theoretical: Optional[bool] = Query(None, description="Whether the material is theoretical."),
self,
theoretical: Optional[bool] = Query(
None, description="Whether the material is theoretical."
),
):

crit = {}
Expand All @@ -157,8 +185,12 @@ class SearchESQuery(QueryOperator):

def query(
self,
is_gap_direct: Optional[bool] = Query(None, description="Whether a band gap is direct or not."),
is_metal: Optional[bool] = Query(None, description="Whether the material is considered a metal."),
is_gap_direct: Optional[bool] = Query(
None, description="Whether a band gap is direct or not."
),
is_metal: Optional[bool] = Query(
None, description="Whether the material is considered a metal."
),
) -> STORE_PARAMS:

crit = defaultdict(dict) # type: dict
Expand All @@ -184,15 +216,19 @@ class SearchStatsQuery(QueryOperator):
"""

def __init__(self, search_doc):
valid_numeric_fields = tuple(sorted(k for k, v in search_doc.__fields__.items() if v.type_ == float))
valid_numeric_fields = tuple(
sorted(k for k, v in search_doc.__fields__.items() if v.type_ == float)
)

def query(
field: Literal[valid_numeric_fields] = Query( # type: ignore
valid_numeric_fields[0],
title=f"SearchDoc field to query on, must be a numerical field, "
f"choose from: {', '.join(valid_numeric_fields)}",
),
num_samples: Optional[int] = Query(None, title="If specified, will only sample this number of documents.",),
num_samples: Optional[int] = Query(
None, title="If specified, will only sample this number of documents.",
),
min_val: Optional[float] = Query(
None,
title="If specified, will only consider documents with field values "
Expand All @@ -203,7 +239,9 @@ def query(
title="If specified, will only consider documents with field values "
"less than or equal to this minimum value.",
),
num_points: int = Query(100, title="The number of values in the returned distribution."),
num_points: int = Query(
100, title="The number of values in the returned distribution."
),
) -> STORE_PARAMS:

self.num_points = num_points
Expand Down
2 changes: 1 addition & 1 deletion tests/emmet-api/summary/test_query_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_material_ids_query():

docs = [{"material_id": "mp-13"}, {"material_id": "mp-149"}]

assert op.post_process(docs, query)[0] == docs[1]
assert op.post_process(docs, {**query, "properties": ["material_id"]})[0] == docs[1]

with ScratchDir("."):
dumpfn(op, "temp.json")
Expand Down

0 comments on commit cfa03ca

Please sign in to comment.