Cut up root node (not just ultimate ancestor) #850

Open
hyanwong opened this issue Jul 11, 2023 · 6 comments


@hyanwong
Member

hyanwong commented Jul 11, 2023

On the basis that the ultimate ancestor is not biologically very plausible, in recent versions of tsinfer we now cut up the edges that lead directly to the ultimate ancestor, by running the new post_process routine.

However, I suspect (and tests show) that the root ancestors we make are still too long. We could therefore think about cutting up not just the ultimate ancestor, but also any root node at each position where its edges-in or edges-out change.
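As a toy illustration of this rule (not tsinfer code), the sketch below takes the edges whose parent is a single root node, given as hypothetical `(left, right, child)` tuples rather than a tskit table, and splits the root's span at every position where one of those edges starts or ends. Each resulting segment has a constant set of children and would become a separate root node. It assumes the edges are already squashed and that the node is the root everywhere it has children.

```python
def root_segments(root_edges):
    """Split one root node's span wherever its set of children changes.

    ``root_edges`` is a hypothetical list of (left, right, child) tuples that all
    share the same parent, assumed to be the root wherever it has children.
    Returns (left, right, children) tuples: each would become a separate root node.
    """
    # Each edge boundary is a position where an edge enters or leaves the root
    cuts = sorted({x for left, right, _ in root_edges for x in (left, right)})
    segments = []
    for seg_left, seg_right in zip(cuts[:-1], cuts[1:]):
        children = frozenset(
            child
            for left, right, child in root_edges
            if left < seg_right and right > seg_left
        )
        if children:  # no segment where the node has no children at all
            segments.append((seg_left, seg_right, children))
    return segments


root_edges = [(0, 10, "a"), (0, 5, "b"), (5, 10, "c")]
for segment in root_segments(root_edges):
    print(segment)
```

Here the root is cut at position 5, where child `b` leaves and child `c` arrives, giving two root nodes: one over [0, 5) with children `a` and `b`, and one over [5, 10) with children `a` and `c`.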

Here's some example code, along with histograms of the (log) genomic spans over which the root node stays the same. Note that this code may result in nodes that are not ordered strictly by time.

```python
import itertools  # itertools.pairwise needs Python >= 3.10

import numpy as np
from matplotlib import pyplot as plt
import msprime
import tsinfer
import tskit

ts = msprime.sim_ancestry(
    100, population_size=1e4, recombination_rate=1e-8, sequence_length=1e7, random_seed=1
)
print("Simulation has", ts.num_trees, "trees")
mts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)
print("Simulation has", mts.num_mutations, "mutations")

sd = tsinfer.SampleData.from_tree_sequence(mts)
its = tsinfer.infer(sd, progress_monitor=True)
ists = its.simplify()  # remove unary nodes

def unsquash(edge_table, positions, edges=None):
    """
    For a set of positions and a given set of edges (or all edges if ``edges`` is None),
    create a set of new edges which cover the same span but which are chopped up into
    separate edges at every specified position. This is essentially the opposite of
    EdgeTable.squash(). The edge table is modified in place.
    """
    new_edges = tskit.EdgeTable()
    positions = np.unique(positions)  # sort and uniquify
    skip = []
    if edges is not None:
        skip = np.ones(edge_table.num_rows, dtype=bool)
        skip[edges] = False
    for edge, do_skip in itertools.zip_longest(edge_table, skip, fillvalue=False):
        if do_skip:
            new_edges.append(edge)
            continue
        # Positions that fall strictly inside this edge become new edge boundaries
        inner = positions[np.logical_and(positions > edge.left, positions < edge.right)]
        for l, r in itertools.pairwise(itertools.chain([edge.left], inner, [edge.right])):
            new_edges.append(edge.replace(left=l, right=r))
    edge_table.replace_with(new_edges)

def break_root_nodes(ts):
    """
    Create a new copy of the root node wherever its child edges change while it
    remains the root. Assumes every non-empty tree has a single root. New nodes
    are appended, so nodes may not be ordered strictly by time.
    """
    tables = ts.dump_tables()
    edges_to_break = set()
    # Find every edge that leads directly to the root of some tree
    for tree in ts.trees():
        if tree.num_edges == 0:
            continue
        for u in tree.children(tree.root):
            edges_to_break.add(tree.edge(u))
    # Split those edges at every tree breakpoint (dtype=int keeps the indexing
    # in unsquash valid even if the set is empty)
    unsquash(
        tables.edges,
        ts.breakpoints(as_array=True),
        edges=np.array(sorted(edges_to_break), dtype=int),
    )

    ts_split = tables.tree_sequence()
    tables.edges.clear()
    tables.mutations.clear()
    prev_root = None
    nd_map = {u: u for u in range(ts.num_nodes)}  # original node ID -> current copy
    # ts and ts_split have identical breakpoints, so their trees line up
    for ed, ed_split, tree in zip(
        ts.edge_diffs(), ts_split.edge_diffs(), ts_split.trees()
    ):
        if tree.num_edges > 0:
            if tree.root == prev_root:
                parents = {e.parent for e in ed.edges_out} | {e.parent for e in ed.edges_in}
                if tree.root in parents:
                    # Same root, but its child edges changed: start a new root node
                    nd_map[tree.root] = tables.nodes.append(ts.node(tree.root))
            prev_root = tree.root
            # The split root edges re-enter at every breakpoint, so they pick up
            # the current copy of the root
            for e in ed_split.edges_in:
                tables.edges.add_row(
                    left=e.left, right=e.right, parent=nd_map[e.parent], child=nd_map[e.child]
                )
        # Copy mutations from every tree (including empty ones) so none are dropped
        for m in tree.mutations():
            tables.mutations.append(m.replace(node=nd_map[m.node]))
    tables.sort()
    tables.edges.squash()  # merge split pieces back together where possible
    tables.sort()
    return tables.tree_sequence()

iists = break_root_nodes(ists)
print("Created", iists.num_nodes - ists.num_nodes, "new roots")

## Do some histograms

# Positions at which the root changes in the true (simulated) trees
prev_root = None
root_breaks = [0]
for tree in mts.trees():
    if prev_root != tree.root:
        if prev_root is not None:
            root_breaks.append(tree.interval.left)
    prev_root = tree.root
root_breaks.append(mts.sequence_length)
plt.hist(np.log(np.diff(root_breaks)), bins=40, density=True, label="True")

# Same for the inferred trees, skipping empty (flanking) trees
r2 = [0]
prev_root = None
for tree in ists.trees():
    if tree.num_edges == 0:
        continue
    if prev_root != tree.root:
        r2.append(tree.interval.left)
    prev_root = tree.root
r2.append(ists.sequence_length)

plt.hist(np.log(np.diff(r2)), alpha=0.5, bins=40, density=True, label="split ultimate")

# Same again after additionally splitting the root nodes
r3 = [0]
prev_root = None
for tree in iists.trees():
    if tree.num_edges == 0:
        continue
    if prev_root != tree.root:
        r3.append(tree.interval.left)
    prev_root = tree.root
r3.append(iists.sequence_length)

plt.hist(np.log(np.diff(r3)), alpha=0.5, bins=40, density=True, label="additionally split root")

plt.legend()
```

```
Simulation has 21109 trees
Simulation has 23596 mutations
Created 1651 new roots
```

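[Figure: overlaid density histograms of log root-span lengths, labelled True, split ultimate and additionally split root]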
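The `unsquash` helper above is described as essentially the opposite of `EdgeTable.squash()`. To make that relationship concrete, here is a plain-Python toy version, using hypothetical `(left, right, parent, child)` tuples instead of a tskit table, together with a simplified squash that merges abutting pieces with the same parent and child. It is only a sketch of the idea, not tskit's implementation.

```python
def unsquash_edges(edges, positions):
    """Split each (left, right, parent, child) edge at every interior position."""
    positions = sorted(set(positions))
    out = []
    for left, right, parent, child in edges:
        cuts = [left] + [x for x in positions if left < x < right] + [right]
        for piece_left, piece_right in zip(cuts[:-1], cuts[1:]):
            out.append((piece_left, piece_right, parent, child))
    return out


def squash_edges(edges):
    """Merge abutting edges with the same parent and child (simplified squash)."""
    merged = []
    for left, right, parent, child in sorted(edges, key=lambda e: (e[2], e[3], e[0])):
        last = merged[-1] if merged else None
        if last is not None and last[2:] == (parent, child) and last[1] == left:
            merged[-1] = (last[0], right, parent, child)  # extend the previous edge
        else:
            merged.append((left, right, parent, child))
    return merged


edges = [(0, 10, 5, 1), (3, 8, 5, 2)]
pieces = unsquash_edges(edges, [2, 4, 9])
print(pieces)
print(squash_edges(pieces) == edges)  # → True
```

Squashing the split edges recovers the original ones, which is why `break_root_nodes` can chop the root edges at every breakpoint and then call `tables.edges.squash()` to tidy up the pieces that were not given a new parent.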

@hyanwong
Member Author

And here are the correlations between the known spans of root nodes and the spans we infer (it's a pretty poor correlation, though!)

```python
plt.figure()  # new figure, separate from the histograms

# Reference: true root spans plotted against themselves (a diagonal)
rb = np.array(root_breaks)
mid_root_pos = rb[:-1] + np.diff(rb) / 2  # midpoint of each true root span
ss = np.searchsorted(rb, mid_root_pos)
plt.scatter(np.diff(root_breaks), rb[ss] - rb[ss - 1])

# Length of the inferred root span (ultimate ancestor split) containing each midpoint
rb = np.array(r2)
ss = np.searchsorted(rb, mid_root_pos)
plt.scatter(np.diff(root_breaks), rb[ss] - rb[ss - 1], alpha=0.1)
print(
    "corr coeff: known root lengths vs lengths with split ultimate:\n ",
    np.corrcoef(np.diff(root_breaks), rb[ss] - rb[ss - 1])[0, 1],
)

# Same, after additionally splitting the root nodes
rb = np.array(r3)
ss = np.searchsorted(rb, mid_root_pos)
plt.scatter(np.diff(root_breaks), rb[ss] - rb[ss - 1], alpha=0.1)
print(
    "corr coeff: known root lengths vs lengths with extra split root:\n ",
    np.corrcoef(np.diff(root_breaks), rb[ss] - rb[ss - 1])[0, 1],
)

plt.xscale("log")
plt.yscale("log")
```

```
corr coeff: known root lengths vs lengths with split ultimate:
  0.06516027384592456
corr coeff: known root lengths vs lengths with extra split root:
  0.13137918764309806
```

[Figure: log-log scatter plot of known root-span lengths (x) against inferred root-span lengths (y)]
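The comparison above relies on two steps: turning the per-tree roots into the positions where the root changes, and, for each true root span, looking up the length of the inferred span that contains its midpoint (the `np.searchsorted` call). Below is a stdlib-only toy version of both, with hypothetical `(left, right, root)` tuples standing in for tskit trees (root `None` marks an empty tree) and `bisect` standing in for `np.searchsorted`; the correlation is then just `np.corrcoef` on the true lengths and these looked-up lengths.

```python
import bisect


def root_breakpoints(trees, sequence_length):
    """Positions where the root changes, given (left, right, root) per tree.

    Trees whose root is None (e.g. empty flanking trees) are skipped.
    """
    breaks = [0]
    prev_root = None
    for left, _right, root in trees:
        if root is None:
            continue
        if prev_root is not None and root != prev_root:
            breaks.append(left)
        prev_root = root
    breaks.append(sequence_length)
    return breaks


def containing_span_lengths(true_breaks, inferred_breaks):
    """For each true root span, the length of the inferred span holding its midpoint."""
    lengths = []
    for left, right in zip(true_breaks[:-1], true_breaks[1:]):
        mid = left + (right - left) / 2
        # bisect_left matches np.searchsorted's default side="left"
        i = bisect.bisect_left(inferred_breaks, mid)
        lengths.append(inferred_breaks[i] - inferred_breaks[i - 1])
    return lengths


true_breaks = root_breakpoints([(0, 4, 7), (4, 6, 7), (6, 10, 8)], 10)
inferred_breaks = root_breakpoints([(0, 2, 3), (2, 10, 4)], 10)
print(true_breaks, inferred_breaks)  # → [0, 6, 10] [0, 2, 10]
print(containing_span_lengths(true_breaks, inferred_breaks))  # → [8, 8]
```

In this toy case both true root spans (lengths 6 and 4) fall inside a single inferred root span of length 8, which is the kind of over-long inferred root that splitting is meant to reduce.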

@hyanwong
Member Author

Extra splitting of the root certainly improves the n=10 plot from @a-ignatieva's preprint, especially when combined with @nspope's variational gamma method:

[Figure: n=10 plot]

@hyanwong
Member Author

hyanwong commented Jul 14, 2023

And here is the same comparison for 100 samples. Since both use exactly the same topology, the improvement cannot be due to, e.g., better polytomy breaking.

[Figure: the same plot for 100 samples]

@hyanwong
Member Author

@jeromekelleher and I decided this should be implemented at a minimum as an option to post_process, and then probably rolled out as the default. However, it would be good to think of a more efficient method than the one coded above, and also one that keeps the nodes in time order (though this might have to be done with a sort at the end).
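One way to do such a final sort is to renumber nodes by time and remap every reference to them. Below is a minimal plain-Python sketch, with hypothetical lists standing in for the node and edge tables, and assuming "time order" means IDs increasing with time (reverse the sort key for oldest-first). Mutation nodes would need the same remapping, and the edges would need re-sorting afterwards. In tskit itself, passing such an ordering to `TableCollection.subset()`, which reorders nodes in the order given, may be a way to do this, but that is untested here.

```python
def time_order_nodes(node_times, edges):
    """Renumber nodes so that IDs increase with time (ties keep their original order).

    ``node_times[u]`` is the time of node u; ``edges`` is a list of
    (left, right, parent, child) tuples. Returns (new_times, new_edges, old_to_new).
    """
    # sorted() is stable, so nodes with equal times keep their relative order
    order = sorted(range(len(node_times)), key=lambda u: node_times[u])
    old_to_new = {old: new for new, old in enumerate(order)}
    new_times = [node_times[u] for u in order]
    new_edges = [
        (left, right, old_to_new[parent], old_to_new[child])
        for left, right, parent, child in edges
    ]
    return new_times, new_edges, old_to_new


# Nodes 0 and 1 are samples; node 4 is a newly appended copy of root node 2,
# so it is younger than node 3 despite having a larger ID
times = [0.0, 0.0, 3.0, 8.0, 3.0]
edges = [
    (0, 5, 2, 0), (0, 5, 2, 1),
    (5, 8, 4, 0), (5, 8, 4, 1),
    (8, 10, 3, 0), (8, 10, 3, 1),
]
new_times, new_edges, old_to_new = time_order_nodes(times, edges)
print(new_times)  # → [0.0, 0.0, 3.0, 3.0, 8.0]
print(new_edges)
```

Because the sort is stable, a copied root with the same time as its original lands directly after it, which keeps the renumbering predictable.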

@hyanwong
Member Author

A better-justified, model-based approach to cutting up the root nodes would be to implement the PSMC-on-the-tree idea for the root. If that is implemented, we should possibly use it to cut up the root nodes instead, so there is an argument for making the version above available only as a non-default post_process option.

@hyanwong
Member Author

hyanwong commented Jan 20, 2025

We are (I would say) fully justified in cutting up the ultimate root, as we know that an ancestor with all-zeros is a simplification. Cutting up root nodes in general is more heuristic, so I suggest this should be part of the tsdate preprocess_ts function instead, like the split_disjoint option.
