Skip to content

Commit

Permalink
Added optional PileupDetection step to Mutect and HaplotypeCaller bef…
Browse files Browse the repository at this point in the history
…ore assembly that supplements the assembly variants with variants that show up in the pileups. (#7432)
  • Loading branch information
bhanugandham committed Apr 14, 2022
1 parent 318fa84 commit 9e04333
Show file tree
Hide file tree
Showing 29 changed files with 1,484 additions and 89 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.broadinstitute.hellbender.engine;

/**
 * Simple holder that pairs an {@link AlignmentContext} with a {@link ReferenceContext}.
 *
 * Both references are fixed at construction time and are stored exactly as supplied
 * (no copying or validation is performed).
 */
public class AlignmentAndReferenceContext {

    private final AlignmentContext alignmentContext;
    private final ReferenceContext referenceContext;

    /**
     * @param alignmentContext the alignment context to bundle
     * @param referenceContext the reference context to bundle alongside it
     */
    public AlignmentAndReferenceContext(final AlignmentContext alignmentContext, final ReferenceContext referenceContext) {
        this.referenceContext = referenceContext;
        this.alignmentContext = alignmentContext;
    }

    /**
     * Getter for the AlignmentContext.
     *
     * @return the alignment context passed to the constructor
     */
    public AlignmentContext getAlignmentContext() {
        return this.alignmentContext;
    }

    /**
     * Getter for the ReferenceContext.
     *
     * @return the reference context passed to the constructor
     */
    public ReferenceContext getReferenceContext() {
        return this.referenceContext;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import htsjdk.samtools.reference.ReferenceSequenceFile;
import htsjdk.samtools.util.Locatable;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyBasedCallerUtils;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SequenceDictionaryUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.clipping.ReadClipper;
Expand Down Expand Up @@ -49,6 +47,11 @@ public final class AssemblyRegion implements Locatable {
*/
private final List<GATKRead> reads;

/**
* These reads are used specifically during haplotype generation: they are kmerized so that they can be matched against haplotype kmers.
*/
private final List<GATKRead> hardClippedPileupReads;

/**
* The active span in which this AssemblyRegion is responsible for calling variants
*/
Expand All @@ -70,6 +73,8 @@ public final class AssemblyRegion implements Locatable {
*/
private boolean hasBeenFinalized;

private List<AlignmentAndReferenceContext> alignmentData = new ArrayList<>();

/**
* Create a new AssemblyRegion containing no reads
* @param activeSpan the span of this active region
Expand Down Expand Up @@ -104,6 +109,7 @@ public AssemblyRegion(final SimpleInterval activeSpan, final SimpleInterval padd
Utils.validate(paddedSpan.contains(activeSpan), "Padded span must contain active span.");

reads = new ArrayList<>();
hardClippedPileupReads = new ArrayList<>();
this.isActive = isActive;
}

Expand All @@ -114,6 +120,23 @@ public AssemblyRegion(final SimpleInterval activeSpan, final int padding, final
this(activeSpan, true, padding, header);
}

/**
 * Method for obtaining the alignment data which is attached to this assembly region.
 *
 * The returned list is a read-only view of the internal collection, so callers cannot modify the
 * region's state through it; use {@link #addAllAlignmentData(List)} to attach more data.
 *
 * @return an unmodifiable view of the AlignmentAndReferenceContext objects associated with this AssemblyRegion.
 */
public List<AlignmentAndReferenceContext> getAlignmentData() {
    // Wrap rather than copy: no per-call copy is made, and the internal list stays protected from outside mutation.
    return Collections.unmodifiableList(alignmentData);
}

/**
 * Method for adding alignment data to the collection of alignment data associated with
 * this AssemblyRegion. The new entries are appended after any already present.
 *
 * @param alignmentData the AlignmentAndReferenceContext objects to attach; must not be null
 */
public void addAllAlignmentData(final List<AlignmentAndReferenceContext> alignmentData) {
    // Validate the argument up front, in the same way the read-adding methods of this class do.
    Utils.nonNull(alignmentData, "Alignment data cannot be null");
    this.alignmentData.addAll(alignmentData);
}

@Override
public String getContig() {
return activeSpan.getContig();
Expand Down Expand Up @@ -176,6 +199,16 @@ public List<GATKRead> getReads(){
return Collections.unmodifiableList(new ArrayList<>(reads));
}

/**
 * Get a snapshot of the hard-clipped pileup reads currently held by this assembly region.
 *
 * The reads appear in the order in which they were added. The returned list is a fresh copy wrapped
 * as unmodifiable, so later changes to the region are not reflected in it.
 *
 * @return an unmodifiable copy of the hard-clipped pileup reads in the assembly region.
 */
public List<GATKRead> getHardClippedPileupReads(){
    final List<GATKRead> snapshot = new ArrayList<>(hardClippedPileupReads);
    return Collections.unmodifiableList(snapshot);
}

/**
* Returns the header for the reads in this region.
*/
Expand Down Expand Up @@ -252,20 +285,24 @@ public AssemblyRegion trim(final SimpleInterval span, final SimpleInterval padde
* @param read a non-null GATKRead
*/
public void add( final GATKRead read ) {
// Delegate to the shared validate-and-append helper, targeting the region's main read list.
addToReadCollectionAndValidate(read, reads);
}

private void addToReadCollectionAndValidate(final GATKRead read, final List<GATKRead> collection) {
Utils.nonNull(read, "Read cannot be null");
final SimpleInterval readLoc = new SimpleInterval( read );
Utils.validateArg(paddedSpan.overlaps(read), () ->
"Read location " + readLoc + " doesn't overlap with active region padded span " + paddedSpan);

if ( ! reads.isEmpty() ) {
final GATKRead lastRead = reads.get(size() - 1);
if ( ! collection.isEmpty() ) {
final GATKRead lastRead = collection.get(collection.size() - 1);
Utils.validateArg(Objects.equals(lastRead.getContig(), read.getContig()), () ->
"Attempting to add a read to ActiveRegion not on the same contig as other reads: lastRead " + lastRead + " attempting to add " + read);
Utils.validateArg( read.getStart() >= lastRead.getStart(), () ->
"Attempting to add a read to ActiveRegion out of order w.r.t. other reads: lastRead " + lastRead + " at " + lastRead.getStart() + " attempting to add " + read + " at " + read.getStart());
}

reads.add( read );
collection.add( read );
}

/**
Expand All @@ -279,6 +316,7 @@ public void add( final GATKRead read ) {
*/
public void clearReads() {
// Empty the main read list...
reads.clear();
// ...and the separately stored hard-clipped pileup reads.
hardClippedPileupReads.clear();
}

/**
Expand All @@ -298,6 +336,10 @@ public void addAll(final Collection<GATKRead> readsToAdd){
Utils.nonNull(readsToAdd).forEach(r -> add(r));
}

/**
 * Adds every read in the given collection to this region's hard-clipped pileup reads, in iteration order.
 * Each read is validated and appended through the same helper used by add().
 *
 * @param readsToAdd the reads to add; must be non-null
 */
public void addHardClippedPileupReads(final Collection<GATKRead> readsToAdd) {
    for (final GATKRead pileupRead : Utils.nonNull(readsToAdd)) {
        addToReadCollectionAndValidate(pileupRead, hardClippedPileupReads);
    }
}

/**
* Get the reference bases from referenceReader spanned by the padded span of this region,
* including additional padding bp on either side. If this expanded region would exceed the boundaries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class AssemblyRegionIterator implements Iterator<AssemblyRegion> {
private final Iterator<AlignmentContext> locusIterator;
private final LocusIteratorByState libs;
private final ActivityProfile activityProfile;
private Queue<AlignmentAndReferenceContext> pendingAlignmentData;

/**
* Constructs an AssemblyRegionIterator over a provided read shard
Expand All @@ -65,7 +66,8 @@ public AssemblyRegionIterator(final MultiIntervalShard<GATKRead> readShard,
final ReferenceDataSource reference,
final FeatureManager features,
final AssemblyRegionEvaluator evaluator,
final AssemblyRegionArgumentCollection assemblyRegionArgs) {
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean trackPileups ) {

Utils.nonNull(readShard);
Utils.nonNull(readHeader);
Expand All @@ -86,6 +88,7 @@ public AssemblyRegionIterator(final MultiIntervalShard<GATKRead> readShard,
this.readCachingIterator = new ReadCachingIterator(readShard.iterator());
this.readCache = new ArrayDeque<>();
this.activityProfile = new BandPassActivityProfile(assemblyRegionArgs.maxProbPropagationDistance, assemblyRegionArgs.activeProbThreshold, BandPassActivityProfile.MAX_FILTER_SIZE, BandPassActivityProfile.DEFAULT_SIGMA, readHeader);
this.pendingAlignmentData = trackPileups ? new ArrayDeque<>() : null;

// We wrap our LocusIteratorByState inside an IntervalAlignmentContextIterator so that we get empty loci
// for uncovered locations. This is critical for reproducing GATK 3.x behavior!
Expand All @@ -103,7 +106,7 @@ public boolean hasNext() {

@Override
public AssemblyRegion next() {
if ( ! hasNext() ) {
if (!hasNext()) {
throw new NoSuchElementException("next() called when there were no more elements");
}

Expand Down Expand Up @@ -132,6 +135,9 @@ private AssemblyRegion loadNextAssemblyRegion() {
final SimpleInterval pileupInterval = new SimpleInterval(pileup);
final ReferenceContext pileupRefContext = new ReferenceContext(reference, pileupInterval);
final FeatureContext pileupFeatureContext = new FeatureContext(features, pileupInterval);
if (pendingAlignmentData!=null) {
pendingAlignmentData.add(new AlignmentAndReferenceContext(pileup, pileupRefContext));
}

final ActivityProfileState profile = evaluator.isActive(pileup, pileupRefContext, pileupFeatureContext);
activityProfile.add(profile);
Expand Down Expand Up @@ -169,6 +175,8 @@ private AssemblyRegion loadNextAssemblyRegion() {
// If there's a region ready, fill it with reads before returning
if ( nextRegion != null ) {
fillNextAssemblyRegionWithReads(nextRegion);
// Fill the next assembly region with pileup data: pending alignment data on a different contig, or located before the region, is popped and discarded.
fillNextAssemblyRegionWithPileupData(nextRegion);
}

return nextRegion;
Expand Down Expand Up @@ -209,6 +217,42 @@ private void fillNextAssemblyRegionWithReads( final AssemblyRegion region ) {
}
}

/**
 * Attaches to the given region the pending pileup data (alignment context plus reference context)
 * that lies between the region's start and end on the region's contig.
 *
 * Pending entries at the head of the queue that are on a different contig, or that start before the
 * region's start, are discarded. Entries that follow on the same contig and start at or before the region's end
 * are added to the region. Does nothing when pileup tracking is disabled (pendingAlignmentData is null).
 *
 * @param region the assembly region to receive the overlapping alignment data
 */
private void fillNextAssemblyRegionWithPileupData(final AssemblyRegion region){
    // Save ourselves the memory footprint and work of saving the pileups in the event they aren't needed for processing.
    if (pendingAlignmentData == null){
        return;
    }
    final List<AlignmentAndReferenceContext> overlappingAlignmentData = new ArrayList<>();
    final Queue<AlignmentAndReferenceContext> previousAlignmentData = new ArrayDeque<>();

    // Drop everything at the head of the queue that cannot belong to this region:
    // entries on another contig or positioned before the region's start.
    while (!pendingAlignmentData.isEmpty()) {
        final AlignmentContext pendingAlignmentContext = pendingAlignmentData.peek().getAlignmentContext();
        if (!pendingAlignmentContext.contigsMatch(region) ||
                pendingAlignmentContext.getStart() < region.getStart()) {
            pendingAlignmentData.poll(); // pop this
        } else {
            break;
        }
    }

    // Collect the entries that fall inside the region. An entry must be on the region's contig AND start
    // no later than the region's end; an entry on a different contig must never be attached to this region.
    while (!pendingAlignmentData.isEmpty()) {
        final AlignmentContext pendingAlignmentContext = pendingAlignmentData.peek().getAlignmentContext();

        if (pendingAlignmentContext.contigsMatch(region) &&
                pendingAlignmentContext.getStart() <= region.getEnd()) {
            overlappingAlignmentData.add(pendingAlignmentData.poll()); // pop into overlappingAlignmentData
        } else {
            break;
        }
    }

    // reconstructing queue to contain items that may be in the next assembly region
    previousAlignmentData.addAll(overlappingAlignmentData);
    previousAlignmentData.addAll(pendingAlignmentData);
    pendingAlignmentData = previousAlignmentData;

    region.addAllAlignmentData(overlappingAlignmentData);
}

@Override
public void remove() {
throw new UnsupportedOperationException("remove() not supported by AssemblyRegionIterator");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package org.broadinstitute.hellbender.engine;

import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.hellbender.engine.filters.CountingReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
Expand Down Expand Up @@ -185,7 +183,7 @@ public void traverse() {
* @param features FeatureManager
*/
private void processReadShard(MultiIntervalLocalReadShard shard, ReferenceDataSource reference, FeatureManager features ) {
final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(shard, getHeaderForReads(), reference, features, assemblyRegionEvaluator(), assemblyRegionArgs);
final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(shard, getHeaderForReads(), reference, features, assemblyRegionEvaluator(), assemblyRegionArgs, shouldTrackPileupsForAssemblyRegions());

// Call into the tool implementation to process each assembly region from this shard.
while ( assemblyRegionIter.hasNext() ) {
Expand Down Expand Up @@ -236,6 +234,12 @@ protected final void onShutdown() {
*/
public abstract AssemblyRegionEvaluator assemblyRegionEvaluator();

/**
* Allows implementing tools to decide whether pileups must be tracked and attached to assembly regions for later processing.
* This is configurable for now in order to avoid a potential increase in the memory consumption of the variant calling
* machinery when the pileups are not needed.
*
* @return true if pileups should be tracked and attached to the assembly regions, false otherwise
*/
public abstract boolean shouldTrackPileupsForAssemblyRegions();

/**
* Process an individual AssemblyRegion. Must be implemented by tool authors.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package org.broadinstitute.hellbender.engine;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import org.apache.commons.collections4.SetUtils;
import org.broadinstitute.hellbender.engine.filters.CountingReadFilter;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.locusiterator.AlignmentContextIteratorBuilder;
import org.broadinstitute.hellbender.utils.read.GATKRead;

import java.util.*;
import java.util.stream.Collectors;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
* An implementation of {@link LocusWalker} that supports arbitrary interval side inputs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ intervalShards, assemblyRegionEvaluatorSupplierBroadcast(ctx), shardingArgs, ass
} else {
return FindAssemblyRegionsSpark.getAssemblyRegionsFast(ctx, getReads(), getHeaderForReads(), sequenceDictionary, referenceFileName, features,
intervalShards, assemblyRegionEvaluatorSupplierBroadcast(ctx), shardingArgs, assemblyRegionArgs,
shuffle);
shuffle, false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,21 @@ public static JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegionsFast(
final Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast,
final AssemblyRegionReadShardArgumentCollection shardingArgs,
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean shuffle) {
final boolean shuffle,
final boolean trackPileups) {
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, reads, GATKRead.class, sequenceDictionary, intervalShards, shardingArgs.readShardSize, shuffle);
Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
return shardedReads.mapPartitions(getAssemblyRegionsFunctionFast(referenceFileName, bFeatureManager, header,
assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs));
assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs, trackPileups));
}

private static FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext> getAssemblyRegionsFunctionFast(
final String referenceFileName,
final Broadcast<FeatureManager> bFeatureManager,
final SAMFileHeader header,
final Broadcast<Supplier<AssemblyRegionEvaluator>> supplierBroadcast,
final AssemblyRegionArgumentCollection assemblyRegionArgs) {
final AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean trackPileups) {
return (FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerContext>) shardedReadIterator -> {
final ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
final FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
Expand All @@ -90,7 +92,7 @@ private static FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegionWalkerCo
.map(downsampledShardedRead -> {
final Iterator<AssemblyRegion> assemblyRegionIter = new AssemblyRegionIterator(
new ShardToMultiIntervalShardAdapter<>(downsampledShardedRead),
header, reference, features, assemblyRegionEvaluator, assemblyRegionArgs);
header, reference, features, assemblyRegionEvaluator, assemblyRegionArgs, trackPileups);
return Utils.stream(assemblyRegionIter).map(assemblyRegion ->
new AssemblyRegionWalkerContext(assemblyRegion,
new ReferenceContext(reference, assemblyRegion.getPaddedSpan()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ public static void callVariantsWithHaplotypeCallerAndWriteOutput(
Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast = assemblyRegionEvaluatorSupplierBroadcast(ctx, hcArgs, assemblyRegionArgs, header, reference, annotations);
JavaRDD<AssemblyRegionWalkerContext> assemblyRegions = strict ?
FindAssemblyRegionsSpark.getAssemblyRegionsStrict(ctx, reads, header, sequenceDictionary, referenceFileName, null, intervalShards, assemblyRegionEvaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, false) :
FindAssemblyRegionsSpark.getAssemblyRegionsFast(ctx, reads, header, sequenceDictionary, referenceFileName, null, intervalShards, assemblyRegionEvaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, false);
FindAssemblyRegionsSpark.getAssemblyRegionsFast(ctx, reads, header, sequenceDictionary, referenceFileName, null, intervalShards, assemblyRegionEvaluatorSupplierBroadcast, shardingArgs, assemblyRegionArgs, false, hcArgs.pileupDetectionArgs.usePileupDetection);
processAssemblyRegions(assemblyRegions, ctx, header, reference, hcArgs, assemblyRegionArgs, output, annotations, logger, createOutputVariantIndex);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public AssemblyRegionEvaluator assemblyRegionEvaluator() {
return (locusPileup, referenceContext, featureContext) -> new ActivityProfileState(new SimpleInterval(locusPileup), 1.0);
}

@Override
public boolean shouldTrackPileupsForAssemblyRegions() {
// Always false: this tool does not request pileup tracking for its assembly regions.
return false;
}

@Override
public void onTraversalStart() {
try {
Expand Down
Loading

0 comments on commit 9e04333

Please sign in to comment.