diff --git a/pyciemss/mira_integration/distributions.py b/pyciemss/mira_integration/distributions.py index 1e0b645d..3fa830cc 100644 --- a/pyciemss/mira_integration/distributions.py +++ b/pyciemss/mira_integration/distributions.py @@ -1,20 +1,20 @@ -from typing import Callable, Dict, Union +import warnings +from typing import Dict import mira.metamodel import pyro +import torch +ParameterDict = Dict[str, torch.Tensor] -def mira_uniform_to_pyro( - parameters: Dict[str, float] -) -> pyro.distributions.Distribution: + +def mira_uniform_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: low = parameters["minimum"] high = parameters["maximum"] return pyro.distributions.Uniform(low=low, high=high) -def mira_normal_to_pyro( - parameters: Dict[str, float] -) -> pyro.distributions.Distribution: +def mira_normal_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: if "mean" in parameters.keys(): loc = parameters["mean"] if "stdev" in parameters.keys(): @@ -28,7 +28,7 @@ def mira_normal_to_pyro( def mira_lognormal_to_pyro( - parameters: Dict[str, float] + parameters: ParameterDict, ) -> pyro.distributions.Distribution: if "meanLog" in parameters.keys(): loc = parameters["meanLog"] @@ -42,7 +42,7 @@ def mira_lognormal_to_pyro( # Provide either probs or logits, not both def mira_bernoulli_to_pyro( - parameters: Dict[str, float] + parameters: ParameterDict, ) -> pyro.distributions.Distribution: if "probability" in parameters.keys(): probs = parameters["probability"] @@ -54,12 +54,12 @@ def mira_bernoulli_to_pyro( return pyro.distributions.Bernoulli(probs=probs, logits=logits) -def mira_beta_to_pyro(parameters: Dict[str, float]) -> pyro.distributions.Distribution: +def mira_beta_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: return pyro.distributions.Beta(alpha=parameters["alpha"], beta=parameters["beta"]) def mira_betabinomial_to_pyro( - parameters: Dict[str, Union[float, list]] + parameters: ParameterDict, ) -> pyro.distributions.Distribution: raise NotImplementedError( "Conversion from MIRA BetaBinomial distribution to Pyro distribution is not implemented." @@ -73,9 +73,7 @@ def mira_betabinomial_to_pyro( ) -def mira_binomial_to_pyro( - parameters: Dict[str, Union[float, list]] -) -> pyro.distributions.Distribution: +def mira_binomial_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: total_count = parameters["numberOfTrials"] if "probability" in parameters.keys(): probs = parameters["probability"] @@ -89,30 +87,28 @@ def mira_binomial_to_pyro( ) -def mira_cauchy_to_pyro( - parameters: Dict[str, float] -) -> pyro.distributions.Distribution: +def mira_cauchy_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: loc = parameters["location"] scale = parameters["scale"] return pyro.distributions.Cauchy(loc=loc, scale=scale) def mira_chisquared_to_pyro( - parameters: Dict[str, float] + parameters: ParameterDict, ) -> pyro.distributions.Distribution: df = parameters["degreesOfFreedom"] return pyro.distributions.Chi2(df=df) def mira_dirichlet_to_pyro( - parameters: Dict[str, list] + parameters: ParameterDict, ) -> pyro.distributions.Distribution: concentration = parameters["concentration"] return pyro.distributions.Dirichlet(concentration=concentration) def mira_exponential_to_pyro( - parameters: Dict[str, float] + parameters: ParameterDict, ) -> pyro.distributions.Distribution: if "rate" in parameters.keys(): rate = parameters["rate"] @@ -121,7 +117,7 @@ def mira_exponential_to_pyro( return pyro.distributions.Exponential(rate=rate) -def mira_gamma_to_pyro(parameters: Dict[str, float]) -> pyro.distributions.Distribution: +def mira_gamma_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: if "shape" in parameters.keys(): concentration = parameters["shape"] if "scale" in parameters.keys(): @@ -132,7 +128,7 @@ def mira_gamma_to_pyro(parameters: Dict[str, float]) -> pyro.distributions.Distr def mira_inversegamma_to_pyro( - parameters: Dict[str, float] + parameters: ParameterDict, ) -> pyro.distributions.Distribution: raise NotImplementedError( "Conversion from MIRA InverseGamma distribution to Pyro distribution is not implemented." @@ -140,17 +136,13 @@ def mira_inversegamma_to_pyro( # TODO: Map parameters to Pyro distribution parameters -def mira_gumbel_to_pyro( - parameters: Dict[str, float] -) -> pyro.distributions.Distribution: +def mira_gumbel_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: loc = parameters["location"] scale = parameters["scale"] return pyro.distributions.Gumbel(loc=loc, scale=scale) -def mira_laplace_to_pyro( - parameters: Dict[str, float] -) -> pyro.distributions.Distribution: +def mira_laplace_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: if "location" in parameters.keys(): loc = parameters["location"] @@ -166,47 +158,37 @@ def mira_laplace_to_pyro( def mira_paretotypeI_to_pyro( - parameters: Dict[str, float] + parameters: ParameterDict, ) -> pyro.distributions.Distribution: - raise NotImplementedError( - "Conversion from MIRA ParetoTypeI distribution to Pyro distribution is not implemented." - ) - # TODO: Confirm that parameters are mapped correctly scale = parameters["scale"] alpha = parameters["shape"] return pyro.distributions.Pareto(scale=scale, alpha=alpha) -def mira_poisson_to_pyro( - parameters: Dict[str, float] -) -> pyro.distributions.Distribution: +def mira_poisson_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: rate = parameters["rate"] return pyro.distributions.Poisson(rate=rate) -def mira_studentt_to_pyro( - parameters: Dict[str, float] -) -> pyro.distributions.Distribution: +def mira_studentt_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: if "mean" in parameters.keys(): loc = parameters["mean"] elif "location" in parameters.keys(): loc = parameters["location"] else: - loc = 0.0 + loc = torch.tensor(0.0) if "scale" in parameters.keys(): scale = parameters["scale"] else: - scale = 1.0 + scale = torch.tensor(1.0) df = parameters["degreesOfFreedom"] return pyro.distributions.StudentT(df=df, loc=loc, scale=scale) -def mira_weibull_to_pyro( - parameters: Dict[str, float] -) -> pyro.distributions.Distribution: +def mira_weibull_to_pyro(parameters: ParameterDict) -> pyro.distributions.Distribution: if "scale" in parameters.keys(): scale = parameters["scale"] @@ -219,7 +201,7 @@ def mira_weibull_to_pyro( # Key - MIRA distribution type : str -# Value - MIRA -> Pyro function : Callable[[Dict[str, float]], pyro.distributions.Distribution] +# Value - MIRA -> Pyro function : Callable[[ParameterDict], pyro.distributions.Distribution] # See https://github.com/indralab/mira/blob/main/mira/dkg/resources/probonto.json for MIRA distribution types _MIRA_TO_PYRO = { "Uniform1": mira_uniform_to_pyro, @@ -256,6 +238,15 @@ def mira_weibull_to_pyro( "Weibull2": mira_weibull_to_pyro, } +_TESTED_DISTRIBUTIONS = [ + "Uniform1", + "StandardUniform1", + "StandardNormal1", + "Normal1", + "Normal2", + "Normal3", +] + def mira_distribution_to_pyro( mira_dist: mira.metamodel.template_model.Distribution, @@ -265,4 +256,11 @@ def mira_distribution_to_pyro( f"Conversion from MIRA distribution type {mira_dist.type} to Pyro distribution is not implemented." ) - return _MIRA_TO_PYRO[mira_dist.type](mira_dist.parameters) + if mira_dist.type not in _TESTED_DISTRIBUTIONS: + warnings.warn( + f"Conversion from MIRA distribution type {mira_dist.type} to Pyro distribution has not been tested." + ) + + parameters = {param.name: torch.as_tensor(param.value) for param in mira_dist} + + return _MIRA_TO_PYRO[mira_dist.type](parameters)