diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java index 76925f943bb..0771da7380d 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java @@ -20,6 +20,10 @@ import org.broadinstitute.hellbender.utils.read.CigarBuilder; import org.broadinstitute.hellbender.utils.read.CigarUtils; import org.broadinstitute.hellbender.utils.smithwaterman.SmithWatermanAligner; +import org.jgrapht.Graph; +import org.jgrapht.alg.ConnectivityInspector; +import org.jgrapht.graph.DefaultEdge; +import org.jgrapht.graph.SimpleGraph; import java.util.*; import java.util.stream.Collectors; @@ -87,51 +91,18 @@ public static AssemblyResultSet generatePDHaplotypes(final AssemblyResultSet sou final boolean debug = pileupArgs.debugPileupStdout; final List eventsInOrder = makeFinalListOfEventsInOrder(sourceSet, badPileupEvents, goodPileupEvents, referenceHaplotype, pileupArgs, debug); - // TODO this is where we filter out if indels > 32 (a heuristic known from DRAGEN that is not implemented here) Map> eventsByDRAGENCoordinates = eventsInOrder.stream() .collect(Collectors.groupingBy(e -> dragenStart(e), LinkedHashMap::new, Collectors.toList())); + eventByDragenCoordinateMessage(referenceHaplotype, debug, eventsByDRAGENCoordinates); SortedMap> variantsByStartPos = eventsInOrder.stream() .collect(Collectors.groupingBy(Event::getStart, TreeMap::new, Collectors.toList())); - List eventGroups = new ArrayList<>(); - int lastEventEnd = -1; - for (Event vc : eventsInOrder) { - // Break everything into independent groups (don't worry about transitivitiy right now) - Double eventKey = dragenStart(vc) - referenceHaplotype.getStart(); - if (eventKey <= lastEventEnd + 0.5) { - eventGroups.get(eventGroups.size()-1).addEvent(vc); - } else { - eventGroups.add(new EventGroup(vc)); - } - int newEnd = (vc.getEnd() - referenceHaplotype.getStart()); - lastEventEnd = Math.max(newEnd, lastEventEnd); - } - eventGroupsMessage(referenceHaplotype, debug, eventsByDRAGENCoordinates); - - // Iterate over all events starting with all indels List> disallowedPairs = smithWatermanRealignPairsOfVariantsForEquivalentEvents(referenceHaplotype, aligner, args.getHaplotypeToReferenceSWParameters(), debug, eventsInOrder); dragenDisallowedGroupsMessage(referenceHaplotype.getStart(), debug, disallowedPairs); - Utils.printIf(debug, () -> "Event groups before merging:\n"+eventGroups.stream().map(eg -> eg.toDisplayString(referenceHaplotype.getStart())).collect(Collectors.joining("\n"))); - - //Now that we have the disallowed groups, lets merge any of them from separate groups: - //TODO this is not an efficient way of doing this - for (List pair : disallowedPairs) { - EventGroup eventGrpLeft = null; - for (Event event : pair) { - EventGroup grpForEvent = eventGroups.stream().filter(grp -> grp.contains(event)).findFirst().get(); - // If the event isn't in the same event group as its predecessor, merge this group with that one and - if (eventGrpLeft != grpForEvent) { - if (eventGrpLeft == null) { - eventGrpLeft = grpForEvent; - } else { - eventGrpLeft.mergeEvent(grpForEvent); - eventGroups.remove(grpForEvent); - } - } - } - } + + final List eventGroups = getEventGroupClusters(eventsInOrder, disallowedPairs); Utils.printIf(debug,() -> "Event groups after merging:\n"+eventGroups.stream().map(eg -> eg.toDisplayString(referenceHaplotype.getStart())).collect(Collectors.joining("\n"))); // if any of our merged event groups is too large, abort. @@ -334,6 +305,42 @@ private static List makeFinalListOfEventsInOrder(final AssemblyResultSet return eventsInOrder; } + /** + * Partition events into clusters that must be considered together, either because they overlap or because they belong to the + * same mutually exclusive pair or trio. To find this clustering we calculate the connected components of an undirected graph + * with an edge connecting events that overlap or are mutually excluded. + */ + private static List getEventGroupClusters(List eventsInOrder, List> disallowedPairsAndTrios) { + final Graph graph = new SimpleGraph<>(DefaultEdge.class); + eventsInOrder.forEach(graph::addVertex); + + // edges due to overlapping position + for (int e1 = 0; e1 < eventsInOrder.size(); e1++) { + final Event event1 = eventsInOrder.get(e1); + for (int e2 = e1 + 1; e2 < eventsInOrder.size(); e2++) { + final Event event2 = eventsInOrder.get(e2); + if (dragenOverlap(event1, event2)) { + graph.addEdge(event1, event2); + } else if (event2.getStart() > event1.getEnd() + 1){ + break; + } + } + } + + // edges due to mutual exclusion + for (final List excludedGroup : disallowedPairsAndTrios) { + graph.addEdge(excludedGroup.get(0), excludedGroup.get(1)); + if (excludedGroup.size() == 3) { + graph.addEdge(excludedGroup.get(1), excludedGroup.get(2)); + } + } + + return new ConnectivityInspector<>(graph).connectedSets().stream() + .map(set -> set.stream().sorted(Comparator.comparingInt(Event::getStart)).collect(Collectors.toList())) + .map(EventGroup::new) + .collect(Collectors.toList()); + } + /** * Helper method that handles one of the Heuristics baked by DRAGEN into this artificial haplotype generation code. * @@ -719,7 +726,7 @@ private static class EventGroup { // Optimization to save ourselves recomputing the subsets at every point its necessary to do so. List>> cachedEventLists = null; - public EventGroup(final Event ... events) { + public EventGroup(final Collection events) { eventsInBitmapOrder = new ArrayList<>(); eventSet = new HashSet<>(); @@ -903,7 +910,12 @@ private static double dragenStart(final Event event) { return event.getStart() + (event.isIndel() ? (event.isSimpleDeletion() ? 1 : 0.5) : 0); } - private static void eventGroupsMessage(final Haplotype referenceHaplotype, final boolean debug, final Map> eventsByDRAGENCoordinates) { + private static final boolean dragenOverlap(final Event event1, final Event event2) { + return event1.getStart() <= event2.getStart() ? + (dragenStart(event2) < event1.getEnd() + 1) : (dragenStart(event1) < event2.getEnd() + 1); + } + + private static void eventByDragenCoordinateMessage(final Haplotype referenceHaplotype, final boolean debug, final Map> eventsByDRAGENCoordinates) { Utils.printIf(debug, () -> eventsByDRAGENCoordinates.entrySet().stream() .map(e -> String.format("%.1f", e.getKey()) + " -> " + dragenString(e.getValue(), referenceHaplotype.getStart(),",")) .collect(Collectors.joining("\n")));