Skip to content

Commit

Permalink
WeightedSplitInterval fixes [VS-384] [VS-332] (#7795)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcovarr committed Apr 25, 2022
1 parent f09b162 commit 614a0f7
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ public WeightedInterval[] split(int basesInFirstInterval) {
this.getStart(),
this.getStart() + basesInFirstInterval - 1,
this.isNegativeStrand(),
this.getName() + "-1", // ensure names are unique
this.getName() == null ? null : this.getName() + "-1", // ensure non-null names are unique
this.getWeightPerBase() * basesInFirstInterval);

WeightedInterval right = new WeightedInterval(
this.getContig(),
this.getStart() + basesInFirstInterval,
this.getEnd(),
this.isNegativeStrand(),
this.getName() + "-2", // ensure names are unique
this.getName() == null ? null : this.getName() + "-2", // ensure non-null names are unique
this.getWeight() - left.getWeight()); // give remainder to right

return new WeightedInterval[]{left, right};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public void onTraversalStart() {
WeightedInterval wi = (WeightedInterval) iter.next();

// if we're not mixing contigs, but we switched contigs, emit the list
if (dontMixContigs && lastContig != null && !lastContig.equals(wi.getContig()) ) {
if (dontMixContigs && lastContig != null && !lastContig.equals(wi.getContig())) {
// write out the current list (uniqued and sorted) and start a new one
writeIntervalList(formatString, scatterPiece++, currentList);
currentList = new IntervalList(sequenceDictionary);
Expand All @@ -122,24 +122,29 @@ public void onTraversalStart() {
lastContig = wi.getContig();

// if the interval fits completely, just add it
if (cumulativeWeight + wi.getWeight() <= targetWeightPerScatter ) {
if (cumulativeWeight + wi.getWeight() <= targetWeightPerScatter) {
cumulativeWeight += wi.getWeight();
currentList.add(wi);

// if it would push us over the edge
} else {
// add a piece of it
// try to add a piece of it
float remainingSpace = targetWeightPerScatter - cumulativeWeight;

// how many bases can we take?
int basesToTake = (int) Math.floor(remainingSpace / wi.getWeightPerBase());

// split and add the first part into this list
WeightedInterval[] pair = wi.split(basesToTake);
currentList.add(pair[0]);

// push the remainder back onto the iterator
iter.pushback(pair[1]);
int basesToTake = (int) Math.floor(remainingSpace / wi.getWeightPerBase());

if (basesToTake == 0) {
// We can't add any more bases to the current interval list so put this interval back.
iter.pushback(wi);
} else {
// split and add the first part into this list
WeightedInterval[] pair = wi.split(basesToTake);
currentList.add(pair[0]);

// push the remainder back onto the iterator
iter.pushback(pair[1]);
}

// add uniqued, sorted output list and reset
writeIntervalList(formatString, scatterPiece++, currentList);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
package org.broadinstitute.hellbender.tools.gvs.common;

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.util.Interval;
import htsjdk.samtools.util.IntervalList;
import htsjdk.samtools.util.OverlapDetector;
import org.broadinstitute.hellbender.CommandLineProgramTest;
import org.broadinstitute.hellbender.testutils.ArgumentsBuilder;
import org.broadinstitute.hellbender.tools.walkers.SplitIntervals;
import org.broadinstitute.hellbender.tools.walkers.SplitIntervalsIntegrationTest;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.TreeSet;
import java.util.stream.Collectors;

import static java.util.Arrays.asList;


// Suppresses the "expression might be null" warnings which would be genuine concerns in production code but which
// are fine in tests (unexpected and dereferenced nulls should crash the test which will appropriately signal failure).
@SuppressWarnings("ConstantConditions")
public class WeightedSplitIntervalsIntegrationTest extends CommandLineProgramTest {

@Test
Expand Down Expand Up @@ -62,7 +54,6 @@ public void testNoLossRealisticWgs() {

final int scatterCount = 100;
final File outputDir = createTempDir("output");
final Interval interval = new Interval("chr20", 1000000, 2000000);

final ArgumentsBuilder args = new ArgumentsBuilder()
.addInterval(wgsIntervalList.getAbsolutePath())
Expand All @@ -88,6 +79,44 @@ public void testNoLossRealisticWgs() {

}

@Test
public void testHandleNotOneMoreBase() {
    // Scenario: the space remaining in a scatter list is non-zero, yet it is too small to
    // accept even a single base of the next interval.
    //
    // Weights file used here:
    //   chr20 60000 60001 . 200      ->  1 base,   200 weight per base
    //   chr20 60001 60011 . 10000    -> 10 bases, 1000 weight per base
    //
    // Total weight is 10200; spread over 10 scatters that is a target of 1020 per shard.
    final int scatterCount = 10;
    final Interval inputInterval = new Interval("chr20", 60000, 60109);
    final String intervalArg =
            inputInterval.getContig() + ":" + inputInterval.getStart() + "-" + inputInterval.getEnd();

    final File weights = new File(publicTestDir + "example_weights_chr20_test_zero.bed");
    final File outputDir = createTempDir("output");

    final ArgumentsBuilder args = new ArgumentsBuilder()
            .addInterval(intervalArg)
            .addReference(hg38Reference)
            .add(SplitIntervals.SCATTER_COUNT_SHORT_NAME, scatterCount)
            .add(WeightedSplitIntervals.WEIGHTS_BED_FILE_FULL_NAME, weights)
            .addOutput(outputDir);
    runCommandLine(args);

    // Although 10 scatters were requested, 11 are expected: with a 1020-weight target and
    // these per-base weights, each of the 11 bases has to land in its own shard.
    final File[] shardFiles = outputDir.listFiles();
    Assert.assertEquals(shardFiles.length, 11);

    // Merging every shard back together must reproduce exactly the input interval.
    final IntervalList merged = IntervalList.fromFiles(Arrays.asList(shardFiles)).uniqued().sorted();
    final List<Interval> mergedIntervals = merged.getIntervals();

    // ... i.e. one single interval ...
    Assert.assertEquals(mergedIntervals.size(), 1);

    // ... that compares equal to the interval we passed in.
    Assert.assertEquals(mergedIntervals.get(0).compareTo(inputInterval), 0);
}

@Test
public void testDontMixContigs() {
final File weights = new File(publicTestDir + "example_weights_chr20_chr21.bed.gz");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,22 @@ public void testWeightGtZero() {
WeightedInterval w = new WeightedInterval("chr1", 1, 1000, -1);
}

@Test
public void testNonNullNameSplit() {
    // Splitting a named interval should give each half the original name plus a
    // "-1" / "-2" suffix so the two names stay distinct.
    final WeightedInterval named =
            new WeightedInterval("chr1", 1, 1000, false, "definitely not null", 1.0f);
    final WeightedInterval[] halves = named.split(100);

    final WeightedInterval left = halves[0];
    final WeightedInterval right = halves[1];

    Assert.assertEquals(left.getName(), "definitely not null-1");
    Assert.assertEquals(right.getName(), "definitely not null-2");
}

@Test
public void testNullNameSplit() {
    // An interval constructed with a null name must split into two halves whose names
    // are also null, rather than the strings "null-1" / "null-2" that naive string
    // concatenation would produce.
    WeightedInterval w = new WeightedInterval("chr1", 1, 1000, false, null, 1.0f);
    WeightedInterval[] split = w.split(100);

    // assertNull states the intent directly and reports a clearer failure message
    // than comparing against a null literal with assertEquals.
    Assert.assertNull(split[0].getName());
    Assert.assertNull(split[1].getName());
}

private void assertWeight(WeightedInterval w, String contig, int start, int end, long weight) {
Assert.assertEquals(w.getContig(), contig);
Assert.assertEquals(w.getStart(), start);
Expand Down
2 changes: 2 additions & 0 deletions src/test/resources/example_weights_chr20_test_zero.bed
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
chr20 60000 60001 . 200
chr20 60001 60011 . 10000

0 comments on commit 614a0f7

Please sign in to comment.