Skip to content

Commit

Permalink
refactor: Streamline join internals, encode indices via two separate …
Browse files Browse the repository at this point in the history
…arrays
  • Loading branch information
nvictus committed May 25, 2024
1 parent e5f3eb5 commit 755abd5
Show file tree
Hide file tree
Showing 2 changed files with 293 additions and 384 deletions.
248 changes: 117 additions & 131 deletions bioframe/core/arrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,10 @@ def overlap_intervals(starts1, ends1, starts2, ends2, closed=False, sort=False):
If True, then treat intervals as closed and report single-point overlaps.
Returns
-------
overlap_ids : numpy.ndarray
An Nx2 array containing the indices of pairs of overlapping intervals.
The 1st column contains ids from the 1st set, the 2nd column has ids
from the 2nd set.
ovids1, ovids2 : numpy.ndarray
Two 1D arrays containing the indices of pairs of overlapping intervals.
The 1st contains ids from the 1st set, the 2nd has ids from the 2nd set.
"""

for vec in [starts1, ends1, starts2, ends2]:
if isinstance(vec, pd.Series):
warnings.warn(
Expand Down Expand Up @@ -353,28 +350,21 @@ def overlap_intervals(starts1, ends1, starts2, ends2, closed=False, sort=False):
)

# Generate IDs of pairs of overlapping intervals
overlap_ids = np.block(
[
[
np.repeat(ids1[match_2in1_mask], match_2in1_ends - match_2in1_starts)[
:, None
],
ids2[arange_multi(match_2in1_starts, match_2in1_ends)][:, None],
],
[
ids1[arange_multi(match_1in2_starts, match_1in2_ends)][:, None],
np.repeat(ids2[match_1in2_mask], match_1in2_ends - match_1in2_starts)[
:, None
],
],
]
)
ovids1 = np.concatenate([
np.repeat(ids1[match_2in1_mask], match_2in1_ends - match_2in1_starts),
ids1[arange_multi(match_1in2_starts, match_1in2_ends)],
])
ovids2 = np.concatenate([
ids2[arange_multi(match_2in1_starts, match_2in1_ends)],
np.repeat(ids2[match_1in2_mask], match_1in2_ends - match_1in2_starts),
])

if sort:
# Sort overlaps according to the 1st
overlap_ids = overlap_ids[np.lexsort([overlap_ids[:, 1], overlap_ids[:, 0]])]
idx = np.lexsort([ovids2, ovids1])
ovids1 = ovids1[idx]
ovids2 = ovids2[idx]

return overlap_ids
return ovids1, ovids2


def overlap_intervals_outer(starts1, ends1, starts2, ends2, closed=False):
Expand All @@ -393,25 +383,25 @@ def overlap_intervals_outer(starts1, ends1, starts2, ends2, closed=False):
Returns
-------
overlap_ids : numpy.ndarray
An Nx2 array containing the indices of pairs of overlapping intervals.
The 1st column contains ids from the 1st set, the 2nd column has ids
from the 2nd set.
ovids1, ovids2 : numpy.ndarray
Two 1D arrays containing the indices of pairs of overlapping intervals.
The 1st contains ids from the 1st set, the 2nd has ids from the 2nd set.
no_overlap_ids1, no_overlap_ids2 : numpy.ndarray
Two 1D arrays containing the indices of intervals in sets 1 and 2
respectively that do not overlap with any interval in the other set.
"""

ovids = overlap_intervals(starts1, ends1, starts2, ends2, closed=closed)
no_overlap_ids1 = np.where(
np.bincount(ovids[:, 0], minlength=starts1.shape[0]) == 0
)[0]
no_overlap_ids2 = np.where(
np.bincount(ovids[:, 1], minlength=starts2.shape[0]) == 0
)[0]
return ovids, no_overlap_ids1, no_overlap_ids2
n1, n2 = len(starts1), len(starts2)
ovids1, ovids2 = overlap_intervals(starts1, ends1, starts2, ends2, closed=closed)
if n1 > 0:
no_overlap_ids1 = np.setdiff1d(np.arange(len(starts1)), ovids1)
else:
no_overlap_ids1 = np.array([], dtype=int)
if n2 > 0:
no_overlap_ids2 = np.setdiff1d(np.arange(len(starts2)), ovids2)
else:
no_overlap_ids2 = np.array([], dtype=int)
return ovids1, ovids2, no_overlap_ids1, no_overlap_ids2


