Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagated all hidden arguments in gCNV CASE mode from the given model #7464

Merged
merged 2 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.barclay.argparser.CommandLineArgumentParser;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.argumentcollections.IntervalArgumentCollection;
Expand Down Expand Up @@ -76,6 +77,9 @@
* diverges the second time we suggest checking if input count or karyotype values or other inputs are abnormal
* (an example of abnormality is a count file containing mostly zeros). </p>
*
* <p> More details about the model and inference procedure can be found in the white paper
* https://github.com/broadinstitute/gatk/blob/master/docs/CNV/germline-cnv-caller-model.pdf</p>
*
* <h3>Python environment setup</h3>
*
* <p>The computation done by this tool, aside from input data parsing and validation, is performed outside of the Java
Expand Down Expand Up @@ -121,9 +125,12 @@
* <dt>CASE mode:</dt>
* <dd><p>The tool will be run in CASE mode using the argument {@code run-mode CASE}. The path to a previously
* obtained model directory must be provided via the {@code model} argument in this mode. The modeled intervals are
* then specified by a file contained in the model directory, all interval-related arguments are ignored in this
* mode, and all model intervals must be present in all of the input count files. The tool output in CASE mode
* is only the "-calls" subdirectory and is organized similarly to that in COHORT mode.</p>
* then specified by a file contained in the model directory, and all model intervals must be present in all of the
* input count files. All interval-related arguments (e.g. {@code interval-psi-scale}) are redundant in this mode
* and will trigger an exception if provided. However, an advanced user can adjust various sample-related
* (e.g. {@code sample-psi-scale}) and global (e.g. {@code p_alt}) arguments for custom applications of the tool.
* Inference-related arguments (e.g. {@code min_training_epochs}) can be adjusted as well. The tool output in CASE
* mode is only the "-calls" subdirectory and is organized similarly to that in COHORT mode.</p>
*
* <p>Note that at the moment, this tool does not automatically verify the compatibility of the provided parametrization
* with the provided count files. Model compatibility may be assessed a posteriori by inspecting the magnitude of
Expand Down Expand Up @@ -359,8 +366,9 @@ protected Object doWork() {
}

