Commit 4bf5286

postprocess fixed
samuelklee committed Dec 8, 2023
1 parent 503bbb8 commit 4bf5286
Showing 6 changed files with 37 additions and 36 deletions.
File 1 of 6

@@ -16,7 +16,7 @@
 
 from . import commons
 from .fancy_model import GeneralizedContinuousModel
-from .pytensor_hmm import TheanoForwardBackward
+from .pytensor_hmm import PytensorForwardBackward
 from .. import config, types
 from ..structs.interval import Interval, GCContentAnnotation
 from ..structs.metadata import SampleMetadataCollection
@@ -980,7 +980,7 @@ def __init__(self,
         self.pi_sjkc: types.TensorSharedVariable = pytensor.shared(pi_sjkc, name='pi_sjkc', borrow=config.borrow_numpy)
 
         # compiled function for forward-backward updates of copy number posterior
-        self._hmm_q_copy_number = TheanoForwardBackward(
+        self._hmm_q_copy_number = PytensorForwardBackward(
             log_posterior_probs_output_tc=None,
             resolve_nans=False,
             do_thermalization=True,
@@ -993,7 +993,7 @@ def __init__(self,
             # Note:
             # if p_active == 0, we have to deal with inf - inf expressions properly.
             # setting resolve_nans = True takes care of such ambiguities.
-            self._hmm_q_class = TheanoForwardBackward(
+            self._hmm_q_class = PytensorForwardBackward(
                 log_posterior_probs_output_tc=shared_workspace.log_q_tau_tk,
                 resolve_nans=(calling_config.p_active == 0),
                 do_thermalization=True,
@@ -1004,7 +1004,7 @@
             # compiled function for update of class log emission
             self._update_log_class_emission_tk_pytensor_func = self._get_update_log_class_emission_tk_pytensor_func()
         else:
-            self._hmm_q_class: Optional[TheanoForwardBackward] = None
+            self._hmm_q_class: Optional[PytensorForwardBackward] = None
             self._update_log_class_emission_tk_pytensor_func = None
 
         # compiled function for variational update of copy number HMM specs
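The note about inf - inf expressions is worth unpacking: when calling_config.p_active == 0, some log-transition terms are -inf, and differences of such terms evaluate to nan rather than the intended -inf. A minimal NumPy sketch of the failure mode and the kind of cleanup that resolve_nans=True enables (illustrative only, not code from this repository):

    import numpy as np

    with np.errstate(divide='ignore', invalid='ignore'):
        log_p_active = np.log(0.)                 # -inf when p_active == 0
        ambiguous = log_p_active - log_p_active   # (-inf) - (-inf) evaluates to nan

    # resolve the ambiguity in favor of a definite log-probability of zero mass
    resolved = np.nan_to_num(ambiguous, nan=-np.inf)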
File 2 of 6

@@ -9,7 +9,7 @@
 from .. import types
 
 
-class TheanoForwardBackward:
+class PytensorForwardBackward:
     """Implementation of the forward-backward algorithm in pytensor."""
     def __init__(self,
                  log_posterior_probs_output_tc: Optional[types.TensorSharedVariable] = None,
@@ -226,9 +226,9 @@ def _get_compiled_forward_backward_pytensor_func(self) -> pytensor.compile.Funct
         return pytensor.function(inputs=inputs, outputs=outputs, updates=updates)
 
     @staticmethod
-    def get_symbolic_log_posterior(log_prior_c: types.TheanoVector,
-                                   log_trans_tcc: types.TheanoTensor3,
-                                   log_emission_tc: types.TheanoMatrix,
+    def get_symbolic_log_posterior(log_prior_c: types.PytensorVector,
+                                   log_trans_tcc: types.PytensorTensor3,
+                                   log_emission_tc: types.PytensorMatrix,
                                    resolve_nans: bool):
         """Generates symbolic tensors representing hidden-state log posterior, log data likelihood,
         forward table (alpha), and backward table (beta).
@@ -238,9 +238,9 @@ def get_symbolic_log_posterior(log_prior_c: types.TheanoVector,
         """
         num_states = log_prior_c.shape[0]
 
-        def calculate_next_alpha(c_log_trans_ab: types.TheanoMatrix,
-                                 c_log_emission_b: types.TheanoVector,
-                                 p_alpha_a: types.TheanoVector):
+        def calculate_next_alpha(c_log_trans_ab: types.PytensorMatrix,
+                                 c_log_emission_b: types.PytensorVector,
+                                 p_alpha_a: types.PytensorVector):
             """Calculates the next entry on the forward table, alpha_{t}, from alpha_{t-1}.
 
             Args:
@@ -259,9 +259,9 @@ def calculate_next_alpha(c_log_trans_ab: types.TheanoMatrix,
             else:
                 return n_alpha_b
 
-        def calculate_prev_beta(n_log_trans_ab: types.TheanoMatrix,
-                                n_log_emission_b: types.TheanoVector,
-                                n_beta_b: types.TheanoVector):
+        def calculate_prev_beta(n_log_trans_ab: types.PytensorMatrix,
+                                n_log_emission_b: types.PytensorVector,
+                                n_beta_b: types.PytensorVector):
             """Calculates the previous entry on the backward table, beta_{t-1}, from beta_{t}.
 
             Args:
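For reference, the recursions these two helpers implement are the standard log-space forward-backward updates. A NumPy sketch with the same index convention as the code (a = previous state, b = current state); this is illustrative, not repository code:

    import numpy as np
    from scipy.special import logsumexp

    # alpha_t(b) = emission_t(b) + logsumexp_a[alpha_{t-1}(a) + log_trans_t(a, b)]
    def calculate_next_alpha(c_log_trans_ab, c_log_emission_b, p_alpha_a):
        return c_log_emission_b + logsumexp(p_alpha_a[:, None] + c_log_trans_ab, axis=0)

    # beta_{t-1}(a) = logsumexp_b[log_trans_t(a, b) + emission_t(b) + beta_t(b)]
    def calculate_prev_beta(n_log_trans_ab, n_log_emission_b, n_beta_b):
        return logsumexp(n_log_trans_ab + (n_log_emission_b + n_beta_b)[None, :], axis=1)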
@@ -313,9 +313,9 @@ def calculate_prev_beta(n_log_trans_ab: types.TheanoMatrix,
         return log_posterior_probs_tc, log_data_likelihood_t.dimshuffle(0), alpha_tc, beta_tc
 
     @staticmethod
-    def get_symbolic_thermal_hmm_params(log_prior_c: types.TheanoVector,
-                                        log_trans_tcc: types.TheanoTensor3,
-                                        log_emission_tc: types.TheanoMatrix,
+    def get_symbolic_thermal_hmm_params(log_prior_c: types.PytensorVector,
+                                        log_trans_tcc: types.PytensorTensor3,
+                                        log_emission_tc: types.PytensorMatrix,
                                         temperature: pt.scalar):
         inv_temperature = pt.reciprocal(temperature)  # TODO inv to reciprocal
 
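The thermalized variant scales the HMM parameters by inv_temperature before running the recursions: raising a probability to the power 1/T is just multiplication by 1/T in log space. A sketch of that identity, with renormalization so the result is again a distribution (the actual method's handling of the transition tensor is more involved):

    import numpy as np
    from scipy.special import logsumexp

    def thermalize(log_p, temperature):
        scaled = log_p / temperature          # p ** (1/T) in log space
        return scaled - logsumexp(scaled)     # renormalize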
Expand Down Expand Up @@ -343,7 +343,7 @@ def __init__(self,
self.update_norm_t = update_norm_t


class TheanoViterbi:
class PytensorViterbi:
"""Implementation of the Viterbi algorithm in pytensor."""
def __init__(self):
self._viterbi_pytensor_func = self._get_compiled_viterbi_pytensor_func()
Expand All @@ -365,9 +365,9 @@ def _get_compiled_viterbi_pytensor_func(self) -> pytensor.compile.Function:
outputs=self._get_symbolic_viterbi_path(log_prior_c, log_trans_tcc, log_emission_tc))

@staticmethod
def _get_symbolic_viterbi_path(log_prior_c: types.TheanoVector,
log_trans_tcc: types.TheanoTensor3,
log_emission_tc: types.TheanoMatrix):
def _get_symbolic_viterbi_path(log_prior_c: types.PytensorVector,
log_trans_tcc: types.PytensorTensor3,
log_emission_tc: types.PytensorMatrix):
"""Generates a symbolic 1d integer tensor representing the most-likely chain of hidden states
(Viterbi algorithm).
Expand Down Expand Up @@ -434,6 +434,6 @@ def calculate_previous_best_state(c_psi_c, c_best_state):
go_backwards=True)

# concatenate with the terminal state
viterbi_path_t = pt.concatenate([pt.stack(last_best_state), rest_best_states_t])[::-1]
viterbi_path_t = pt.concatenate([last_best_state.dimshuffle('x'), rest_best_states_t])[::-1]

return viterbi_path_t
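The one behavioral fix in this file is in the last hunk: last_best_state comes out of pytensor.scan as a 0-d (scalar) tensor, which needs a length-1 leading axis before it can be concatenated with the 1-d tensor of remaining states, and dimshuffle('x') adds exactly that broadcastable axis. A minimal sketch of the pattern (names illustrative):

    import pytensor.tensor as pt

    last_best_state = pt.lscalar("last_best_state")      # 0-d integer tensor
    rest_best_states_t = pt.lvector("rest_best_states")  # 1-d scan output

    # lift the scalar to shape (1,), concatenate, and reverse the chain
    viterbi_path_t = pt.concatenate(
        [last_best_state.dimshuffle('x'), rest_best_states_t])[::-1]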
File 3 of 6

@@ -2,10 +2,10 @@
 from typing import Dict, List
 
 import numpy as np
-import pymc as pm
 import pytensor
 import pytensor.tensor as pt
-from ..models.commons import logsumexp
+from ..models import commons
+from scipy.special import logsumexp
 
 from ..utils.math import logsumexp_double_complement, logp_to_phred
 
@@ -112,7 +112,7 @@ def _get_compiled_constrained_path_logp_pytensor_func() -> pytensor.compile.Func
         def update_alpha(c_log_emission_c: pt.vector,
                          c_log_trans_cc: pt.matrix,
                          p_alpha_c: pt.vector):
-            return c_log_emission_c + logsumexp(
+            return c_log_emission_c + commons.logsumexp(
                 p_alpha_c.dimshuffle(0, 'x') + c_log_trans_cc, axis=0).dimshuffle(1)  # TODO check this
 
         alpha_seg_iters, _ = pytensor.scan(
@@ -122,7 +122,7 @@ def update_alpha(c_log_emission_c: pt.vector,
         alpha_seg_end_c = alpha_seg_iters[-1, :]
 
         inputs = [alpha_first_c, beta_last_c, log_emission_tc, log_trans_tcc, log_data_likelihood]
-        output = logsumexp(alpha_seg_end_c + beta_last_c) - log_data_likelihood
+        output = commons.logsumexp(alpha_seg_end_c + beta_last_c) - log_data_likelihood
         return pytensor.function(inputs=inputs, outputs=output)
 
     # make a private static instance
@@ -166,7 +166,7 @@ def get_log_constrained_posterior_prob(self,
             constrained_alpha_first_c, constrained_beta_last_c,
             constrained_log_emission_tc, constrained_log_trans_tcc, self.log_data_likelihood)
 
-        return np.asscalar(logp)
+        return logp.item()
 
     def get_segment_quality_some_called(self, start_index: int, end_index: int, call_state: int) -> float:
         """Calculates the phred-scaled posterior probability that one or more ("some") sites in a segment have
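The import change separates two different uses of logsumexp: reductions inside the compiled graph must stay symbolic, so they now go through commons.logsumexp, while scipy.special.logsumexp serves the eager NumPy call sites. A short sketch of the distinction (illustrative):

    import numpy as np
    import pytensor.tensor as pt
    from scipy.special import logsumexp

    log_p = np.array([-1.0, -2.0, -3.0])
    total = logsumexp(log_p)      # eager: fine on a concrete ndarray

    x = pt.vector("x")
    # symbolic: must use a graph-level op, such as pt.logsumexp(x) (or, in this
    # codebase, commons.logsumexp); scipy's version would try to coerce the
    # symbolic tensor into an ndarray
    sym_total = pt.logsumexp(x)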
File 4 of 6

@@ -13,7 +13,7 @@
 from ..io import io_consts, io_commons, io_denoising_calling, io_intervals_and_counts, io_vcf_parsing
 from ..models.model_denoising_calling import DenoisingModelConfig, CopyNumberCallingConfig, \
     HHMMClassAndCopyNumberBasicCaller
-from ..models.pytensor_hmm import TheanoForwardBackward, TheanoViterbi
+from ..models.pytensor_hmm import PytensorForwardBackward, PytensorViterbi
 from ..structs.interval import Interval
 from ..structs.metadata import IntervalListMetadata
 from ..structs.metadata import SampleMetadataCollection
@@ -98,7 +98,7 @@ def __init__(self,
 
         # forward-backward algorithm
         _logger.info("Compiling pytensor forward-backward function...")
-        self.pytensor_forward_backward = TheanoForwardBackward(
+        self.pytensor_forward_backward = PytensorForwardBackward(
             log_posterior_probs_output_tc=None,
             resolve_nans=False,
             do_thermalization=False,
@@ -108,7 +108,7 @@
 
         # viterbi algorithm
         _logger.info("Compiling pytensor Viterbi function...")
-        self.pytensor_viterbi = TheanoViterbi()
+        self.pytensor_viterbi = PytensorViterbi()
 
         # copy-number HMM specs generator
         _logger.info("Compiling pytensor variational HHMM...")
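For completeness, a condensed usage sketch of the renamed classes as constructed in these hunks. Only the keyword arguments visible in this diff are shown; the real constructors take more, and the import path assumes the same package layout as the relative import above:

    from ..models.pytensor_hmm import PytensorForwardBackward, PytensorViterbi

    pytensor_forward_backward = PytensorForwardBackward(
        log_posterior_probs_output_tc=None,   # no shared output tensor; presumably posteriors are returned directly
        resolve_nans=False,
        do_thermalization=False)
    pytensor_viterbi = PytensorViterbi()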
File 5 of 6

@@ -19,7 +19,7 @@
 uint_dtypes = [np.uint8, np.uint16, np.uint32, np.uint64]
 
 # pytensor tensor types
-TheanoVector = pt.TensorType(floatX, (False,))
-TheanoMatrix = pt.TensorType(floatX, (False, False))
-TheanoTensor3 = pt.TensorType(floatX, (False, False, False))
+PytensorVector = pt.TensorType(floatX, (False,))
+PytensorMatrix = pt.TensorType(floatX, (False, False))
+PytensorTensor3 = pt.TensorType(floatX, (False, False, False))
 TensorSharedVariable = pytensor.tensor.sharedvar.TensorSharedVariable
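These aliases are TensorType instances, so the rename is purely cosmetic; calling an instance produces a fresh symbolic variable of the corresponding dtype and rank. A brief sketch mirroring the definitions above:

    import pytensor
    import pytensor.tensor as pt

    floatX = pytensor.config.floatX

    PytensorVector = pt.TensorType(floatX, (False,))
    PytensorMatrix = pt.TensorType(floatX, (False, False))

    log_prior_c = PytensorVector("log_prior_c")           # 1-d symbolic variable
    log_emission_tc = PytensorMatrix("log_emission_tc")   # 2-d symbolic variable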
File 6 of 6

@@ -78,12 +78,13 @@ def logsumexp_double_complement(a: np.ndarray, rel_tol: float = 1e-3) -> float:
     Returns:
         a float scalar
     """
+    print(a)
     try:
         assert isinstance(a, np.ndarray)
-        a = np.asarray(a.copy(), dtype=np.float)
+        a = np.asarray(a.copy(), dtype=float)
     except AssertionError:
         try:
-            a = np.asarray(a, dtype=np.float)
+            a = np.asarray(a, dtype=float)
         except ValueError:
             raise ValueError("The input argument must be castable to a float ndarray.")
     assert len(a) > 0
@@ -93,7 +94,7 @@ def logsumexp_double_complement(a: np.ndarray, rel_tol: float = 1e-3) -> float:
     a[a > 0.] = 0.
 
     if len(a) == 1:
-        return np.asscalar(a)
+        return a.item()
     else:
         a = np.sort(a.flatten())[::-1]
         x = a[0]
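The remaining edits track NumPy deprecations: the np.float alias and np.asscalar were removed in recent NumPy releases, and the builtin float plus ndarray.item() are their drop-in replacements. A quick illustration:

    import numpy as np

    a = np.asarray([-0.5], dtype=float)  # was dtype=np.float, removed from NumPy
    scalar = a.item()                    # was np.asscalar(a), removed from NumPy
    assert isinstance(scalar, float)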
