Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strider MCQ #458

Merged
merged 16 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
402 changes: 288 additions & 114 deletions strider/fetcher.py

Large diffs are not rendered by default.

46 changes: 17 additions & 29 deletions strider/graph.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,22 @@
"""Graph - a dict with extra methods."""

import json

def connected_edges(message, node_id):
    """Find edges connected to a node.

    Returns a pair of lists: (outgoing edge ids where the node is the
    subject, incoming edge ids where the node is the object).
    """
    edge_items = message.query_graph.edges.items()
    outgoing = [eid for eid, edge in edge_items if edge.subject == node_id]
    incoming = [eid for eid, edge in edge_items if edge.object == node_id]
    return outgoing, incoming

class Graph(dict):
    """A dict-based graph with extra helper methods.

    Expected shape: {"nodes": {node_id: node}, "edges": {edge_id: edge}},
    where each edge is a dict with "subject" and "object" node ids.
    """

    def __hash__(self):
        """Hash the graph via its canonical (sorted-key) JSON serialization.

        NOTE(review): the graph is mutable, so the hash changes if the graph
        is mutated after hashing; callers must treat hashed graphs as frozen.
        """
        return hash(json.dumps(self, sort_keys=True))

    def connected_edges(self, node_id):
        """Find edges connected to node.

        Returns (outgoing_edge_ids, incoming_edge_ids) for node_id.
        """
        outgoing = []
        incoming = []
        for edge_id, edge in self["edges"].items():
            if node_id == edge["subject"]:
                outgoing.append(edge_id)
            if node_id == edge["object"]:
                incoming.append(edge_id)
        return outgoing, incoming

    def remove_orphaned(self):
        """Remove nodes with degree 0 (no incident edges)."""
        # Collect all edge endpoints in a single pass over the edges
        # instead of rescanning the full edge set once per node.
        connected = set()
        for edge in self["edges"].values():
            connected.add(edge["subject"])
            connected.add(edge["object"])
        self["nodes"] = {
            node_id: node
            for node_id, node in self["nodes"].items()
            if node_id in connected
        }
def remove_orphaned(message):
    """Remove query-graph nodes with degree 0 (no incident edges).

    Mutates message.query_graph.nodes in place, keeping only nodes that
    appear as the subject or object of at least one query-graph edge.
    """
    # One pass over the edges to collect every endpoint, instead of
    # rescanning the whole edge set once per node (O(n + e) vs O(n * e)).
    endpoints = set()
    for edge in message.query_graph.edges.values():
        endpoints.add(edge.subject)
        endpoints.add(edge.object)
    message.query_graph.nodes = {
        node_id: node
        for node_id, node in message.query_graph.nodes.items()
        if node_id in endpoints
    }
