Skip to content

Commit

Permalink
Fix asynchronous Python exception propagation in StreamingPythonExecutor/CNNScoreVariants. (#7402)
Browse files Browse the repository at this point in the history

* Catch python exceptions during execution of asynchronous batch write statements, update tests and example tool.
* Add a CNNScoreVariants test to force an exception during batch write.
* fixes #7401
  • Loading branch information
cmnbroad committed Dec 23, 2022
1 parent 7f40444 commit aa05918
Show file tree
Hide file tree
Showing 6 changed files with 687 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class ExampleStreamingPythonExecutor extends ReadWalker {
final StreamingPythonScriptExecutor<String> pythonExecutor = new StreamingPythonScriptExecutor<>(true);

private List<String> batchList = new ArrayList<>(batchSize);
private boolean batchIsOutstanding = false;
private int batchCount = 0;

@Override
Expand All @@ -77,8 +78,11 @@ public void apply(GATKRead read, ReferenceContext referenceContext, FeatureConte
// Extract data from the read and accumulate, unless we've reached a batch size, in which case we
// kick off an asynchronous batch write.
if (batchCount == batchSize) {
pythonExecutor.waitForPreviousBatchCompletion();
if (batchIsOutstanding) {
pythonExecutor.waitForPreviousBatchCompletion();
}
startAsynchronousBatchWrite(); // start a new batch
batchIsOutstanding = true;
}
batchList.add(String.format(
"Read at %s:%d-%d:\n%s\n",
Expand All @@ -91,12 +95,15 @@ public void apply(GATKRead read, ReferenceContext referenceContext, FeatureConte
* @return Success indicator.
*/
public Object onTraversalSuccess() {
pythonExecutor.waitForPreviousBatchCompletion(); // wait for the previous batch to complete, if there is one
if (batchCount != 0) {
// If we have any accumulated reads that haven't been dispatched, start one last
// async batch write, and then wait for it to complete
if (batchIsOutstanding) {
pythonExecutor.waitForPreviousBatchCompletion();
}
startAsynchronousBatchWrite();
pythonExecutor.waitForPreviousBatchCompletion();
batchIsOutstanding = false;
}

return true;
Expand All @@ -108,6 +115,7 @@ private void startAsynchronousBatchWrite() {
pythonExecutor.startBatchWrite(
String.format("for i in range(%s):\n tempFile.write(tool.readDataFIFO())" + NL + NL, batchCount),
batchList);
batchIsOutstanding = true;
batchList = new ArrayList<>(batchSize);
batchCount = 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,15 @@ public void startBatchWrite(final String pythonCommand, final List<T> batchList)
* @return returns null if no previous work to complete, otherwise a completed Future
*/
public Future<Integer> waitForPreviousBatchCompletion() {
// wait for the batch queue to be completely written
// Rather than waiting for the asyncWriter Future to complete first, and THEN waiting for
// the ack, call waitForAck() first instead, because it will detect and propagate any
// exception that occurs on the python side that causes it to stop pulling data from the
// FIFO (which in turn can result in the background thread blocking, thereby preventing the
// asyncWriter Future from ever completing). This is safer than waiting for the Future first,
// since the Future might never complete if the async writer thread is blocked.
waitForAck();
// now that we have the ack, verify that the async batch write completed
final Future<Integer> numberOfItemsWritten = asyncWriter.waitForPreviousBatchCompletion();
if (numberOfItemsWritten != null) {
// wait for the written items to be completely consumed
waitForAck();
}
return numberOfItemsWritten;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public Future<Integer> waitForPreviousBatchCompletion() {
*/
public boolean terminate() {
boolean isCancelled = true;
if (previousBatch != null) {
if (previousBatch != null && !previousBatch.isDone()) {
logger.warn("Cancelling outstanding asynchronous writing");
isCancelled = previousBatch.cancel(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.broadinstitute.hellbender.testutils.VariantContextTestUtils;
import org.broadinstitute.hellbender.utils.Utils;

import org.broadinstitute.hellbender.utils.python.PythonScriptExecutorException;
import org.testng.Assert;
import org.testng.SkipException;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
Expand Down Expand Up @@ -65,6 +66,22 @@ public void testAllDefaultArgs() {
assertInfoFieldsAreClose(tempVcf, expectedVcf, GATKVCFConstants.CNN_1D_KEY);
}

/**
 * Runs the tool on a VCF whose final record carries "." for the default CNN float annotations and
 * expects the run to fail with a {@link PythonScriptExecutorException}.
 */
@Test(groups = {"python"}, expectedExceptions = PythonScriptExecutorException.class)
public void testExceptionDuringAsyncBatch() {
    // the last variant in this vcf has a value of "." for the float attributes in the default CNN
    // annotation set MQ, MQRankSum, ReadPosRankSum, SOR, VQSLOD, and QD
    //TODO: move this into the large resources dir
    final File badAnnotationsVcf = new File("src/test/resources/cnn_1d_chr20_subset_expected.badAnnotations.vcf");
    final File outputVcf = createTempFile("tester", ".vcf");

    final ArgumentsBuilder toolArgs = new ArgumentsBuilder();
    toolArgs.add(StandardArgumentDefinitions.VARIANT_LONG_NAME, badAnnotationsVcf)
            .add(StandardArgumentDefinitions.OUTPUT_LONG_NAME, outputVcf.getPath())
            .add(StandardArgumentDefinitions.REFERENCE_LONG_NAME, b37_reference_20_21)
            .add(StandardArgumentDefinitions.ADD_OUTPUT_VCF_COMMANDLINE, "false");

    // the expected exception must surface from this call for the test to pass
    runCommandLine(toolArgs);
}

@Test(groups = {"python"})
public void testInferenceArchitecture() {
final boolean newExpectations = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import htsjdk.samtools.util.BufferedLineReader;
import org.broadinstitute.hellbender.GATKBaseTest;
import org.broadinstitute.hellbender.utils.runtime.AsynchronousStreamWriter;
import org.broadinstitute.hellbender.utils.runtime.ProcessOutput;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -180,8 +179,8 @@ public void testAsyncWriteService(final PythonScriptExecutor.PythonExecutableNam
for (int i = 0; i < ROUND_TRIP_COUNT; i++) {
if (i != 0 && (i % syncFrequency) == 0) {
// wait for the last batch to complete before we start a new one
streamingPythonExecutor.waitForPreviousBatchCompletion();
streamingPythonExecutor.startBatchWrite(String.format(PYTHON_TRANSFER_FIFO_TO_TEMP_FILE, count), fifoData);
streamingPythonExecutor.waitForPreviousBatchCompletion();
count = 0;
fifoData = new ArrayList<>(syncFrequency);
}
Expand All @@ -194,9 +193,6 @@ public void testAsyncWriteService(final PythonScriptExecutor.PythonExecutableNam
count++;
}

// wait for the writing to complete
streamingPythonExecutor.waitForPreviousBatchCompletion();

if (fifoData.size() != 0) {
streamingPythonExecutor.startBatchWrite(String.format(PYTHON_TRANSFER_FIFO_TO_TEMP_FILE, count), fifoData);
// wait for the writing to complete
Expand Down Expand Up @@ -254,6 +250,55 @@ public void testRaisePythonException(final PythonScriptExecutor.PythonExecutable
executeBadPythonCode(executableName,"raise Exception");
}

/**
 * Sends a raising Python statement as an asynchronous command and expects waiting for the ack to
 * fail with a {@link PythonScriptExecutorException}. The executor is terminated either way.
 *
 * @param executableName the Python executable to launch, supplied by the data provider
 */
@Test(groups = "python", dataProvider = "supportedPythonVersions", dependsOnMethods = "testPythonExists",
        expectedExceptions = PythonScriptExecutorException.class)
public void testRaiseAsynchronousPythonException(final PythonScriptExecutor.PythonExecutableName executableName) {
    final StreamingPythonScriptExecutor<String> executor =
            new StreamingPythonScriptExecutor<>(executableName, true);
    Assert.assertNotNull(executor);

    final boolean started = executor.start(Collections.emptyList(), true, null);
    Assert.assertTrue(started);

    try {
        final String raisingStatement = "raise Exception" + NL;
        executor.sendAsynchronousCommand(raisingStatement);
        // the Python-side failure is expected to surface here
        executor.waitForAck();
    } finally {
        // always shut the process down, then confirm it is really gone
        executor.terminate();
        final boolean stillRunning = executor.getProcess().isAlive();
        Assert.assertFalse(stillRunning);
    }
}

/**
 * Starts an asynchronous batch write whose Python command reads every item from the FIFO and then
 * raises, and expects waiting for the batch to fail with a {@link PythonScriptExecutorException}.
 * The executor is terminated either way.
 *
 * @param executableName the Python executable to launch, supplied by the data provider
 */
@Test(groups = "python", dataProvider = "supportedPythonVersions", dependsOnMethods = "testPythonExists",
        expectedExceptions = PythonScriptExecutorException.class)
public void testRaiseAsynchronousBatchWritePythonException(final PythonScriptExecutor.PythonExecutableName executableName) {
    final StreamingPythonScriptExecutor<String> executor =
            new StreamingPythonScriptExecutor<>(executableName, true);
    Assert.assertNotNull(executor);

    final boolean started = executor.start(Collections.emptyList(), true, null);
    Assert.assertTrue(started);

    try {
        final int itemCount = 1000;
        final List<String> items = createLargeBatch(itemCount);
        executor.initStreamWriter(AsynchronousStreamWriter.stringSerializer);

        // Python side: drain itemCount entries from the FIFO, then raise
        final String commandTemplate =
                "for i in range(0, %d):"+ NL + "\t tool.readDataFIFO()" + NL + NL + "raise Exception" + NL;
        final String raisingBatchCommand = String.format(commandTemplate, itemCount);

        executor.startBatchWrite(raisingBatchCommand, items);
        // the Python-side failure is expected to surface here
        executor.waitForPreviousBatchCompletion();
    } finally {
        // always shut the process down, then confirm it is really gone
        executor.terminate();
        final boolean stillRunning = executor.getProcess().isAlive();
        Assert.assertFalse(stillRunning);
    }
}

/**
 * Builds a batch of {@code batchSize} strings, each the decimal item index followed by a newline
 * ("0\n", "1\n", ...).
 *
 * @param batchSize number of items to generate; a non-positive value yields an empty list
 * @return a new mutable list containing the generated items in ascending index order
 */
private List<String> createLargeBatch(final int batchSize) {
    // Size the list from the request instead of a fixed 1000. Clamp at zero so a negative
    // request still returns an empty list rather than throwing from the ArrayList constructor.
    final List<String> batchList = new ArrayList<>(Math.max(batchSize, 0));
    for (int i = 0; i < batchSize; i++) {
        batchList.add(String.format("%d\n", i));
    }
    return batchList;
}

@Test(groups = "python", dataProvider="supportedPythonVersions", dependsOnMethods = "testPythonExists",
expectedExceptions = PythonScriptExecutorException.class)
public void testRaisePythonAssert(final PythonScriptExecutor.PythonExecutableName executableName) {
Expand Down
Loading

0 comments on commit aa05918

Please sign in to comment.