Skip to content

Commit

Permalink
add microbial mode (#6694)
Browse files Browse the repository at this point in the history
* this changes the behavior of the adaptive pruner at low/patchy reference coverage sites.
* fixed an issue where the logic for suffix ends was incorrect for dangling tails
* change misspellings of dangling to the correct spelling
* doc as beta feature
Co-authored-by: James <[email protected]>
  • Loading branch information
ahaessly committed Aug 2, 2021
1 parent 9951f77 commit f548ccd
Show file tree
Hide file tree
Showing 15 changed files with 396 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
package org.broadinstitute.hellbender.tools.walkers.fasta;

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.reference.FastaReferenceWriter;
import htsjdk.samtools.reference.FastaReferenceWriterBuilder;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.engine.ReferenceDataSource;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import picard.cmdline.programgroups.ReferenceProgramGroup;

import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;

/**
* Create a fasta with the bases shifted by offset
*
* delta1 = offset - 1
* delta2 = total - delta1
*
* To shift forward:
* if you are given a position in the regular fasta (pos_r) and want the position in the shifted fasta (pos_s):
* if pos_r > delta1 => pos_s = pos_r - delta1 == pos_r - offset +1
* otherwise pos_s = pos_r + delta2 == pos_r + total - offset + 1
*
* To shift back:
* if you are given a position in the shifted fasta (pos_s) and want the position in the regular fasta (pos_r):
* if pos_s > delta2 => pos_r = pos_s - delta2 == pos_s - total + offset - 1
* otherwise pos_r = pos_s + delta1 == pos_s + offset - 1
*
* Example command line:
* ShiftFasta
* -R "<CIRCURLAR_REFERENCE.fasta>" // the reference to shift
* -O "<SHIFTED_REFERENCE.fasta>" // output; the shifted fasta
* --shift-back-output "<SHIFT_BACK.chain>" // output; the shiftback chain file to use when lifting over
* --shift-offset-list "<SHIFT_OFFSETS>" // optional; Specifies the offset to shift for each contig in the reference. If not specified, the offset will be half the length of the contig.
* --interval-file-name "<SHIFT_INTERVALS>" // output; base name for output interval files (.intervals and .shifted.intervals) that should be used when calling variants against the unshifted and shifted reference.
* --line-width 100
*/
@DocumentedFeature
@BetaFeature
@CommandLineProgramProperties(
summary = "Create a new fasta starting at the shift-offset +1 position and a shift_back chain file that can be used with the Liftover tool. It will shift all contigs by default.",
oneLineSummary = "Creates a shifted fasta file and shift_back file",
programGroup = ReferenceProgramGroup.class
)
public class ShiftFasta extends GATKTool {

    @Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME,
            shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME,
            doc = "Path to write the output fasta to")
    protected String output;

    public static final String SHIFT_BACK_OUTPUT = "shift-back-output";
    @Argument(fullName = SHIFT_BACK_OUTPUT,
            doc = "Path to write the shift_back file to")
    protected String shiftBackOutput;

    public static final String SHIFT_OFFSET_LIST = "shift-offset-list";
    @Argument(fullName = SHIFT_OFFSET_LIST,
            doc = "Number of bases to skip in the reference before starting the shifted reference. If the reference contains multiple contigs, a value should be specified for each contig. " +
                    "For example, if 300 is specified, the new fasta will start at the 301st base (count starting at 1)." +
                    "If not specified, each contig will be shifted by half the number of bases. To skip the shifting of a contig, specify 0 in the list.", optional = true)
    private List<Integer> shiftOffsets = null;

    // NOTE(review): constant name carries a typo ("INTERAL"); kept as-is because it is public API.
    public static final String INTERAL_FILE_NAME = "interval-file-name";
    @Argument(fullName = INTERAL_FILE_NAME,
            doc = "Base name for interval files. Intervals will be midway between beginning and computed offset. If not specified, no interval files will be written.", optional = true)
    private String intervalFilename;

    public static final String LINE_WIDTH_LONG_NAME = "line-width";
    @Argument(fullName = LINE_WIDTH_LONG_NAME, doc = "Maximum length of sequence to write per line", optional = true)
    public int basesPerLine = FastaReferenceWriter.DEFAULT_BASES_PER_LINE;

    // Source of the unshifted reference bases; opened in onTraversalStart, closed in closeTool.
    private ReferenceDataSource refSource;
    // Writer for the shifted fasta output.
    private FastaReferenceWriter refWriter;
    // Writer for the shift-back chain file consumed by the Liftover tool.
    private FileWriter chainFileWriter;
    // Interval-list writers; only created when intervalFilename is specified.
    private FileWriter intervalRegularWriter;
    private FileWriter intervalShiftedWriter;

    // Monotonically increasing id stamped on each chain record; reset to 1 at traversal start.
    private int chainId = 0;

    @Override
    public boolean requiresReference() {
        return true;
    }

    /**
     * Opens the reference data source and all output writers (shifted fasta, chain file,
     * and — when requested — the two interval files).
     *
     * @throws UserException.CouldNotCreateOutputFile if any output file cannot be created
     */
    @Override
    public void onTraversalStart() {
        refSource = referenceArguments.getReferencePath() != null ? ReferenceDataSource.of(referenceArguments.getReferencePath()) : null;
        final Path path = IOUtils.getPath(output);
        chainId = 1;
        try {
            refWriter = new FastaReferenceWriterBuilder()
                    .setFastaFile(path)
                    .setBasesPerLine(basesPerLine)
                    .build();
            chainFileWriter = new FileWriter(shiftBackOutput);
            if (intervalFilename != null) {
                intervalRegularWriter = new FileWriter(intervalFilename + ".intervals");
                intervalShiftedWriter = new FileWriter(intervalFilename + ".shifted.intervals");
            }
        } catch (IOException e) {
            throw new UserException.CouldNotCreateOutputFile("Couldn't create " + output + ", encountered exception: " + e.getMessage(), e);
        }
    }

    /**
     * Iterates over every contig in the reference dictionary and shifts each one,
     * validating first that any user-supplied offset list matches the number of contigs.
     *
     * @throws UserException.BadInput if the reference is too long to hold in memory or the
     *         offset list size does not match the contig count
     */
    @Override
    public void traverse() {
        SAMSequenceDictionary refDict = refSource.getSequenceDictionary();
        long refLengthLong = refDict.getReferenceLength();
        if (refLengthLong > Integer.MAX_VALUE) {
            // The shift is done in memory (see shiftContig), so the reference must fit in a byte[].
            // TODO fix this??
            throw new UserException.BadInput("Reference length is too long");
        }
        List<SAMSequenceRecord> contigs = refSource.getSequenceDictionary().getSequences();
        if (shiftOffsets != null && !shiftOffsets.isEmpty() && shiftOffsets.size() != contigs.size()) {
            throw new UserException.BadInput("Shift offset list size " + shiftOffsets.size() + " must equal number of contigs in the reference " + contigs.size());
        }
        final ListIterator<Integer> shiftOffsetsIt = shiftOffsets != null && !shiftOffsets.isEmpty() ? shiftOffsets.listIterator() : null;
        refSource.getSequenceDictionary().getSequences().forEach(seq -> shiftContig(seq, shiftOffsetsIt));
    }

    /**
     * This method adds to a new fasta ref file that has been shifted by the amount indicated in the shiftOffsetsIt.
     * This also adds to the supporting files: chainfile, interval list for both the shifted and unshifted fasta files.
     * The shift is all done in memory. This is a scaling limitation.
     *
     * @param seq the contig or sequence within the fasta file
     * @param shiftOffsetsIt the iterator at the correct position to get the next offset, or null if dividing contig by 2
     */
    protected final void shiftContig(SAMSequenceRecord seq, ListIterator<Integer> shiftOffsetsIt) {
        final int contigLength = seq.getSequenceLength();
        final String seqName = seq.getSequenceName();
        int shiftOffset = shiftOffsetsIt == null ? contigLength / 2 : shiftOffsetsIt.next();
        if (shiftOffset > 0 && shiftOffset < contigLength) {
            byte[] bases = refSource.queryAndPrefetch(new SimpleInterval(seqName, 1, contigLength)).getBases();
            // Split the contig at the offset; the two pieces are written back in swapped order.
            byte[] basesAtEnd = Arrays.copyOfRange(bases, shiftOffset, bases.length);
            byte[] basesAtStart = Arrays.copyOf(bases, shiftOffset);
            int shiftBackOffset = bases.length - shiftOffset;

            addToShiftedReference(refWriter, seqName, basesPerLine, basesAtStart, basesAtEnd);
            addToChainFile(seqName, contigLength, shiftOffset, bases, shiftBackOffset);
            // Interval files only make sense for the default (half-contig) shift.
            if (intervalFilename != null && shiftOffsetsIt == null) {
                addToIntervalFiles(intervalRegularWriter, intervalShiftedWriter, seqName, shiftOffset, contigLength);
            }
        } else {
            logger.info("not shifting contig " + seq.getContig() + " because shift offset " + shiftOffset + " is not between 1-" + contigLength);
        }
    }

    /**
     * Writes one interval line per contig to both the regular and the shifted interval files.
     * The intervals are centered midway between the contig start and the computed offset so
     * that variant calls near the original contig boundary fall inside the shifted interval
     * and vice versa.
     *
     * @throws UserException if either interval file cannot be written
     */
    private void addToIntervalFiles(FileWriter intervalRegularWriter, FileWriter intervalShiftedWriter, String seqName, int shiftOffset, int contigLength) {
        try {
            int intervalStart = shiftOffset / 2;
            int intervalEnd = intervalStart + contigLength / 2 - 1;
            int shiftedIntervalStart = intervalStart;
            // For odd-length contigs the shifted interval absorbs the extra base.
            int shiftedIntervalEnd = intervalEnd + contigLength % 2;
            intervalRegularWriter.append(seqName + ":" + intervalStart + "-" + intervalEnd + "\n");
            intervalShiftedWriter.append(seqName + ":" + shiftedIntervalStart + "-" + shiftedIntervalEnd + "\n");
        } catch (IOException e) {
            throw new UserException("Failed to write interval files due to " + e.getMessage(), e);
        }
    }

    /**
     * Appends one shifted contig to the output fasta: the bases after the offset first,
     * then the bases before the offset.
     *
     * @throws UserException if the fasta cannot be written
     */
    private void addToShiftedReference(FastaReferenceWriter refWriter, String seqName, int basesPerLine, byte[] basesAtStart, byte[] basesAtEnd) {
        try {
            refWriter.startSequence(seqName, basesPerLine);
            // swap the bases
            refWriter.appendBases(basesAtEnd).appendBases(basesAtStart);
        } catch (IOException e) {
            throw new UserException("Failed to write shifted reference due to " + e.getMessage(), e);
        }
    }

    /**
     * Writes the two chain records (one per swapped segment) that map positions in the
     * shifted fasta back to the original reference, in UCSC chain file format.
     *
     * @throws UserException if the chain file cannot be written
     */
    private void addToChainFile(String seqName, int contigLength, int shiftOffset, byte[] bases, int shiftBackOffset) {
        try {
            chainFileWriter.append(createChainString(seqName, shiftBackOffset, contigLength, shiftOffset, bases.length, 0, shiftBackOffset, chainId++));
            chainFileWriter.append("\n" + shiftBackOffset + "\n\n");
            chainFileWriter.append(createChainString(seqName, shiftOffset - 1, contigLength, 0, shiftOffset, shiftBackOffset, bases.length, chainId++));
            chainFileWriter.append("\n" + shiftOffset + "\n\n");
        } catch (IOException e) {
            throw new UserException("Failed to write chainFile due to " + e.getMessage(), e);
        }
    }

    /**
     * Builds a single tab-separated chain header line in UCSC chain format:
     * chain score tName tSize tStrand tStart tEnd qName qSize qStrand qStart qEnd id.
     */
    private String createChainString(String name, int score, int length, int start, int end, int shiftBackStart, int shiftBackEnd, int id) {
        String[] items = new String[]{ "chain",
                Integer.toString(score),
                name,
                Integer.toString(length),
                "+",
                Integer.toString(shiftBackStart),
                Integer.toString(shiftBackEnd),
                name,
                Integer.toString(length),
                "+",
                Integer.toString(start),
                Integer.toString(end),
                Integer.toString(id)
        };
        return String.join("\t", items);
    }

    @Override
    public Object onTraversalSuccess() {
        return null;
    }

    /**
     * Closes every resource this tool opened. Each close is attempted independently so a
     * failure in one does not prevent the others from closing.
     */
    @Override
    public void closeTool() {
        super.closeTool();
        // Fix: the reference data source was previously never closed (resource leak).
        if (refSource != null) {
            refSource.close();
        }
        try {
            if (refWriter != null) {
                refWriter.close();
            }
        } catch (IOException e) {
            throw new UserException("Failed to write fasta due to " + e.getMessage(), e);
        }
        try {
            if (chainFileWriter != null) {
                chainFileWriter.close();
            }
        } catch (IOException e) {
            throw new UserException("Failed to write chain file due to " + e.getMessage(), e);
        }
        try {
            if (intervalRegularWriter != null) {
                intervalRegularWriter.close();
            }
            if (intervalShiftedWriter != null) {
                intervalShiftedWriter.close();
            }
        } catch (IOException e) {
            throw new UserException("Failed to write intervals due to " + e.getMessage(), e);
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ private Collection<Path<V,E>> likelyErrorChains(final List<Path<V, E>> chains, f
final Multimap<V, Path<V,E>> vertexToGoodOutgoingChains = ArrayListMultimap.create();

for (final Path<V,E> chain : chains) {
if (chainLogOdds.get(chain).getRight() >= logOddsThreshold) {
if (chainLogOdds.get(chain).getRight() >= logOddsThreshold || chain.getEdges().get(0).isRef()) {
vertexToGoodIncomingChains.put(chain.getLastVertex(), chain);
}

if (chainLogOdds.get(chain).getLeft() >= logOddsThreshold) {
if (chainLogOdds.get(chain).getLeft() >= logOddsThreshold || chain.getEdges().get(0).isRef()) {
vertexToGoodOutgoingChains.put(chain.getFirstVertex(), chain);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public abstract class AbstractReadThreadingGraph extends BaseGraph<MultiDeBruijn
private static final boolean DEBUG_NON_UNIQUE_CALC = false;
private static final int MAX_CIGAR_COMPLEXITY = 3;
private static final boolean INCREASE_COUNTS_BACKWARDS = true;
private int minMatchingBasesToDangingEndRecovery = -1;
private int minMatchingBasesToDanglingEndRecovery = -1;

/**
* for debugging info printing
Expand Down Expand Up @@ -87,7 +87,7 @@ public AbstractReadThreadingGraph(final int kmerSize, final boolean debugGraphTr

this.debugGraphTransformations = debugGraphTransformations;
this.minBaseQualityToUseInAssembly = minBaseQualityToUseInAssembly;
this.minMatchingBasesToDangingEndRecovery = numDanglingMatchingPrefixBases;
this.minMatchingBasesToDanglingEndRecovery = numDanglingMatchingPrefixBases;
}

/**
Expand Down Expand Up @@ -181,8 +181,8 @@ protected void setAlreadyBuilt() {
}

@VisibleForTesting
void setMinMatchingBasesToDangingEndRecovery(final int minMatchingBasesToDangingEndRecovery) {
this.minMatchingBasesToDangingEndRecovery = minMatchingBasesToDangingEndRecovery;
void setMinMatchingBasesToDanglingEndRecovery(final int minMatchingBasesToDanglingEndRecovery) {
this.minMatchingBasesToDanglingEndRecovery = minMatchingBasesToDanglingEndRecovery;
}

@VisibleForTesting
Expand Down Expand Up @@ -512,7 +512,7 @@ private Pair<Integer, Integer> bestPrefixMatch(final List<CigarElement> cigarEle
* The minimum number of matches to be considered allowable for recovering dangling ends
*/
private int getMinMatchingBases() {
return minMatchingBasesToDangingEndRecovery;
return minMatchingBasesToDanglingEndRecovery;
}

/**
Expand All @@ -539,7 +539,7 @@ private int recoverDanglingHead(final MultiDeBruijnVertex vertex, final int prun
}

// merge
return minMatchingBasesToDangingEndRecovery >= 0 ? mergeDanglingHead(danglingHeadMergeResult) : mergeDanglingHeadLegacy(danglingHeadMergeResult);
return minMatchingBasesToDanglingEndRecovery >= 0 ? mergeDanglingHead(danglingHeadMergeResult) : mergeDanglingHeadLegacy(danglingHeadMergeResult);
}

/**
Expand All @@ -557,7 +557,7 @@ final int mergeDanglingTail(final DanglingChainMergeHelper danglingTailMergeResu

final int lastRefIndex = danglingTailMergeResult.cigar.getReferenceLength() - 1;
final int matchingSuffix = Math.min(longestSuffixMatch(danglingTailMergeResult.referencePathString, danglingTailMergeResult.danglingPathString, lastRefIndex), lastElement.getLength());
if (minMatchingBasesToDangingEndRecovery >= 0 ? matchingSuffix < minMatchingBasesToDangingEndRecovery : matchingSuffix == 0 ) {
if (minMatchingBasesToDanglingEndRecovery >= 0 ? matchingSuffix < minMatchingBasesToDanglingEndRecovery : matchingSuffix == 0 ) {
return 0;
}

Expand Down Expand Up @@ -887,7 +887,7 @@ private int bestPrefixMatchLegacy(final byte[] path1, final byte[] path2, final

/**
* NOTE: this method is only used for dangling heads and not tails.
*
*
* Determine the maximum number of mismatches permitted on the branch.
* Unless it's preset (e.g. by unit tests) it should be the length of the branch divided by the kmer size.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class JunctionTreeLinkedDeBruijnGraph extends AbstractReadThreadingGraph
private static final long serialVersionUID = 1l;
private static final MultiDeBruijnVertex SYMBOLIC_END_VETEX = new MultiDeBruijnVertex(new byte[]{'_'});
private MultiSampleEdge SYMBOLIC_END_EDGE;

private Map<MultiDeBruijnVertex, ThreadingTree> readThreadingJunctionTrees = new HashMap<>();

// TODO should this be constructed here or elsewhere
Expand Down Expand Up @@ -232,7 +232,7 @@ private List<MultiDeBruijnVertex> getReferencePathForwardFromKmer(final MultiDeB
return extraSequence;
}

// TODO this behavior is frankly silly and needs to be fixed, there is no way upwards paths should be dangingling head recovered differently
// TODO this behavior is frankly silly and needs to be fixed, there is no way upwards paths should be dangling head recovered differently
private List<MultiDeBruijnVertex> getReferencePathBackwardsForKmer(final MultiDeBruijnVertex targetKmer) {
int firstIndex = referencePath.indexOf(targetKmer);
if (firstIndex == -1) return Collections.singletonList(targetKmer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public final class ReadThreadingAssembler {
private boolean recoverDanglingBranches = true;
private boolean recoverAllDanglingBranches = false;
private int minDanglingBranchLength = 0;

protected byte minBaseQualityToUseInAssembly = DEFAULT_MIN_BASE_QUALITY_TO_USE;
private int pruneFactor;
private final ChainPruner<MultiDeBruijnVertex, MultiSampleEdge> chainPruner;
Expand All @@ -79,7 +79,7 @@ public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler
final int numPruningSamples, final int pruneFactor, final boolean useAdaptivePruning,
final double initialErrorRateForPruning, final double pruningLogOddsThreshold,
final double pruningSeedingLogOddsThreshold, final int maxUnprunedVariants, final boolean useLinkedDebruijnGraphs,
final boolean enableLegacyGraphCycleDetection, final int minMachingBasesToDanglngEndRecovery) {
final boolean enableLegacyGraphCycleDetection, final int minMachingBasesToDanglingEndRecovery) {
Utils.validateArg( maxAllowedPathsForReadThreadingAssembler >= 1, "numBestHaplotypesPerGraph should be >= 1 but got " + maxAllowedPathsForReadThreadingAssembler);
this.kmerSizes = kmerSizes.stream().sorted(Integer::compareTo).collect(Collectors.toList());
this.dontIncreaseKmerSizesForCycles = dontIncreaseKmerSizesForCycles;
Expand All @@ -95,7 +95,7 @@ public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler
chainPruner = useAdaptivePruning ? new AdaptiveChainPruner<>(initialErrorRateForPruning, pruningLogOddsThreshold, pruningSeedingLogOddsThreshold, maxUnprunedVariants) :
new LowWeightChainPruner<>(pruneFactor);
numBestHaplotypesPerGraph = maxAllowedPathsForReadThreadingAssembler;
this.minMatchingBasesToDanglingEndRecovery = minMachingBasesToDanglngEndRecovery;
this.minMatchingBasesToDanglingEndRecovery = minMachingBasesToDanglingEndRecovery;
}

@VisibleForTesting
Expand Down Expand Up @@ -866,4 +866,4 @@ public void setArtificialHaplotypeRecoveryMode(boolean disableUncoveredJunctionT
recoverHaplotypesFromEdgesNotCoveredInJunctionTrees = false;
}
}
}
}
Loading

0 comments on commit f548ccd

Please sign in to comment.