11 changes: 8 additions & 3 deletions strider/knowledge_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,20 @@ async def map_prefixes(
curie_map = self.normalizer.map(curies, prefixes)
apply_curie_map(message, curie_map, self.id, self.logger)

async def get_mcq_uuid(self, curies: list[str]) -> str:
    """Given a list of curies, get the MCQ uuid from NN.

    Delegates to the normalizer client and returns its set id string.
    """
    return await self.normalizer.get_mcq_uuid(curies)

async def solve_onehop(
self, request, bypass_cache: bool, call_stack: list, last_hop: bool
):
"""Solve one-hop query."""
request = remove_null_values(request)
request = remove_null_values(request.dict())
response = None
try:
response = await self.throttle.query(
{"message": {"query_graph": request}},
{"message": request},
bypass_cache,
call_stack,
last_hop,
Expand Down Expand Up @@ -182,7 +187,7 @@ async def solve_onehop(
message = response.message
if message.query_graph is None:
message = Message(
query_graph=QueryGraph.parse_obj(request),
query_graph=QueryGraph.parse_obj(request["query_graph"]),
knowledge_graph=KnowledgeGraph.parse_obj({"nodes": {}, "edges": {}}),
results=Results.parse_obj([]),
auxiliary_graphs=AuxiliaryGraphs.parse_obj({}),
Expand Down
31 changes: 31 additions & 0 deletions strider/mcq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Utility functions for MCQ queries."""

from typing import List
from reasoner_pydantic import (
QNode,
Result,
Edge,
KnowledgeGraph,
AuxiliaryGraphs,
)


def is_mcq_node(qnode: QNode) -> bool:
    """Determine if a query graph node is a set for MCQ (MultiCurieQuery).

    MCQ set nodes carry set_interpretation == "MANY".
    """
    return qnode.set_interpretation == "MANY"


def get_mcq_edge_ids(
    result: Result, kgraph: KnowledgeGraph, auxgraph: AuxiliaryGraphs
) -> List[str]:
    """Collect knowledge-graph edge ids referenced by MCQ support graphs.

    For every edge bound in the result's analyses, follow any
    "biolink:support_graphs" attributes on the corresponding knowledge-graph
    edge into the auxiliary graphs, and gather the ids of the edges those
    support graphs contain.

    Note: returns edge *ids* (strings), not Edge objects; the original
    List[Edge] annotation was incorrect.
    """
    mcq_edge_ids = []
    for analysis in result.analyses:
        for edge_bindings in analysis.edge_bindings.values():
            for edge_binding in edge_bindings:
                kgraph_edge = kgraph.edges[edge_binding.id]
                # attributes may be absent on an edge; treat None as empty.
                for attribute in kgraph_edge.attributes or []:
                    if attribute.attribute_type_id == "biolink:support_graphs":
                        for auxgraph_id in attribute.value:
                            mcq_edge_ids.extend(auxgraph[auxgraph_id].edges)
    return mcq_edge_ids
14 changes: 7 additions & 7 deletions strider/node_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@

from collections import defaultdict
from datetime import datetime
from reasoner_pydantic import Message
from reasoner_pydantic import Message, Query


def collapse_sets(query: dict, logger) -> None:
def collapse_sets(query: Query, logger) -> None:
"""Collase results according to set_interpretation qnode notations."""
# just deserializing the query_graph is very fast
qgraph = query.message.query_graph.dict()
unique_qnodes = {
set_qnodes = {
qnode_id
for qnode_id, qnode in qgraph["nodes"].items()
if (qnode.get("set_interpretation", None) or "BATCH") == "BATCH"
if (qnode.get("set_interpretation", None) or "ALL") == "ALL"
}
if len(unique_qnodes) == len(query.message.query_graph.nodes):
if len(set_qnodes) == 0:
# no set nodes
return
logger.info("Collapsing sets. This might take a while...")
message = query.message.dict()
unique_qedges = {
qedge_id
for qedge_id, qedge in message["query_graph"]["edges"].items()
if (qedge["subject"] in unique_qnodes and qedge["object"] in unique_qnodes)
if (qedge["subject"] in set_qnodes and qedge["object"] in set_qnodes)
}
result_buckets = defaultdict(
lambda: {
Expand All @@ -34,7 +34,7 @@ def collapse_sets(query: dict, logger) -> None:
bucket_key = tuple(
[
binding["id"]
for qnode_id in unique_qnodes
for qnode_id in set_qnodes
for binding in result["node_bindings"][qnode_id]
]
+ [
Expand Down
25 changes: 25 additions & 0 deletions strider/normalizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Node Normalizer Utilities."""

from collections import namedtuple
import httpx
import logging
import uuid

from reasoner_pydantic import Message

Expand Down Expand Up @@ -155,3 +157,26 @@ def map_curie(
),
)
return [curie]

async def get_mcq_uuid(self, curies: list[str]) -> str:
    """Get the MCQ uuid (set id) for the given curies from the Node Normalizer.

    Best-effort: any failure is logged and a locally generated placeholder
    set id is returned instead of raising.
    """
    payload = {}
    try:
        # NOTE(review): verify=False disables TLS certificate checks for the
        # normalizer call — confirm this is intentional.
        async with httpx.AsyncClient(verify=False, timeout=10.0) as client:
            self.logger.debug("Sending request to NN for MCQ setid.")
            result = await client.get(
                f"{settings.normalizer_url}/get_setid",
                params={
                    "curie": curies,
                    "conflation": ["GeneProtein", "DrugChemical"],
                },
            )
            result.raise_for_status()
            payload = result.json()
    except Exception as e:
        self.logger.error(f"Normalizer MCQ setid failed with: {e}")

    return payload.get("setid", f"uuid:unknown-{str(uuid.uuid4())}")
109 changes: 78 additions & 31 deletions strider/query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from itertools import chain
import logging
import math
from typing import Generator, Union
from reasoner_pydantic import QueryGraph
from typing import Generator, Union, Dict

from strider.caching import get_kp_registry
from strider.config import settings
from strider.traversal import get_traversals, NoAnswersError
from strider.utils import WBMT
from strider.mcq import is_mcq_node

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -213,13 +215,13 @@ def get_kp_operations_queries(


async def generate_plan(
qgraph: dict,
qgraph: QueryGraph,
backup_kps: dict,
logger: logging.Logger = None,
) -> tuple[dict[str, list[str]], dict[str, dict]]:
"""Generate traversal plan."""
# check that qgraph is traversable
get_traversals(qgraph)
get_traversals(qgraph.dict())

if logger is None:
logger = logging.getLogger(__name__)
Expand All @@ -231,38 +233,77 @@ async def generate_plan(
"Unable to get kp registry from cache. Falling back to in-memory registry..."
)
registry = backup_kps
for qedge_id in qgraph["edges"]:
qedge = qgraph["edges"][qedge_id]
provided_by = {"allowlist": None, "denylist": None} | qedge.pop(
for qedge_id in qgraph.edges:
inverse_kps = dict()
qedge = qgraph.edges[qedge_id]
provided_by = {"allowlist": None, "denylist": None} | qedge.dict().pop(
"provided_by", {}
)
(
subject_categories,
object_categories,
predicates,
inverse_predicates,
) = get_kp_operations_queries(
qgraph["nodes"][qedge["subject"]]["categories"],
qedge["predicates"],
qgraph["nodes"][qedge["object"]]["categories"],
)
direct_kps = search(
registry,
subject_categories,
predicates,
object_categories,
settings.openapi_server_maturity,
)
if inverse_predicates:
inverse_kps = search(
registry,
if is_mcq_node(qgraph.nodes[qedge.subject]) or is_mcq_node(
qgraph.nodes[qedge.object]
):
# TODO: update from hard-coded MCQ KPs
direct_kps = {
"infores:answer-coalesce": {
"url": "https://answercoalesce.renci.org/query",
"title": "Answer Coalescer",
"infores": "infores:answer-coalesce",
"maturity": "development",
"operations": [],
"details": {"preferred_prefixes": {}},
},
"infores:genetics-data-provider": {
"url": "https://translator.broadinstitute.org/genetics_provider/trapi/v1.5/query",
"title": "Genetics KP",
"infores": "infores:genetics-data-provider",
"maturity": "development",
"operations": [],
"details": {"preferred_prefixes": {}},
},
"infores:cohd": {
"url": "https://cohd.io/api/query",
"title": "COHD KP",
"infores": "infores:cohd",
"maturity": "development",
"operations": [],
"details": {"preferred_prefixes": {}},
},
"infores:semsemian": {
"url": "http://mcq-trapi.monarchinitiative.org/1.5/query",
"title": "Semsemian Monarch KP",
"infores": "infores:semsemian",
"maturity": "development",
"operations": [],
"details": {"preferred_prefixes": {}},
},
}
else:
# normal lookup edge
(
subject_categories,
object_categories,
predicates,
inverse_predicates,
) = get_kp_operations_queries(
qgraph.nodes[qedge.subject].categories,
qedge.predicates,
qgraph.nodes[qedge.object].categories,
)
direct_kps = search(
registry,
subject_categories,
predicates,
object_categories,
settings.openapi_server_maturity,
)
else:
inverse_kps = dict()
if inverse_predicates:
inverse_kps = search(
registry,
object_categories,
inverse_predicates,
subject_categories,
settings.openapi_server_maturity,
)
kp_results = {
kpid: details
for kpid, details in chain(*(direct_kps.items(), inverse_kps.items()))
Expand All @@ -280,8 +321,8 @@ async def generate_plan(
raise NoAnswersError(msg)
for kp in kp_results.values():
for op in kp["operations"]:
op["subject"] = (qedge["subject"], op.pop("subject_category"))
op["object"] = (qedge["object"], op.pop("object_category"))
op["subject"] = (qedge.subject, op.pop("subject_category"))
op["object"] = (qedge.object, op.pop("object_category"))
plan[qedge_id] = list(kp_results.keys())
kps.update(kp_results)
return plan, kps
Expand Down Expand Up @@ -330,7 +371,13 @@ def get_next_qedge(qgraph):
"""Get next qedge to solve."""
qgraph = copy.deepcopy(qgraph)
for qnode in qgraph["nodes"].values():
if qnode.get("ids") is not None:
if (
qnode.get("set_interpretation") == "MANY"
and len(qnode.get("member_ids") or []) > 0
):
# MCQ
qnode["ids"] = len(qnode["member_ids"])
elif qnode.get("ids") is not None:
qnode["ids"] = len(qnode["ids"])
else:
qnode["ids"] = N
Expand Down
Loading
Loading