private void validateArguments() {
germlineCallingArgumentCollection.validate();
germlineDenoisingModelArgumentCollection.validate();
final CommandLineArgumentParser clpParser = (CommandLineArgumentParser) getCommandLineParser();
germlineCallingArgumentCollection.validate(clpParser, runMode);
germlineDenoisingModelArgumentCollection.validate(clpParser, runMode);
germlineCNVHybridADVIArgumentCollection.validate();

Utils.validateArg(inputReadCountPaths.size() == new HashSet<>(inputReadCountPaths).size(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package org.broadinstitute.hellbender.tools.copynumber.arguments;

import com.google.common.collect.ImmutableList;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineArgumentParser;
import org.broadinstitute.hellbender.tools.copynumber.GermlineCNVCaller;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

import java.io.Serializable;
Expand All @@ -21,6 +24,11 @@ public final class GermlineCallingArgumentCollection implements Serializable {
public static final String CLASS_COHERENCE_LENGTH_LONG_NAME = "class-coherence-length";
public static final String MAX_COPY_NUMBER_LONG_NAME = "max-copy-number";

// these model parameters will be extracted from provided model in CASE mode
private static final List<String> HIDDEN_ARGS_CASE_MODE = ImmutableList.of(
P_ACTIVE_LONG_NAME,
CLASS_COHERENCE_LENGTH_LONG_NAME);

@Argument(
doc = "Total prior probability of alternative copy-number states (the reference copy-number " +
"is set to the contig integer ploidy)",
Expand Down Expand Up @@ -78,7 +86,11 @@ public List<String> generatePythonArguments(final GermlineCNVCaller.RunMode runM
return arguments;
}

public void validate() {
public void validate(final CommandLineArgumentParser clpParser, final GermlineCNVCaller.RunMode runMode) {
if (runMode == GermlineCNVCaller.RunMode.CASE)
HIDDEN_ARGS_CASE_MODE.forEach(a -> Utils.validateArg(
!clpParser.getNamedArgumentDefinitionByAlias(a).getHasBeenSet(),
String.format("Argument '--%s' cannot be set in the CASE mode.", a)));
ParamUtils.isPositive(cnvCoherenceLength,
String.format("Coherence length of CNV events (%s) must be positive.",
CNV_COHERENCE_LENGTH_LONG_NAME));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package org.broadinstitute.hellbender.tools.copynumber.arguments;

import com.google.common.collect.ImmutableList;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineArgumentParser;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.tools.copynumber.GermlineCNVCaller;
import org.broadinstitute.hellbender.utils.param.ParamUtils;
import org.broadinstitute.hellbender.utils.Utils;

import java.io.Serializable;
import java.util.ArrayList;
Expand All @@ -28,6 +32,16 @@ public final class GermlineDenoisingModelArgumentCollection implements Serializa
public static final String ENABLE_BIAS_FACTORS_LONG_NAME = "enable-bias-factors";
public static final String ACTIVE_CLASS_PADDING_HYBRID_MODE_LONG_NAME = "active-class-padding-hybrid-mode";

// these model parameters will be extracted from provided model in CASE mode
private static final List<String> HIDDEN_ARGS_CASE_MODE = ImmutableList.of(
MAX_BIAS_FACTORS_LONG_NAME,
INTERVAL_PSI_SCALE_LONG_NAME,
LOG_MEAN_BIAS_STANDARD_DEVIATION_LONG_NAME,
INIT_ARD_REL_UNEXPLAINED_VARIANCE_LONG_NAME,
ENABLE_BIAS_FACTORS_LONG_NAME,
NUM_GC_BINS_LONG_NAME,
GC_CURVE_STANDARD_DEVIATION_LONG_NAME);

public enum CopyNumberPosteriorExpectationMode {
MAP("map"),
EXACT("exact"),
Expand Down Expand Up @@ -168,7 +182,11 @@ public List<String> generatePythonArguments(final GermlineCNVCaller.RunMode runM
return arguments;
}

public void validate() {
public void validate(final CommandLineArgumentParser clpParser, final GermlineCNVCaller.RunMode runMode) {
if (runMode == GermlineCNVCaller.RunMode.CASE)
HIDDEN_ARGS_CASE_MODE.forEach(a -> Utils.validateArg(
!clpParser.getNamedArgumentDefinitionByAlias(a).getHasBeenSet(),
String.format("Argument '--%s' cannot be set in the CASE mode.", a)));
ParamUtils.isPositive(maxBiasFactors,
String.format("Maximum number of bias factors (%s) must be positive.",
MAX_BIAS_FACTORS_LONG_NAME));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@
# logging args
gcnvkernel.cli_commons.add_logging_args_to_argparse(parser)

hidden_denoising_args = {
"max_bias_factors",
"psi_t_scale",
"log_mean_bias_std",
"init_ard_rel_unexplained_variance",
"enable_bias_factors",
"enable_explicit_gc_bias_modeling",
"disable_bias_factors_in_active_class",
"num_gc_bins",
"gc_curve_sd"
}

hidden_calling_args = {
"p_active",
"class_coherence_length"
}

# add tool-specific args
group = parser.add_argument_group(title="Required arguments")

Expand Down Expand Up @@ -79,26 +96,13 @@
# Note: we are hiding parameters that are either set by the model or are irrelevant to the case calling task
gcnvkernel.DenoisingModelConfig.expose_args(
parser,
hide={
"--max_bias_factors",
"--psi_t_scale",
"--log_mean_bias_std",
"--init_ard_rel_unexplained_variance",
"--enable_bias_factors",
"--enable_explicit_gc_bias_modeling",
"--disable_bias_factors_in_active_class",
"--num_gc_bins",
"--gc_curve_sd",
})
hide={"--" + arg for arg in hidden_denoising_args})

# add calling config args
# Note: we are hiding parameters that are either set by the model or are irrelevant to the case calling task
gcnvkernel.CopyNumberCallingConfig.expose_args(
parser,
hide={
'--p_active',
'--class_coherence_length'
})
hide={"--" + arg for arg in hidden_calling_args})

# override some inference parameters
gcnvkernel.HybridInferenceParameters.expose_args(parser)
Expand All @@ -109,24 +113,16 @@ def update_args_dict_from_saved_model(input_model_path: str,
logging.info("Loading denoising model configuration from the provided model...")
with open(os.path.join(input_model_path, "denoising_config.json"), 'r') as fp:
loaded_denoising_config_dict = json.load(fp)

# boolean flags
_args_dict['enable_bias_factors'] = \
loaded_denoising_config_dict['enable_bias_factors']
_args_dict['enable_explicit_gc_bias_modeling'] = \
loaded_denoising_config_dict['enable_explicit_gc_bias_modeling']
_args_dict['disable_bias_factors_in_active_class'] = \
loaded_denoising_config_dict['disable_bias_factors_in_active_class']

# bias factor related
_args_dict['max_bias_factors'] = \
loaded_denoising_config_dict['max_bias_factors']

# gc-related
_args_dict['num_gc_bins'] = \
loaded_denoising_config_dict['num_gc_bins']
_args_dict['gc_curve_sd'] = \
loaded_denoising_config_dict['gc_curve_sd']
with open(os.path.join(input_model_path, "calling_config.json"), 'r') as fp:
loaded_calling_config_dict = json.load(fp)

# load arguments from the model denoising config that are hidden by the tool
for arg in hidden_denoising_args:
_args_dict[arg] = \
loaded_denoising_config_dict[arg]
for arg in hidden_calling_args:
_args_dict[arg] = \
loaded_calling_config_dict[arg]

logging.info("- bias factors enabled: "
+ repr(_args_dict['enable_bias_factors']))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.broadinstitute.hellbender.cmdline.argumentcollections.IntervalArgumentCollection;
import org.broadinstitute.hellbender.testutils.ArgumentsBuilder;
import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberStandardArgument;
import org.broadinstitute.hellbender.tools.copynumber.arguments.GermlineDenoisingModelArgumentCollection;
import org.broadinstitute.hellbender.utils.IntervalMergingRule;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -71,6 +72,22 @@ public void testCaseWithoutModel() {
runCommandLine(argsBuilder);
}

@Test(groups = {"python"}, expectedExceptions = IllegalArgumentException.class)
public void testCaseWithHiddenArguments() {
final ArgumentsBuilder argsBuilder = new ArgumentsBuilder();
Arrays.stream(TEST_COUNT_FILES, 0, 5).forEach(argsBuilder::addInput);
argsBuilder.add(GermlineCNVCaller.RUN_MODE_LONG_NAME, GermlineCNVCaller.RunMode.CASE.name())
.add(GermlineCNVCaller.CONTIG_PLOIDY_CALLS_DIRECTORY_LONG_NAME,
CONTIG_PLOIDY_CALLS_OUTPUT_DIR.getAbsolutePath())
.add(CopyNumberStandardArgument.MODEL_LONG_NAME,
new File(OUTPUT_DIR, "test-germline-cnv-cohort-model").getAbsolutePath())
.add(StandardArgumentDefinitions.OUTPUT_LONG_NAME, OUTPUT_DIR.getAbsolutePath())
.add(CopyNumberStandardArgument.OUTPUT_PREFIX_LONG_NAME, "test-germline-cnv-case");
// add argument that is not applicable in CASE mode
argsBuilder.add(GermlineDenoisingModelArgumentCollection.INTERVAL_PSI_SCALE_LONG_NAME, 0.1);
runCommandLine(argsBuilder);
}

@Test(groups = {"python"}, enabled = false)
public void testCohortWithInputModel() {
}
Expand Down