From cfa03ca1c9770b3bd3a0098bd45d9dc3de9e38c8 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Thu, 7 Apr 2022 19:56:57 -0700 Subject: [PATCH] Fix material sorting post-process (#391) * Add es steps to trajectory and fix struc frames * Fix material_ids post-process * Fix material_ids query op test * Fox query in test --- .../api/routes/summary/query_operators.py | 68 +++++++++++++++---- .../emmet-api/summary/test_query_operators.py | 2 +- 2 files changed, 54 insertions(+), 16 deletions(-) diff --git a/emmet-api/emmet/api/routes/summary/query_operators.py b/emmet-api/emmet/api/routes/summary/query_operators.py index 856531aa51..064bda678c 100644 --- a/emmet-api/emmet/api/routes/summary/query_operators.py +++ b/emmet-api/emmet/api/routes/summary/query_operators.py @@ -28,14 +28,17 @@ class HasPropsQuery(QueryOperator): def query( self, has_props: Optional[str] = Query( - None, description="Comma-delimited list of possible properties given by HasPropsEnum to search for.", + None, + description="Comma-delimited list of possible properties given by HasPropsEnum to search for.", ), ) -> STORE_PARAMS: crit = {} if has_props: - crit = {"has_props": {"$all": [prop.strip() for prop in has_props.split(",")]}} + crit = { + "has_props": {"$all": [prop.strip() for prop in has_props.split(",")]} + } return {"criteria": crit} @@ -46,22 +49,36 @@ class MaterialIDsSearchQuery(QueryOperator): """ def query( - self, material_ids: Optional[str] = Query(None, description="Comma-separated list of material_ids to query on"), + self, + material_ids: Optional[str] = Query( + None, description="Comma-separated list of material_ids to query on" + ), ) -> STORE_PARAMS: crit = {} if material_ids: - crit.update({"material_id": {"$in": [material_id.strip() for material_id in material_ids.split(",")]}}) + crit.update( + { + "material_id": { + "$in": [ + material_id.strip() + for material_id in material_ids.split(",") + ] + } + } + ) return {"criteria": crit} def post_process(self, docs, query): if not query.get("sort", None): - mpid_list = query.get("criteria", {}).get("material_id", {}).get("$in", None) + mpid_list = ( + query.get("criteria", {}).get("material_id", {}).get("$in", None) + ) - if mpid_list is not None: + if mpid_list is not None and "material_id" in query.get("properties", []): mpid_mapping = {mpid: ind for ind, mpid in enumerate(mpid_list)} docs = sorted(docs, key=lambda d: mpid_mapping[d["material_id"]]) @@ -75,7 +92,10 @@ class SearchIsStableQuery(QueryOperator): """ def query( - self, is_stable: Optional[bool] = Query(None, description="Whether the material is stable."), + self, + is_stable: Optional[bool] = Query( + None, description="Whether the material is stable." + ), ): crit = {} @@ -96,7 +116,9 @@ class SearchHasReconstructedQuery(QueryOperator): def query( self, - has_reconstructed: Optional[bool] = Query(None, description="Whether the material has reconstructed surfaces."), + has_reconstructed: Optional[bool] = Query( + None, description="Whether the material has reconstructed surfaces." + ), ): crit = {} @@ -116,7 +138,10 @@ class SearchMagneticQuery(QueryOperator): """ def query( - self, ordering: Optional[Ordering] = Query(None, description="Magnetic ordering of the material."), + self, + ordering: Optional[Ordering] = Query( + None, description="Magnetic ordering of the material." + ), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict @@ -136,7 +161,10 @@ class SearchIsTheoreticalQuery(QueryOperator): """ def query( - self, theoretical: Optional[bool] = Query(None, description="Whether the material is theoretical."), + self, + theoretical: Optional[bool] = Query( + None, description="Whether the material is theoretical." + ), ): crit = {} @@ -157,8 +185,12 @@ class SearchESQuery(QueryOperator): def query( self, - is_gap_direct: Optional[bool] = Query(None, description="Whether a band gap is direct or not."), - is_metal: Optional[bool] = Query(None, description="Whether the material is considered a metal."), + is_gap_direct: Optional[bool] = Query( + None, description="Whether a band gap is direct or not." + ), + is_metal: Optional[bool] = Query( + None, description="Whether the material is considered a metal." + ), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict @@ -184,7 +216,9 @@ class SearchStatsQuery(QueryOperator): """ def __init__(self, search_doc): - valid_numeric_fields = tuple(sorted(k for k, v in search_doc.__fields__.items() if v.type_ == float)) + valid_numeric_fields = tuple( + sorted(k for k, v in search_doc.__fields__.items() if v.type_ == float) + ) def query( field: Literal[valid_numeric_fields] = Query( # type: ignore @@ -192,7 +226,9 @@ def query( title=f"SearchDoc field to query on, must be a numerical field, " f"choose from: {', '.join(valid_numeric_fields)}", ), - num_samples: Optional[int] = Query(None, title="If specified, will only sample this number of documents.",), + num_samples: Optional[int] = Query( + None, title="If specified, will only sample this number of documents.", + ), min_val: Optional[float] = Query( None, title="If specified, will only consider documents with field values " @@ -203,7 +239,9 @@ def query( title="If specified, will only consider documents with field values " "less than or equal to this minimum value.", ), - num_points: int = Query(100, title="The number of values in the returned distribution."), + num_points: int = Query( + 100, title="The number of values in the returned distribution." + ), ) -> STORE_PARAMS: self.num_points = num_points diff --git a/tests/emmet-api/summary/test_query_operators.py b/tests/emmet-api/summary/test_query_operators.py index e84644bd92..abe3ad20b2 100644 --- a/tests/emmet-api/summary/test_query_operators.py +++ b/tests/emmet-api/summary/test_query_operators.py @@ -42,7 +42,7 @@ def test_material_ids_query(): docs = [{"material_id": "mp-13"}, {"material_id": "mp-149"}] - assert op.post_process(docs, query)[0] == docs[1] + assert op.post_process(docs, {**query, "properties": ["material_id"]})[0] == docs[1] with ScratchDir("."): dumpfn(op, "temp.json")