diff --git a/strider/fetcher.py b/strider/fetcher.py
index b8e9fd51..ec157dfe 100644
--- a/strider/fetcher.py
+++ b/strider/fetcher.py
@@ -13,16 +13,16 @@
 import aiostream
 import asyncio
-from kp_registry import Registry
 from reasoner_pydantic import (
     Message,
-    QueryGraph,
     KnowledgeGraph,
     AuxiliaryGraphs,
     Result,
+    Node,
+    EdgeBinding,
 )
-from .graph import Graph
+from .graph import remove_orphaned
 from .normalizer import Normalizer
 from .knowledge_provider import KnowledgeProvider
 from .trapi import (
@@ -30,6 +30,7 @@
     fill_categories_predicates,
 )
 from .query_planner import generate_plan, get_next_qedge
+from .mcq import is_mcq_node, get_mcq_edge_ids
 from .config import settings
 from .utils import (
     WBMT,
@@ -67,7 +68,7 @@ def __init__(self, logger, bypass_cache, parameters):
     async def lookup(
         self,
-        qgraph: Graph = None,
+        message: Message = None,
         call_stack: List = [],
         qid: str = "",
     ):
@@ -77,9 +78,9 @@ async def lookup(
         else:
             qid = str(uuid.uuid4())[:8]
         # if this is a leaf node, we're done
-        if qgraph is None:
-            qgraph = Graph(self.qgraph)
-        if not qgraph["edges"]:
+        if message is None:
+            message = self.message
+        if not message.query_graph.edges:
             self.logger.info(f"[{qid}] Finished call stack: {(', ').join(call_stack)}")
             # gets sent to generate_from_result for final result merge and then yield to server.py
             yield KnowledgeGraph.parse_obj(
@@ -96,28 +97,34 @@ async def lookup(
             return
         try:
-            qedge_id, qedge = get_next_qedge(qgraph)
+            qedge_id, qedge = get_next_qedge(message.query_graph.dict())
         except StopIteration:
             self.logger.error("Cannot find qedge with pinned endpoint")
             raise RuntimeError("Cannot find qedge with pinned endpoint")
         except Exception as e:
-            self.logger.error("Unable to get next qedge")
+            self.logger.error(f"Unable to get next qedge: {e}")
         self.logger.info(f"[{qid}] Getting results for {qedge_id}")
-        qedge = qgraph["edges"][qedge_id]
-        onehop = {
-            "nodes": {
-                key: value
-                for key, value in qgraph["nodes"].items()
-                if key in (qedge["subject"], qedge["object"])
-            },
-            "edges": {qedge_id: qedge},
-        }
+        qedge = message.query_graph.edges[qedge_id]
+        onehop = Message.parse_obj(
+            {
+                "query_graph": {
+                    "nodes": {
+                        key: value
+                        for key, value in message.query_graph.nodes.items()
+                        if key in (qedge.subject, qedge.object)
+                    },
+                    "edges": {qedge_id: qedge},
+                }
+            }
+        )
+        onehop.knowledge_graph = message.knowledge_graph
+        onehop.auxiliary_graphs = message.auxiliary_graphs
         generators = [
             self.generate_from_kp(
-                qgraph,
+                message,
                 onehop,
                 self.kps[kp_id],
                 copy.deepcopy(call_stack),
@@ -131,8 +138,8 @@
     async def generate_from_kp(
         self,
-        qgraph: Graph,
-        onehop_qgraph: Graph,
+        message: Message,
+        onehop_message: Message,
         kp: KnowledgeProvider,
         call_stack: List,
         qid: str,
@@ -141,13 +148,17 @@ async def generate_from_kp(
         # keep track of call stack for each kp plan branch
         call_stack.append(kp.id)
         self.logger.info(f"[{qid}] Current call stack: {(', ').join(call_stack)}")
+        is_mcq = False
+        for node in onehop_message.query_graph.nodes.values():
+            node_is_mcq = is_mcq_node(node)
+            is_mcq = is_mcq or node_is_mcq
         onehop_response = None
         # check if message wants to override the cache
         overwrite_cache = self.parameters.get("overwrite_cache")
         overwrite_cache = overwrite_cache if type(overwrite_cache) is bool else False
         if not self.bypass_cache and not overwrite_cache:
             # get onehop response from cache
-            onehop_response = await get_kp_onehop(kp.id, onehop_qgraph)
+            onehop_response = await get_kp_onehop(kp.id, onehop_message.dict())
         if onehop_response is not None:
             self.logger.info(
                 f"[{qid}] [{kp.id}]: Got onehop with {len(onehop_response['results'])} results from cache"
             )
@@ -156,16 +167,18 @@ async def generate_from_kp(
         if onehop_response is None and not settings.offline_mode:
             # onehop not in cache, have to go get response
             self.logger.info(
-                f"[{kp.id}] Need to get results for: {json.dumps(elide_curies(onehop_qgraph))}"
+                f"[{kp.id}] Need to get results for: {json.dumps(elide_curies(onehop_message.dict()))}"
             )
             onehop_response = await kp.solve_onehop(
-                onehop_qgraph,
+                onehop_message,
                 self.bypass_cache,
                 call_stack,
-                last_hop=len(qgraph["edges"]) == 1,
+                last_hop=len(message.query_graph.edges) == 1,
             )
             if not self.bypass_cache:
-                await save_kp_onehop(kp.id, onehop_qgraph, onehop_response.dict())
+                await save_kp_onehop(
+                    kp.id, onehop_message.dict(), onehop_response.dict()
+                )
         if onehop_response is None and settings.offline_mode:
             self.logger.info(
                 f"[{kp.id}] Didn't get anything back from cache in offline mode."
@@ -177,30 +190,36 @@ async def generate_from_kp(
         onehop_kgraph = onehop_response.knowledge_graph
         onehop_results = onehop_response.results
         onehop_auxgraphs = onehop_response.auxiliary_graphs
-        qedge_id = next(iter(onehop_qgraph["edges"].keys()))
+        qedge_id = next(iter(onehop_message.query_graph.edges.keys()))
         generators = []
         if onehop_results:
-            subqgraph = copy.deepcopy(qgraph)
+            subqgraph = copy.deepcopy(message)
             # remove edge
-            subqgraph["edges"].pop(qedge_id)
+            subqgraph.query_graph.edges.pop(qedge_id)
             # remove orphaned nodes
-            subqgraph.remove_orphaned()
+            remove_orphaned(subqgraph)
         else:
             self.logger.info(
                 f"[{qid}] Ending call stack with no results: {(', ').join(call_stack)}"
             )
+            return
+
+        result_map = defaultdict(list)
         for batch_results in batch(onehop_results, self.parameters["batch_size"]):
-            result_map = defaultdict(list)
+            if is_mcq:
+                # only take the top 100 results
+                batch_results = batch_results[:100]
+
             # copy subqgraph between each batch
             # before we fill it with result curies
             # this keeps the sub query graph from being modified and passing
             # extra curies into subsequent batches
             populated_subqgraph = copy.deepcopy(subqgraph)
             # clear out any existing bindings to only use the new ones we get back
-            for qnode_id in onehop_qgraph["nodes"].keys():
-                if qnode_id in populated_subqgraph["nodes"]:
-                    populated_subqgraph["nodes"][qnode_id]["ids"] = []
+            for qnode_id in onehop_message.query_graph.nodes.keys():
+                if qnode_id in populated_subqgraph.query_graph.nodes:
+                    populated_subqgraph.query_graph.nodes[qnode_id].ids = []
             for result in batch_results:
                 # add edge to results and kgraph
@@ -248,23 +267,26 @@ async def generate_from_kp(
                     ]
                 )
-                kgraph_node_ids = set(
-                    binding.id
-                    for _, bindings in result.node_bindings.items()
-                    for binding in bindings
-                )
-
-                for aux_graph_id in aux_graphs:
-                    for edge_id in result_auxgraph[aux_graph_id].edges or []:
-                        kgraph_node_ids.add(onehop_kgraph.edges[edge_id].subject)
-                        kgraph_node_ids.add(onehop_kgraph.edges[edge_id].object)
                 try:
+                    # do some knowledge graph collection
+                    node_ids = [
+                        onehop_kgraph.edges[edge_id].subject
+                        for edge_id in kgraph_edge_ids
+                        if onehop_kgraph.edges[edge_id].subject in onehop_kgraph.nodes
+                    ]
+                    node_ids.extend(
+                        [
+                            onehop_kgraph.edges[edge_id].object
+                            for edge_id in kgraph_edge_ids
+                            if onehop_kgraph.edges[edge_id].object
+                            in onehop_kgraph.nodes
+                        ]
+                    )
                     result_kgraph = KnowledgeGraph.parse_obj(
                         {
                             "nodes": {
                                 node_id: onehop_kgraph.nodes[node_id]
-                                for node_id in kgraph_node_ids
+                                for node_id in node_ids
                             },
                             "edges": {
                                 edge_id: onehop_kgraph.edges[edge_id]
@@ -276,52 +298,139 @@ async def generate_from_kp(
                     self.logger.error(
                         f"Something went wrong making the sub-result kgraph: {traceback.format_exc()}"
                     )
-                    # with open("bad_kp_response.json", "w") as f:
-                    #     json.dump(onehop_response.dict(), f)
                     raise Exception(e)
                 # pin nodes
                 for qnode_id, bindings in result.node_bindings.items():
-                    if qnode_id not in populated_subqgraph["nodes"]:
+                    if qnode_id not in populated_subqgraph.query_graph.nodes:
                         continue
                     # add curies from result into the qgraph
-                    populated_subqgraph["nodes"][qnode_id]["ids"] = list(
-                        # need to call set() to remove any duplicates
-                        set(
-                            (populated_subqgraph["nodes"][qnode_id].get("ids") or [])
-                            # use query_id (original curie) for any subclass results
-                            + [binding.query_id or binding.id for binding in bindings]
+                    if is_mcq:
+                        # TODO: this doesn't support cyclic graphs
+                        populated_subqgraph.query_graph.nodes[qnode_id].member_ids = (
+                            list(
+                                # need to call set() to remove any duplicates
+                                set(
+                                    (
+                                        populated_subqgraph.query_graph.nodes[
+                                            qnode_id
+                                        ].member_ids
+                                        or []
+                                    )
+                                    # use query_id (original curie) for any subclass results
+                                    + [
+                                        binding.query_id or binding.id
+                                        for binding in bindings
+                                    ]
+                                )
+                            )
+                        )
+                    else:
+                        populated_subqgraph.query_graph.nodes[qnode_id].ids = list(
+                            # need to call set() to remove any duplicates
+                            set(
+                                (
+                                    populated_subqgraph.query_graph.nodes[qnode_id].ids
+                                    or []
+                                )
+                                # use query_id (original curie) for any subclass results
+                                + [
+                                    binding.query_id or binding.id
+                                    for binding in bindings
+                                ]
+                            )
                         )
-                    )
                 # get intersection of result node ids and new sub qgraph
                 # should be empty on last hop because the qgraph is empty
-                qnode_ids = set(populated_subqgraph["nodes"].keys()) & set(
+                qnode_ids = set(populated_subqgraph.query_graph.nodes.keys()) & set(
                     result.node_bindings.keys()
                 )
                 # result key becomes ex. ((n0, (MONDO:0005737,)), (n1, (RXCUI:340169,)))
-                key_fcn = lambda res: tuple(
-                    (
-                        qnode_id,
-                        tuple(
-                            binding.query_id if binding.query_id else binding.id
-                            for binding in bindings
-                        ),  # probably only one
-                    )
+                def result_key_fcn(res, kgraph, auxgraph):
+                    result_keys = []
                     # for cyclic queries, the qnode ids can get out of order, so we need to sort the keys
-                    for qnode_id, bindings in sorted(res.node_bindings.items())
-                    if qnode_id in qnode_ids
-                )
-                result_map[key_fcn(result)].append(
-                    (result, result_kgraph, result_auxgraph)
-                )
+                    for qnode_id, bindings in sorted(res.node_bindings.items()):
+                        if qnode_id not in qnode_ids:
+                            continue
+                        if is_mcq_node(onehop_message.query_graph.nodes[qnode_id]):
+                            # is mcq node, the binding is going to point to the standard uuid, so we need to look
+                            # into the kgraph and auxgraphs to find its origin
+                            try:
+                                curie_list = (
+                                    populated_subqgraph.query_graph.nodes[qnode_id].ids
+                                    or []
+                                ) + (
+                                    populated_subqgraph.query_graph.nodes[
+                                        qnode_id
+                                    ].member_ids
+                                    or []
+                                )
+                                mcq_edge_ids = get_mcq_edge_ids(res, kgraph, auxgraph)
+                                for mcq_edge_id in mcq_edge_ids:
+                                    mcq_edge = kgraph.edges[mcq_edge_id]
+                                    if mcq_edge.predicate != "biolink:member_of":
+                                        # we assume that the node not referenced in the qgraph is what we want
+                                        if mcq_edge.subject in curie_list:
+                                            curie_key = mcq_edge.subject
+                                        else:
+                                            curie_key = mcq_edge.object
+                                        result_keys.append((qnode_id, (curie_key,)))
+                            except Exception as e:
+                                self.logger.error(
+                                    f"Failed to create result map key: {e}"
+                                )
+                        else:
+                            result_keys.append(
+                                (
+                                    qnode_id,
+                                    tuple(
+                                        (
+                                            binding.query_id
+                                            if binding.query_id
+                                            else binding.id
+                                        )
+                                        for binding in bindings
+                                    ),  # probably only one
+                                )
+                            )
+                    return tuple(result_keys)
+
+                result_keys = result_key_fcn(result, result_kgraph, result_auxgraph)
+                if len(result_keys) == 0:
+                    result_map[()].append((result, result_kgraph, result_auxgraph))
+                else:
+                    for result_key in result_keys:
+                        result_map[result_key].append(
+                            (result, result_kgraph, result_auxgraph)
+                        )
+
+            for node in populated_subqgraph.query_graph.nodes.values():
+                node_is_mcq = is_mcq_node(node)
+                if node_is_mcq:
+                    # get MCQ uuid from NN
+                    mcq_node_id = await kp.get_mcq_uuid(node.member_ids)
+                    node.ids = [mcq_node_id]
+                    node_dict = node.dict()
+                    populated_subqgraph.knowledge_graph.nodes[mcq_node_id] = (
+                        Node.parse_obj(
+                            {
+                                "categories": node_dict["categories"],
+                                "is_set": True,
+                                "name": "MCQ_Set",
+                                "attributes": [],
+                            }
+                        )
+                    )
             generators.append(
                 self.generate_from_result(
                     populated_subqgraph,
-                    key_fcn,
-                    lambda result: result_map[key_fcn(result)],
+                    result_key_fcn,
                     result_map,
+                    is_mcq,
                     call_stack,
                     qid,
                 )
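For reference, a standalone sketch of the non-MCQ keying above, using a hypothetical helper and plain-dict bindings rather than the patch's reasoner_pydantic objects. It reproduces the tuple shape from the "result key becomes ex." comment:

def simple_result_key(node_bindings: dict, qnode_ids: set) -> tuple:
    # Mirror of the non-MCQ branch of result_key_fcn: one (qnode_id, curies)
    # pair per shared qnode, sorted so cyclic queries key consistently.
    return tuple(
        (qnode_id, tuple(b.get("query_id") or b["id"] for b in bindings))
        for qnode_id, bindings in sorted(node_bindings.items())
        if qnode_id in qnode_ids
    )

key = simple_result_key(
    {"n0": [{"id": "MONDO:0005737"}], "n1": [{"id": "RXCUI:340169"}]},
    {"n0", "n1"},
)
assert key == (("n0", ("MONDO:0005737",)), ("n1", ("RXCUI:340169",)))

For MCQ qnodes the binding instead points at the synthetic set uuid, which is why the MCQ branch has to chase biolink:member_of edges through the kgraph to recover a concrete curie for the key.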
map") - for result, kgraph, auxgraph in get_results(subresult): - # combine one-hop with subquery results - # Need to create a new result with all node bindings combined - new_subresult = Result.parse_obj( - { - "node_bindings": { - **subresult.node_bindings, - **result.node_bindings, - }, - "analyses": [ - *subresult.analyses, - *result.analyses, - # reconsider - ], - } - ) - new_subkgraph = copy.deepcopy(subkgraph) - new_subkgraph.nodes.update(kgraph.nodes) - new_subkgraph.edges.update(kgraph.edges) - new_auxgraph = copy.deepcopy(subauxgraph) - new_auxgraph.update(auxgraph) - yield new_subkgraph, new_subresult, new_auxgraph, qid + result_keys = key_fcn(subresult, subkgraph, subauxgraph) + if len(result_keys) == 0: + result_keys = [()] + for result_key in result_keys: + if result_key not in result_map: + self.logger.error( + f"[{qid}] Couldn't find subresult in result map: {key_fcn(subresult, subkgraph, subauxgraph)}" + ) + self.logger.error(f"[{sub_qid}] Result map: {result_map.keys()}") + self.logger.error( + f"[{qid}] subresult from lookup: {subresult.json()}" + ) + raise KeyError("Subresult not found in result map") + for result, kgraph, auxgraph in result_map[result_key]: + # result above is previous/current hop + # for result, kgraph, auxgraph in result_map[key_fcn(subresult, subkgraph, subauxgraph)]: + # combine one-hop with subquery results + # Need to create a new result with all node bindings combined + if not is_mcq: + new_subresult = Result.parse_obj( + { + "node_bindings": { + **subresult.node_bindings, + **result.node_bindings, + }, + "analyses": [ + *subresult.analyses, + *result.analyses, + # reconsider + ], + } + ) + + else: + mcq_edge_ids = get_mcq_edge_ids( + subresult, subkgraph, subauxgraph + ) + member_of_edge_id = None + mcq_node_id = None + mcq_node_curie = None + for node_id, node_binding in subresult.node_bindings.items(): + node_binding_curie = next(iter(node_binding)).id + if node_binding_curie.startswith("uuid"): + mcq_node_id = node_id + mcq_node_curie = node_binding_curie + if mcq_node_id is not None: + previous_hop_node_curie = next( + iter(result.node_bindings[mcq_node_id]) + ).id + for mcq_edge_id in mcq_edge_ids: + mcq_edge = subkgraph.edges[mcq_edge_id] + if mcq_edge.predicate == "biolink:member_of": + if ( + mcq_edge.subject == mcq_node_curie + and mcq_edge.object == previous_hop_node_curie + ) or ( + mcq_edge.subject == previous_hop_node_curie + and mcq_edge.object == mcq_node_curie + ): + if member_of_edge_id is not None: + raise ValueError("Got two member of edges!") + member_of_edge_id = mcq_edge_id + + for result_analysis in result.analyses: + edge_binding = next( + iter(result_analysis.edge_bindings.values()) + ) + edge_binding.add( + EdgeBinding.parse_obj( + { + "id": member_of_edge_id, + "attributes": [], + } + ) + ) + # handle mcq merging + new_subresult = Result.parse_obj( + { + "node_bindings": { + **subresult.node_bindings, + **result.node_bindings, + }, + "analyses": [ + *result.analyses, + *subresult.analyses, + ], + } + ) + + new_subkgraph = copy.deepcopy(subkgraph) + new_subkgraph.nodes.update(kgraph.nodes) + new_subkgraph.edges.update(kgraph.edges) + new_auxgraph = copy.deepcopy(subauxgraph) + new_auxgraph.update(auxgraph) + yield new_subkgraph, new_subresult, new_auxgraph, qid async def __aenter__(self): """Enter context.""" @@ -392,28 +568,26 @@ async def __aexit__(self, *args): # pylint: disable=arguments-differ async def setup( self, - qgraph: dict, + message: dict, backup_kps: dict, information_content_threshold: int, ): """Set 
up.""" # Update qgraph identifiers - message = Message.parse_obj({"query_graph": qgraph}) - curies = get_curies(message) + self.message = Message.parse_obj(message) + curies = get_curies(self.message) if len(curies): await self.normalizer.load_curies(*curies) curie_map = self.normalizer.map(curies, self.preferred_prefixes) - map_qgraph_curies(message.query_graph, curie_map, primary=True) - - self.qgraph = message.query_graph.dict() + map_qgraph_curies(self.message.query_graph, curie_map, primary=True) # Fill in missing categories and predicates using normalizer - await fill_categories_predicates(self.qgraph, self.logger) + await fill_categories_predicates(self.message.query_graph, self.logger) # Generate traversal plan self.plan, kps = await generate_plan( - self.qgraph, + self.message.query_graph, backup_kps=backup_kps, logger=self.logger, ) diff --git a/strider/graph.py b/strider/graph.py index ba4f2288..3f93ada4 100644 --- a/strider/graph.py +++ b/strider/graph.py @@ -1,34 +1,22 @@ """Graph - a dict with extra methods.""" -import json +def connected_edges(message, node_id): + """Find edges connected to node.""" + outgoing = [] + incoming = [] + for edge_id, edge in message.query_graph.edges.items(): + if node_id == edge.subject: + outgoing.append(edge_id) + if node_id == edge.object: + incoming.append(edge_id) + return outgoing, incoming -class Graph(dict): - """Graph.""" - def __init__(self, *args, **kwargs): - """Initialize.""" - super().__init__(*args, **kwargs) - - def __hash__(self): - """Compute hash.""" - return hash(json.dumps(self, sort_keys=True)) - - def connected_edges(self, node_id): - """Find edges connected to node.""" - outgoing = [] - incoming = [] - for edge_id, edge in self["edges"].items(): - if node_id == edge["subject"]: - outgoing.append(edge_id) - if node_id == edge["object"]: - incoming.append(edge_id) - return outgoing, incoming - - def remove_orphaned(self): - """Remove nodes with degree 0.""" - self["nodes"] = { - node_id: node - for node_id, node in self["nodes"].items() - if any(self.connected_edges(node_id)) - } +def remove_orphaned(message): + """Remove nodes with degree 0.""" + message.query_graph.nodes = { + node_id: node + for node_id, node in message.query_graph.nodes.items() + if any(connected_edges(message, node_id)) + } diff --git a/strider/knowledge_provider.py b/strider/knowledge_provider.py index ac06bada..e420ac80 100644 --- a/strider/knowledge_provider.py +++ b/strider/knowledge_provider.py @@ -102,15 +102,20 @@ async def map_prefixes( curie_map = self.normalizer.map(curies, prefixes) apply_curie_map(message, curie_map, self.id, self.logger) + async def get_mcq_uuid(self, curies: list[str]) -> str: + """Given a list of curies, get the MCQ uuid from NN.""" + uuid = await self.normalizer.get_mcq_uuid(curies) + return uuid + async def solve_onehop( self, request, bypass_cache: bool, call_stack: list, last_hop: bool ): """Solve one-hop query.""" - request = remove_null_values(request) + request = remove_null_values(request.dict()) response = None try: response = await self.throttle.query( - {"message": {"query_graph": request}}, + {"message": request}, bypass_cache, call_stack, last_hop, @@ -182,7 +187,7 @@ async def solve_onehop( message = response.message if message.query_graph is None: message = Message( - query_graph=QueryGraph.parse_obj(request), + query_graph=QueryGraph.parse_obj(request["query_graph"]), knowledge_graph=KnowledgeGraph.parse_obj({"nodes": {}, "edges": {}}), results=Results.parse_obj([]), 
diff --git a/strider/mcq.py b/strider/mcq.py
new file mode 100644
index 00000000..e3ea280a
--- /dev/null
+++ b/strider/mcq.py
@@ -0,0 +1,31 @@
+"""Utility functions for MCQ queries."""
+
+from typing import List
+from reasoner_pydantic import (
+    QNode,
+    Result,
+    KnowledgeGraph,
+    AuxiliaryGraphs,
+)
+
+
+def is_mcq_node(qnode: QNode) -> bool:
+    """Determine if query graph node is a set for MCQ (MultiCurieQuery)."""
+    return qnode.set_interpretation == "MANY"
+
+
+def get_mcq_edge_ids(
+    result: Result, kgraph: KnowledgeGraph, auxgraph: AuxiliaryGraphs
+) -> List[str]:
+    """Collect the edge ids referenced by a result's support graphs."""
+    mcq_edge_ids = []
+    for analysis in result.analyses:
+        for edge_bindings in analysis.edge_bindings.values():
+            for edge_binding in edge_bindings:
+                kgraph_edge = edge_binding.id
+                for attribute in kgraph.edges[kgraph_edge].attributes:
+                    if attribute.attribute_type_id == "biolink:support_graphs":
+                        for auxgraph_id in attribute.value:
+                            for edge_id in auxgraph[auxgraph_id].edges:
+                                mcq_edge_ids.append(edge_id)
+    return mcq_edge_ids
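To make that traversal concrete, here is a plain-dict mirror of get_mcq_edge_ids with hypothetical TRAPI-shaped data (the real function walks reasoner_pydantic objects):

def mcq_edge_ids(result: dict, kgraph: dict, auxgraphs: dict) -> list:
    # edge bindings -> kgraph edges -> biolink:support_graphs attributes
    # -> auxiliary graphs -> the member edges we are after
    ids = []
    for analysis in result["analyses"]:
        for bindings in analysis["edge_bindings"].values():
            for binding in bindings:
                for attr in kgraph["edges"][binding["id"]].get("attributes") or []:
                    if attr["attribute_type_id"] == "biolink:support_graphs":
                        for aux_id in attr["value"]:
                            ids.extend(auxgraphs[aux_id]["edges"])
    return ids

kgraph = {
    "edges": {
        "e_top": {
            "attributes": [
                {"attribute_type_id": "biolink:support_graphs", "value": ["ag1"]}
            ]
        },
        "e_member": {"attributes": []},
    }
}
result = {"analyses": [{"edge_bindings": {"e01": [{"id": "e_top"}]}}]}
assert mcq_edge_ids(result, kgraph, {"ag1": {"edges": ["e_member"]}}) == ["e_member"]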
diff --git a/strider/node_sets.py b/strider/node_sets.py
index 5b0a8712..bba36374 100644
--- a/strider/node_sets.py
+++ b/strider/node_sets.py
@@ -2,19 +2,19 @@
 from collections import defaultdict
 from datetime import datetime
-from reasoner_pydantic import Message
+from reasoner_pydantic import Message, Query
-def collapse_sets(query: dict, logger) -> None:
+def collapse_sets(query: Query, logger) -> None:
     """Collase results according to set_interpretation qnode notations."""
     # just deserializing the query_graph is very fast
     qgraph = query.message.query_graph.dict()
-    unique_qnodes = {
+    set_qnodes = {
         qnode_id
         for qnode_id, qnode in qgraph["nodes"].items()
-        if (qnode.get("set_interpretation", None) or "BATCH") == "BATCH"
+        if (qnode.get("set_interpretation", None) or "ALL") == "ALL"
     }
-    if len(unique_qnodes) == len(query.message.query_graph.nodes):
+    if len(set_qnodes) == 0:
+        # no set nodes
         return
     logger.info("Collapsing sets. This might take a while...")
@@ -22,7 +22,7 @@ def collapse_sets(query: dict, logger) -> None:
     unique_qedges = {
         qedge_id
         for qedge_id, qedge in message["query_graph"]["edges"].items()
-        if (qedge["subject"] in unique_qnodes and qedge["object"] in unique_qnodes)
+        if (qedge["subject"] in set_qnodes and qedge["object"] in set_qnodes)
     }
     result_buckets = defaultdict(
         lambda: {
@@ -34,7 +34,7 @@ def collapse_sets(query: dict, logger) -> None:
             bucket_key = tuple(
                 [
                     binding["id"]
-                    for qnode_id in unique_qnodes
+                    for qnode_id in set_qnodes
                    for binding in result["node_bindings"][qnode_id]
                 ]
                 + [
diff --git a/strider/normalizer.py b/strider/normalizer.py
index 7cfa38a1..853317b0 100644
--- a/strider/normalizer.py
+++ b/strider/normalizer.py
@@ -1,7 +1,9 @@
 """Node Normalizer Utilities."""
 from collections import namedtuple
+import httpx
 import logging
+import uuid
 from reasoner_pydantic import Message
@@ -155,3 +157,26 @@ def map_curie(
             ),
         )
         return [curie]
+
+    async def get_mcq_uuid(self, curies: list[str]) -> str:
+        """Get the MCQ uuid from NN."""
+        response = {}
+        try:
+            async with httpx.AsyncClient(verify=False, timeout=10.0) as client:
+                self.logger.debug("Sending request to NN for MCQ setid.")
+                res = await client.get(
+                    f"{settings.normalizer_url}/get_setid",
+                    params={
+                        "curie": curies,
+                        "conflation": [
+                            "GeneProtein",
+                            "DrugChemical",
+                        ],
+                    },
+                )
+                res.raise_for_status()
+                response = res.json()
+        except Exception as e:
+            self.logger.error(f"Normalizer MCQ setid failed with: {e}")
+
+        return response.get("setid", f"uuid:unknown-{str(uuid.uuid4())}")
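Note that the new get_mcq_uuid deliberately never raises: any failure is logged and the method falls back to a locally minted placeholder. A sketch of that contract, assuming the /get_setid endpoint returns JSON with a setid field as the code above expects:

import uuid

def setid_or_placeholder(response_json: dict) -> str:
    # Mirrors the fallback at the end of Normalizer.get_mcq_uuid: prefer
    # NodeNorm's setid, otherwise mint a recognizable placeholder id.
    return response_json.get("setid", f"uuid:unknown-{str(uuid.uuid4())}")

assert setid_or_placeholder({"setid": "uuid:1234"}) == "uuid:1234"
assert setid_or_placeholder({}).startswith("uuid:unknown-")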
diff --git a/strider/query_planner.py b/strider/query_planner.py
index e132831d..f52f2a32 100644
--- a/strider/query_planner.py
+++ b/strider/query_planner.py
@@ -5,12 +5,14 @@
 from itertools import chain
 import logging
 import math
-from typing import Generator, Union
+from reasoner_pydantic import QueryGraph
+from typing import Generator, Union, Dict
 from strider.caching import get_kp_registry
 from strider.config import settings
 from strider.traversal import get_traversals, NoAnswersError
 from strider.utils import WBMT
+from strider.mcq import is_mcq_node
 LOGGER = logging.getLogger(__name__)
@@ -213,13 +215,13 @@ def get_kp_operations_queries(
 async def generate_plan(
-    qgraph: dict,
+    qgraph: QueryGraph,
     backup_kps: dict,
     logger: logging.Logger = None,
 ) -> tuple[dict[str, list[str]], dict[str, dict]]:
     """Generate traversal plan."""
     # check that qgraph is traversable
-    get_traversals(qgraph)
+    get_traversals(qgraph.dict())
     if logger is None:
         logger = logging.getLogger(__name__)
@@ -231,38 +233,77 @@ async def generate_plan(
             "Unable to get kp registry from cache. Falling back to in-memory registry..."
         )
         registry = backup_kps
-    for qedge_id in qgraph["edges"]:
-        qedge = qgraph["edges"][qedge_id]
-        provided_by = {"allowlist": None, "denylist": None} | qedge.pop(
+    for qedge_id in qgraph.edges:
+        inverse_kps = dict()
+        qedge = qgraph.edges[qedge_id]
+        provided_by = {"allowlist": None, "denylist": None} | qedge.dict().pop(
             "provided_by", {}
         )
-        (
-            subject_categories,
-            object_categories,
-            predicates,
-            inverse_predicates,
-        ) = get_kp_operations_queries(
-            qgraph["nodes"][qedge["subject"]]["categories"],
-            qedge["predicates"],
-            qgraph["nodes"][qedge["object"]]["categories"],
-        )
-        direct_kps = search(
-            registry,
-            subject_categories,
-            predicates,
-            object_categories,
-            settings.openapi_server_maturity,
-        )
-        if inverse_predicates:
-            inverse_kps = search(
-                registry,
-                object_categories,
-                inverse_predicates,
-                subject_categories,
-                settings.openapi_server_maturity,
-            )
-        else:
-            inverse_kps = dict()
+        if is_mcq_node(qgraph.nodes[qedge.subject]) or is_mcq_node(
+            qgraph.nodes[qedge.object]
+        ):
+            # TODO: update from hard-coded MCQ KPs
+            direct_kps = {
+                "infores:answer-coalesce": {
+                    "url": "https://answercoalesce.renci.org/query",
+                    "title": "Answer Coalescer",
+                    "infores": "infores:answer-coalesce",
+                    "maturity": "development",
+                    "operations": [],
+                    "details": {"preferred_prefixes": {}},
+                },
+                "infores:genetics-data-provider": {
+                    "url": "https://translator.broadinstitute.org/genetics_provider/trapi/v1.5/query",
+                    "title": "Genetics KP",
+                    "infores": "infores:genetics-data-provider",
+                    "maturity": "development",
+                    "operations": [],
+                    "details": {"preferred_prefixes": {}},
+                },
+                "infores:cohd": {
+                    "url": "https://cohd.io/api/query",
+                    "title": "COHD KP",
+                    "infores": "infores:cohd",
+                    "maturity": "development",
+                    "operations": [],
+                    "details": {"preferred_prefixes": {}},
+                },
+                "infores:semsemian": {
+                    "url": "http://mcq-trapi.monarchinitiative.org/1.5/query",
+                    "title": "Semsemian Monarch KP",
+                    "infores": "infores:semsemian",
+                    "maturity": "development",
+                    "operations": [],
+                    "details": {"preferred_prefixes": {}},
+                },
+            }
+        else:
+            # normal lookup edge
+            (
+                subject_categories,
+                object_categories,
+                predicates,
+                inverse_predicates,
+            ) = get_kp_operations_queries(
+                qgraph.nodes[qedge.subject].categories,
+                qedge.predicates,
+                qgraph.nodes[qedge.object].categories,
+            )
+            direct_kps = search(
+                registry,
+                subject_categories,
+                predicates,
+                object_categories,
+                settings.openapi_server_maturity,
+            )
+            if inverse_predicates:
+                inverse_kps = search(
+                    registry,
+                    object_categories,
+                    inverse_predicates,
+                    subject_categories,
+                    settings.openapi_server_maturity,
+                )
        kp_results = {
            kpid: details
            for kpid, details in chain(*(direct_kps.items(), inverse_kps.items()))
@@ -280,8 +321,8 @@ async def generate_plan(
             raise NoAnswersError(msg)
         for kp in kp_results.values():
             for op in kp["operations"]:
-                op["subject"] = (qedge["subject"], op.pop("subject_category"))
-                op["object"] = (qedge["object"], op.pop("object_category"))
+                op["subject"] = (qedge.subject, op.pop("subject_category"))
+                op["object"] = (qedge.object, op.pop("object_category"))
         plan[qedge_id] = list(kp_results.keys())
         kps.update(kp_results)
     return plan, kps
@@ -330,7 +371,13 @@ def get_next_qedge(qgraph):
     """Get next qedge to solve."""
     qgraph = copy.deepcopy(qgraph)
     for qnode in qgraph["nodes"].values():
-        if qnode.get("ids") is not None:
+        if (
+            qnode.get("set_interpretation") == "MANY"
+            and len(qnode.get("member_ids") or []) > 0
+        ):
+            # MCQ
+            qnode["ids"] = len(qnode["member_ids"])
+        elif qnode.get("ids") is not None:
             qnode["ids"] = len(qnode["ids"])
         else:
             qnode["ids"] = N
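The get_next_qedge change means a MANY qnode with member_ids is now scored as pinned by its member count. A standalone mirror of that scoring rule (a hypothetical helper; N stands in for the module's unpinned weight, whose real value lives in query_planner.py):

N = 1_000_000  # stand-in for the planner's "unpinned" constant

def effective_id_count(qnode: dict) -> int:
    # MCQ qnodes count their member_ids; ordinary pinned qnodes count ids;
    # everything else is treated as effectively unbounded.
    if qnode.get("set_interpretation") == "MANY" and (qnode.get("member_ids") or []):
        return len(qnode["member_ids"])
    if qnode.get("ids") is not None:
        return len(qnode["ids"])
    return N

assert effective_id_count(
    {"set_interpretation": "MANY", "member_ids": ["CHEBI:6801", "CHEBI:45783"]}
) == 2
assert effective_id_count({"ids": ["MONDO:0005148"]}) == 1
assert effective_id_count({"categories": ["biolink:Disease"]}) == N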
diff --git a/strider/server.py b/strider/server.py
index 9d615946..266fe6a1 100644
--- a/strider/server.py
+++ b/strider/server.py
@@ -74,7 +74,7 @@
     title="Strider",
     description=DESCRIPTION,
     docs_url=None,
-    version="4.7.3",
+    version="4.8.0",
     terms_of_service=(
         "http://robokop.renci.org:7055/tos"
         "?service_long=Strider"
@@ -450,7 +450,7 @@ async def lookup(
     """Perform lookup operation."""
     global backup_kps
     lookup_start_time = time.time()
-    qgraph = query_dict["message"]["query_graph"]
+    message = query_dict.get("message", {})
     log_level = query_dict.get("log_level") or "INFO"
     # grab information content threshold from message if exists, otherwise grab from environment
@@ -471,17 +471,13 @@ async def lookup(
     fetcher = Fetcher(logger, bypass_cache, parameters)
-    logger.info(f"Doing lookup for qgraph: {json.dumps(qgraph)}")
+    logger.info(f"Doing lookup for message: {json.dumps(message)}")
     try:
-        await fetcher.setup(qgraph, backup_kps, information_content_threshold)
-    except NoAnswersError:
-        logger.warning("Returning no results.")
+        await fetcher.setup(message, backup_kps, information_content_threshold)
+    except NoAnswersError as e:
+        logger.warning(f"Returning no results. {e}")
         return {
-            "message": {
-                "query_graph": qgraph,
-                "knowledge_graph": {"nodes": {}, "edges": {}},
-                "results": [],
-            },
+            "message": message,
             "logs": list(log_handler.contents()),
         }
     except Exception as e:
@@ -490,9 +486,11 @@ async def lookup(
     # Result container to make result merging much faster
     output_results = HashableMapping[str, Result]()
-    output_kgraph = KnowledgeGraph.parse_obj({"nodes": {}, "edges": {}})
+    output_kgraph = KnowledgeGraph.parse_obj(
+        message.get("knowledge_graph") or {"nodes": {}, "edges": {}}
+    )
-    output_auxgraphs = AuxiliaryGraphs.parse_obj({})
+    output_auxgraphs = AuxiliaryGraphs.parse_obj(message.get("auxiliary_graphs") or {})
     message_merging_time = 0
@@ -531,11 +529,14 @@ async def lookup(
     output_query = Query(
         message=Message(
-            query_graph=QueryGraph.parse_obj(qgraph),
+            query_graph=fetcher.message.query_graph,
             knowledge_graph=output_kgraph,
             results=results,
             auxiliary_graphs=output_auxgraphs,
-        )
+        ),
+        log_level=log_level,
+        workflow=query_dict.get("workflow"),
+        bypass_cache=bypass_cache,
     )
     # Collapse sets
@@ -634,10 +635,13 @@ async def single_lookup(query_key):
                 timeout=httpx.Timeout(timeout=600.0)
             ) as client:
                 LOGGER.info(f"[{qid}]: Calling back to {callback}...")
+                query_result["parameters"] = queries[query_key].get("parameters") or {}
+                query_result["parameters"]["multiquery_uid"] = query_key
                 callback_response = await client.post(callback, json=query_result)
                 LOGGER.info(
                     f"[{qid}]: Called back to {callback}. Status={callback_response.status_code}"
                 )
+                callback_response.raise_for_status()
         except Exception as e:
             LOGGER.error(f"[{qid}]: Callback to {callback} failed with: {e}")
         return query_result
Status={callback_response.status_code}" ) + callback_response.raise_for_status() except Exception as e: LOGGER.error(f"[{qid}]: Callback to {callback} failed with: {e}") return query_result diff --git a/strider/throttle.py b/strider/throttle.py index e03156fa..351b7f04 100644 --- a/strider/throttle.py +++ b/strider/throttle.py @@ -32,7 +32,7 @@ remove_curies, filter_by_curie_mapping, ) -from .trapi import get_canonical_qgraphs +from .trapi import get_canonical_qgraphs, validate_message from .utils import elide_curies, log_request, remove_null_values from .caching import async_locking_cache from .config import settings @@ -266,7 +266,9 @@ async def process_batch( # Parse with reasoner_pydantic to validate response_body = ReasonerResponse.parse_obj(response_dict) + # validate_message(response_body.message.dict(), self.logger) await self.postproc(response_body, last_hop) + # validate_message(response_body.message.dict(), self.logger) new_num_results = len(response_body.message.results or []) if num_results != new_num_results: self.logger.info( @@ -433,7 +435,13 @@ async def _query( qgraphs = get_canonical_qgraphs(query.message.query_graph) for qgraph in qgraphs: - subquery = Query(message=Message(query_graph=qgraph)) + subquery = Query( + message=Message( + query_graph=qgraph, + knowledge_graph=query.message.knowledge_graph, + auxiliary_graphs=query.message.auxiliary_graphs, + ) + ) # Queue query for processing request_id = str(uuid.uuid1()) diff --git a/strider/throttle_utils.py b/strider/throttle_utils.py index 775626b8..a299f114 100644 --- a/strider/throttle_utils.py +++ b/strider/throttle_utils.py @@ -27,6 +27,18 @@ def get_curies(qgraph: QueryGraph) -> dict[str, list[str]]: } +def get_member_ids(qgraph: QueryGraph) -> dict[str, list[str]]: + """ + Pull curies from query graph and + return them as a mapping of node_id -> curie_list + """ + return { + node_id: copy.deepcopy(member_ids) + for node_id, node in qgraph.nodes.items() + if (member_ids := node.member_ids or None) is not None + } + + def get_max_num_curies(requests: list) -> int: """ Given a collection of requests, find the maximum curie length of all query graph nodes @@ -34,8 +46,11 @@ def get_max_num_curies(requests: list) -> int: total_curies = defaultdict(int) for request_payload in requests: num_curies = get_curies(request_payload.message.query_graph) + num_member_ids = get_member_ids(request_payload.message.query_graph) for qnode_id, curies in num_curies.items(): total_curies[qnode_id] += len(curies) + for qnode_id, member_ids in num_member_ids.items(): + total_curies[qnode_id] += len(member_ids) # get the max value in dict of curies return max(total_curies.values()) diff --git a/strider/trapi.py b/strider/trapi.py index 43babc74..7e57a4d7 100644 --- a/strider/trapi.py +++ b/strider/trapi.py @@ -373,45 +373,45 @@ async def fill_categories_predicates( normalizer = Normalizer(logger) # Fill in missing predicates with most general term - for edge in qg["edges"].values(): - if ("predicates" not in edge) or (edge["predicates"] is None): - edge["predicates"] = ["biolink:related_to"] + for edge in qg.edges.values(): + if edge.predicates is None: + edge.predicates = ["biolink:related_to"] # Fill in missing categories with most general term - for node in qg["nodes"].values(): - if ("categories" not in node) or (node["categories"] is None): - node["categories"] = ["biolink:NamedThing"] + for node in qg.nodes.values(): + if node.categories is None: + node.categories = ["biolink:NamedThing"] # Use node normalizer to add # a category to 
diff --git a/strider/trapi.py b/strider/trapi.py
index 43babc74..7e57a4d7 100644
--- a/strider/trapi.py
+++ b/strider/trapi.py
@@ -373,45 +373,45 @@ async def fill_categories_predicates(
     normalizer = Normalizer(logger)
     # Fill in missing predicates with most general term
-    for edge in qg["edges"].values():
-        if ("predicates" not in edge) or (edge["predicates"] is None):
-            edge["predicates"] = ["biolink:related_to"]
+    for edge in qg.edges.values():
+        if edge.predicates is None:
+            edge.predicates = ["biolink:related_to"]
     # Fill in missing categories with most general term
-    for node in qg["nodes"].values():
-        if ("categories" not in node) or (node["categories"] is None):
-            node["categories"] = ["biolink:NamedThing"]
+    for node in qg.nodes.values():
+        if node.categories is None:
+            node.categories = ["biolink:NamedThing"]
     # Use node normalizer to add
     # a category to nodes with a curie
-    for node in qg["nodes"].values():
-        node_id = node.get("ids", None)
+    for node in qg.nodes.values():
+        node_id = node.ids
         if not node_id:
             if (
-                "biolink:Gene" in node["categories"]
-                and "biolink:Protein" not in node["categories"]
+                "biolink:Gene" in node.categories
+                and "biolink:Protein" not in node.categories
             ):
-                node["categories"].append("biolink:Protein")
+                node.categories.append("biolink:Protein")
             if (
-                "biolink:Protein" in node["categories"]
-                and "biolink:Gene" not in node["categories"]
+                "biolink:Protein" in node.categories
+                and "biolink:Gene" not in node.categories
             ):
-                node["categories"].append("biolink:Gene")
+                node.categories.append("biolink:Gene")
         else:
             if not isinstance(node_id, list):
                 node_id = [node_id]
             # Get full list of categorys
-            categories = await normalizer.get_types(node_id)
+            categories = await normalizer.get_types(list(node_id))
             # Remove duplicates
             categories = list(set(categories))
             if categories:
                 # Filter categorys that are ancestors of other categorys we were given
-                node["categories"] = filter_ancestor_types(categories)
-            elif "categories" not in node:
-                node["categories"] = []
+                node.categories = filter_ancestor_types(categories)
+            elif node.categories is None:
+                node.categories = []
 def convert_subclasses_to_aux_graphs(
diff --git a/strider/utils.py b/strider/utils.py
index 10a8ef9d..aa80666c 100644
--- a/strider/utils.py
+++ b/strider/utils.py
@@ -126,6 +126,10 @@ def elide_curies(payload):
         for qnode in payload["message"]["query_graph"]["nodes"].values():
             if (num_curies := len(qnode.get("ids", None) or [])) > 10:
                 qnode["ids"] = f"**{num_curies} CURIEs not shown for brevity**"
+    if "query_graph" in payload:  # TRAPI message
+        for qnode in payload["query_graph"]["nodes"].values():
+            if (num_curies := len(qnode.get("ids", None) or [])) > 10:
+                qnode["ids"] = f"**{num_curies} CURIEs not shown for brevity**"
     return payload
@@ -257,9 +261,7 @@ def remove_null_values(obj):
         return {
             key: remove_null_values(value)
             for key, value in obj.items()
-            # TODO: also handle empty lists
-            # need to take this out when automats get fixed
-            if value is not None and value != []
+            if value is not None
         }
     elif isinstance(obj, list):
         return [remove_null_values(el) for el in obj]
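The remove_null_values relaxation is observable directly: per the new code, empty lists survive and only None values are dropped, which is why the tests below can now pass explicit empty member_ids/constraints lists instead of commenting them out. A short usage sketch:

from strider.utils import remove_null_values

payload = {"ids": None, "constraints": [], "categories": ["biolink:Disease"]}
# "ids" is dropped; the empty "constraints" list is kept as-is.
assert remove_null_values(payload) == {
    "constraints": [],
    "categories": ["biolink:Disease"],
}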
"qualifier_constraints": [], }, }, }, @@ -121,14 +121,14 @@ async def test_symmetric_noncanonical(monkeypatch, mocker): "ids": ["CHEBI:6801"], "categories": ["biolink:ChemicalSubstance"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, "n1": { "categories": ["biolink:Disease"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, }, "edges": { @@ -136,8 +136,8 @@ async def test_symmetric_noncanonical(monkeypatch, mocker): "subject": "n0", "object": "n1", "predicates": ["biolink:genetically_interacts_with"], - # "attribute_constraints": [], - # "qualifier_constraints": [], + "attribute_constraints": [], + "qualifier_constraints": [], }, }, }, @@ -236,14 +236,14 @@ async def test_protein_gene_conflation(monkeypatch, mocker): "ids": ["MONDO:0008114"], "categories": ["biolink:Disease"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, "n1": { "categories": ["biolink:Protein", "biolink:Gene"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, }, "edges": { @@ -251,8 +251,8 @@ async def test_protein_gene_conflation(monkeypatch, mocker): "subject": "n0", "object": "n1", "predicates": ["biolink:related_to"], - # "attribute_constraints": [], - # "qualifier_constraints": [], + "attribute_constraints": [], + "qualifier_constraints": [], }, }, }, @@ -296,15 +296,15 @@ async def test_gene_protein_conflation(monkeypatch, mocker): "n0": { "categories": ["biolink:Gene", "biolink:Protein"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, "n1": { "ids": ["MONDO:0008114"], "categories": ["biolink:Disease"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, }, "edges": { @@ -312,8 +312,8 @@ async def test_gene_protein_conflation(monkeypatch, mocker): "subject": "n0", "object": "n1", "predicates": ["biolink:related_to"], - # "attribute_constraints": [], - # "qualifier_constraints": [], + "attribute_constraints": [], + "qualifier_constraints": [], }, }, }, @@ -361,14 +361,14 @@ async def test_node_set(monkeypatch, mocker): "ids": ["CHEBI:6801"], "categories": ["biolink:ChemicalSubstance"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, "n1": { "categories": ["biolink:Disease"], "set_interpretation": "ALL", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, }, "edges": { @@ -376,8 +376,8 @@ async def test_node_set(monkeypatch, mocker): "subject": "n0", "object": "n1", "predicates": ["biolink:treats"], - # "attribute_constraints": [], - # "qualifier_constraints": [], + "attribute_constraints": [], + "qualifier_constraints": [], }, }, }, @@ -425,14 +425,14 @@ async def test_bypass_cache_is_sent_along_to_kps(monkeypatch, mocker): "ids": ["CHEBI:6801"], "categories": ["biolink:ChemicalSubstance"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, "n1": { "categories": ["biolink:Disease"], "set_interpretation": "BATCH", - # "member_ids": [], - # "constraints": [], + "member_ids": [], + "constraints": [], }, }, "edges": { @@ -440,8 +440,8 @@ async def test_bypass_cache_is_sent_along_to_kps(monkeypatch, mocker): "subject": 
"n0", "object": "n1", "predicates": ["biolink:treats"], - # "attribute_constraints": [], - # "qualifier_constraints": [], + "attribute_constraints": [], + "qualifier_constraints": [], }, }, }, diff --git a/tests/test_query_planner.py b/tests/test_query_planner.py index bcedde9f..69ec1d77 100644 --- a/tests/test_query_planner.py +++ b/tests/test_query_planner.py @@ -2,6 +2,7 @@ import pytest from pytest_httpx import HTTPXMock import redis.asyncio +from reasoner_pydantic import QueryGraph from tests.helpers.utils import ( query_graph_from_string, @@ -39,6 +40,7 @@ async def test_not_enough_kps(monkeypatch): n0-- biolink:related_to -->n1 """ ) + qg = QueryGraph.parse_obj(qg) with pytest.raises(NoAnswersError, match=r"cannot reach"): plan, kps = await generate_plan(qg, {}, logger=logging.getLogger()) @@ -60,6 +62,7 @@ async def test_plan_reverse_edge(monkeypatch): n1-- biolink:treats -->n0 """ ) + qg = QueryGraph.parse_obj(qg) plan, kps = await generate_plan(qg, {}) assert plan == {"n1n0": ["infores:kp1"]} @@ -84,6 +87,7 @@ async def test_plan_loop(monkeypatch): n2-- biolink:treats -->n1 """ ) + qg = QueryGraph.parse_obj(qg) plan, _ = await generate_plan(qg, {}) @@ -109,6 +113,7 @@ async def test_plan_reuse_pinned(monkeypatch): n0-- biolink:related_to -->n3 """ ) + qg = QueryGraph.parse_obj(qg) plan, kps = await generate_plan(qg, {}) @@ -135,6 +140,7 @@ async def test_plan_double_loop(monkeypatch): n4-- biolink:related_to -->n2 """ ) + qg = QueryGraph.parse_obj(qg) plan, kps = await generate_plan(qg, {}) @@ -156,6 +162,7 @@ async def test_valid_two_pinned_nodes(monkeypatch): n2(( categories[] biolink:Disease )) """ ) + qg = QueryGraph.parse_obj(qg) await prepare_query_graph(qg) plan, kps = await generate_plan(qg, {}) @@ -180,26 +187,19 @@ async def test_fork(monkeypatch): n0-- biolink:has_phenotype -->n2 """ ) + qg = QueryGraph.parse_obj(qg) plan, kps = await generate_plan(qg, {}) @pytest.mark.asyncio -async def test_unbound_unconnected_node(monkeypatch, httpx_mock: HTTPXMock): +async def test_unbound_unconnected_node(monkeypatch): """ Test Pinned -> Unbound + Unbound This should be invalid because there is no path to the unbound node """ monkeypatch.setattr(redis.asyncio, "Redis", redisMock) - httpx_mock.add_response( - url="http://normalizer/get_normalized_nodes", - json=get_normalizer_response( - """ - MONDO:0005148 categories biolink:Disease - """ - ), - ) qg = query_graph_from_string( """ n0(( ids[] MONDO:0005148 )) @@ -208,7 +208,9 @@ async def test_unbound_unconnected_node(monkeypatch, httpx_mock: HTTPXMock): n2(( categories[] biolink:PhenotypicFeature )) """ ) + qg = QueryGraph.parse_obj(qg) await prepare_query_graph(qg) + print(qg) with pytest.raises(NoAnswersError, match=r"cannot reach"): plan, kps = await generate_plan(qg, {}) @@ -234,6 +236,7 @@ async def test_valid_two_disconnected_components(monkeypatch): n2-- biolink:treated_by -->n3 """ ) + qg = QueryGraph.parse_obj(qg) plan, kps = await generate_plan(qg, {}) assert plan == {"n0n1": ["infores:kp1"], "n2n3": ["infores:kp1"]} @@ -263,6 +266,7 @@ async def test_bad_norm(monkeypatch): } }, } + qg = QueryGraph.parse_obj(qg) plan, kps = await generate_plan(qg, {}) assert plan == {"e01": ["infores:kp1"]} @@ -282,6 +286,7 @@ async def test_double_sided(monkeypatch): n0-- biolink:treated_by -->n1 """ ) + qg = QueryGraph.parse_obj(qg) plan, kps = await generate_plan(qg, {}, logger=logging.getLogger()) assert plan == {"n0n1": ["infores:kp1"]} assert "infores:kp1" in kps @@ -376,6 +381,7 @@ async def test_predicate_fanout(monkeypatch): } 
@@ -376,6 +381,7 @@ async def test_predicate_fanout(monkeypatch):
             }
         },
     }
+    qg = QueryGraph.parse_obj(qg)
     plan, kps = await generate_plan(qg, {}, logger=logging.getLogger())
     assert plan == {"ab": ["infores:kp2"]}
@@ -397,6 +403,7 @@ async def test_inverse_predicate(monkeypatch):
         n0-- biolink:treated_by -->n1
         """
     )
+    qg = QueryGraph.parse_obj(qg)
     plan, kps = await generate_plan(qg, {}, logger=logging.getLogger())
     assert plan == {"n0n1": ["infores:kp1"]}
@@ -417,6 +424,7 @@ async def test_symmetric_predicate(monkeypatch):
         n1-- biolink:correlated_with -->n0
         """
     )
+    qg = QueryGraph.parse_obj(qg)
     plan, kps = await generate_plan(qg, {}, logger=logging.getLogger())
     assert plan == {"n1n0": ["infores:kp1"]}
     assert "infores:kp1" in kps
@@ -439,6 +447,7 @@ async def test_subpredicate(monkeypatch):
             }
         },
     }
+    qg = QueryGraph.parse_obj(qg)
     plan, kps = await generate_plan(qg, {}, logger=logging.getLogger())
     assert plan == {"ab": ["infores:kp1"]}
@@ -460,6 +469,7 @@ async def test_solve_double_subclass(monkeypatch):
         n0-- biolink:ameliorates -->n1
         """
     )
+    qg = QueryGraph.parse_obj(qg)
     plan, kps = await generate_plan(qg, {}, logger=logging.getLogger())
     assert plan == {"n0n1": ["infores:kp1"]}
@@ -482,6 +492,7 @@ async def test_pinned_to_pinned(monkeypatch):
         n0-- biolink:related_to -->n1
         """
     )
+    qg = QueryGraph.parse_obj(qg)
     plan, kps = await generate_plan(qg, {}, logger=logging.getLogger())
     assert plan == {"n0n1": ["infores:kp3"]}
@@ -501,6 +512,7 @@ async def test_self_edge(monkeypatch):
         n0-- biolink:related_to -->n0
         """
     )
+    qg = QueryGraph.parse_obj(qg)
     # await prepare_query_graph(qg)
     plan, kps = await generate_plan(qg, {}, logger=logging.getLogger())
diff --git a/tests/test_server.py b/tests/test_server.py
index 7afa80d9..5353021d 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -149,14 +149,14 @@ async def test_solve_missing_predicate(monkeypatch, mocker):
                     "ids": ["HP:001"],
                     "categories": ["biolink:Gene"],
                     "set_interpretation": "BATCH",
-                    # "member_ids": [],
-                    # "constraints": [],
+                    "member_ids": [],
+                    "constraints": [],
                 },
                 "n1": {
                     "categories": ["biolink:Gene", "biolink:Protein"],
                     "set_interpretation": "BATCH",
-                    # "member_ids": [],
-                    # "constraints": [],
+                    "member_ids": [],
+                    "constraints": [],
                 },
             },
             "edges": {
@@ -164,8 +164,8 @@ async def test_solve_missing_predicate(monkeypatch, mocker):
                     "subject": "n0",
                     "object": "n1",
                     "predicates": ["biolink:related_to"],
-                    # "attribute_constraints": [],
-                    # "qualifier_constraints": [],
+                    "attribute_constraints": [],
+                    "qualifier_constraints": [],
                 }
            },
        }
@@ -211,14 +211,14 @@ async def test_solve_missing_category(monkeypatch, mocker):
                     "ids": ["CHEBI:6801"],
                     "categories": ["biolink:NamedThing"],
                     "set_interpretation": "BATCH",
-                    # "member_ids": [],
-                    # "constraints": [],
+                    "member_ids": [],
+                    "constraints": [],
                 },
                 "n1": {
                     "categories": ["biolink:Disease"],
                     "set_interpretation": "BATCH",
-                    # "member_ids": [],
-                    # "constraints": [],
+                    "member_ids": [],
+                    "constraints": [],
                 },
             },
             "edges": {
@@ -226,8 +226,8 @@ async def test_solve_missing_category(monkeypatch, mocker):
                     "subject": "n0",
                     "object": "n1",
                     "predicates": ["biolink:treats"],
-                    # "attribute_constraints": [],
-                    # "qualifier_constraints": [],
+                    "attribute_constraints": [],
+                    "qualifier_constraints": [],
                 }
            },
        }
@@ -284,14 +284,14 @@ async def test_normalizer_different_category(
                     "ids": ["CHEBI:6801"],
                     "categories": ["biolink:Vitamin"],
                     "set_interpretation": "BATCH",
-                    # "member_ids": [],
-                    # "constraints": [],
+                    "member_ids": [],
+                    "constraints": [],
                 },
                 "n1": {
                     "categories": ["biolink:Disease"],
                     "set_interpretation": "BATCH",
-                    # "member_ids": [],
-                    # "constraints": [],
+                    "member_ids": [],
+                    "constraints": [],
                 },
             },
             "edges": {
@@ -299,8 +299,8 @@ async def test_normalizer_different_category(
                     "subject": "n0",
                     "object": "n1",
                     "predicates": ["biolink:treats"],
-                    # "attribute_constraints": [],
-                    # "qualifier_constraints": [],
+                    "attribute_constraints": [],
+                    "qualifier_constraints": [],
                 }
            },
        }