Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Narrow deep paths mutation #4

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
8 changes: 8 additions & 0 deletions config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@
MUTPB_FV_SAMPLE_MAXN = 32  # max n of instantiations to sample from top k
MUTPB_FV_QUERY_LIMIT = 256  # SPARQL query limit for the top k instantiations
MUTPB_SP = 0.05  # prob to simplify pattern (warning: can restrict exploration)
MUTPB_DN = 0.05  # prob to try a deep and narrow paths mutation
MUTPB_DN_MIN_LEN = 2  # minimum length (in hops) of a deep and narrow path
MUTPB_DN_MAX_LEN = 10  # max path length if expansion isn't stopped by term_pb
MUTPB_DN_TERM_PB = 0.7  # prob to terminate expansion at each step > min_len
MUTPB_DN_MAX_NODE_COUNT = 10  # only consider edge fixations with <= this many nodes
MUTPB_DN_MIN_EDGE_COUNT = 2  # edges need to be valid for >= this many GTPs
MUTPB_DN_QUERY_LIMIT = 32  # SPARQL query limit for the top edge fixations
MUTPB_DN_REC_RETRIES = 3  # retrial attempts in each recursion, WARNING: EXP!

# export all upper-case constants (for import in helpers and __init__)
__all__ = [_v for _v in globals().keys() if _v.isupper()]
139 changes: 136 additions & 3 deletions gp_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from gp_query import predict_query
from gp_query import query_time_hard_exceeded
from gp_query import query_time_soft_exceeded
from gp_query import dnp_query
from gp_query import variable_substitution_query
from graph_pattern import canonicalize
from graph_pattern import gen_random_var
Expand Down Expand Up @@ -653,6 +654,137 @@ def mutate_fix_var(
return res


def _mutate_deep_narrow_path_helper(
        sparql, timeout, gtp_scores, child, edge_var, node_var,
        gtp_sample_n=config.MUTPB_FV_RGTP_SAMPLE_N,
        max_node_count=config.MUTPB_DN_MAX_NODE_COUNT,
        min_edge_count=config.MUTPB_DN_MIN_EDGE_COUNT,
        limit=config.MUTPB_DN_QUERY_LIMIT,
        sample_n=config.MUTPB_FV_SAMPLE_MAXN,
):
    """Try to fix edge_var in child to an edge leading to a narrow path.

    Queries the endpoint (via dnp_query) for candidate substitutions of
    edge_var together with aggregated counts for node_var, ranks the
    candidates by how "narrow" they are (many GTPs fulfilled through a low
    average node degree) and substitutes one of the sampled top candidates
    into child.

    :param sparql: SPARQLWrapper endpoint.
    :param timeout: query timeout, passed through to dnp_query.
    :param gtp_scores: GTPScores used to sample remaining ground truth pairs.
    :param child: GraphPattern to mutate.
    :param edge_var: edge variable to fix.
    :param node_var: node variable whose counts are aggregated per edge.
    :param gtp_sample_n: max number of remaining GTPs to sample.
    :param max_node_count: upper bound on nodes per edge fixation.
    :param min_edge_count: min number of GTPs an edge has to be valid for.
    :param limit: SPARQL query limit for the top edge fixations.
    :param sample_n: max number of substitutions sampled from the top ones.
    :return: (child, fixed) tuple. fixed is True only if a substitution for
        edge_var was actually applied to the returned child.
    """
    assert isinstance(child, GraphPattern)
    assert isinstance(gtp_scores, GTPScores)

    # The further we get, the less gtps are remaining. Sampling too many (all)
    # of them might hurt as common substitutions (> limit ones) which are dead
    # ends could cover less common ones that could actually help
    gtp_sample_n = min(gtp_sample_n, int(gtp_scores.remaining_gain))
    gtp_sample_n = random.randint(1, gtp_sample_n)

    ground_truth_pairs = gtp_scores.remaining_gain_sample_gtps(
        max_n=gtp_sample_n)
    t, substitution_counts = dnp_query(
        sparql, timeout, child, ground_truth_pairs,
        edge_var=edge_var,
        node_var=node_var,
        max_node_count=max_node_count,
        min_edge_count=min_edge_count,
        limit=limit,
    )
    edge_count, node_sum_count = substitution_counts
    if not node_sum_count:
        # the current pattern is unfit, as we can't find anything fulfilling it
        # BUGFIX: added the missing "\n" between the two concatenated literals
        # so the query and the trailing sentence aren't fused in the log
        logger.debug(
            "tried to fix a var %s without result:\n%s\n"
            "seems as if the pattern can't be fulfilled!",
            edge_var, child.to_sparql_select_query())
        return child, False
    mutate_fix_var_filter(node_sum_count)
    mutate_fix_var_filter(edge_count)
    if not node_sum_count:
        # could have happened that we removed the only possible substitution
        return child, False

    # prioritize edges that fulfil many GTPs (high ec) through a low average
    # node degree (node_sum / ec)
    # NOTE(review): assumes true division (from __future__ import division)
    # is in effect for this module — verify, otherwise prios are truncated
    prio = Counter()
    for edge, node_sum in node_sum_count.items():
        ec = edge_count[edge]
        prio[edge] = ec / (node_sum / ec)  # ec / AVG degree
    # randomly pick n of the substitutions with a prob ~ to their prios
    edges, prios = zip(*prio.most_common())

    substs = sample_from_list(edges, prios, sample_n)
    logger.info(
        'fixed variable %s to %s in %s\n %s\n<%d out of:\n%s\n',
        edge_var.n3(),
        substs[0] if substs else '',
        child,
        '\n '.join([subst.n3() for subst in substs]),
        sample_n,
        '\n'.join([
            ' %.3f: %s' % (c, v.n3()) for v, c in prio.most_common()]),
    )
    children = [
        GraphPattern(child, mapping={edge_var: subst})
        for subst in substs
    ]
    children = [c for c in children if fit_to_live(c)]
    if children:
        return children[0], True
    # BUGFIX: previously fixed=True was returned even when all substituted
    # children failed fit_to_live and the unmodified child was returned,
    # making callers believe edge_var had been fixed when it hadn't
    return child, False


def mutate_deep_narrow_path(
        child, sparql, timeout, gtp_scores,
        _rec_depth=0,
        start_node=None,
        target_nodes=None,
        min_len=config.MUTPB_DN_MIN_LEN,
        max_len=config.MUTPB_DN_MAX_LEN,
        term_pb=config.MUTPB_DN_TERM_PB,
        pb_en_out_link=config.MUTPB_EN_OUT_LINK,
        retries=config.MUTPB_DN_REC_RETRIES,
):
    """Recursively grow a deep and narrow path from start_node.

    Expands the pattern one hop at a time from start_node, fixing each new
    edge variable to a "narrow" edge via _mutate_deep_narrow_path_helper.
    Once the path is at least min_len hops long, it tries to close the path
    back onto one of the original pattern's other nodes.

    :param child: GraphPattern to mutate.
    :param sparql: SPARQLWrapper endpoint.
    :param timeout: query timeout, passed through to the helper.
    :param gtp_scores: GTPScores for GTP sampling in the helper.
    :param _rec_depth: internal recursion depth == current path length.
    :param start_node: node to grow from (random node of child if None).
    :param target_nodes: candidate end points for closing the path
        (computed at depth 0 from child's other nodes).
    :param min_len: minimum path length before closing/termination.
    :param max_len: hard cap on the path length.
    :param term_pb: prob to terminate expansion at each step > min_len.
    :param pb_en_out_link: prob that a new edge points away from start_node.
    :param retries: attempts per recursion step (WARNING: exponential).
    :return: a mutated GraphPattern, the unmodified child (at depth 0 when
        nothing worked), or None (in recursive calls that failed/terminated).
    """
    assert isinstance(child, GraphPattern)
    assert min_len > 0
    if _rec_depth > max_len:
        return None
    if _rec_depth >= min_len and random.random() < term_pb:
        return None
    if _rec_depth == 0:
        nodes = child.nodes
        if not start_node:
            start_node = random.choice(list(nodes))
        target_nodes = list(nodes - {start_node})

    if _rec_depth >= min_len:
        # path is long enough: try to close it back onto one of the original
        # pattern's nodes via a single fixable edge
        for node in target_nodes:
            var_edge_to_target = gen_random_var()
            if random.random() < pb_en_out_link:
                new_triple = (start_node, var_edge_to_target, node)
            else:
                new_triple = (node, var_edge_to_target, start_node)
            # BUGFIX: build each closing candidate directly from child, so a
            # failed attempt no longer leaves its dangling edge variable in
            # the pattern that's tried for the next target node
            closed_gp, fixed_edge_to_target = _mutate_deep_narrow_path_helper(
                sparql, timeout, gtp_scores, child + [new_triple],
                var_edge_to_target, node)
            if fixed_edge_to_target:
                return closed_gp

    # expand the path by one more hop from start_node
    gp = child
    new_triple, var_node, var_edge = _mutate_expand_node_helper(start_node)
    gp += [new_triple]
    for _ in range(retries):
        fixed_gp, fixed = _mutate_deep_narrow_path_helper(
            sparql, timeout, gtp_scores, gp, var_edge, var_node)
        rec_gp = mutate_deep_narrow_path(
            fixed_gp, sparql, timeout, gtp_scores,
            _rec_depth + 1,
            start_node=var_node,
            target_nodes=target_nodes,
        )
        if rec_gp:
            return rec_gp
        if fixed:
            if _rec_depth > min_len:
                return fixed_gp
    if _rec_depth == 0:
        # never fail at the top level: hand back the unmodified child
        return child
    return None


def mutate_simplify_pattern(gp):
if len(gp) < 2:
return gp
Expand Down Expand Up @@ -757,6 +889,7 @@ def mutate(
pb_dt=config.MUTPB_DT,
pb_en=config.MUTPB_EN,
pb_fv=config.MUTPB_FV,
pb_dn=config.MUTPB_DN,
pb_id=config.MUTPB_ID,
pb_iv=config.MUTPB_IV,
pb_mv=config.MUTPB_MV,
Expand Down Expand Up @@ -796,15 +929,15 @@ def mutate(
if random.random() < pb_sp:
child = mutate_simplify_pattern(child)

if random.random() < pb_dn:
child = mutate_deep_narrow_path(child, sparql, timeout, gtp_scores)

if random.random() < pb_fv:
child = canonicalize(child)
children = mutate_fix_var(sparql, timeout, gtp_scores, child)
else:
children = [child]


# TODO: deep & narrow paths mutation

children = {
c if fit_to_live(c) else orig_child
for c in children
Expand Down
75 changes: 74 additions & 1 deletion gp_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from graph_pattern import TARGET_VAR
from graph_pattern import ASK_VAR
from graph_pattern import COUNT_VAR
from graph_pattern import NODE_VAR_SUM
from graph_pattern import EDGE_VAR_COUNT
from utils import exception_stack_catcher
from utils import sparql_json_result_bindings_to_rdflib
from utils import timer
Expand Down Expand Up @@ -279,7 +281,6 @@ def _combined_chunk_res(q_res, _vars, _ret_val_mapping):
return chunk_res



def count_query(sparql, timeout, graph_pattern, source=None,
**kwds):
assert isinstance(graph_pattern, GraphPattern)
Expand Down Expand Up @@ -457,6 +458,78 @@ def _var_subst_res_update(res, update, **_):
res += update


def dnp_query(
        sparql, timeout, graph_pattern, source_target_pairs,
        edge_var, node_var, max_node_count, min_edge_count, limit,
        batch_size=config.BATCH_SIZE
):
    """Batched deep-narrow-path query for edge fixation candidates.

    Resolves the pattern's variables and value mappings for the given
    source target pairs, then delegates to _multi_query with the dnp
    specific init / chunk-query / result-extension / update callbacks.
    """
    _vars, _values, _ret_val_mapping = _get_vars_values_mapping(
        graph_pattern, source_target_pairs)
    # non standard kwds, passed on via **kwds of _multi_query to the
    # dnp callbacks below
    dnp_kwds = dict(
        edge_var=edge_var,
        node_var=node_var,
        max_node_count=max_node_count,
        min_edge_count=min_edge_count,
        limit=limit,
    )
    return _multi_query(
        sparql, timeout, graph_pattern, source_target_pairs, batch_size,
        _vars, _values, _ret_val_mapping,
        _dnp_res_init, _dnp_chunk_q, _dnp_chunk_result_ext,
        _res_update=_dnp_res_update,
        **dnp_kwds
    )


# noinspection PyUnusedLocal
def _dnp_res_init(_, **kwds):
return Counter(), Counter()


def _dnp_chunk_q(
        gp, _vars, values_chunk,
        edge_var, node_var, max_node_count, min_edge_count, limit,
        **_
):
    """Build the deep-narrow-path SPARQL query for one chunk of values."""
    query_kwds = {
        'edge_var': edge_var,
        'node_var': node_var,
        'vars_': _vars,
        'values': {_vars: values_chunk},
        'max_node_count': max_node_count,
        'min_edge_count': min_edge_count,
        'limit': limit,
    }
    return gp.to_deep_narrow_path_query(**query_kwds)


# noinspection PyUnusedLocal
def _dnp_chunk_result_ext(
        q_res, _vars, _,
        edge_var,
        **kwds
):
    """Extract per-edge counts from one chunk's SPARQL JSON result.

    Returns a pair of Counters keyed by the binding of edge_var:
    (summed EDGE_VAR_COUNT values, summed NODE_VAR_SUM values).
    """
    edge_counter = Counter()
    node_sum_counter = Counter()
    bindings = sparql_json_result_bindings_to_rdflib(
        get_path(q_res, ['results', 'bindings'], default=[])
    )
    for binding in bindings:
        edge = get_path(binding, [edge_var])
        edge_counter[edge] += int(get_path(binding, [EDGE_VAR_COUNT], '0'))
        node_sum_counter[edge] += int(get_path(binding, [NODE_VAR_SUM], '0'))
    return edge_counter, node_sum_counter


def _dnp_res_update(res, up, **_):
edge_count, node_sum_count = res
if up:
chunk_edge_count, chunk_node_sum = up
edge_count.update(chunk_edge_count)
node_sum_count.update(chunk_node_sum)


def generate_stps_from_gp(sparql, gp):
"""Generates a list of source target pairs from a given graph pattern.

Expand Down
Loading