def merge_intervals(starts, ends, min_dist=0):
Expand Down Expand Up @@ -532,14 +522,12 @@ def _closest_intervals_nooverlap(
Returns
-------
ids: numpy.ndarray
One Nx2 array containing the indices of pairs of closest intervals,
ids1, ids2: numpy.ndarray
Two arrays containing the indices of pairs of closest intervals,
reported for the neighbors in specified direction (by genomic
coordinate). The two columns are the inteval ids from set 1, ids of
coordinate). The two arrays are the inteval ids from set 1, ids of
the closest intevals from set 2.
"""

for vec in [starts1, ends1, starts2, ends2]:
if isinstance(vec, pd.Series):
warnings.warn(
Expand All @@ -556,7 +544,8 @@ def _closest_intervals_nooverlap(
n1 = starts1.shape[0]
n2 = starts2.shape[0]

ids = np.zeros((0, 2), dtype=int)
ids1 = np.array([], dtype=int)
ids2 = np.array([], dtype=int)

if k > 0 and direction == "left":
if tie_arr is None:
Expand All @@ -573,14 +562,8 @@ def _closest_intervals_nooverlap(
int1_ids = np.repeat(np.arange(n1), left_closest_endidx - left_closest_startidx)
int2_sorted_ids = arange_multi(left_closest_startidx, left_closest_endidx)

ids = np.vstack(
[
int1_ids,
ids2_endsorted[int2_sorted_ids],
# ends2_sorted[int2_sorted_ids] - starts1[int1_ids],
# arange_multi(left_closest_startidx - left_closest_endidx, 0)
]
).T
ids1 = int1_ids
ids2 = ids2_endsorted[int2_sorted_ids]

elif k > 0 and direction == "right":
if tie_arr is None:
Expand All @@ -598,17 +581,11 @@ def _closest_intervals_nooverlap(
np.arange(n1), right_closest_endidx - right_closest_startidx
)
int2_sorted_ids = arange_multi(right_closest_startidx, right_closest_endidx)
ids = np.vstack(
[
int1_ids,
ids2_startsorted[int2_sorted_ids],
# starts2_sorted[int2_sorted_ids] - ends1[int1_ids],
# arange_multi(1, right_closest_endidx -
# right_closest_startidx + 1)
]
).T

return ids
ids1 = int1_ids
ids2 = ids2_startsorted[int2_sorted_ids]

return ids1, ids2


def closest_intervals(
Expand All @@ -621,10 +598,11 @@ def closest_intervals(
ignore_overlaps=False,
ignore_upstream=False,
ignore_downstream=False,
direction=None,
along=None,
):
"""
For every interval in set 1, return the indices of k closest intervals from set 2.
For every interval in set 1, return the indices of k closest intervals
from set 2.
Parameters
----------
Expand All @@ -637,127 +615,135 @@ def closest_intervals(
The number of neighbors to report.
tie_arr : numpy.ndarray or None
Extra data describing intervals in set 2 to break ties when multiple intervals
are located at the same distance. Intervals with *lower* tie_arr values will
be given priority.
Extra data describing intervals in set 2 to break ties when multiple
intervals are located at the same distance. Intervals with *lower*
tie_arr values will be given priority.
ignore_overlaps : bool
If True, ignore set 2 intervals that overlap with set 1 intervals.
ignore_upstream, ignore_downstream : bool
If True, ignore set 2 intervals upstream/downstream of set 1 intervals.
direction : numpy.ndarray with dtype bool or None
Strand vector to define the upstream/downstream orientation of the intervals.
along : numpy.ndarray with dtype bool or None
Strand vector to define the upstream/downstream orientation of the
intervals.
Returns
-------
closest_ids : numpy.ndarray
An Nx2 array containing the indices of pairs of closest intervals.
The 1st column contains ids from the 1st set, the 2nd column has ids
closest_ids1, closest_ids2 : numpy.ndarray
Two arrays containing the indices of pairs of closest intervals.
The 1st array contains ids from the 1st set, the 2nd array has ids
from the 2nd set.
"""

# Get overlapping intervals:
# Get overlaps
if ignore_overlaps:
overlap_ids = np.zeros((0, 2), dtype=int)
ovids1, ovids2 = np.array([], dtype=int), np.array([], dtype=int)
elif (starts2 is None) and (ends2 is None):
starts2, ends2 = starts1, ends1
overlap_ids = overlap_intervals(starts1, ends1, starts2, ends2)
overlap_ids = overlap_ids[overlap_ids[:, 0] != overlap_ids[:, 1]]
ovids1, ovids2 = overlap_intervals(starts1, ends1, starts2, ends2)
mask = ovids1 != ovids2
ovids1 = ovids1[mask]
ovids2 = ovids2[mask]
else:
overlap_ids = overlap_intervals(starts1, ends1, starts2, ends2)
ovids1, ovids2 = overlap_intervals(starts1, ends1, starts2, ends2)

# Get non-overlapping intervals:
# Get non-overlapping nearest neighbors
n = len(starts1)
all_ids = np.arange(n)

# + directed intervals
ids_left_upstream = _closest_intervals_nooverlap(
starts1[direction],
ends1[direction],
if along is None:
along = np.ones(n, dtype=bool)

# + stranded intervals
pos_starts1, pos_ends1 = starts1[along], ends1[along]
pos_up1, pos_up2 = _closest_intervals_nooverlap(
pos_starts1,
pos_ends1,
starts2,
ends2,
direction="left",
tie_arr=tie_arr,
k=0 if ignore_upstream else k,
)
ids_right_downstream = _closest_intervals_nooverlap(
starts1[direction],
ends1[direction],
pos_dn1, pos_dn2 = _closest_intervals_nooverlap(
pos_starts1,
pos_ends1,
starts2,
ends2,
direction="right",
tie_arr=tie_arr,
k=0 if ignore_downstream else k,
)
# - directed intervals
ids_right_upstream = _closest_intervals_nooverlap(
starts1[~direction],
ends1[~direction],

# - stranded intervals
neg_starts1, neg_ends1 = starts1[~along], ends1[~along]
neg_up1, neg_up2 = _closest_intervals_nooverlap(
neg_starts1,
neg_ends1,
starts2,
ends2,
direction="right",
tie_arr=tie_arr,
k=0 if ignore_upstream else k,
)
ids_left_downstream = _closest_intervals_nooverlap(
starts1[~direction],
ends1[~direction],
neg_dn1, neg_dn2 = _closest_intervals_nooverlap(
neg_starts1,
neg_ends1,
starts2,
ends2,
direction="left",
tie_arr=tie_arr,
k=0 if ignore_downstream else k,
)

# Reconstruct original indexes (b/c we split regions by direction above)
ids_left_upstream[:, 0] = all_ids[direction][ids_left_upstream[:, 0]]
ids_right_downstream[:, 0] = all_ids[direction][ids_right_downstream[:, 0]]
ids_left_downstream[:, 0] = all_ids[~direction][ids_left_downstream[:, 0]]
ids_right_upstream[:, 0] = all_ids[~direction][ids_right_upstream[:, 0]]
# Reconstruct original indices (b/c we split ranges by strand above)
pos_ids = np.where(along)[0]
neg_ids = np.where(~along)[0]
pos_up1 = pos_ids[pos_up1]
pos_dn1 = pos_ids[pos_dn1]
neg_dn1 = neg_ids[neg_dn1]
neg_up1 = neg_ids[neg_up1]

left_ids = np.concatenate([ids_left_upstream, ids_left_downstream])
right_ids = np.concatenate([ids_right_upstream, ids_right_downstream])
# Combine by absolute search direction
left_ids1 = np.concatenate([pos_up1, neg_dn1])
left_ids2 = np.concatenate([pos_up2, neg_dn2])
right_ids1 = np.concatenate([neg_up1, pos_dn1])
right_ids2 = np.concatenate([neg_up2, pos_dn2])

# Increase the distance by 1 to distinguish between overlapping
# and non-overlapping set 2 intervals.
left_dists = starts1[left_ids[:, 0]] - ends2[left_ids[:, 1]] + 1
right_dists = starts2[right_ids[:, 1]] - ends1[right_ids[:, 0]] + 1

closest_ids = np.vstack([left_ids, right_ids, overlap_ids])
closest_dists = np.concatenate(
[left_dists, right_dists, np.zeros(overlap_ids.shape[0])]
left_dists = starts1[left_ids1] - ends2[left_ids2] + 1
right_dists = starts2[right_ids2] - ends1[right_ids1] + 1

# Combine the results
events1 = np.concatenate([left_ids1, right_ids1, ovids1])
events2 = np.concatenate([left_ids2, right_ids2, ovids2])
dists = np.concatenate(
[left_dists, right_dists, np.zeros(ovids1.shape[0])]
)

if len(closest_ids) == 0:
return np.empty((0, 2), dtype=int)
if len(events1) == 0:
return np.array([], dtype=int), np.array([], dtype=int)

# Sort by distance to set 1 intervals and, if present, by the tie-breaking
# data array.
if tie_arr is None:
order = np.lexsort([closest_ids[:, 1], closest_dists, closest_ids[:, 0]])
order = np.lexsort([events2, dists, events1])
else:
order = np.lexsort(
[closest_ids[:, 1], tie_arr, closest_dists, closest_ids[:, 0]]
)

closest_ids = closest_ids[order, :2]

# For each set 1 interval, select up to k closest neighbours.
interval1_run_border_mask = closest_ids[:-1, 0] != closest_ids[1:, 0]
interval1_run_borders = np.where(np.r_[True, interval1_run_border_mask, True])[0]
interval1_run_starts = interval1_run_borders[:-1]
interval1_run_ends = interval1_run_borders[1:]
closest_ids = closest_ids[
arange_multi(
interval1_run_starts,
lengths=np.minimum(k, interval1_run_ends - interval1_run_starts),
)
]
order = np.lexsort([events2, tie_arr, dists, events1])
events1 = events1[order]
events2 = events2[order]

# Prune the results to the k nearest neighbors
# For each sorted run of set 1 intervals, select up to k entries
run_borders = np.where(np.r_[True, events1[:-1] != events1[1:], True])[0]
run_starts = run_borders[:-1]
run_ends = run_borders[1:]
idx = arange_multi(
run_starts,
lengths=np.minimum(k, run_ends - run_starts),
)

return closest_ids
return events1[idx], events2[idx]


def coverage_intervals_rle(starts, ends, weights=None):
Expand Down
Loading

0 comments on commit 755abd5

Please sign in to comment.