remove unused distributions
jules-samaran committed Feb 26, 2024
1 parent 8b88250 commit a787320
Showing 1 changed file with 0 additions and 254 deletions.
254 changes: 0 additions & 254 deletions scconfluence/distributions.py
@@ -16,30 +16,6 @@
# are quite auxiliary functions.


def log_nm_positive(x: torch.Tensor, r: torch.Tensor, probs: torch.Tensor, eps=1e-8):
"""
Log likelihood of a minibatch according to a negative multinomial model
(one value per sample, shape: minibatch).
Parameters
----------
x
Data
r
Dispersion parameter (must have positive support) (shape: 1)
probs
Success probabilities, with the failure category in column 0; each row must lie on the simplex (shape: minibatch x (1 + vars))
eps
Numerical stability constant
"""
ll = (
torch.lgamma(r + torch.sum(x, dim=1))
+ r * torch.log(probs[:, 0] + eps)
- torch.lgamma(r)
+ torch.sum(x * torch.log(probs[:, 1:] + eps), dim=1)
)
return ll
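For reference, a minimal usage sketch of this removed helper (tensor values are illustrative only; log_nm_positive is assumed to be in scope, and the zero-column logits construction mirrors the NegativeMultinomial class further down):

import torch

# toy minibatch: 2 samples, 3 count features (illustrative values)
x = torch.tensor([[1.0, 0.0, 2.0], [0.0, 3.0, 1.0]])
r = torch.tensor([2.0])      # shared positive dispersion parameter
logits = torch.randn(2, 3)   # per-feature logits
# prepend a zero column so the softmax puts each row on the simplex,
# with column 0 playing the role of the "failure" category
probs = torch.softmax(torch.cat((torch.zeros(2, 1), logits), dim=-1), dim=1)
ll = log_nm_positive(x, r=r, probs=probs)  # per-sample log likelihood, shape (2,)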


def log_zinb_positive(
x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, pi: torch.Tensor, eps=1e-8
):
@@ -125,77 +101,6 @@ def log_nb_positive(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, eps=
return res
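The body of log_nb_positive (kept in the file) is collapsed in this hunk; purely as a usage sketch, assuming it keeps the scvi-style elementwise output, a call could look like:

import torch

x = torch.randint(0, 20, (4, 5)).float()   # counts, minibatch x features
mu = torch.rand(4, 5) * 10 + 0.1           # NB means (positive)
theta = torch.rand(4, 5) + 0.5             # inverse dispersions (positive)
res = log_nb_positive(x, mu, theta)        # elementwise log likelihood, shape (4, 5)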


def log_mixture_nb(
x: torch.Tensor,
mu_1: torch.Tensor,
mu_2: torch.Tensor,
theta_1: torch.Tensor,
theta_2: torch.Tensor,
pi_logits: torch.Tensor,
eps=1e-8,
):
"""
Log likelihood of a minibatch according to a mixture negative binomial model
(one value per entry, shape: minibatch x features).
pi_logits is the probability, on the logits scale, of belonging to the first component.
For totalVI, the first component should be background.
Parameters
----------
x
Observed data
mu_1
Mean of the first negative binomial component (must have positive support) (shape: minibatch x features)
mu_2
Mean of the second negative binomial component (must have positive support) (shape: minibatch x features)
theta_1
First inverse dispersion parameter (must have positive support) (shape: minibatch x features)
theta_2
Second inverse dispersion parameter (must have positive support) (shape: minibatch x features).
If None, a single dispersion parameter shared by both components is assumed.
pi_logits
Probability of belonging to mixture component 1 (logits scale)
eps
Numerical stability constant
"""
if theta_2 is not None:
log_nb_1 = log_nb_positive(x, mu_1, theta_1)
log_nb_2 = log_nb_positive(x, mu_2, theta_2)
# the else branch below computes both components inline to avoid repeated computations when the dispersion is shared
else:
theta = theta_1
if theta.ndimension() == 1:
theta = theta.view(
1, theta.size(0)
) # In this case, we reshape theta for broadcasting

log_theta_mu_1_eps = torch.log(theta + mu_1 + eps)
log_theta_mu_2_eps = torch.log(theta + mu_2 + eps)
lgamma_x_theta = torch.lgamma(x + theta)
lgamma_theta = torch.lgamma(theta)
lgamma_x_plus_1 = torch.lgamma(x + 1)

log_nb_1 = (
theta * (torch.log(theta + eps) - log_theta_mu_1_eps)
+ x * (torch.log(mu_1 + eps) - log_theta_mu_1_eps)
+ lgamma_x_theta
- lgamma_theta
- lgamma_x_plus_1
)
log_nb_2 = (
theta * (torch.log(theta + eps) - log_theta_mu_2_eps)
+ x * (torch.log(mu_2 + eps) - log_theta_mu_2_eps)
+ lgamma_x_theta
- lgamma_theta
- lgamma_x_plus_1
)

logsumexp = torch.logsumexp(torch.stack((log_nb_1, log_nb_2 - pi_logits)), dim=0)
softplus_pi = F.softplus(-pi_logits)

log_mixture_nb = logsumexp - softplus_pi

return log_mixture_nb
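A minimal sketch of how this removed mixture helper could be called (shapes follow the docstring; values are illustrative and the function is assumed to be in scope):

import torch

x = torch.randint(0, 10, (4, 5)).float()   # observed counts
mu_1 = torch.rand(4, 5) * 5 + 0.1          # background component means
mu_2 = torch.rand(4, 5) * 20 + 0.1         # foreground component means
theta_1 = torch.rand(4, 5) + 0.5           # inverse dispersion
pi_logits = torch.zeros(4, 5)              # 50/50 mixture prior on the logits scale
# passing theta_2=None takes the shared-dispersion branch above
ll = log_mixture_nb(x, mu_1, mu_2, theta_1, theta_2=None, pi_logits=pi_logits)
print(ll.shape)                            # torch.Size([4, 5])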


def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
r"""
NB parameterizations conversion.
@@ -249,61 +154,6 @@ def _gamma(theta, mu):
return gamma_d
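The bodies of _convert_mean_disp_to_counts_logits and _gamma (both kept in the file) are collapsed in this hunk. As a hedged sketch only, this is the standard mean/dispersion to counts/logits conversion and the Gamma-Poisson construction of the negative binomial, not necessarily the exact code in the file; the _sketch names are hypothetical:

import torch
from torch.distributions import Gamma


def _convert_mean_disp_to_counts_logits_sketch(mu, theta, eps=1e-6):
    # NB(mean=mu, inverse dispersion=theta) corresponds to
    # NB(total_count=theta, logits=log(mu) - log(theta))
    total_count = theta
    logits = (mu + eps).log() - (theta + eps).log()
    return total_count, logits


def _gamma_sketch(theta, mu):
    # Gamma mixing distribution of the Gamma-Poisson construction:
    # concentration=theta and rate=theta/mu give a Gamma with mean mu
    return Gamma(concentration=theta, rate=theta / mu)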


class NegativeMultinomial(Distribution):
r"""
Negative multinomial distribution.
Parameters
----------
r
Real-valued positive dispersion parameter
logits
Matrix of per-category logits; a zero column is prepended internally as the reference (failure) category before the softmax (shape: minibatch x vars)
validate_args
Raise ValueError if arguments do not match constraints
"""

arg_constraints = {"r": constraints.greater_than(0)}
support = constraints.nonnegative_integer

def __init__(
self,
logits: torch.Tensor = None,
r: torch.Tensor = None,
validate_args: bool = False,
):
self._eps = 1e-8

self.r = r
self.probs = torch.softmax(
torch.cat(
(torch.zeros((logits.size(0), 1), device=logits.device), logits), dim=-1
),
dim=1,
)
super().__init__(validate_args=validate_args)

@property
def mean(self):
# index probs with [:, :1] to keep the batch dimension so the ratio broadcasts against probs
return (self.r / self.probs[:, :1]) * self.probs

@property
def variance(self):
raise NotImplementedError

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
if self._validate_args:
try:
self._validate_sample(value)
except ValueError:
warnings.warn(
"The value argument must be within the support of the distribution",
UserWarning,
)

return log_nm_positive(value, r=self.r, probs=self.probs, eps=self._eps)
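A minimal sketch of how this removed distribution could be used (illustrative values; the class is assumed to be importable from a revision of scconfluence.distributions preceding this commit):

import torch

logits = torch.randn(3, 4)                 # minibatch of 3 samples, 4 categories
r = torch.tensor([5.0])                    # positive dispersion
dist = NegativeMultinomial(logits=logits, r=r)
counts = torch.randint(0, 6, (3, 4)).float()
ll = dist.log_prob(counts)                 # per-sample log likelihood
print(ll.shape)                            # torch.Size([3])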


class NegativeBinomial(Distribution):
r"""
Negative binomial distribution.
@@ -496,107 +346,3 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
UserWarning,
)
return log_zinb_positive(value, self.mu, self.theta, self.zi_logits, eps=1e-08)


class NegativeBinomialMixture(Distribution):
"""
Negative binomial mixture distribution.
See :class:`~scvi.distributions.NegativeBinomial` for further description
of parameters.
Parameters
----------
mu1
Mean of the component 1 distribution.
mu2
Mean of the component 2 distribution.
theta1
Inverse dispersion for component 1.
mixture_logits
Logits scale probability of belonging to component 1.
theta2
Inverse dispersion for component 2. If `None`, assumed to be equal to `theta1`.
validate_args
Raise ValueError if arguments do not match constraints
"""

arg_constraints = {
"mu1": constraints.greater_than_eq(0),
"mu2": constraints.greater_than_eq(0),
"theta1": constraints.greater_than_eq(0),
"mixture_probs": constraints.half_open_interval(0.0, 1.0),
"mixture_logits": constraints.real,
}
support = constraints.nonnegative_integer

def __init__(
self,
mu1: torch.Tensor,
mu2: torch.Tensor,
theta1: torch.Tensor,
mixture_logits: torch.Tensor,
theta2: Optional[torch.Tensor] = None,
validate_args: bool = False,
):

(
self.mu1,
self.theta1,
self.mu2,
self.mixture_logits,
) = broadcast_all(mu1, theta1, mu2, mixture_logits)

super().__init__(validate_args=validate_args)

if theta2 is not None:
_, self.theta2 = broadcast_all(mu1, theta2)  # keep the broadcast tensor, not the tuple returned by broadcast_all
else:
self.theta2 = None

@property
def mean(self):
pi = self.mixture_probs
return pi * self.mu1 + (1 - pi) * self.mu2

@lazy_property
def mixture_probs(self) -> torch.Tensor:
return logits_to_probs(self.mixture_logits, is_binary=True)

def sample(
self, sample_shape: Union[torch.Size, Tuple] = torch.Size()
) -> torch.Tensor:
with torch.no_grad():
pi = self.mixture_probs
mixing_sample = torch.distributions.Bernoulli(pi).sample()
mu = self.mu1 * mixing_sample + self.mu2 * (1 - mixing_sample)
if self.theta2 is None:
theta = self.theta1
else:
theta = self.theta1 * mixing_sample + self.theta2 * (1 - mixing_sample)
gamma_d = _gamma(theta, mu)  # _gamma expects (dispersion, mean)
p_means = gamma_d.sample(sample_shape)

# Clamping as distributions objects can have buggy behaviors when
# their parameters are too high
l_train = torch.clamp(p_means, max=1e8)
counts = Poisson(
l_train
).sample() # Shape : (n_samples, n_cells_batch, n_features)
return counts

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
try:
self._validate_sample(value)
except ValueError:
warnings.warn(
"The value argument must be within the support of the distribution",
UserWarning,
)
return log_mixture_nb(
value,
self.mu1,
self.mu2,
self.theta1,
self.theta2,
self.mixture_logits,
eps=1e-08,
)
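Finally, a small sketch exercising the removed mixture distribution (illustrative values; the class is assumed to be in scope, with theta2 left as None so the shared-dispersion path is used):

import torch

mu1 = torch.rand(6, 10) * 2 + 0.1           # background component means
mu2 = torch.rand(6, 10) * 30 + 0.1          # foreground component means
theta1 = torch.rand(6, 10) + 0.5            # shared inverse dispersion
mixture_logits = torch.randn(6, 10)         # logits of belonging to component 1

dist = NegativeBinomialMixture(
    mu1=mu1, mu2=mu2, theta1=theta1, mixture_logits=mixture_logits
)
counts = dist.sample()                      # Gamma-Poisson sampling per component
print(counts.shape)                         # torch.Size([6, 10])
print(dist.log_prob(counts).shape)          # torch.Size([6, 10])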
