From 38b6158059f89322566dc0ddb0411cde3b2a77a0 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Wed, 8 Nov 2023 15:33:20 -0500 Subject: [PATCH 01/66] added robust folder --- chirho/robust/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 chirho/robust/__init__.py diff --git a/chirho/robust/__init__.py b/chirho/robust/__init__.py new file mode 100644 index 000000000..e69de29bb From c09dcab3cab50fa29de3a553440ace31e29932cb Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Thu, 9 Nov 2023 13:28:45 -0500 Subject: [PATCH 02/66] uncommited scratch work for log prob --- chirho/robust/internals.py | 109 +++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 chirho/robust/internals.py diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py new file mode 100644 index 000000000..2be21d0ae --- /dev/null +++ b/chirho/robust/internals.py @@ -0,0 +1,109 @@ +from typing import ParamSpec, Callable, TypeVar, Optional +import torch +from pyro.infer import Predictive +from pyro.infer import Trace_ELBO +from pyro.infer.elbo import ELBOModule +from pyro.infer.importance import vectorized_importance_weights +from pyro.poutine import mask + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + +Point = dict[str, T] +Guide = Callable[P, Optional[T | Point[T]]] + + +class LogProbModule: + def __init__( + self, + model: Callable[P, T], + guide: Guide[P, T], + # elbo: ELBOModule = Trace_ELBO, + theta_names_to_mask: Optional[list[str]] = None, + ): + self.theta_names_to_mask = theta_names_to_mask + # self._log_prob_from_elbo = elbo()(mask(model, ...), mask(guide, ...)) + + def log_prob(self, X, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def log_prob_gradient(self, X, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError + + +class ReparametrizableLogProbModule(LogProbModule): + def __init__( + self, + model: Callable[P, T], + guide: Guide[P, T], + elbo: ELBOModule = Trace_ELBO, + theta_names_to_mask: Optional[list[str]] = None, + ): + self._log_prob_from_elbo = elbo()(mask(model, ...), mask(guide, ...)) + + # Use vmap here to get elbo at multiple points + def log_prob(self, X, *args, **kwargs) -> torch.Tensor: + elbos = [] + for x in X: + elbos.append(self._log_prob_from_elbo(X, *args, **kwargs)) + return torch.stack(elbos) + + def log_prob_gradient(self, X, *args, **kwargs) -> torch.Tensor: + return torch.functional.autograd( + partial(self.log_prob(*args, **kwargs)), X, elbo.parameters() + ) + + +# For continous latents, vectorized importance weights +# https://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.importance.vectorized_importance_weights + +# Predictive(model, guide) + + +import pyro +import pyro.distributions as dist + + +# Create simple pyro model +def model(x: torch.Tensor) -> torch.Tensor: + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + with pyro.plate("data", x.shape[0]): + y = a * x + b + pyro.sample("y", dist.Normal(y, 1)) + + +# Create guide +guide_normal = pyro.infer.autoguide.AutoNormal(model) + + +def fixed_guide(x: torch.Tensor) -> None: + pyro.sample("a", dist.Delta(torch.tensor(1.0))) + pyro.sample("b", dist.Delta(torch.tensor(1.0))) + + +# Create predictive +predictive = Predictive(model, guide=fixed_guide, num_samples=1000) + +samps = predictive(torch.tensor([1.0])) + +# Create elbo loss +elbo = pyro.infer.Trace_ELBO(num_particles=10000)(model, guide=guide_normal) + + +torch.autograd(elbo(torch.tensor([1.0])), 
elbo.parameters()) + +torch.autograd.functional.jacobian(elbo, torch.tensor([1.0]), elbo.parameters()) + +x0 = torch.tensor([1.0, 2.0], requires_grad=True) + +elbo(x0) + +x1 = torch.tensor([[1.0, 2.0], [2.0, 3.0]], requires_grad=True) + + +vectorized_importance_weights( + model, guide_normal, x=x0, max_plate_nesting=4, num_samples=10000 +)[0].mean() From 21e31bf0579612499bbfffaa28b84e480501d0ed Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Thu, 9 Nov 2023 15:55:29 -0500 Subject: [PATCH 03/66] untested variational log prob --- chirho/robust/internals.py | 61 ++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 2be21d0ae..da109e427 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -4,7 +4,7 @@ from pyro.infer import Trace_ELBO from pyro.infer.elbo import ELBOModule from pyro.infer.importance import vectorized_importance_weights -from pyro.poutine import mask +from pyro.poutine import mask, replay, trace P = ParamSpec("P") Q = ParamSpec("Q") @@ -15,25 +15,23 @@ Guide = Callable[P, Optional[T | Point[T]]] -class LogProbModule: - def __init__( - self, - model: Callable[P, T], - guide: Guide[P, T], - # elbo: ELBOModule = Trace_ELBO, - theta_names_to_mask: Optional[list[str]] = None, - ): - self.theta_names_to_mask = theta_names_to_mask - # self._log_prob_from_elbo = elbo()(mask(model, ...), mask(guide, ...)) +# guide should hide obs_names sites - def log_prob(self, X, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError - def log_prob_gradient(self, X, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError +def vectorized_variational_log_prob( + model: Callable[P, T], guide: Guide[P, T], X: Point, *args, **kwargs +): + guide_trace = trace(guide).get_trace(*args, **kwargs) + model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs) + log_probs = dict() + for site_name, site_val in X.items(): + site = model_trace.nodes[site_name] + assert site["type"] == "sample" + log_probs[site_name] = site["fn"].log_prob(site_val) + return log_probs -class ReparametrizableLogProbModule(LogProbModule): +class LogProbModule: def __init__( self, model: Callable[P, T], @@ -41,21 +39,28 @@ def __init__( elbo: ELBOModule = Trace_ELBO, theta_names_to_mask: Optional[list[str]] = None, ): + self.theta_names_to_mask = theta_names_to_mask + self.model = model + self.guide = guide self._log_prob_from_elbo = elbo()(mask(model, ...), mask(guide, ...)) - # Use vmap here to get elbo at multiple points - def log_prob(self, X, *args, **kwargs) -> torch.Tensor: + def log_prob(self, X: Point, *args, **kwargs) -> torch.Tensor: elbos = [] for x in X: elbos.append(self._log_prob_from_elbo(X, *args, **kwargs)) return torch.stack(elbos) - def log_prob_gradient(self, X, *args, **kwargs) -> torch.Tensor: + def log_prob_gradient(self, X: Point, *args, **kwargs) -> torch.Tensor: return torch.functional.autograd( partial(self.log_prob(*args, **kwargs)), X, elbo.parameters() ) +class ReparametrizableLogProb(LogProbModule): + def log_prob(self, X: Point, *args, **kwargs) -> torch.Tensor: + pass + + # For continous latents, vectorized importance weights # https://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.importance.vectorized_importance_weights @@ -72,7 +77,7 @@ def model(x: torch.Tensor) -> torch.Tensor: b = pyro.sample("b", dist.Normal(0, 1)) with pyro.plate("data", x.shape[0]): y = a * x + b - pyro.sample("y", dist.Normal(y, 1)) + return pyro.sample("y", 
dist.Normal(y, 1)) # Create guide @@ -90,14 +95,18 @@ def fixed_guide(x: torch.Tensor) -> None: samps = predictive(torch.tensor([1.0])) # Create elbo loss -elbo = pyro.infer.Trace_ELBO(num_particles=10000)(model, guide=guide_normal) +elbo = pyro.infer.Trace_ELBO(num_particles=100)(model, guide=guide_normal) torch.autograd(elbo(torch.tensor([1.0])), elbo.parameters()) -torch.autograd.functional.jacobian(elbo, torch.tensor([1.0]), elbo.parameters()) +torch.autograd.functional.jacobian( + elbo, + torch.tensor([1.0, 2.0]), + dict(elbo.named_parameters())["guide.locs.a_unconstrained"], +) -x0 = torch.tensor([1.0, 2.0], requires_grad=True) +x0 = torch.tensor([1.0, 2.0], requires_grad=False) elbo(x0) @@ -107,3 +116,9 @@ def fixed_guide(x: torch.Tensor) -> None: vectorized_importance_weights( model, guide_normal, x=x0, max_plate_nesting=4, num_samples=10000 )[0].mean() + + +torch.stack([torch.zeros(3), torch.zeros(3)]) + + +elbo.parameters() From faed2356a8541de1731534d0874233e0f5d307eb Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Mon, 13 Nov 2023 10:01:04 -0500 Subject: [PATCH 04/66] uncomitted changes --- chirho/robust/internals.py | 130 ++++++++++++++++++++----------------- 1 file changed, 70 insertions(+), 60 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index da109e427..35ead6842 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -1,10 +1,10 @@ -from typing import ParamSpec, Callable, TypeVar, Optional +from typing import ParamSpec, Callable, TypeVar, Optional, Dict, List import torch from pyro.infer import Predictive from pyro.infer import Trace_ELBO from pyro.infer.elbo import ELBOModule from pyro.infer.importance import vectorized_importance_weights -from pyro.poutine import mask, replay, trace +from pyro.poutine import block, replay, trace, mask P = ParamSpec("P") Q = ParamSpec("Q") @@ -15,19 +15,47 @@ Guide = Callable[P, Optional[T | Point[T]]] -# guide should hide obs_names sites +def _shuffle_dict(d: dict[str, T]): + """ + Shuffle values of a dictionary in first batch dimension + """ + return {k: v[torch.randperm(v.shape[0])] for k, v in d.items()} +# Need to add vectorize function from vectorized_importance_weights + + +# Issue: gradients detached in predictives def vectorized_variational_log_prob( - model: Callable[P, T], guide: Guide[P, T], X: Point, *args, **kwargs + model: Callable[P, T], + guide: Guide[P, T], + trace_predictive: Dict, + obs_names: List[str], + # num_particles: int = 1, # TODO: support this next + *args, + **kwargs ): - guide_trace = trace(guide).get_trace(*args, **kwargs) - model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs) - log_probs = dict() - for site_name, site_val in X.items(): + """ + See eq. 
3 in http://approximateinference.org/2017/accepted/TangRanganath2017.pdf + """ + latent_params_trace = _shuffle_dict( + {k: v.clone() for k, v in trace_predictive.items() if k not in obs_names} + ) + obs_vars_trace = { + k: v.clone().detach() for k, v in trace_predictive.items() if k in obs_names + } + import pdb + + pdb.set_trace() + model_trace = trace(replay(model, latent_params_trace)).get_trace(*args, **kwargs) + + N_samples = next(iter(latent_params_trace.values())).shape[0] + + log_probs = torch.zeros(N_samples) + for site_name, site_val in obs_vars_trace.items(): site = model_trace.nodes[site_name] assert site["type"] == "sample" - log_probs[site_name] = site["fn"].log_prob(site_val) + log_probs += site["fn"].log_prob(site_val) return log_probs @@ -61,64 +89,46 @@ def log_prob(self, X: Point, *args, **kwargs) -> torch.Tensor: pass +def log_likelihood_fn(flat_theta: torch.tensor, X: Dict[str, torch.Tensor]): + n_monte_carlo = X[next(iter(X))].shape[0] + theta = _unflatten_dict(flat_theta, theta_hat) + model_at_theta = condition(data=theta)(DataConditionedModel(model)) + log_like_trace = pyro.poutine.trace(model_at_theta).get_trace(X) + log_like_trace.compute_log_prob() + log_prob_at_datapoints = torch.zeros(n_monte_carlo) + for name in obs_names: + log_prob_at_datapoints += log_like_trace.nodes[name]["log_prob"] + return log_prob_at_datapoints + + # For continous latents, vectorized importance weights # https://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.importance.vectorized_importance_weights # Predictive(model, guide) +if __name__ == "__main__": + import pyro + import pyro.distributions as dist -import pyro -import pyro.distributions as dist - - -# Create simple pyro model -def model(x: torch.Tensor) -> torch.Tensor: - a = pyro.sample("a", dist.Normal(0, 1)) - b = pyro.sample("b", dist.Normal(0, 1)) - with pyro.plate("data", x.shape[0]): - y = a * x + b - return pyro.sample("y", dist.Normal(y, 1)) - - -# Create guide -guide_normal = pyro.infer.autoguide.AutoNormal(model) - - -def fixed_guide(x: torch.Tensor) -> None: - pyro.sample("a", dist.Delta(torch.tensor(1.0))) - pyro.sample("b", dist.Delta(torch.tensor(1.0))) - - -# Create predictive -predictive = Predictive(model, guide=fixed_guide, num_samples=1000) - -samps = predictive(torch.tensor([1.0])) - -# Create elbo loss -elbo = pyro.infer.Trace_ELBO(num_particles=100)(model, guide=guide_normal) - - -torch.autograd(elbo(torch.tensor([1.0])), elbo.parameters()) - -torch.autograd.functional.jacobian( - elbo, - torch.tensor([1.0, 2.0]), - dict(elbo.named_parameters())["guide.locs.a_unconstrained"], -) - -x0 = torch.tensor([1.0, 2.0], requires_grad=False) - -elbo(x0) - -x1 = torch.tensor([[1.0, 2.0], [2.0, 3.0]], requires_grad=True) - - -vectorized_importance_weights( - model, guide_normal, x=x0, max_plate_nesting=4, num_samples=10000 -)[0].mean() + # Create simple pyro model + def model(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + return pyro.sample("y", dist.Normal(a + b, 1)) + # Create guide on latents a and b + guide = pyro.infer.autoguide.AutoNormal(block(model, hide=["y"])) + # with pyro.poutine.trace() as tr: + # guide() + # print(tr.trace.nodes.keys()) + # Create predictive + predictive = Predictive(model, guide=guide, num_samples=100) + # with pyro.poutine.trace() as tr: + X = predictive() -torch.stack([torch.zeros(3), torch.zeros(3)]) + vectorized_variational_log_prob(model, guide, X, ["y"]) + # print(X) + # import pdb -elbo.parameters() + # pdb.set_trace() From 
fac98cdf0aa1f1df36139fbfe713798fd6789329 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Thu, 16 Nov 2023 09:35:29 -0500 Subject: [PATCH 05/66] uncomitted changes --- chirho/robust/internals.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 35ead6842..a19e429c7 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -101,6 +101,12 @@ def log_likelihood_fn(flat_theta: torch.tensor, X: Dict[str, torch.Tensor]): return log_prob_at_datapoints +def stochastic_variational_log_likelihood_fn( + flat_theta: torch.tensor, X: Dict[str, torch.Tensor] +): + pass + + # For continous latents, vectorized importance weights # https://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.importance.vectorized_importance_weights From 4edcb5eaf56dd1cf176c5a3b192f8280626cf488 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Thu, 16 Nov 2023 17:40:36 -0500 Subject: [PATCH 06/66] pair coding w/ eli --- chirho/robust/internals.py | 231 ++++++++++++++++++++++--------------- chirho/robust/utils.py | 110 ++++++++++++++++++ 2 files changed, 245 insertions(+), 96 deletions(-) create mode 100644 chirho/robust/utils.py diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index a19e429c7..82b369f93 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -1,140 +1,179 @@ -from typing import ParamSpec, Callable, TypeVar, Optional, Dict, List +import math +import collections +import pyro +from typing import Container, ParamSpec, Callable, Tuple, TypeVar, Optional, Dict, List, Protocol import torch from pyro.infer import Predictive from pyro.infer import Trace_ELBO from pyro.infer.elbo import ELBOModule from pyro.infer.importance import vectorized_importance_weights from pyro.poutine import block, replay, trace, mask +from chirho.indexed.handlers import DependentMaskMessenger +from chirho.observational.handlers import condition +from torch.func import functional_call +from functools import partial +from utils import conjugate_gradient_solve + +pyro.settings.set(module_local_params=True) P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") -T = TypeVar("T") +T = TypeVar("T") # This will be a torch.Tensor usually Point = dict[str, T] Guide = Callable[P, Optional[T | Point[T]]] -def _shuffle_dict(d: dict[str, T]): - """ - Shuffle values of a dictionary in first batch dimension - """ - return {k: v[torch.randperm(v.shape[0])] for k, v in d.items()} +def make_empirical_inverse_fisher_vp( + log_prob: torch.nn.Module, + **solver_kwargs, +) -> Callable: + fvp = make_empirical_fisher_vp(log_prob) + return lambda v: conjugate_gradient_solve(fvp, v, **solver_kwargs) -# Need to add vectorize function from vectorized_importance_weights +def make_empirical_fisher_vp( + log_prob: torch.nn.Module, +) -> Callable: -# Issue: gradients detached in predictives -def vectorized_variational_log_prob( - model: Callable[P, T], - guide: Guide[P, T], - trace_predictive: Dict, - obs_names: List[str], - # num_particles: int = 1, # TODO: support this next - *args, - **kwargs -): - """ - See eq. 
3 in http://approximateinference.org/2017/accepted/TangRanganath2017.pdf - """ - latent_params_trace = _shuffle_dict( - {k: v.clone() for k, v in trace_predictive.items() if k not in obs_names} - ) - obs_vars_trace = { - k: v.clone().detach() for k, v in trace_predictive.items() if k in obs_names - } - import pdb + def _empirical_fisher_vp(v: T) -> T: + params = dict(log_prob.named_parameters()) + vnew = torch.func.jvp(partial(torch.func.functional_call, log_prob), params, v) + (_, vjp_fn) = torch.func.vjp(partial(torch.func.functional_call, log_prob), params) + return vjp_fn(vnew) + + return _empirical_fisher_vp + + +class UnmaskNamedSites(DependentMaskMessenger): + names: Container[str] - pdb.set_trace() - model_trace = trace(replay(model, latent_params_trace)).get_trace(*args, **kwargs) + def __init__(self, names: Container[str]): + self.names = names - N_samples = next(iter(latent_params_trace.values())).shape[0] + def get_mask( + self, dist: pyro.distributions.Distribution, value: Optional[torch.Tensor], device: torch.device, name: str + ) -> torch.Tensor: + return torch.tensor(name in self.names, device=device) - log_probs = torch.zeros(N_samples) - for site_name, site_val in obs_vars_trace.items(): - site = model_trace.nodes[site_name] - assert site["type"] == "sample" - log_probs += site["fn"].log_prob(site_val) - return log_probs +class NMCLogLikelihood(torch.nn.Module): -class LogProbModule: def __init__( self, - model: Callable[P, T], - guide: Guide[P, T], - elbo: ELBOModule = Trace_ELBO, - theta_names_to_mask: Optional[list[str]] = None, + model: pyro.nn.PyroModule, + guide: pyro.nn.PyroModule, + num_samples: int, ): - self.theta_names_to_mask = theta_names_to_mask + super().__init__() self.model = model self.guide = guide - self._log_prob_from_elbo = elbo()(mask(model, ...), mask(guide, ...)) + self.num_samples = num_samples + + def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: + num_monte_carlo_outer = data[next(iter(data))].shape[0] + # if num_monte_inner is None: + # # Optimal scaling for inner expectation: + # # see https://arxiv.org/pdf/1709.06181.pdf + # num_monte_inner = num_monte_carlo_outer ** 2 + + log_weights = [] + for i in range(num_monte_carlo_outer): + log_weights_i = [] + for j in range(self.num_samples): + masked_guide = pyro.poutine.mask(mask=False)(self.guide) # Mask all sites in guide + masked_model = UnmaskNamedSites(names=set(data.keys()))(condition(data={k: v[i] for k, v in data.items()})(self.model)) + log_weight_ij = pyro.infer.Trace_ELBO().differentiable_loss(masked_model, masked_guide, *args, **kwargs) + log_weights_i.append(log_weight_ij) + log_weight_i = torch.logsumexp(torch.stack(log_weights_i), dim=0) - math.log(self.num_samples) + log_weights.append(log_weight_i) + + log_weights = torch.stack(log_weights) + assert log_weights.shape == (num_monte_carlo_outer,) + return log_weights / (num_monte_carlo_outer ** 0.5) + + +class NMCLogLikelihoodSingle(NMCLogLikelihood): + def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: + masked_guide = pyro.poutine.mask(mask=False)(self.guide) # Mask all sites in guide + masked_model = UnmaskNamedSites(names=set(data.keys()))(condition(data=data)(self.model)) + log_weights = pyro.infer.importance.vectorized_importance_weights(masked_model, masked_guide, *args, num_samples=self.num_samples, max_plate_nesting=1, **kwargs)[0] + return torch.logsumexp(log_weights * self.guide.zzz.w, dim=0) - math.log(self.num_samples) + + +class 
DummyAutoNormal(pyro.infer.autoguide.AutoNormal): + + def __getattr__(self, name): + # PyroParams trigger pyro.param statements. + if "_pyro_params" in self.__dict__: + _pyro_params = self.__dict__["_pyro_params"] + if name in _pyro_params: + constraint, event_dim = _pyro_params[name] + unconstrained_value = getattr(self, name + "_unconstrained") + import weakref + constrained_value = torch.distributions.transform_to(constraint)(unconstrained_value) + constrained_value.unconstrained = weakref.ref(unconstrained_value) + return constrained_value + return super().__getattr__(name) - def log_prob(self, X: Point, *args, **kwargs) -> torch.Tensor: - elbos = [] - for x in X: - elbos.append(self._log_prob_from_elbo(X, *args, **kwargs)) - return torch.stack(elbos) - def log_prob_gradient(self, X: Point, *args, **kwargs) -> torch.Tensor: - return torch.functional.autograd( - partial(self.log_prob(*args, **kwargs)), X, elbo.parameters() - ) - - -class ReparametrizableLogProb(LogProbModule): - def log_prob(self, X: Point, *args, **kwargs) -> torch.Tensor: - pass +if __name__ == "__main__": + import pyro + import pyro.distributions as dist + # Create simple pyro model + class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + return pyro.sample("y", dist.Normal(a + b, 1)) -def log_likelihood_fn(flat_theta: torch.tensor, X: Dict[str, torch.Tensor]): - n_monte_carlo = X[next(iter(X))].shape[0] - theta = _unflatten_dict(flat_theta, theta_hat) - model_at_theta = condition(data=theta)(DataConditionedModel(model)) - log_like_trace = pyro.poutine.trace(model_at_theta).get_trace(X) - log_like_trace.compute_log_prob() - log_prob_at_datapoints = torch.zeros(n_monte_carlo) - for name in obs_names: - log_prob_at_datapoints += log_like_trace.nodes[name]["log_prob"] - return log_prob_at_datapoints + model = SimpleModel() + # Create guide on latents a and b + num_monte_carlo_outer = 100 + guide = DummyAutoNormal(block(model, hide=["y"])) + zzz = pyro.nn.PyroModule() + zzz.w = pyro.nn.PyroParam(torch.rand(10), dist.constraints.positive) + guide() + guide.zzz = zzz + print(dict(guide.named_parameters())) + data = Predictive(model, guide=guide, num_samples=num_monte_carlo_outer, return_sites=["y"])() + + # Create log likelihood function + log_prob = NMCLogLikelihoodSingle(model, guide, num_samples=10) + + log_prob_func = torch.func.vmap( + torch.func.functionalize(pyro.validation_enabled(False)(partial(torch.func.functional_call, log_prob))), + # pyro.validation_enabled(False)(partial(torch.func.functional_call, log_prob)), + in_dims=(None, 0), + randomness='different' + ) -def stochastic_variational_log_likelihood_fn( - flat_theta: torch.tensor, X: Dict[str, torch.Tensor] -): - pass + print(log_prob_func(dict(log_prob.named_parameters()), data)[0]) + # func + grad_log_prob = torch.func.vjp(log_prob_func, dict(log_prob.named_parameters()), data)[1] + print(grad_log_prob(torch.ones(num_monte_carlo_outer))[0]) -# For continous latents, vectorized importance weights -# https://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.importance.vectorized_importance_weights + # autograd.functional + param_dict = collections.OrderedDict(log_prob.named_parameters()) + print(dict(zip(param_dict.keys(), torch.autograd.functional.vjp( + lambda *params: log_prob_func(dict(zip(param_dict.keys(), params)), data), + tuple(param_dict.values()), + torch.ones(num_monte_carlo_outer) + )[1]))) -# Predictive(model, guide) + # 
print(torch.autograd.grad(partial(torch.func.functional_call, log_prob)(dict(log_prob.named_parameters()), data), tuple(log_prob.parameters()))) + # fvp = make_empirical_fisher_vp(log_prob) -if __name__ == "__main__": - import pyro - import pyro.distributions as dist + # v = tuple(torch.ones_like(p) for p in guide.parameters()) - # Create simple pyro model - def model(): - a = pyro.sample("a", dist.Normal(0, 1)) - b = pyro.sample("b", dist.Normal(0, 1)) - return pyro.sample("y", dist.Normal(a + b, 1)) + # print(v, fvp(v)) + - # Create guide on latents a and b - guide = pyro.infer.autoguide.AutoNormal(block(model, hide=["y"])) - # with pyro.poutine.trace() as tr: - # guide() - # print(tr.trace.nodes.keys()) - # Create predictive - predictive = Predictive(model, guide=guide, num_samples=100) - # with pyro.poutine.trace() as tr: - X = predictive() - vectorized_variational_log_prob(model, guide, X, ["y"]) - # print(X) - # import pdb - # pdb.set_trace() diff --git a/chirho/robust/utils.py b/chirho/robust/utils.py new file mode 100644 index 000000000..129d133ee --- /dev/null +++ b/chirho/robust/utils.py @@ -0,0 +1,110 @@ +import inspect +import collections +import functools +from typing import Callable, Dict, Optional, Tuple, TypeVar +import torch + + +T = TypeVar("T") + + +@functools.singledispatch +def make_flatten_unflatten(v) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: + raise NotImplementedError + + +@make_flatten_unflatten.register(tuple) +def _make_flatten_unflatten_tuple(v: Tuple[torch.Tensor, ...]): + sizes = [x.size() for x in v] + + def flatten(xs: Tuple[torch.Tensor, ...]) -> torch.Tensor: + return torch.cat([x.reshape(-1) for x in xs], dim=0) + + def unflatten(x: torch.Tensor) -> Tuple[torch.Tensor, ...]: + tensors = [] + i = 0 + for size in sizes: + num_elements = torch.prod(torch.tensor(size)) + tensors.append(x[i : i + num_elements].view(size)) + i += num_elements + return tuple(tensors) + + return flatten, unflatten + + +@make_flatten_unflatten.register(dict) +def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): + + def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: + r""" + Flatten a dictionary of tensors into a single vector. + """ + return torch.cat([v.flatten() for k, v in d.items()]) + + + def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: + r""" + Unflatten a vector into a dictionary of tensors. + """ + return collections.OrderedDict( + zip( + d.keys(), + [ + v_flat.reshape(v.shape) + for v, v_flat in zip( + d.values(), torch.split(x, [v.numel() for k, v in d.items()]) + ) + ], + ) + ) + + return flatten, unflatten + + +def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: + flatten, unflatten = make_flatten_unflatten(b) + f_Ax_flat = lambda v: flatten(f_Ax(flatten(v))) + b_flat = flatten(b) + return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, b_flat, **kwargs)) + + +def _flat_conjugate_gradient_solve( + f_Ax: Callable[[torch.Tensor], torch.Tensor], b: torch.Tensor, *, cg_iters: Optional[int] = None, residual_tol: float = 1e-10 +) -> torch.Tensor: + r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. + + Args: + f_Ax (callable): A function to compute matrix vector product. + b (torch.Tensor): Right hand side of the equation to solve. + cg_iters (int): Number of iterations to run conjugate gradient + algorithm. + residual_tol (float): Tolerence for convergence. + + Returns: + torch.Tensor: Solution x* for equation Ax = b. 
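A minimal standalone sketch of the conjugate-gradient recurrence this docstring describes, shown for context only (not part of the patch; the 2x2 system A, b below is made up):

import torch

def cg_solve(f_Ax, b, cg_iters=None, residual_tol=1e-10):
    # Illustrative sketch with assumed toy inputs; not from the chirho patches.
    # Standard CG for a symmetric positive-definite operator given only as a
    # matrix-vector product f_Ax, mirroring _flat_conjugate_gradient_solve.
    if cg_iters is None:
        cg_iters = b.numel()
    x = torch.zeros_like(b)
    r = b.clone()              # residual b - A x, since x starts at zero
    p = r.clone()              # search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        z = f_Ax(p)
        alpha = rdotr / p.dot(z)
        x = x + alpha * p
        r = r - alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])   # symmetric positive-definite
b = torch.tensor([1.0, -1.0])
x = cg_solve(lambda v: A @ v, b)
assert torch.allclose(A @ x, b, atol=1e-5)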
+ + Notes: This code is copied from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py + """ + if cg_iters is None: + cg_iters = b.numel() + + p = b.clone() + r = b.clone() + x = torch.zeros_like(b) + rdotr = torch.dot(r, r) + + for _ in range(cg_iters): + z = f_Ax(p) + v = rdotr / torch.dot(p, z) + x += v * p + r -= v * z + newrdotr = torch.dot(r, r) + mu = newrdotr / rdotr + p = r + mu * p + + # Still executes loop but effectively stops update (can't break loop since we're using vmap) + # rdotr = torch.where(rdotr < residual_tol, rdotr, newrdotr) + # rdotr = newrdotr + # if rdotr < residual_tol: + # break + return x From fe1740362ea8b0a2eb67d15951292b8d203574c5 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 17 Nov 2023 16:31:17 -0500 Subject: [PATCH 07/66] added tests w/ Eli --- tests/robust/test_autograd.py | 281 +++++++++++++++++++++++++++ tests/robust/test_dice_correction.py | 140 +++++++++++++ tests/robust/test_internals.py | 39 ++++ 3 files changed, 460 insertions(+) create mode 100644 tests/robust/test_autograd.py create mode 100644 tests/robust/test_dice_correction.py create mode 100644 tests/robust/test_internals.py diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py new file mode 100644 index 000000000..734d38d5b --- /dev/null +++ b/tests/robust/test_autograd.py @@ -0,0 +1,281 @@ +import math +import collections +import functools +import pyro +from typing import Concatenate, Container, ParamSpec, Callable, Tuple, TypeVar, Optional, Mapping, Dict, List, Protocol +import torch +import pyro.distributions as dist +from pyro.infer import Predictive +from pyro.infer import Trace_ELBO +from pyro.infer.elbo import ELBOModule +from pyro.infer.importance import vectorized_importance_weights +from pyro.poutine import block, replay, trace, mask +from chirho.indexed.handlers import DependentMaskMessenger +from chirho.observational.handlers import condition +from torch.func import functional_call +from functools import partial + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") # This will be a torch.Tensor usually + +Point = dict[str, T] +Guide = Callable[P, Optional[T | Point[T]]] + + +@functools.singledispatch +def make_flatten_unflatten(v) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: + raise NotImplementedError + + +@make_flatten_unflatten.register(tuple) +def _make_flatten_unflatten_tuple(v: Tuple[torch.Tensor, ...]): + sizes = [x.size() for x in v] + + def flatten(xs: Tuple[torch.Tensor, ...]) -> torch.Tensor: + return torch.cat([x.reshape(-1) for x in xs], dim=0) + + def unflatten(x: torch.Tensor) -> Tuple[torch.Tensor, ...]: + tensors = [] + i = 0 + for size in sizes: + num_elements = torch.prod(torch.tensor(size)) + tensors.append(x[i : i + num_elements].view(size)) + i += num_elements + return tuple(tensors) + + return flatten, unflatten + + +@make_flatten_unflatten.register(dict) +def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): + + def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: + r""" + Flatten a dictionary of tensors into a single vector. + """ + return torch.cat([v.flatten() for k, v in d.items()]) + + def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: + r""" + Unflatten a vector into a dictionary of tensors. 
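For usage context only (not part of the patch; the example parameter dict is made up), the flatten/unflatten pair built here is just a lossless round trip between a dict of tensors and one flat vector:

import torch

def flatten_dict(d):
    # Concatenate all tensor values into a single flat vector (fixed key order).
    return torch.cat([v.reshape(-1) for v in d.values()])

def unflatten_dict(x, like):
    # Split the flat vector and reshape each piece like the original tensors.
    # Illustrative sketch with an assumed toy dict; not from the chirho patches.
    pieces = torch.split(x, [v.numel() for v in like.values()])
    return {k: p.reshape(v.shape) for (k, v), p in zip(like.items(), pieces)}

params = {"loc": torch.zeros(2, 3), "scale": torch.ones(4)}
flat = flatten_dict(params)                  # shape (10,)
restored = unflatten_dict(flat, params)
assert all(torch.equal(params[k], restored[k]) for k in params)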
+ """ + return dict( + zip( + d.keys(), + [ + v_flat.reshape(v.shape) + for v, v_flat in zip( + d.values(), torch.split(x, [v.numel() for k, v in d.items()]) + ) + ], + ) + ) + + return flatten, unflatten + + +def _flat_conjugate_gradient_solve( + f_Ax: Callable[[torch.Tensor], torch.Tensor], b: torch.Tensor, *, cg_iters: Optional[int] = None, residual_tol: float = 1e-10 +) -> torch.Tensor: + r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. + + Args: + f_Ax (callable): A function to compute matrix vector product. + b (torch.Tensor): Right hand side of the equation to solve. + cg_iters (int): Number of iterations to run conjugate gradient + algorithm. + residual_tol (float): Tolerence for convergence. + + Returns: + torch.Tensor: Solution x* for equation Ax = b. + + Notes: This code is copied from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py + """ + if cg_iters is None: + cg_iters = b.numel() + + p = b.clone() + r = b.clone() + x = torch.zeros_like(b) + rdotr = torch.dot(r, r) + + import pdb; pdb.set_trace() + + for _ in range(cg_iters): + z = f_Ax(p) + import pdb; pdb.set_trace() + v = rdotr / torch.dot(p, z) + x += v * p + r -= v * z + newrdotr = torch.dot(r, r) + mu = newrdotr / rdotr + p = r + mu * p + import pdb; pdb.set_trace() + + # Still executes loop but effectively stops update (can't break loop since we're using vmap) + # rdotr = torch.where(rdotr < residual_tol, rdotr, newrdotr) + # rdotr = newrdotr + # if rdotr < residual_tol: + # break + import pdb; pdb.set_trace() + return x + + +def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: + flatten, unflatten = make_flatten_unflatten(b) + + def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: + v_unflattened = unflatten(v) + result_unflattened = f_Ax(v_unflattened) + return flatten(result_unflattened) + + return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) + + +def make_functional_call( + mod: Callable[P, T] +) -> Tuple[ + Mapping[str, torch.Tensor], + Callable[Concatenate[Mapping[str, torch.Tensor], P], T] +]: + assert isinstance(mod, torch.nn.Module) + return dict(mod.named_parameters()), torch.func.functionalize(pyro.validation_enabled(False)(functools.partial(torch.func.functional_call, mod))) + + +def make_bound_batched_func_log_prob( + log_prob: Callable[[T], torch.Tensor], + data: T +) -> Tuple[Mapping[str, torch.Tensor], Callable[[Mapping[str, torch.Tensor]], torch.Tensor]]: + + assert isinstance(log_prob, torch.nn.Module) + log_prob_params_and_fn = make_functional_call(log_prob) + log_prob_params: Mapping[str, torch.Tensor] = log_prob_params_and_fn[0] + func_log_prob: Callable[[Mapping[str, torch.Tensor], T], torch.Tensor] = log_prob_params_and_fn[1] + + batched_func_log_prob: Callable[[Mapping[str, torch.Tensor], T], torch.Tensor] = torch.vmap( + func_log_prob, + in_dims=(None, 0), + randomness='different' + ) + + def bound_batched_func_log_prob(params: Mapping[str, torch.Tensor]) -> torch.Tensor: + return batched_func_log_prob(params, data) + + return log_prob_params, bound_batched_func_log_prob + + +def make_empirical_fisher_vp( + log_prob: Callable[[T], torch.Tensor], + data: T +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: + + log_prob_params, bound_batched_func_log_prob = make_bound_batched_func_log_prob(log_prob, data) + + def _empirical_fisher_vp(v: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]: + (_, vnew) = 
torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,)) + (_, vjp_fn) = torch.func.vjp(bound_batched_func_log_prob, log_prob_params) + result: Mapping[str, torch.Tensor] = vjp_fn(vnew / vnew.shape[0])[0] + import pdb; pdb.set_trace() + # result is batched over datapoints (via vmap), so we must sum out the batch dimension 0? + # return {k: torch.sum(v, dim=0) for k, v in result.items()} + assert result.keys() == v.keys() and all(result[k].shape == v[k].shape for k in result.keys()) + return result + + return _empirical_fisher_vp + + +def make_empirical_inverse_fisher_vp( + log_prob: Callable[[T], torch.Tensor], + data: T, + **solver_kwargs, +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: + + assert isinstance(log_prob, torch.nn.Module) + fvp = make_empirical_fisher_vp(log_prob, data) + return functools.partial(conjugate_gradient_solve, fvp, **solver_kwargs) + + +class UnmaskNamedSites(DependentMaskMessenger): + names: Container[str] + + def __init__(self, names: Container[str]): + self.names = names + + def get_mask( + self, dist: pyro.distributions.Distribution, value: Optional[torch.Tensor], device: torch.device, name: str + ) -> torch.Tensor: + return torch.tensor(name in self.names, device=device) + + +class NMCLogPredictiveLikelihood(torch.nn.Module): + + def __init__( + self, + model: torch.nn.Module, + guide: torch.nn.Module, + *, + num_samples: int = 1, + max_plate_nesting: int = 1, + ): + super().__init__() + self.model = model + self.guide = guide + self.num_samples = num_samples + self.max_plate_nesting = max_plate_nesting + + def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: + masked_guide = pyro.poutine.mask(mask=False)(self.guide) + masked_model = UnmaskNamedSites(names=set(data.keys()))(condition(data=data)(self.model)) + log_weights = pyro.infer.importance.vectorized_importance_weights( + masked_model, + masked_guide, + *args, + num_samples=self.num_samples, + max_plate_nesting=self.max_plate_nesting, + **kwargs + )[0] + return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) + + +def test_nmc_log_likelihood(): + + # Create simple pyro model + class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + return pyro.sample("y", dist.Normal(a + b, 1)) + + class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand(())) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1.)) + b = pyro.sample("b", dist.Normal(self.loc_b, 1.)) + + model = SimpleModel() + guide = SimpleGuide() + + # Create guide on latents a and b + num_samples_outer = 100 + data = pyro.infer.Predictive(model, guide=guide, num_samples=num_samples_outer, return_sites=["y"], parallel=True)() + + # Create log likelihood function + log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=10000, max_plate_nesting=1) + + v = {k: torch.ones_like(v) for k, v in log_prob.named_parameters()} + + # fvp = make_empirical_fisher_vp(log_prob, data) + # print(v, fvp(v)) + + flatten_v, unflatten_v = make_flatten_unflatten(v) + assert unflatten_v(flatten_v(v)) == v + fivp = make_empirical_inverse_fisher_vp(log_prob, data, cg_iters = 1) + print(v, fivp(v)) diff --git a/tests/robust/test_dice_correction.py b/tests/robust/test_dice_correction.py new file mode 100644 index 000000000..bc5265d8c --- /dev/null +++ 
b/tests/robust/test_dice_correction.py @@ -0,0 +1,140 @@ +import pyro +import math +import collections +from functools import partial +import pyro.distributions as dist +import torch +from typing import Callable, Dict, List, Optional + +from chirho.robust import one_step_correction +from chirho.robust.functionals import dice_correction, average_treatment_effect +from chirho.robust.utils import _flatten_dict, _unflatten_dict + + +class HighDimLinearModel(pyro.nn.PyroModule): + def __init__( + self, + p: int, + link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0), + prior_scale: float = None, + ): + super().__init__() + self.p = p + self.link_fn = link_fn + if prior_scale is None: + self.prior_scale = 1 / math.sqrt(self.p) + else: + self.prior_scale = prior_scale + + def sample_outcome_weights(self): + return pyro.sample( + "outcome_weights", + dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1), + ) + + def sample_intercept(self): + return pyro.sample("intercept", dist.Normal(0.0, 1.0)) + + def sample_propensity_weights(self): + return pyro.sample( + "propensity_weights", + dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1), + ) + + def sample_treatment_weight(self): + return pyro.sample("treatment_weight", dist.Normal(0.0, 1.0)) + + def sample_covariate_loc_scale(self): + loc = pyro.sample( + "covariate_loc", dist.Normal(0.0, 1.0).expand((self.p,)).to_event(1) + ) + scale = pyro.sample( + "covariate_scale", dist.LogNormal(0, 1).expand((self.p,)).to_event(1) + ) + return loc, scale + + def forward(self, N: int): + intercept = self.sample_intercept() + outcome_weights = self.sample_outcome_weights() + propensity_weights = self.sample_propensity_weights() + tau = self.sample_treatment_weight() + x_loc, x_scale = self.sample_covariate_loc_scale() + with pyro.plate("obs", N, dim=-1): + X = pyro.sample("X", dist.Normal(x_loc, x_scale).to_event(1)) + A = pyro.sample( + "A", + dist.Bernoulli( + logits=torch.einsum("...np,...p->...n", X, propensity_weights) + ).mask(False), + ) + return pyro.sample( + "Y", + self.link_fn( + torch.einsum("...np,...p->...n", X, outcome_weights) + + A * tau + + intercept + ), + ) + + +# Internal structure of the model (that's given) +# and then outer monte carlo samples + + +class KnownCovariateDistModel(HighDimLinearModel): + def sample_covariate_loc_scale(self): + return torch.zeros(self.p), torch.ones(self.p) + + +class FakeNormal(dist.Normal): + has_rsample = False + + +def test_bernoulli_model(): + p = 1 + n_monte_carlo_outer = 1000 + avg_plug_in_grads = torch.zeros(4) + for _ in range(n_monte_carlo_outer): + n_monte_carlo_inner = 100 + target_functional = partial( + dice_correction(average_treatment_effect), n_monte_carlo=n_monte_carlo_inner + ) + # bernoulli_link = lambda mu: dist.Bernoulli(logits=mu) + link = lambda mu: FakeNormal(mu, 1.0) + # link = lambda mu: dist.Normal(mu, 1.0) + model = KnownCovariateDistModel(p, link) + theta_hat = { + "intercept": torch.tensor(0.0).requires_grad_(True), + "outcome_weights": torch.tensor([1.0]).requires_grad_(True), + "propensity_weights": torch.tensor([1.0]).requires_grad_(True), + "treatment_weight": torch.tensor(1.0).requires_grad_(True), + } + + # Canonical ordering of parameters when flattening and unflattening + theta_hat = collections.OrderedDict( + (k, theta_hat[k]) for k in sorted(theta_hat.keys()) + ) + flat_theta = _flatten_dict(theta_hat) + + # Compute gradient of plug-in functional + plug_in = target_functional(model, theta_hat) + plug_in += ( + 0 * 
flat_theta.sum() + ) # hack for full gradient (maintain flattened shape) + + avg_plug_in_grads += ( + _flatten_dict( + collections.OrderedDict( + zip( + theta_hat.keys(), + torch.autograd.grad(plug_in, theta_hat.values()), + ) + ) + ) + / n_monte_carlo_outer + ) + + correct_grad = torch.tensor([0, 0, 0, 1.0]) + # assert (avg_plug_in_grads - correct_grad).abs().sum() < 1 / torch.sqrt( + # torch.tensor(n_monte_carlo_outer) + # ) diff --git a/tests/robust/test_internals.py b/tests/robust/test_internals.py new file mode 100644 index 000000000..c7049a1ef --- /dev/null +++ b/tests/robust/test_internals.py @@ -0,0 +1,39 @@ + +import pyro +from chirho.robust.internals import * +from pyro.infer import Predictive +import pyro.distributions as dist + + +class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.Normal(0, 1)) + return pyro.sample("y", dist.Normal(a + b, 1)) + + +class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.a_loc = torch.nn.Parameter(torch.tensor(0.)) + self.b_loc = torch.nn.Parameter(torch.tensor(0.)) + + def forward(self): + pyro.sample("a", dist.Delta(self.a_loc)) + pyro.sample("b", dist.Delta(self.b_loc)) + + +def test_nmc_log_likelihood(): + model = SimpleModel() + guide = SimpleGuide() + num_monte_carlo_outer = 100 + data = Predictive(model, guide=guide, num_samples=num_monte_carlo_outer, return_sites=["y"])() + nmc_ll = NMCLogLikelihood(model, guide, num_samples=100) + ll_at_data = nmc_ll(data) + print(ll_at_data) + + nmc_ll_single = NMCLogLikelihoodSingle(model, guide, num_samples=10) + nmc_ll_single._vectorized_log_prob({'y': torch.tensor(1.)}) + nmc_ll({'y': torch.tensor([1.])}) + ll_at_data_single = nmc_ll_single.vectorized_log_prob(data) + From b15968761052128cc7ac9802a7e95fac81beae83 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 17 Nov 2023 18:22:55 -0500 Subject: [PATCH 08/66] eif --- tests/robust/test_autograd.py | 60 ++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py index 734d38d5b..e5d071c39 100644 --- a/tests/robust/test_autograd.py +++ b/tests/robust/test_autograd.py @@ -93,7 +93,7 @@ def _flat_conjugate_gradient_solve( Returns: torch.Tensor: Solution x* for equation Ax = b. 
- Notes: This code is copied from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py + Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py """ if cg_iters is None: cg_iters = b.numel() @@ -101,27 +101,30 @@ def _flat_conjugate_gradient_solve( p = b.clone() r = b.clone() x = torch.zeros_like(b) + z = f_Ax(p) rdotr = torch.dot(r, r) + v = rdotr / torch.dot(p, z) + newrdotr = rdotr + mu = newrdotr / rdotr - import pdb; pdb.set_trace() + zeros_x = torch.zeros_like(x) + zeros_r = torch.zeros_like(r) for _ in range(cg_iters): - z = f_Ax(p) - import pdb; pdb.set_trace() - v = rdotr / torch.dot(p, z) - x += v * p - r -= v * z - newrdotr = torch.dot(r, r) - mu = newrdotr / rdotr - p = r + mu * p - import pdb; pdb.set_trace() - - # Still executes loop but effectively stops update (can't break loop since we're using vmap) - # rdotr = torch.where(rdotr < residual_tol, rdotr, newrdotr) + not_converged = rdotr > residual_tol + + z = torch.where(not_converged, f_Ax(p), z) + v = torch.where(not_converged, rdotr / torch.dot(p, z), v) + x += torch.where(not_converged, v * p, zeros_x) + r -= torch.where(not_converged, v * z, zeros_r) + newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) + mu = torch.where(not_converged, newrdotr / rdotr, mu) + p = torch.where(not_converged, r + mu * p, p) + rdotr = torch.where(not_converged, newrdotr, rdotr) + # rdotr = newrdotr # if rdotr < residual_tol: # break - import pdb; pdb.set_trace() return x @@ -179,10 +182,6 @@ def _empirical_fisher_vp(v: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Te (_, vnew) = torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,)) (_, vjp_fn) = torch.func.vjp(bound_batched_func_log_prob, log_prob_params) result: Mapping[str, torch.Tensor] = vjp_fn(vnew / vnew.shape[0])[0] - import pdb; pdb.set_trace() - # result is batched over datapoints (via vmap), so we must sum out the batch dimension 0? 
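A small self-contained check of the jvp/vjp composition used in make_empirical_fisher_vp, for context only (not part of the patch; the N(mu, 1) model and data below are made up): applying J and then (1/N) J^T to a direction v reproduces the empirical Fisher (1/N) sum_i grad log p_i (grad log p_i)^T acting on v.

import math
import torch

def batched_log_prob(params, data):
    # Per-datapoint log-density of N(mu, 1), written out explicitly.
    return -0.5 * (data - params["mu"]) ** 2 - 0.5 * math.log(2 * math.pi)

# Illustrative sketch with assumed toy data; not from the chirho patches.
data = torch.randn(1000)
params = {"mu": torch.tensor(0.3)}
v = {"mu": torch.tensor(1.0)}        # direction for the Fisher-vector product

f = lambda p: batched_log_prob(p, data)
_, jv = torch.func.jvp(f, (params,), (v,))    # J v: one entry per datapoint
_, vjp_fn = torch.func.vjp(f, params)
fvp = vjp_fn(jv / data.shape[0])[0]           # (1/N) J^T (J v)

# Explicit empirical Fisher for this model: the per-datapoint score is
# g_i = d/dmu log p(x_i; mu) = (x_i - mu), so F v = mean(g_i^2) * v.
g = data - params["mu"]
assert torch.allclose(fvp["mu"], (g * g).mean() * v["mu"], atol=1e-4)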
- # return {k: torch.sum(v, dim=0) for k, v in result.items()} - assert result.keys() == v.keys() and all(result[k].shape == v[k].shape for k in result.keys()) return result return _empirical_fisher_vp @@ -247,8 +246,9 @@ def test_nmc_log_likelihood(): class SimpleModel(pyro.nn.PyroModule): def forward(self): a = pyro.sample("a", dist.Normal(0, 1)) - b = pyro.sample("b", dist.Normal(0, 1)) - return pyro.sample("y", dist.Normal(a + b, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(0, 1)) + return pyro.sample("y", dist.Normal(a + b, 1)) class SimpleGuide(torch.nn.Module): def __init__(self): @@ -258,17 +258,19 @@ def __init__(self): def forward(self): a = pyro.sample("a", dist.Normal(self.loc_a, 1.)) - b = pyro.sample("b", dist.Normal(self.loc_b, 1.)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1.)) + return {"a": a, "b": b} model = SimpleModel() guide = SimpleGuide() # Create guide on latents a and b - num_samples_outer = 100 + num_samples_outer = 10 data = pyro.infer.Predictive(model, guide=guide, num_samples=num_samples_outer, return_sites=["y"], parallel=True)() # Create log likelihood function - log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=10000, max_plate_nesting=1) + log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=100, max_plate_nesting=1) v = {k: torch.ones_like(v) for k, v in log_prob.named_parameters()} @@ -277,5 +279,13 @@ def forward(self): flatten_v, unflatten_v = make_flatten_unflatten(v) assert unflatten_v(flatten_v(v)) == v - fivp = make_empirical_inverse_fisher_vp(log_prob, data, cg_iters = 1) + fivp = make_empirical_inverse_fisher_vp(log_prob, data, cg_iters = 10) print(v, fivp(v)) + + d2 = pyro.infer.Predictive(model, num_samples=30, return_sites=["y"], parallel=True)() + log_prob_params, func_log_prob = make_functional_call(log_prob) + + def eif(d: Point[torch.Tensor]) -> Mapping[str, torch.Tensor]: + return fivp(torch.func.grad(lambda params: func_log_prob(params, d))(log_prob_params)) + + print(torch.vmap(eif)(d2)) \ No newline at end of file From 33f4811344fb63e363750fc26ea543c9aa667a93 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 17 Nov 2023 21:40:19 -0500 Subject: [PATCH 09/66] linting --- chirho/robust/internals.py | 35 +++++-- chirho/robust/utils.py | 16 ++-- tests/robust/test_autograd.py | 134 +++++++++++++++++---------- tests/robust/test_dice_correction.py | 9 +- tests/robust/test_internals.py | 21 +++-- 5 files changed, 140 insertions(+), 75 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 82b369f93..6fedaf060 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -12,7 +12,7 @@ from chirho.observational.handlers import condition from torch.func import functional_call from functools import partial -from utils import conjugate_gradient_solve +from chirho.robust.utils import conjugate_gradient_solve pyro.settings.set(module_local_params=True) @@ -40,6 +40,7 @@ def make_empirical_fisher_vp( def _empirical_fisher_vp(v: T) -> T: params = dict(log_prob.named_parameters()) + # TODO: I think vnew = RHS[1] vnew = torch.func.jvp(partial(torch.func.functional_call, log_prob), params, v) (_, vjp_fn) = torch.func.vjp(partial(torch.func.functional_call, log_prob), params) return vjp_fn(vnew) @@ -71,21 +72,39 @@ def __init__( self.model = model self.guide = guide self.num_samples = num_samples + + def _vectorized_log_prob(self, single_datapoint: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: + 
masked_guide = pyro.poutine.mask(mask=False)(self.guide) # Mask all sites in guide + masked_model = UnmaskNamedSites(names=set(single_datapoint.keys()))(condition(data=single_datapoint)(self.model)) + log_weights = pyro.infer.importance.vectorized_importance_weights(masked_model, masked_guide, *args, num_samples=self.num_samples, max_plate_nesting=1, **kwargs)[0] + return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) + + def vectorized_log_prob(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: + num_monte_carlo_outer = data[next(iter(data))].shape[0] + log_prob_func = torch.func.vmap( + torch.func.functionalize(pyro.validation_enabled(False)(partial(torch.func.functional_call, self._vectorized_log_prob))), + in_dims=(None, 0), + randomness='different' + ) + # log_prob_func = torch.func.vmap( + # pyro.validation_enabled(False)(self._vectorized_log_prob), + # in_dims=(0,), + # randomness='different' + # ) + import pdb; pdb.set_trace() + pyro.validation_enabled(False)(self._vectorized_log_prob) + # return log_prob_func(dict(self.named_parameters()), data)[0] / (num_monte_carlo_outer ** 0.5) + return log_prob_func(data)[0] / (num_monte_carlo_outer ** 0.5) def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: - num_monte_carlo_outer = data[next(iter(data))].shape[0] - # if num_monte_inner is None: - # # Optimal scaling for inner expectation: - # # see https://arxiv.org/pdf/1709.06181.pdf - # num_monte_inner = num_monte_carlo_outer ** 2 - + num_monte_carlo_outer = data[next(iter(data))].shape[0] log_weights = [] for i in range(num_monte_carlo_outer): log_weights_i = [] for j in range(self.num_samples): masked_guide = pyro.poutine.mask(mask=False)(self.guide) # Mask all sites in guide masked_model = UnmaskNamedSites(names=set(data.keys()))(condition(data={k: v[i] for k, v in data.items()})(self.model)) - log_weight_ij = pyro.infer.Trace_ELBO().differentiable_loss(masked_model, masked_guide, *args, **kwargs) + log_weight_ij = -1. * pyro.infer.Trace_ELBO().differentiable_loss(masked_model, masked_guide, *args, **kwargs) # -1 since negative elbo here log_weights_i.append(log_weight_ij) log_weight_i = torch.logsumexp(torch.stack(log_weights_i), dim=0) - math.log(self.num_samples) log_weights.append(log_weight_i) diff --git a/chirho/robust/utils.py b/chirho/robust/utils.py index 129d133ee..68367a76a 100644 --- a/chirho/robust/utils.py +++ b/chirho/robust/utils.py @@ -1,15 +1,17 @@ -import inspect import collections import functools +import inspect from typing import Callable, Dict, Optional, Tuple, TypeVar -import torch +import torch T = TypeVar("T") @functools.singledispatch -def make_flatten_unflatten(v) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: +def make_flatten_unflatten( + v, +) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: raise NotImplementedError @@ -34,14 +36,12 @@ def unflatten(x: torch.Tensor) -> Tuple[torch.Tensor, ...]: @make_flatten_unflatten.register(dict) def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): - def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: r""" Flatten a dictionary of tensors into a single vector. """ return torch.cat([v.flatten() for k, v in d.items()]) - def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: r""" Unflatten a vector into a dictionary of tensors. 
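For intuition about the estimator used in the NMC log-likelihood code above (illustrative only, not part of the patch; the conjugate-Gaussian model and proposal are made up): logsumexp_k(log w_k) - log K with w_k = p(y, z_k) / q(z_k), z_k ~ q, is a simple importance-sampling estimate of log p(y), and it can be checked against an exact marginal:

import math
import torch

# Illustrative sketch with an assumed toy model; not from the chirho patches.
# Model: z ~ N(0, 1), y | z ~ N(z, 1), so the exact marginal is y ~ N(0, 2).
# Proposal q(z) = N(0.5, 1), deliberately offset to exercise the weights.
torch.manual_seed(0)
y = torch.tensor(0.7)
K = 200_000
q = torch.distributions.Normal(0.5, 1.0)
z = q.sample((K,))

log_w = (
    torch.distributions.Normal(0.0, 1.0).log_prob(z)   # log p(z)
    + torch.distributions.Normal(z, 1.0).log_prob(y)   # log p(y | z)
    - q.log_prob(z)                                     # - log q(z)
)
log_py_hat = torch.logsumexp(log_w, dim=0) - math.log(K)
log_py_exact = torch.distributions.Normal(0.0, math.sqrt(2.0)).log_prob(y)
assert abs(log_py_hat - log_py_exact) < 5e-2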
@@ -69,7 +69,11 @@ def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: def _flat_conjugate_gradient_solve( - f_Ax: Callable[[torch.Tensor], torch.Tensor], b: torch.Tensor, *, cg_iters: Optional[int] = None, residual_tol: float = 1e-10 + f_Ax: Callable[[torch.Tensor], torch.Tensor], + b: torch.Tensor, + *, + cg_iters: Optional[int] = None, + residual_tol: float = 1e-10 ) -> torch.Tensor: r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py index e5d071c39..a921d8da9 100644 --- a/tests/robust/test_autograd.py +++ b/tests/robust/test_autograd.py @@ -1,33 +1,48 @@ -import math import collections import functools +import math +from functools import partial +from typing import ( + Callable, + Concatenate, + Container, + Dict, + List, + Mapping, + Optional, + ParamSpec, + Protocol, + Tuple, + TypeVar, +) + import pyro -from typing import Concatenate, Container, ParamSpec, Callable, Tuple, TypeVar, Optional, Mapping, Dict, List, Protocol -import torch import pyro.distributions as dist -from pyro.infer import Predictive -from pyro.infer import Trace_ELBO +import torch +from pyro.infer import Predictive, Trace_ELBO from pyro.infer.elbo import ELBOModule from pyro.infer.importance import vectorized_importance_weights -from pyro.poutine import block, replay, trace, mask +from pyro.poutine import block, mask, replay, trace +from torch.func import functional_call + from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition -from torch.func import functional_call -from functools import partial pyro.settings.set(module_local_params=True) P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") -T = TypeVar("T") # This will be a torch.Tensor usually +T = TypeVar("T") # This will be a torch.Tensor usually Point = dict[str, T] Guide = Callable[P, Optional[T | Point[T]]] @functools.singledispatch -def make_flatten_unflatten(v) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: +def make_flatten_unflatten( + v, +) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: raise NotImplementedError @@ -52,7 +67,6 @@ def unflatten(x: torch.Tensor) -> Tuple[torch.Tensor, ...]: @make_flatten_unflatten.register(dict) def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): - def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: r""" Flatten a dictionary of tensors into a single vector. @@ -79,7 +93,11 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: def _flat_conjugate_gradient_solve( - f_Ax: Callable[[torch.Tensor], torch.Tensor], b: torch.Tensor, *, cg_iters: Optional[int] = None, residual_tol: float = 1e-10 + f_Ax: Callable[[torch.Tensor], torch.Tensor], + b: torch.Tensor, + *, + cg_iters: Optional[int] = None, + residual_tol: float = 1e-10, ) -> torch.Tensor: r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. 
@@ -142,28 +160,31 @@ def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: def make_functional_call( mod: Callable[P, T] ) -> Tuple[ - Mapping[str, torch.Tensor], - Callable[Concatenate[Mapping[str, torch.Tensor], P], T] + Mapping[str, torch.Tensor], Callable[Concatenate[Mapping[str, torch.Tensor], P], T] ]: assert isinstance(mod, torch.nn.Module) - return dict(mod.named_parameters()), torch.func.functionalize(pyro.validation_enabled(False)(functools.partial(torch.func.functional_call, mod))) + return dict(mod.named_parameters()), torch.func.functionalize( + pyro.validation_enabled(False)( + functools.partial(torch.func.functional_call, mod) + ) + ) def make_bound_batched_func_log_prob( - log_prob: Callable[[T], torch.Tensor], - data: T -) -> Tuple[Mapping[str, torch.Tensor], Callable[[Mapping[str, torch.Tensor]], torch.Tensor]]: - + log_prob: Callable[[T], torch.Tensor], data: T +) -> Tuple[ + Mapping[str, torch.Tensor], Callable[[Mapping[str, torch.Tensor]], torch.Tensor] +]: assert isinstance(log_prob, torch.nn.Module) log_prob_params_and_fn = make_functional_call(log_prob) log_prob_params: Mapping[str, torch.Tensor] = log_prob_params_and_fn[0] - func_log_prob: Callable[[Mapping[str, torch.Tensor], T], torch.Tensor] = log_prob_params_and_fn[1] + func_log_prob: Callable[ + [Mapping[str, torch.Tensor], T], torch.Tensor + ] = log_prob_params_and_fn[1] - batched_func_log_prob: Callable[[Mapping[str, torch.Tensor], T], torch.Tensor] = torch.vmap( - func_log_prob, - in_dims=(None, 0), - randomness='different' - ) + batched_func_log_prob: Callable[ + [Mapping[str, torch.Tensor], T], torch.Tensor + ] = torch.vmap(func_log_prob, in_dims=(None, 0), randomness="different") def bound_batched_func_log_prob(params: Mapping[str, torch.Tensor]) -> torch.Tensor: return batched_func_log_prob(params, data) @@ -172,14 +193,18 @@ def bound_batched_func_log_prob(params: Mapping[str, torch.Tensor]) -> torch.Ten def make_empirical_fisher_vp( - log_prob: Callable[[T], torch.Tensor], - data: T + log_prob: Callable[[T], torch.Tensor], data: T ) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: + log_prob_params, bound_batched_func_log_prob = make_bound_batched_func_log_prob( + log_prob, data + ) - log_prob_params, bound_batched_func_log_prob = make_bound_batched_func_log_prob(log_prob, data) - - def _empirical_fisher_vp(v: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]: - (_, vnew) = torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,)) + def _empirical_fisher_vp( + v: Mapping[str, torch.Tensor] + ) -> Mapping[str, torch.Tensor]: + (_, vnew) = torch.func.jvp( + bound_batched_func_log_prob, (log_prob_params,), (v,) + ) (_, vjp_fn) = torch.func.vjp(bound_batched_func_log_prob, log_prob_params) result: Mapping[str, torch.Tensor] = vjp_fn(vnew / vnew.shape[0])[0] return result @@ -192,7 +217,6 @@ def make_empirical_inverse_fisher_vp( data: T, **solver_kwargs, ) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: - assert isinstance(log_prob, torch.nn.Module) fvp = make_empirical_fisher_vp(log_prob, data) return functools.partial(conjugate_gradient_solve, fvp, **solver_kwargs) @@ -205,13 +229,16 @@ def __init__(self, names: Container[str]): self.names = names def get_mask( - self, dist: pyro.distributions.Distribution, value: Optional[torch.Tensor], device: torch.device, name: str + self, + dist: pyro.distributions.Distribution, + value: Optional[torch.Tensor], + device: torch.device, + name: str, ) -> torch.Tensor: return torch.tensor(name in self.names, 
device=device) class NMCLogPredictiveLikelihood(torch.nn.Module): - def __init__( self, model: torch.nn.Module, @@ -228,20 +255,21 @@ def __init__( def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: masked_guide = pyro.poutine.mask(mask=False)(self.guide) - masked_model = UnmaskNamedSites(names=set(data.keys()))(condition(data=data)(self.model)) + masked_model = UnmaskNamedSites(names=set(data.keys()))( + condition(data=data)(self.model) + ) log_weights = pyro.infer.importance.vectorized_importance_weights( masked_model, masked_guide, *args, num_samples=self.num_samples, max_plate_nesting=self.max_plate_nesting, - **kwargs + **kwargs, )[0] return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) def test_nmc_log_likelihood(): - # Create simple pyro model class SimpleModel(pyro.nn.PyroModule): def forward(self): @@ -257,35 +285,47 @@ def __init__(self): self.loc_b = torch.nn.Parameter(torch.rand(())) def forward(self): - a = pyro.sample("a", dist.Normal(self.loc_a, 1.)) + a = pyro.sample("a", dist.Normal(self.loc_a, 1.0)) with pyro.plate("data", 3, dim=-1): - b = pyro.sample("b", dist.Normal(self.loc_b, 1.)) + b = pyro.sample("b", dist.Normal(self.loc_b, 1.0)) return {"a": a, "b": b} model = SimpleModel() guide = SimpleGuide() # Create guide on latents a and b - num_samples_outer = 10 - data = pyro.infer.Predictive(model, guide=guide, num_samples=num_samples_outer, return_sites=["y"], parallel=True)() + num_samples_outer = 100000 + data = pyro.infer.Predictive( + model, + guide=guide, + num_samples=num_samples_outer, + return_sites=["y"], + parallel=True, + )() # Create log likelihood function - log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=100, max_plate_nesting=1) + log_prob = NMCLogPredictiveLikelihood( + model, guide, num_samples=1, max_plate_nesting=1 + ) v = {k: torch.ones_like(v) for k, v in log_prob.named_parameters()} - # fvp = make_empirical_fisher_vp(log_prob, data) + # fvp = make_empirical_fisher_vp(log_prob, data) # print(v, fvp(v)) flatten_v, unflatten_v = make_flatten_unflatten(v) assert unflatten_v(flatten_v(v)) == v - fivp = make_empirical_inverse_fisher_vp(log_prob, data, cg_iters = 10) + fivp = make_empirical_inverse_fisher_vp(log_prob, data, cg_iters=10) print(v, fivp(v)) - d2 = pyro.infer.Predictive(model, num_samples=30, return_sites=["y"], parallel=True)() + d2 = pyro.infer.Predictive( + model, num_samples=30, return_sites=["y"], parallel=True + )() log_prob_params, func_log_prob = make_functional_call(log_prob) def eif(d: Point[torch.Tensor]) -> Mapping[str, torch.Tensor]: - return fivp(torch.func.grad(lambda params: func_log_prob(params, d))(log_prob_params)) + return fivp( + torch.func.grad(lambda params: func_log_prob(params, d))(log_prob_params) + ) - print(torch.vmap(eif)(d2)) \ No newline at end of file + print(torch.vmap(eif)(d2)) diff --git a/tests/robust/test_dice_correction.py b/tests/robust/test_dice_correction.py index bc5265d8c..e968e9660 100644 --- a/tests/robust/test_dice_correction.py +++ b/tests/robust/test_dice_correction.py @@ -1,13 +1,14 @@ -import pyro -import math import collections +import math from functools import partial +from typing import Callable, Dict, List, Optional + +import pyro import pyro.distributions as dist import torch -from typing import Callable, Dict, List, Optional from chirho.robust import one_step_correction -from chirho.robust.functionals import dice_correction, average_treatment_effect +from chirho.robust.functionals import average_treatment_effect, 
dice_correction from chirho.robust.utils import _flatten_dict, _unflatten_dict diff --git a/tests/robust/test_internals.py b/tests/robust/test_internals.py index c7049a1ef..30725429e 100644 --- a/tests/robust/test_internals.py +++ b/tests/robust/test_internals.py @@ -1,8 +1,8 @@ - import pyro -from chirho.robust.internals import * -from pyro.infer import Predictive import pyro.distributions as dist +from pyro.infer import Predictive + +from chirho.robust.internals import * class SimpleModel(pyro.nn.PyroModule): @@ -15,9 +15,9 @@ def forward(self): class SimpleGuide(torch.nn.Module): def __init__(self): super().__init__() - self.a_loc = torch.nn.Parameter(torch.tensor(0.)) - self.b_loc = torch.nn.Parameter(torch.tensor(0.)) - + self.a_loc = torch.nn.Parameter(torch.tensor(0.0)) + self.b_loc = torch.nn.Parameter(torch.tensor(0.0)) + def forward(self): pyro.sample("a", dist.Delta(self.a_loc)) pyro.sample("b", dist.Delta(self.b_loc)) @@ -27,13 +27,14 @@ def test_nmc_log_likelihood(): model = SimpleModel() guide = SimpleGuide() num_monte_carlo_outer = 100 - data = Predictive(model, guide=guide, num_samples=num_monte_carlo_outer, return_sites=["y"])() + data = Predictive( + model, guide=guide, num_samples=num_monte_carlo_outer, return_sites=["y"] + )() nmc_ll = NMCLogLikelihood(model, guide, num_samples=100) ll_at_data = nmc_ll(data) print(ll_at_data) nmc_ll_single = NMCLogLikelihoodSingle(model, guide, num_samples=10) - nmc_ll_single._vectorized_log_prob({'y': torch.tensor(1.)}) - nmc_ll({'y': torch.tensor([1.])}) + nmc_ll_single._vectorized_log_prob({"y": torch.tensor(1.0)}) + nmc_ll({"y": torch.tensor([1.0])}) ll_at_data_single = nmc_ll_single.vectorized_log_prob(data) - From 8e171f481a955fca81db5b8d74265cbd601283fd Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Mon, 20 Nov 2023 11:40:50 -0500 Subject: [PATCH 10/66] moving test autograd to internals and deleted old utils file --- chirho/robust/internals.py | 366 ++++++++++++++++++++-------------- chirho/robust/utils.py | 114 ----------- tests/robust/test_autograd.py | 273 +------------------------ 3 files changed, 222 insertions(+), 531 deletions(-) delete mode 100644 chirho/robust/utils.py diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 6fedaf060..268a4bf45 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -1,53 +1,216 @@ +import functools import math -import collections +from typing import ( + Callable, + Concatenate, + Container, + Dict, + Mapping, + Optional, + ParamSpec, + Tuple, + TypeVar, +) + import pyro -from typing import Container, ParamSpec, Callable, Tuple, TypeVar, Optional, Dict, List, Protocol import torch -from pyro.infer import Predictive -from pyro.infer import Trace_ELBO -from pyro.infer.elbo import ELBOModule -from pyro.infer.importance import vectorized_importance_weights -from pyro.poutine import block, replay, trace, mask from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition -from torch.func import functional_call -from functools import partial -from chirho.robust.utils import conjugate_gradient_solve pyro.settings.set(module_local_params=True) P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") -T = TypeVar("T") # This will be a torch.Tensor usually +T = TypeVar("T") # This will be a torch.Tensor usually Point = dict[str, T] Guide = Callable[P, Optional[T | Point[T]]] -def make_empirical_inverse_fisher_vp( - log_prob: torch.nn.Module, - **solver_kwargs, -) -> Callable: +@functools.singledispatch +def 
make_flatten_unflatten( + v, +) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: + raise NotImplementedError + + +@make_flatten_unflatten.register(tuple) +def _make_flatten_unflatten_tuple(v: Tuple[torch.Tensor, ...]): + sizes = [x.size() for x in v] + + def flatten(xs: Tuple[torch.Tensor, ...]) -> torch.Tensor: + return torch.cat([x.reshape(-1) for x in xs], dim=0) + + def unflatten(x: torch.Tensor) -> Tuple[torch.Tensor, ...]: + tensors = [] + i = 0 + for size in sizes: + num_elements = torch.prod(torch.tensor(size)) + tensors.append(x[i : i + num_elements].view(size)) + i += num_elements + return tuple(tensors) + + return flatten, unflatten + + +@make_flatten_unflatten.register(dict) +def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): + def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: + r""" + Flatten a dictionary of tensors into a single vector. + """ + return torch.cat([v.flatten() for k, v in d.items()]) + + def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: + r""" + Unflatten a vector into a dictionary of tensors. + """ + return dict( + zip( + d.keys(), + [ + v_flat.reshape(v.shape) + for v, v_flat in zip( + d.values(), torch.split(x, [v.numel() for k, v in d.items()]) + ) + ], + ) + ) + + return flatten, unflatten + + +def _flat_conjugate_gradient_solve( + f_Ax: Callable[[torch.Tensor], torch.Tensor], + b: torch.Tensor, + *, + cg_iters: Optional[int] = None, + residual_tol: float = 1e-10, +) -> torch.Tensor: + r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. + + Args: + f_Ax (callable): A function to compute matrix vector product. + b (torch.Tensor): Right hand side of the equation to solve. + cg_iters (int): Number of iterations to run conjugate gradient + algorithm. + residual_tol (float): Tolerence for convergence. + + Returns: + torch.Tensor: Solution x* for equation Ax = b. 
+ + Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py + """ + if cg_iters is None: + cg_iters = b.numel() + + p = b.clone() + r = b.clone() + x = torch.zeros_like(b) + z = f_Ax(p) + rdotr = torch.dot(r, r) + v = rdotr / torch.dot(p, z) + newrdotr = rdotr + mu = newrdotr / rdotr + + zeros_x = torch.zeros_like(x) + zeros_r = torch.zeros_like(r) + + for _ in range(cg_iters): + not_converged = rdotr > residual_tol + + z = torch.where(not_converged, f_Ax(p), z) + v = torch.where(not_converged, rdotr / torch.dot(p, z), v) + x += torch.where(not_converged, v * p, zeros_x) + r -= torch.where(not_converged, v * z, zeros_r) + newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) + mu = torch.where(not_converged, newrdotr / rdotr, mu) + p = torch.where(not_converged, r + mu * p, p) + rdotr = torch.where(not_converged, newrdotr, rdotr) + + # rdotr = newrdotr + # if rdotr < residual_tol: + # break + return x + + +def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: + flatten, unflatten = make_flatten_unflatten(b) + + def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: + v_unflattened = unflatten(v) + result_unflattened = f_Ax(v_unflattened) + return flatten(result_unflattened) + + return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) + + +def make_functional_call( + mod: Callable[P, T] +) -> Tuple[ + Mapping[str, torch.Tensor], Callable[Concatenate[Mapping[str, torch.Tensor], P], T] +]: + assert isinstance(mod, torch.nn.Module) + return dict(mod.named_parameters()), torch.func.functionalize( + pyro.validation_enabled(False)( + functools.partial(torch.func.functional_call, mod) + ) + ) + + +def make_bound_batched_func_log_prob( + log_prob: Callable[[T], torch.Tensor], data: T +) -> Tuple[ + Mapping[str, torch.Tensor], Callable[[Mapping[str, torch.Tensor]], torch.Tensor] +]: + assert isinstance(log_prob, torch.nn.Module) + log_prob_params_and_fn = make_functional_call(log_prob) + log_prob_params: Mapping[str, torch.Tensor] = log_prob_params_and_fn[0] + func_log_prob: Callable[ + [Mapping[str, torch.Tensor], T], torch.Tensor + ] = log_prob_params_and_fn[1] - fvp = make_empirical_fisher_vp(log_prob) - return lambda v: conjugate_gradient_solve(fvp, v, **solver_kwargs) + batched_func_log_prob: Callable[ + [Mapping[str, torch.Tensor], T], torch.Tensor + ] = torch.vmap(func_log_prob, in_dims=(None, 0), randomness="different") + + def bound_batched_func_log_prob(params: Mapping[str, torch.Tensor]) -> torch.Tensor: + return batched_func_log_prob(params, data) + + return log_prob_params, bound_batched_func_log_prob def make_empirical_fisher_vp( - log_prob: torch.nn.Module, -) -> Callable: + log_prob: Callable[[T], torch.Tensor], data: T +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: + log_prob_params, bound_batched_func_log_prob = make_bound_batched_func_log_prob( + log_prob, data + ) - def _empirical_fisher_vp(v: T) -> T: - params = dict(log_prob.named_parameters()) - # TODO: I think vnew = RHS[1] - vnew = torch.func.jvp(partial(torch.func.functional_call, log_prob), params, v) - (_, vjp_fn) = torch.func.vjp(partial(torch.func.functional_call, log_prob), params) - return vjp_fn(vnew) + def _empirical_fisher_vp( + v: Mapping[str, torch.Tensor] + ) -> Mapping[str, torch.Tensor]: + (_, vnew) = torch.func.jvp( + bound_batched_func_log_prob, (log_prob_params,), (v,) + ) + (_, vjp_fn) = torch.func.vjp(bound_batched_func_log_prob, 
log_prob_params) + result: Mapping[str, torch.Tensor] = vjp_fn(vnew / vnew.shape[0])[0] + return result return _empirical_fisher_vp +def make_empirical_inverse_fisher_vp( + log_prob: Callable[[T], torch.Tensor], + data: T, + **solver_kwargs, +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: + assert isinstance(log_prob, torch.nn.Module) + fvp = make_empirical_fisher_vp(log_prob, data) + return functools.partial(conjugate_gradient_solve, fvp, **solver_kwargs) + + class UnmaskNamedSites(DependentMaskMessenger): names: Container[str] @@ -55,144 +218,41 @@ def __init__(self, names: Container[str]): self.names = names def get_mask( - self, dist: pyro.distributions.Distribution, value: Optional[torch.Tensor], device: torch.device, name: str + self, + dist: pyro.distributions.Distribution, + value: Optional[torch.Tensor], + device: torch.device, + name: str, ) -> torch.Tensor: return torch.tensor(name in self.names, device=device) -class NMCLogLikelihood(torch.nn.Module): - +class NMCLogPredictiveLikelihood(torch.nn.Module): def __init__( self, - model: pyro.nn.PyroModule, - guide: pyro.nn.PyroModule, - num_samples: int, + model: torch.nn.Module, + guide: torch.nn.Module, + *, + num_samples: int = 1, + max_plate_nesting: int = 1, ): super().__init__() self.model = model self.guide = guide self.num_samples = num_samples - - def _vectorized_log_prob(self, single_datapoint: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: - masked_guide = pyro.poutine.mask(mask=False)(self.guide) # Mask all sites in guide - masked_model = UnmaskNamedSites(names=set(single_datapoint.keys()))(condition(data=single_datapoint)(self.model)) - log_weights = pyro.infer.importance.vectorized_importance_weights(masked_model, masked_guide, *args, num_samples=self.num_samples, max_plate_nesting=1, **kwargs)[0] - return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) - - def vectorized_log_prob(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: - num_monte_carlo_outer = data[next(iter(data))].shape[0] - log_prob_func = torch.func.vmap( - torch.func.functionalize(pyro.validation_enabled(False)(partial(torch.func.functional_call, self._vectorized_log_prob))), - in_dims=(None, 0), - randomness='different' - ) - # log_prob_func = torch.func.vmap( - # pyro.validation_enabled(False)(self._vectorized_log_prob), - # in_dims=(0,), - # randomness='different' - # ) - import pdb; pdb.set_trace() - pyro.validation_enabled(False)(self._vectorized_log_prob) - # return log_prob_func(dict(self.named_parameters()), data)[0] / (num_monte_carlo_outer ** 0.5) - return log_prob_func(data)[0] / (num_monte_carlo_outer ** 0.5) + self.max_plate_nesting = max_plate_nesting def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: - num_monte_carlo_outer = data[next(iter(data))].shape[0] - log_weights = [] - for i in range(num_monte_carlo_outer): - log_weights_i = [] - for j in range(self.num_samples): - masked_guide = pyro.poutine.mask(mask=False)(self.guide) # Mask all sites in guide - masked_model = UnmaskNamedSites(names=set(data.keys()))(condition(data={k: v[i] for k, v in data.items()})(self.model)) - log_weight_ij = -1. 
* pyro.infer.Trace_ELBO().differentiable_loss(masked_model, masked_guide, *args, **kwargs) # -1 since negative elbo here - log_weights_i.append(log_weight_ij) - log_weight_i = torch.logsumexp(torch.stack(log_weights_i), dim=0) - math.log(self.num_samples) - log_weights.append(log_weight_i) - - log_weights = torch.stack(log_weights) - assert log_weights.shape == (num_monte_carlo_outer,) - return log_weights / (num_monte_carlo_outer ** 0.5) - - -class NMCLogLikelihoodSingle(NMCLogLikelihood): - def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: - masked_guide = pyro.poutine.mask(mask=False)(self.guide) # Mask all sites in guide - masked_model = UnmaskNamedSites(names=set(data.keys()))(condition(data=data)(self.model)) - log_weights = pyro.infer.importance.vectorized_importance_weights(masked_model, masked_guide, *args, num_samples=self.num_samples, max_plate_nesting=1, **kwargs)[0] - return torch.logsumexp(log_weights * self.guide.zzz.w, dim=0) - math.log(self.num_samples) - - -class DummyAutoNormal(pyro.infer.autoguide.AutoNormal): - - def __getattr__(self, name): - # PyroParams trigger pyro.param statements. - if "_pyro_params" in self.__dict__: - _pyro_params = self.__dict__["_pyro_params"] - if name in _pyro_params: - constraint, event_dim = _pyro_params[name] - unconstrained_value = getattr(self, name + "_unconstrained") - import weakref - constrained_value = torch.distributions.transform_to(constraint)(unconstrained_value) - constrained_value.unconstrained = weakref.ref(unconstrained_value) - return constrained_value - return super().__getattr__(name) - - -if __name__ == "__main__": - import pyro - import pyro.distributions as dist - - # Create simple pyro model - class SimpleModel(pyro.nn.PyroModule): - def forward(self): - a = pyro.sample("a", dist.Normal(0, 1)) - b = pyro.sample("b", dist.Normal(0, 1)) - return pyro.sample("y", dist.Normal(a + b, 1)) - - model = SimpleModel() - - # Create guide on latents a and b - num_monte_carlo_outer = 100 - guide = DummyAutoNormal(block(model, hide=["y"])) - zzz = pyro.nn.PyroModule() - zzz.w = pyro.nn.PyroParam(torch.rand(10), dist.constraints.positive) - guide() - guide.zzz = zzz - print(dict(guide.named_parameters())) - data = Predictive(model, guide=guide, num_samples=num_monte_carlo_outer, return_sites=["y"])() - - # Create log likelihood function - log_prob = NMCLogLikelihoodSingle(model, guide, num_samples=10) - - log_prob_func = torch.func.vmap( - torch.func.functionalize(pyro.validation_enabled(False)(partial(torch.func.functional_call, log_prob))), - # pyro.validation_enabled(False)(partial(torch.func.functional_call, log_prob)), - in_dims=(None, 0), - randomness='different' - ) - - print(log_prob_func(dict(log_prob.named_parameters()), data)[0]) - - # func - grad_log_prob = torch.func.vjp(log_prob_func, dict(log_prob.named_parameters()), data)[1] - print(grad_log_prob(torch.ones(num_monte_carlo_outer))[0]) - - # autograd.functional - param_dict = collections.OrderedDict(log_prob.named_parameters()) - print(dict(zip(param_dict.keys(), torch.autograd.functional.vjp( - lambda *params: log_prob_func(dict(zip(param_dict.keys(), params)), data), - tuple(param_dict.values()), - torch.ones(num_monte_carlo_outer) - )[1]))) - - # print(torch.autograd.grad(partial(torch.func.functional_call, log_prob)(dict(log_prob.named_parameters()), data), tuple(log_prob.parameters()))) - # fvp = make_empirical_fisher_vp(log_prob) - - # v = tuple(torch.ones_like(p) for p in guide.parameters()) - - # print(v, fvp(v)) - - - - - + 
masked_guide = pyro.poutine.mask(mask=False)(self.guide) + masked_model = UnmaskNamedSites(names=set(data.keys()))( + condition(data=data)(self.model) + ) + log_weights = pyro.infer.importance.vectorized_importance_weights( + masked_model, + masked_guide, + *args, + num_samples=self.num_samples, + max_plate_nesting=self.max_plate_nesting, + **kwargs, + )[0] + return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) diff --git a/chirho/robust/utils.py b/chirho/robust/utils.py deleted file mode 100644 index 68367a76a..000000000 --- a/chirho/robust/utils.py +++ /dev/null @@ -1,114 +0,0 @@ -import collections -import functools -import inspect -from typing import Callable, Dict, Optional, Tuple, TypeVar - -import torch - -T = TypeVar("T") - - -@functools.singledispatch -def make_flatten_unflatten( - v, -) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: - raise NotImplementedError - - -@make_flatten_unflatten.register(tuple) -def _make_flatten_unflatten_tuple(v: Tuple[torch.Tensor, ...]): - sizes = [x.size() for x in v] - - def flatten(xs: Tuple[torch.Tensor, ...]) -> torch.Tensor: - return torch.cat([x.reshape(-1) for x in xs], dim=0) - - def unflatten(x: torch.Tensor) -> Tuple[torch.Tensor, ...]: - tensors = [] - i = 0 - for size in sizes: - num_elements = torch.prod(torch.tensor(size)) - tensors.append(x[i : i + num_elements].view(size)) - i += num_elements - return tuple(tensors) - - return flatten, unflatten - - -@make_flatten_unflatten.register(dict) -def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): - def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: - r""" - Flatten a dictionary of tensors into a single vector. - """ - return torch.cat([v.flatten() for k, v in d.items()]) - - def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: - r""" - Unflatten a vector into a dictionary of tensors. - """ - return collections.OrderedDict( - zip( - d.keys(), - [ - v_flat.reshape(v.shape) - for v, v_flat in zip( - d.values(), torch.split(x, [v.numel() for k, v in d.items()]) - ) - ], - ) - ) - - return flatten, unflatten - - -def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: - flatten, unflatten = make_flatten_unflatten(b) - f_Ax_flat = lambda v: flatten(f_Ax(flatten(v))) - b_flat = flatten(b) - return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, b_flat, **kwargs)) - - -def _flat_conjugate_gradient_solve( - f_Ax: Callable[[torch.Tensor], torch.Tensor], - b: torch.Tensor, - *, - cg_iters: Optional[int] = None, - residual_tol: float = 1e-10 -) -> torch.Tensor: - r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. - - Args: - f_Ax (callable): A function to compute matrix vector product. - b (torch.Tensor): Right hand side of the equation to solve. - cg_iters (int): Number of iterations to run conjugate gradient - algorithm. - residual_tol (float): Tolerence for convergence. - - Returns: - torch.Tensor: Solution x* for equation Ax = b. 
- - Notes: This code is copied from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py - """ - if cg_iters is None: - cg_iters = b.numel() - - p = b.clone() - r = b.clone() - x = torch.zeros_like(b) - rdotr = torch.dot(r, r) - - for _ in range(cg_iters): - z = f_Ax(p) - v = rdotr / torch.dot(p, z) - x += v * p - r -= v * z - newrdotr = torch.dot(r, r) - mu = newrdotr / rdotr - p = r + mu * p - - # Still executes loop but effectively stops update (can't break loop since we're using vmap) - # rdotr = torch.where(rdotr < residual_tol, rdotr, newrdotr) - # rdotr = newrdotr - # if rdotr < residual_tol: - # break - return x diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py index a921d8da9..729b3bec3 100644 --- a/tests/robust/test_autograd.py +++ b/tests/robust/test_autograd.py @@ -1,273 +1,18 @@ -import collections -import functools -import math -from functools import partial -from typing import ( - Callable, - Concatenate, - Container, - Dict, - List, - Mapping, - Optional, - ParamSpec, - Protocol, - Tuple, - TypeVar, -) - +from typing import Mapping import pyro import pyro.distributions as dist import torch -from pyro.infer import Predictive, Trace_ELBO -from pyro.infer.elbo import ELBOModule -from pyro.infer.importance import vectorized_importance_weights -from pyro.poutine import block, mask, replay, trace -from torch.func import functional_call - -from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition +from chirho.robust.internals import ( + NMCLogPredictiveLikelihood, + make_empirical_inverse_fisher_vp, + make_flatten_unflatten, + make_functional_call, + Point, +) pyro.settings.set(module_local_params=True) -P = ParamSpec("P") -Q = ParamSpec("Q") -S = TypeVar("S") -T = TypeVar("T") # This will be a torch.Tensor usually - -Point = dict[str, T] -Guide = Callable[P, Optional[T | Point[T]]] - - -@functools.singledispatch -def make_flatten_unflatten( - v, -) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: - raise NotImplementedError - - -@make_flatten_unflatten.register(tuple) -def _make_flatten_unflatten_tuple(v: Tuple[torch.Tensor, ...]): - sizes = [x.size() for x in v] - - def flatten(xs: Tuple[torch.Tensor, ...]) -> torch.Tensor: - return torch.cat([x.reshape(-1) for x in xs], dim=0) - - def unflatten(x: torch.Tensor) -> Tuple[torch.Tensor, ...]: - tensors = [] - i = 0 - for size in sizes: - num_elements = torch.prod(torch.tensor(size)) - tensors.append(x[i : i + num_elements].view(size)) - i += num_elements - return tuple(tensors) - - return flatten, unflatten - - -@make_flatten_unflatten.register(dict) -def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): - def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: - r""" - Flatten a dictionary of tensors into a single vector. - """ - return torch.cat([v.flatten() for k, v in d.items()]) - - def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: - r""" - Unflatten a vector into a dictionary of tensors. - """ - return dict( - zip( - d.keys(), - [ - v_flat.reshape(v.shape) - for v, v_flat in zip( - d.values(), torch.split(x, [v.numel() for k, v in d.items()]) - ) - ], - ) - ) - - return flatten, unflatten - - -def _flat_conjugate_gradient_solve( - f_Ax: Callable[[torch.Tensor], torch.Tensor], - b: torch.Tensor, - *, - cg_iters: Optional[int] = None, - residual_tol: float = 1e-10, -) -> torch.Tensor: - r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. 
- - Args: - f_Ax (callable): A function to compute matrix vector product. - b (torch.Tensor): Right hand side of the equation to solve. - cg_iters (int): Number of iterations to run conjugate gradient - algorithm. - residual_tol (float): Tolerence for convergence. - - Returns: - torch.Tensor: Solution x* for equation Ax = b. - - Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py - """ - if cg_iters is None: - cg_iters = b.numel() - - p = b.clone() - r = b.clone() - x = torch.zeros_like(b) - z = f_Ax(p) - rdotr = torch.dot(r, r) - v = rdotr / torch.dot(p, z) - newrdotr = rdotr - mu = newrdotr / rdotr - - zeros_x = torch.zeros_like(x) - zeros_r = torch.zeros_like(r) - - for _ in range(cg_iters): - not_converged = rdotr > residual_tol - - z = torch.where(not_converged, f_Ax(p), z) - v = torch.where(not_converged, rdotr / torch.dot(p, z), v) - x += torch.where(not_converged, v * p, zeros_x) - r -= torch.where(not_converged, v * z, zeros_r) - newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) - mu = torch.where(not_converged, newrdotr / rdotr, mu) - p = torch.where(not_converged, r + mu * p, p) - rdotr = torch.where(not_converged, newrdotr, rdotr) - - # rdotr = newrdotr - # if rdotr < residual_tol: - # break - return x - - -def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: - flatten, unflatten = make_flatten_unflatten(b) - - def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: - v_unflattened = unflatten(v) - result_unflattened = f_Ax(v_unflattened) - return flatten(result_unflattened) - - return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) - - -def make_functional_call( - mod: Callable[P, T] -) -> Tuple[ - Mapping[str, torch.Tensor], Callable[Concatenate[Mapping[str, torch.Tensor], P], T] -]: - assert isinstance(mod, torch.nn.Module) - return dict(mod.named_parameters()), torch.func.functionalize( - pyro.validation_enabled(False)( - functools.partial(torch.func.functional_call, mod) - ) - ) - - -def make_bound_batched_func_log_prob( - log_prob: Callable[[T], torch.Tensor], data: T -) -> Tuple[ - Mapping[str, torch.Tensor], Callable[[Mapping[str, torch.Tensor]], torch.Tensor] -]: - assert isinstance(log_prob, torch.nn.Module) - log_prob_params_and_fn = make_functional_call(log_prob) - log_prob_params: Mapping[str, torch.Tensor] = log_prob_params_and_fn[0] - func_log_prob: Callable[ - [Mapping[str, torch.Tensor], T], torch.Tensor - ] = log_prob_params_and_fn[1] - - batched_func_log_prob: Callable[ - [Mapping[str, torch.Tensor], T], torch.Tensor - ] = torch.vmap(func_log_prob, in_dims=(None, 0), randomness="different") - - def bound_batched_func_log_prob(params: Mapping[str, torch.Tensor]) -> torch.Tensor: - return batched_func_log_prob(params, data) - - return log_prob_params, bound_batched_func_log_prob - - -def make_empirical_fisher_vp( - log_prob: Callable[[T], torch.Tensor], data: T -) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: - log_prob_params, bound_batched_func_log_prob = make_bound_batched_func_log_prob( - log_prob, data - ) - - def _empirical_fisher_vp( - v: Mapping[str, torch.Tensor] - ) -> Mapping[str, torch.Tensor]: - (_, vnew) = torch.func.jvp( - bound_batched_func_log_prob, (log_prob_params,), (v,) - ) - (_, vjp_fn) = torch.func.vjp(bound_batched_func_log_prob, log_prob_params) - result: Mapping[str, torch.Tensor] = vjp_fn(vnew / vnew.shape[0])[0] - return result - - return _empirical_fisher_vp - - 
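# The flatten/unflatten, conjugate gradient, and Fisher-vector-product helpers
# deleted from this test now live in chirho.robust.internals. A minimal sketch
# of the workflow they support, assuming the SimpleModel and SimpleGuide
# classes defined in test_nmc_log_likelihood below; the sample sizes are
# illustrative only.
import pyro
import torch

from chirho.robust.internals import (
    NMCLogPredictiveLikelihood,
    make_empirical_inverse_fisher_vp,
)

model, guide = SimpleModel(), SimpleGuide()
data = pyro.infer.Predictive(
    model, guide=guide, num_samples=100, return_sites=["y"], parallel=True
)()
log_prob = NMCLogPredictiveLikelihood(
    model, guide, num_samples=10, max_plate_nesting=1
)
# "Vectors" for the empirical Fisher information are dicts keyed by parameter name.
v = {k: torch.ones_like(p) for k, p in log_prob.named_parameters()}
fivp = make_empirical_inverse_fisher_vp(log_prob, data, cg_iters=10)
fivp(v)  # approximate solution of F x = v via conjugate gradients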
-def make_empirical_inverse_fisher_vp( - log_prob: Callable[[T], torch.Tensor], - data: T, - **solver_kwargs, -) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: - assert isinstance(log_prob, torch.nn.Module) - fvp = make_empirical_fisher_vp(log_prob, data) - return functools.partial(conjugate_gradient_solve, fvp, **solver_kwargs) - - -class UnmaskNamedSites(DependentMaskMessenger): - names: Container[str] - - def __init__(self, names: Container[str]): - self.names = names - - def get_mask( - self, - dist: pyro.distributions.Distribution, - value: Optional[torch.Tensor], - device: torch.device, - name: str, - ) -> torch.Tensor: - return torch.tensor(name in self.names, device=device) - - -class NMCLogPredictiveLikelihood(torch.nn.Module): - def __init__( - self, - model: torch.nn.Module, - guide: torch.nn.Module, - *, - num_samples: int = 1, - max_plate_nesting: int = 1, - ): - super().__init__() - self.model = model - self.guide = guide - self.num_samples = num_samples - self.max_plate_nesting = max_plate_nesting - - def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: - masked_guide = pyro.poutine.mask(mask=False)(self.guide) - masked_model = UnmaskNamedSites(names=set(data.keys()))( - condition(data=data)(self.model) - ) - log_weights = pyro.infer.importance.vectorized_importance_weights( - masked_model, - masked_guide, - *args, - num_samples=self.num_samples, - max_plate_nesting=self.max_plate_nesting, - **kwargs, - )[0] - return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) - def test_nmc_log_likelihood(): # Create simple pyro model @@ -294,7 +39,7 @@ def forward(self): guide = SimpleGuide() # Create guide on latents a and b - num_samples_outer = 100000 + num_samples_outer = 10000 data = pyro.infer.Predictive( model, guide=guide, From 93cc014bb3ef87d4031369800a883759015c553e Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 21 Nov 2023 12:22:28 -0500 Subject: [PATCH 11/66] sketch influence implementation --- chirho/robust/internals.py | 79 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 4 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 268a4bf45..67f7e237b 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -1,6 +1,7 @@ import functools import math from typing import ( + Any, Callable, Concatenate, Container, @@ -191,10 +192,8 @@ def make_empirical_fisher_vp( def _empirical_fisher_vp( v: Mapping[str, torch.Tensor] ) -> Mapping[str, torch.Tensor]: - (_, vnew) = torch.func.jvp( - bound_batched_func_log_prob, (log_prob_params,), (v,) - ) - (_, vjp_fn) = torch.func.vjp(bound_batched_func_log_prob, log_prob_params) + vnew = torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,))[1] + vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] result: Mapping[str, torch.Tensor] = vjp_fn(vnew / vnew.shape[0])[0] return result @@ -256,3 +255,75 @@ def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: **kwargs, )[0] return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) + + +def linearize( + model: Callable[P, Any], + guide: Callable[P, Any], + *, + max_plate_nesting: int, + num_samples_outer: int, + num_samples_inner: Optional[int] = None, + cg_iters: Optional[int] = None, + cg_tol: float = 1e-10, +) -> Callable[Concatenate[Point[T], P], Mapping[str, torch.Tensor]]: + + assert isinstance(model, torch.nn.Module) + assert isinstance(guide, torch.nn.Module) + if 
num_samples_inner is None: + num_samples_inner = num_samples_outer ** 2 + + predictive = pyro.infer.Predictive( + model, + guide=guide, + num_samples=num_samples_outer, + parallel=True, + ) + + log_prob = NMCLogPredictiveLikelihood( + model, + guide, + num_samples=num_samples_inner, + max_plate_nesting=max_plate_nesting + ) + log_prob_params, log_prob_func = make_functional_call(log_prob) + score_fn = torch.func.grad(log_prob_func) + + cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=cg_iters, cg_tol=cg_tol) + + @functools.wraps(score_fn) + def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> Mapping[str, torch.Tensor]: + data = predictive(*args, **kwargs) + fvp = make_empirical_fisher_vp(log_prob, data) + point_score = score_fn(log_prob_params, point, *args, **kwargs) + return cg_solver(fvp, point_score) + + return _fn + + +def influence_fn( + model: Callable[P, Any], + guide: Callable[P, Any], + functional: Optional[Callable[[Callable[P, Any], Callable[P, Any]], Callable[P, S]]] = None, + **linearize_kwargs +) -> Callable[Concatenate[Point[T], P], S]: + + linearized = linearize(model, guide, **linearize_kwargs) + + if functional is None: + return linearized + + target = functional(model, guide) + assert isinstance(target, torch.nn.Module) + target_params, func_target = make_functional_call(target) + + @functools.wraps(target) + def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: + param_eif = linearized(point, *args, **kwargs) + return torch.func.jvp( + lambda p: func_target(p, *args, **kwargs), + (target_params,), + (param_eif,) + )[1] + + return _fn From 9bc704ce1072293d34354d56081eee6f2e2764d1 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 21 Nov 2023 12:58:33 -0500 Subject: [PATCH 12/66] fix more args --- chirho/robust/internals.py | 104 ++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 60 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 67f7e237b..7bd6f9466 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -6,6 +6,7 @@ Concatenate, Container, Dict, + Generic, Mapping, Optional, ParamSpec, @@ -17,6 +18,7 @@ import torch from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition +from chirho.observational.ops import Observation pyro.settings.set(module_local_params=True) @@ -25,8 +27,10 @@ S = TypeVar("S") T = TypeVar("T") # This will be a torch.Tensor usually -Point = dict[str, T] -Guide = Callable[P, Optional[T | Point[T]]] +Model = Callable[P, Any] +Point = Mapping[str, Observation[T]] +Functional = Callable[[Model[P], Model[P]], Callable[P, S]] +ParamDict = Mapping[str, torch.Tensor] @functools.singledispatch @@ -149,67 +153,45 @@ def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: def make_functional_call( mod: Callable[P, T] -) -> Tuple[ - Mapping[str, torch.Tensor], Callable[Concatenate[Mapping[str, torch.Tensor], P], T] -]: +) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: assert isinstance(mod, torch.nn.Module) - return dict(mod.named_parameters()), torch.func.functionalize( + param_dict: ParamDict = dict(mod.named_parameters()) + return param_dict, torch.func.functionalize( pyro.validation_enabled(False)( functools.partial(torch.func.functional_call, mod) ) ) -def make_bound_batched_func_log_prob( - log_prob: Callable[[T], torch.Tensor], data: T -) -> Tuple[ - Mapping[str, torch.Tensor], Callable[[Mapping[str, torch.Tensor]], torch.Tensor] -]: - assert isinstance(log_prob, torch.nn.Module) 
- log_prob_params_and_fn = make_functional_call(log_prob) - log_prob_params: Mapping[str, torch.Tensor] = log_prob_params_and_fn[0] - func_log_prob: Callable[ - [Mapping[str, torch.Tensor], T], torch.Tensor - ] = log_prob_params_and_fn[1] +def make_empirical_fisher_vp( + func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], + log_prob_params: ParamDict, + data: Point[T], + *args: P.args, + **kwargs: P.kwargs +) -> Callable[[ParamDict], ParamDict]: batched_func_log_prob: Callable[ - [Mapping[str, torch.Tensor], T], torch.Tensor - ] = torch.vmap(func_log_prob, in_dims=(None, 0), randomness="different") + [ParamDict, Point[T]], torch.Tensor + ] = torch.vmap( + lambda p, data: func_log_prob(p, data, *args, **kwargs), + in_dims=(None, 0), + randomness="different" + ) - def bound_batched_func_log_prob(params: Mapping[str, torch.Tensor]) -> torch.Tensor: + def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: return batched_func_log_prob(params, data) - return log_prob_params, bound_batched_func_log_prob - + jvp_fn = functools.partial(torch.func.jvp, bound_batched_func_log_prob, (log_prob_params,)) + vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] -def make_empirical_fisher_vp( - log_prob: Callable[[T], torch.Tensor], data: T -) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: - log_prob_params, bound_batched_func_log_prob = make_bound_batched_func_log_prob( - log_prob, data - ) - - def _empirical_fisher_vp( - v: Mapping[str, torch.Tensor] - ) -> Mapping[str, torch.Tensor]: - vnew = torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,))[1] - vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] - result: Mapping[str, torch.Tensor] = vjp_fn(vnew / vnew.shape[0])[0] - return result + def _empirical_fisher_vp(v: ParamDict) -> ParamDict: + jvp_log_prob_v = jvp_fn((v,))[1] + return vjp_fn(jvp_log_prob_v / jvp_log_prob_v.shape[0])[0] return _empirical_fisher_vp -def make_empirical_inverse_fisher_vp( - log_prob: Callable[[T], torch.Tensor], - data: T, - **solver_kwargs, -) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, torch.Tensor]]: - assert isinstance(log_prob, torch.nn.Module) - fvp = make_empirical_fisher_vp(log_prob, data) - return functools.partial(conjugate_gradient_solve, fvp, **solver_kwargs) - - class UnmaskNamedSites(DependentMaskMessenger): names: Container[str] @@ -226,7 +208,8 @@ def get_mask( return torch.tensor(name in self.names, device=device) -class NMCLogPredictiveLikelihood(torch.nn.Module): +class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): + def __init__( self, model: torch.nn.Module, @@ -241,7 +224,7 @@ def __init__( self.num_samples = num_samples self.max_plate_nesting = max_plate_nesting - def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: + def forward(self, data: Point[T], *args: P.args, **kwargs: P.kwargs) -> torch.Tensor: masked_guide = pyro.poutine.mask(mask=False)(self.guide) masked_model = UnmaskNamedSites(names=set(data.keys()))( condition(data=data)(self.model) @@ -258,15 +241,15 @@ def forward(self, data: Point[torch.Tensor], *args, **kwargs) -> torch.Tensor: def linearize( - model: Callable[P, Any], - guide: Callable[P, Any], + model: Model[P], + guide: Model[P], *, max_plate_nesting: int, num_samples_outer: int, num_samples_inner: Optional[int] = None, cg_iters: Optional[int] = None, cg_tol: float = 1e-10, -) -> Callable[Concatenate[Point[T], P], Mapping[str, torch.Tensor]]: +) -> 
Callable[Concatenate[Point[T], P], ParamDict]: assert isinstance(model, torch.nn.Module) assert isinstance(guide, torch.nn.Module) @@ -286,25 +269,26 @@ def linearize( num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting ) - log_prob_params, log_prob_func = make_functional_call(log_prob) - score_fn = torch.func.grad(log_prob_func) + log_prob_params, func_log_prob = make_functional_call(log_prob) + score_fn = torch.func.grad(func_log_prob) cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=cg_iters, cg_tol=cg_tol) @functools.wraps(score_fn) - def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> Mapping[str, torch.Tensor]: - data = predictive(*args, **kwargs) - fvp = make_empirical_fisher_vp(log_prob, data) - point_score = score_fn(log_prob_params, point, *args, **kwargs) + def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict: + with torch.no_grad(): + data: Point[T] = predictive(*args, **kwargs) + fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data, *args, **kwargs) + point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) return cg_solver(fvp, point_score) return _fn def influence_fn( - model: Callable[P, Any], - guide: Callable[P, Any], - functional: Optional[Callable[[Callable[P, Any], Callable[P, Any]], Callable[P, S]]] = None, + model: Model[P], + guide: Model[P], + functional: Optional[Functional[P, S]] = None, **linearize_kwargs ) -> Callable[Concatenate[Point[T], P], S]: From cedb818a2bef5a4e113141bfd1b95d80a34ae004 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 21 Nov 2023 13:11:18 -0500 Subject: [PATCH 13/66] ops file --- chirho/robust/internals.py | 52 ++--------------------------- chirho/robust/ops.py | 67 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 50 deletions(-) create mode 100644 chirho/robust/ops.py diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 7bd6f9466..1eda01878 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -1,13 +1,11 @@ import functools import math from typing import ( - Any, Callable, Concatenate, Container, Dict, Generic, - Mapping, Optional, ParamSpec, Tuple, @@ -18,20 +16,14 @@ import torch from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition -from chirho.observational.ops import Observation +from chirho.robust.ops import Model, Point, ParamDict, make_functional_call pyro.settings.set(module_local_params=True) P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") -T = TypeVar("T") # This will be a torch.Tensor usually - -Model = Callable[P, Any] -Point = Mapping[str, Observation[T]] -Functional = Callable[[Model[P], Model[P]], Callable[P, S]] -ParamDict = Mapping[str, torch.Tensor] - +T = TypeVar("T") @functools.singledispatch def make_flatten_unflatten( @@ -151,18 +143,6 @@ def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) -def make_functional_call( - mod: Callable[P, T] -) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: - assert isinstance(mod, torch.nn.Module) - param_dict: ParamDict = dict(mod.named_parameters()) - return param_dict, torch.func.functionalize( - pyro.validation_enabled(False)( - functools.partial(torch.func.functional_call, mod) - ) - ) - - def make_empirical_fisher_vp( func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], log_prob_params: ParamDict, @@ -283,31 +263,3 @@ def _fn(point: Point[T], 
*args: P.args, **kwargs: P.kwargs) -> ParamDict: return cg_solver(fvp, point_score) return _fn - - -def influence_fn( - model: Model[P], - guide: Model[P], - functional: Optional[Functional[P, S]] = None, - **linearize_kwargs -) -> Callable[Concatenate[Point[T], P], S]: - - linearized = linearize(model, guide, **linearize_kwargs) - - if functional is None: - return linearized - - target = functional(model, guide) - assert isinstance(target, torch.nn.Module) - target_params, func_target = make_functional_call(target) - - @functools.wraps(target) - def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: - param_eif = linearized(point, *args, **kwargs) - return torch.func.jvp( - lambda p: func_target(p, *args, **kwargs), - (target_params,), - (param_eif,) - )[1] - - return _fn diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py new file mode 100644 index 000000000..8642f8bdd --- /dev/null +++ b/chirho/robust/ops.py @@ -0,0 +1,67 @@ +import functools +from typing import ( + Any, + Callable, + Concatenate, + Mapping, + Optional, + ParamSpec, + Tuple, + TypeVar, +) + +import pyro +import torch +from chirho.observational.ops import Observation + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + +Model = Callable[P, Any] +Point = Mapping[str, Observation[T]] +Functional = Callable[[Model[P], Model[P]], Callable[P, S]] +ParamDict = Mapping[str, torch.Tensor] + + +def make_functional_call( + mod: Callable[P, T] +) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: + assert isinstance(mod, torch.nn.Module) + param_dict: ParamDict = dict(mod.named_parameters()) + mod_func: Callable[Concatenate[ParamDict, P], T] = functools.partial(torch.func.functional_call, mod) + functionalized_mod_func: Callable[Concatenate[ParamDict, P], T] = torch.func.functionalize( + pyro.validation_enabled(False)(mod_func) + ) + return param_dict, functionalized_mod_func + + +def influence_fn( + model: Model[P], + guide: Model[P], + functional: Optional[Functional[P, S]] = None, + **linearize_kwargs +) -> Callable[Concatenate[Point[T], P], S]: + + from chirho.robust.internals import linearize + + linearized = linearize(model, guide, **linearize_kwargs) + + if functional is None: + return linearized + + target = functional(model, guide) + assert isinstance(target, torch.nn.Module) + target_params, func_target = make_functional_call(target) + + @functools.wraps(target) + def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: + param_eif = linearized(point, *args, **kwargs) + return torch.func.jvp( + lambda p: func_target(p, *args, **kwargs), + (target_params,), + (param_eif,) + )[1] + + return _fn From 418f79287613892d28809e9af7b042781be5ccfc Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 21 Nov 2023 13:13:29 -0500 Subject: [PATCH 14/66] file --- chirho/robust/internals.py | 14 +++++++++++++- chirho/robust/ops.py | 16 +--------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 1eda01878..3089935e8 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -16,7 +16,7 @@ import torch from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition -from chirho.robust.ops import Model, Point, ParamDict, make_functional_call +from chirho.robust.ops import Model, Point, ParamDict pyro.settings.set(module_local_params=True) @@ -143,6 +143,18 @@ def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: return 
unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) +def make_functional_call( + mod: Callable[P, T] +) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: + assert isinstance(mod, torch.nn.Module) + param_dict: ParamDict = dict(mod.named_parameters()) + mod_func: Callable[Concatenate[ParamDict, P], T] = functools.partial(torch.func.functional_call, mod) + functionalized_mod_func: Callable[Concatenate[ParamDict, P], T] = torch.func.functionalize( + pyro.validation_enabled(False)(mod_func) + ) + return param_dict, functionalized_mod_func + + def make_empirical_fisher_vp( func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], log_prob_params: ParamDict, diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 8642f8bdd..d779eefcd 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -6,11 +6,9 @@ Mapping, Optional, ParamSpec, - Tuple, TypeVar, ) -import pyro import torch from chirho.observational.ops import Observation @@ -25,18 +23,6 @@ ParamDict = Mapping[str, torch.Tensor] -def make_functional_call( - mod: Callable[P, T] -) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: - assert isinstance(mod, torch.nn.Module) - param_dict: ParamDict = dict(mod.named_parameters()) - mod_func: Callable[Concatenate[ParamDict, P], T] = functools.partial(torch.func.functional_call, mod) - functionalized_mod_func: Callable[Concatenate[ParamDict, P], T] = torch.func.functionalize( - pyro.validation_enabled(False)(mod_func) - ) - return param_dict, functionalized_mod_func - - def influence_fn( model: Model[P], guide: Model[P], @@ -44,7 +30,7 @@ def influence_fn( **linearize_kwargs ) -> Callable[Concatenate[Point[T], P], S]: - from chirho.robust.internals import linearize + from chirho.robust.internals import linearize, make_functional_call linearized = linearize(model, guide, **linearize_kwargs) From f792ddffb85189b1a99cfb43c132bab2c5120a87 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 21 Nov 2023 13:14:06 -0500 Subject: [PATCH 15/66] format --- chirho/robust/internals.py | 46 +++++++++++++++++++---------------- chirho/robust/ops.py | 16 +++--------- tests/robust/test_autograd.py | 4 ++- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 3089935e8..8ce55b827 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -14,9 +14,10 @@ import pyro import torch + from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition -from chirho.robust.ops import Model, Point, ParamDict +from chirho.robust.ops import Model, ParamDict, Point pyro.settings.set(module_local_params=True) @@ -25,6 +26,7 @@ S = TypeVar("S") T = TypeVar("T") + @functools.singledispatch def make_flatten_unflatten( v, @@ -148,10 +150,12 @@ def make_functional_call( ) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: assert isinstance(mod, torch.nn.Module) param_dict: ParamDict = dict(mod.named_parameters()) - mod_func: Callable[Concatenate[ParamDict, P], T] = functools.partial(torch.func.functional_call, mod) - functionalized_mod_func: Callable[Concatenate[ParamDict, P], T] = torch.func.functionalize( - pyro.validation_enabled(False)(mod_func) + mod_func: Callable[Concatenate[ParamDict, P], T] = functools.partial( + torch.func.functional_call, mod ) + functionalized_mod_func: Callable[ + Concatenate[ParamDict, P], T + ] = torch.func.functionalize(pyro.validation_enabled(False)(mod_func)) return param_dict, 
functionalized_mod_func @@ -160,21 +164,20 @@ def make_empirical_fisher_vp( log_prob_params: ParamDict, data: Point[T], *args: P.args, - **kwargs: P.kwargs + **kwargs: P.kwargs, ) -> Callable[[ParamDict], ParamDict]: - - batched_func_log_prob: Callable[ - [ParamDict, Point[T]], torch.Tensor - ] = torch.vmap( + batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap( lambda p, data: func_log_prob(p, data, *args, **kwargs), in_dims=(None, 0), - randomness="different" + randomness="different", ) def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: return batched_func_log_prob(params, data) - jvp_fn = functools.partial(torch.func.jvp, bound_batched_func_log_prob, (log_prob_params,)) + jvp_fn = functools.partial( + torch.func.jvp, bound_batched_func_log_prob, (log_prob_params,) + ) vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] def _empirical_fisher_vp(v: ParamDict) -> ParamDict: @@ -201,7 +204,6 @@ def get_mask( class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): - def __init__( self, model: torch.nn.Module, @@ -216,7 +218,9 @@ def __init__( self.num_samples = num_samples self.max_plate_nesting = max_plate_nesting - def forward(self, data: Point[T], *args: P.args, **kwargs: P.kwargs) -> torch.Tensor: + def forward( + self, data: Point[T], *args: P.args, **kwargs: P.kwargs + ) -> torch.Tensor: masked_guide = pyro.poutine.mask(mask=False)(self.guide) masked_model = UnmaskNamedSites(names=set(data.keys()))( condition(data=data)(self.model) @@ -242,11 +246,10 @@ def linearize( cg_iters: Optional[int] = None, cg_tol: float = 1e-10, ) -> Callable[Concatenate[Point[T], P], ParamDict]: - assert isinstance(model, torch.nn.Module) assert isinstance(guide, torch.nn.Module) if num_samples_inner is None: - num_samples_inner = num_samples_outer ** 2 + num_samples_inner = num_samples_outer**2 predictive = pyro.infer.Predictive( model, @@ -256,21 +259,22 @@ def linearize( ) log_prob = NMCLogPredictiveLikelihood( - model, - guide, - num_samples=num_samples_inner, - max_plate_nesting=max_plate_nesting + model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting ) log_prob_params, func_log_prob = make_functional_call(log_prob) score_fn = torch.func.grad(func_log_prob) - cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=cg_iters, cg_tol=cg_tol) + cg_solver = functools.partial( + conjugate_gradient_solve, cg_iters=cg_iters, cg_tol=cg_tol + ) @functools.wraps(score_fn) def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict: with torch.no_grad(): data: Point[T] = predictive(*args, **kwargs) - fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data, *args, **kwargs) + fvp = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, *args, **kwargs + ) point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) return cg_solver(fvp, point_score) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index d779eefcd..ce219666b 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -1,15 +1,8 @@ import functools -from typing import ( - Any, - Callable, - Concatenate, - Mapping, - Optional, - ParamSpec, - TypeVar, -) +from typing import Any, Callable, Concatenate, Mapping, Optional, ParamSpec, TypeVar import torch + from chirho.observational.ops import Observation P = ParamSpec("P") @@ -29,7 +22,6 @@ def influence_fn( functional: Optional[Functional[P, S]] = None, **linearize_kwargs ) -> Callable[Concatenate[Point[T], P], S]: - from 
chirho.robust.internals import linearize, make_functional_call linearized = linearize(model, guide, **linearize_kwargs) @@ -45,9 +37,7 @@ def influence_fn( def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: param_eif = linearized(point, *args, **kwargs) return torch.func.jvp( - lambda p: func_target(p, *args, **kwargs), - (target_params,), - (param_eif,) + lambda p: func_target(p, *args, **kwargs), (target_params,), (param_eif,) )[1] return _fn diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py index 729b3bec3..808ae101a 100644 --- a/tests/robust/test_autograd.py +++ b/tests/robust/test_autograd.py @@ -1,14 +1,16 @@ from typing import Mapping + import pyro import pyro.distributions as dist import torch + from chirho.observational.handlers import condition from chirho.robust.internals import ( NMCLogPredictiveLikelihood, + Point, make_empirical_inverse_fisher_vp, make_flatten_unflatten, make_functional_call, - Point, ) pyro.settings.set(module_local_params=True) From 88a100bd211fdb289a1bfff98b8d63df5b6e7330 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 21 Nov 2023 13:25:20 -0500 Subject: [PATCH 16/66] lint --- chirho/robust/internals.py | 72 +++++++++++++------------------------- 1 file changed, 25 insertions(+), 47 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 8ce55b827..2cae5d515 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -14,10 +14,9 @@ import pyro import torch - from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition -from chirho.robust.ops import Model, ParamDict, Point +from chirho.robust.ops import Model, Point, ParamDict pyro.settings.set(module_local_params=True) @@ -34,25 +33,6 @@ def make_flatten_unflatten( raise NotImplementedError -@make_flatten_unflatten.register(tuple) -def _make_flatten_unflatten_tuple(v: Tuple[torch.Tensor, ...]): - sizes = [x.size() for x in v] - - def flatten(xs: Tuple[torch.Tensor, ...]) -> torch.Tensor: - return torch.cat([x.reshape(-1) for x in xs], dim=0) - - def unflatten(x: torch.Tensor) -> Tuple[torch.Tensor, ...]: - tensors = [] - i = 0 - for size in sizes: - num_elements = torch.prod(torch.tensor(size)) - tensors.append(x[i : i + num_elements].view(size)) - i += num_elements - return tuple(tensors) - - return flatten, unflatten - - @make_flatten_unflatten.register(dict) def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: @@ -138,7 +118,7 @@ def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: flatten, unflatten = make_flatten_unflatten(b) def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: - v_unflattened = unflatten(v) + v_unflattened: T = unflatten(v) result_unflattened = f_Ax(v_unflattened) return flatten(result_unflattened) @@ -150,12 +130,10 @@ def make_functional_call( ) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: assert isinstance(mod, torch.nn.Module) param_dict: ParamDict = dict(mod.named_parameters()) - mod_func: Callable[Concatenate[ParamDict, P], T] = functools.partial( - torch.func.functional_call, mod + mod_func: Callable[Concatenate[ParamDict, P], T] = functools.partial(torch.func.functional_call, mod) + functionalized_mod_func: Callable[Concatenate[ParamDict, P], T] = torch.func.functionalize( + pyro.validation_enabled(False)(mod_func) ) - functionalized_mod_func: Callable[ - Concatenate[ParamDict, P], T - ] = 
torch.func.functionalize(pyro.validation_enabled(False)(mod_func)) return param_dict, functionalized_mod_func @@ -164,20 +142,21 @@ def make_empirical_fisher_vp( log_prob_params: ParamDict, data: Point[T], *args: P.args, - **kwargs: P.kwargs, + **kwargs: P.kwargs ) -> Callable[[ParamDict], ParamDict]: - batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap( + + batched_func_log_prob: Callable[ + [ParamDict, Point[T]], torch.Tensor + ] = torch.vmap( lambda p, data: func_log_prob(p, data, *args, **kwargs), in_dims=(None, 0), - randomness="different", + randomness="different" ) def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: return batched_func_log_prob(params, data) - jvp_fn = functools.partial( - torch.func.jvp, bound_batched_func_log_prob, (log_prob_params,) - ) + jvp_fn = functools.partial(torch.func.jvp, bound_batched_func_log_prob, (log_prob_params,)) vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] def _empirical_fisher_vp(v: ParamDict) -> ParamDict: @@ -197,13 +176,14 @@ def get_mask( self, dist: pyro.distributions.Distribution, value: Optional[torch.Tensor], - device: torch.device, - name: str, + device: torch.device = torch.device("cpu"), + name: Optional[str] = None, ) -> torch.Tensor: - return torch.tensor(name in self.names, device=device) + return torch.tensor(name is None or name in self.names, device=device) class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): + def __init__( self, model: torch.nn.Module, @@ -218,9 +198,7 @@ def __init__( self.num_samples = num_samples self.max_plate_nesting = max_plate_nesting - def forward( - self, data: Point[T], *args: P.args, **kwargs: P.kwargs - ) -> torch.Tensor: + def forward(self, data: Point[T], *args: P.args, **kwargs: P.kwargs) -> torch.Tensor: masked_guide = pyro.poutine.mask(mask=False)(self.guide) masked_model = UnmaskNamedSites(names=set(data.keys()))( condition(data=data)(self.model) @@ -246,10 +224,11 @@ def linearize( cg_iters: Optional[int] = None, cg_tol: float = 1e-10, ) -> Callable[Concatenate[Point[T], P], ParamDict]: + assert isinstance(model, torch.nn.Module) assert isinstance(guide, torch.nn.Module) if num_samples_inner is None: - num_samples_inner = num_samples_outer**2 + num_samples_inner = num_samples_outer ** 2 predictive = pyro.infer.Predictive( model, @@ -259,22 +238,21 @@ def linearize( ) log_prob = NMCLogPredictiveLikelihood( - model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting + model, + guide, + num_samples=num_samples_inner, + max_plate_nesting=max_plate_nesting ) log_prob_params, func_log_prob = make_functional_call(log_prob) score_fn = torch.func.grad(func_log_prob) - cg_solver = functools.partial( - conjugate_gradient_solve, cg_iters=cg_iters, cg_tol=cg_tol - ) + cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=cg_iters, cg_tol=cg_tol) @functools.wraps(score_fn) def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict: with torch.no_grad(): data: Point[T] = predictive(*args, **kwargs) - fvp = make_empirical_fisher_vp( - func_log_prob, log_prob_params, data, *args, **kwargs - ) + fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data, *args, **kwargs) point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) return cg_solver(fvp, point_score) From 94c2fc6cafeddce0ffa165044fa791d713875ca0 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 21 Nov 2023 17:32:47 -0500 Subject: [PATCH 17/66] clean up influence and tests --- 
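The reworked tests below exercise the public entry point end to end. A minimal usage sketch, assuming the SimpleModel and SimpleGuide classes defined in tests/robust/test_autograd.py in this patch (sample counts and variable names are illustrative only):

    import pyro
    import torch

    from chirho.robust.ops import influence_fn

    pyro.settings.set(module_local_params=True)

    model, guide = SimpleModel(), SimpleGuide()

    # With functional=None, influence_fn returns the linearized correction:
    # a function mapping an observed point to a dict of per-parameter values.
    param_eif = influence_fn(
        model,
        guide,
        max_plate_nesting=1,
        num_samples_outer=100,
    )

    with torch.no_grad():
        # Draw synthetic observations of the site "y" from the prior predictive.
        test_data = pyro.infer.Predictive(
            model, num_samples=4, return_sites=["y"], parallel=True
        )()
        test_datum = {k: v[0] for k, v in test_data.items()}

    # Evaluate at a single datum, or batch over data with torch.vmap as in the
    # vmap smoke test.
    eif_at_datum = param_eif(test_datum)
    batch_param_eif = torch.vmap(param_eif, randomness="different")
    eif_at_data = batch_param_eif(test_data)
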
chirho/robust/internals.py | 63 +++++++++--------- chirho/robust/ops.py | 1 + tests/robust/test_autograd.py | 120 +++++++++++++++++----------------- 3 files changed, 93 insertions(+), 91 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 2cae5d515..9a3150ff4 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -14,9 +14,10 @@ import pyro import torch + from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition -from chirho.robust.ops import Model, Point, ParamDict +from chirho.robust.ops import Model, ParamDict, Point pyro.settings.set(module_local_params=True) @@ -93,16 +94,14 @@ def _flat_conjugate_gradient_solve( newrdotr = rdotr mu = newrdotr / rdotr - zeros_x = torch.zeros_like(x) - zeros_r = torch.zeros_like(r) + zeros_xr = torch.zeros_like(x) for _ in range(cg_iters): not_converged = rdotr > residual_tol - z = torch.where(not_converged, f_Ax(p), z) v = torch.where(not_converged, rdotr / torch.dot(p, z), v) - x += torch.where(not_converged, v * p, zeros_x) - r -= torch.where(not_converged, v * z, zeros_r) + x += torch.where(not_converged, v * p, zeros_xr) + r -= torch.where(not_converged, v * z, zeros_xr) newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) mu = torch.where(not_converged, newrdotr / rdotr, mu) p = torch.where(not_converged, r + mu * p, p) @@ -130,11 +129,13 @@ def make_functional_call( ) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: assert isinstance(mod, torch.nn.Module) param_dict: ParamDict = dict(mod.named_parameters()) - mod_func: Callable[Concatenate[ParamDict, P], T] = functools.partial(torch.func.functional_call, mod) - functionalized_mod_func: Callable[Concatenate[ParamDict, P], T] = torch.func.functionalize( - pyro.validation_enabled(False)(mod_func) - ) - return param_dict, functionalized_mod_func + + @torch.func.functionalize + def mod_func(params: ParamDict, *args: P.args, **kwargs: P.kwargs) -> T: + with pyro.validation_enabled(False): + return torch.func.functional_call(mod, params, args, dict(**kwargs)) + + return param_dict, mod_func def make_empirical_fisher_vp( @@ -142,21 +143,20 @@ def make_empirical_fisher_vp( log_prob_params: ParamDict, data: Point[T], *args: P.args, - **kwargs: P.kwargs + **kwargs: P.kwargs, ) -> Callable[[ParamDict], ParamDict]: - - batched_func_log_prob: Callable[ - [ParamDict, Point[T]], torch.Tensor - ] = torch.vmap( + batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap( lambda p, data: func_log_prob(p, data, *args, **kwargs), in_dims=(None, 0), - randomness="different" + randomness="different", ) def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: return batched_func_log_prob(params, data) - jvp_fn = functools.partial(torch.func.jvp, bound_batched_func_log_prob, (log_prob_params,)) + jvp_fn = functools.partial( + torch.func.jvp, bound_batched_func_log_prob, (log_prob_params,) + ) vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] def _empirical_fisher_vp(v: ParamDict) -> ParamDict: @@ -183,7 +183,6 @@ def get_mask( class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): - def __init__( self, model: torch.nn.Module, @@ -198,7 +197,9 @@ def __init__( self.num_samples = num_samples self.max_plate_nesting = max_plate_nesting - def forward(self, data: Point[T], *args: P.args, **kwargs: P.kwargs) -> torch.Tensor: + def forward( + self, data: Point[T], *args: P.args, **kwargs: P.kwargs + ) -> torch.Tensor: 
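        # Naive Monte Carlo estimate of the log predictive likelihood of `data`:
        # latents are drawn from the guide (masked so its own log-densities do not
        # contribute), the model is conditioned on `data` with only those sites
        # unmasked, and the importance weights are averaged in log space
        # (logsumexp minus log num_samples) below.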
masked_guide = pyro.poutine.mask(mask=False)(self.guide) masked_model = UnmaskNamedSites(names=set(data.keys()))( condition(data=data)(self.model) @@ -222,13 +223,12 @@ def linearize( num_samples_outer: int, num_samples_inner: Optional[int] = None, cg_iters: Optional[int] = None, - cg_tol: float = 1e-10, + residual_tol: float = 1e-10, ) -> Callable[Concatenate[Point[T], P], ParamDict]: - assert isinstance(model, torch.nn.Module) assert isinstance(guide, torch.nn.Module) if num_samples_inner is None: - num_samples_inner = num_samples_outer ** 2 + num_samples_inner = num_samples_outer**2 predictive = pyro.infer.Predictive( model, @@ -236,23 +236,26 @@ def linearize( num_samples=num_samples_outer, parallel=True, ) + predictive_params, func_predictive = make_functional_call(predictive) log_prob = NMCLogPredictiveLikelihood( - model, - guide, - num_samples=num_samples_inner, - max_plate_nesting=max_plate_nesting + model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting ) log_prob_params, func_log_prob = make_functional_call(log_prob) score_fn = torch.func.grad(func_log_prob) - cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=cg_iters, cg_tol=cg_tol) + cg_solver = functools.partial( + conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol + ) @functools.wraps(score_fn) def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict: with torch.no_grad(): - data: Point[T] = predictive(*args, **kwargs) - fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data, *args, **kwargs) + data: Point[T] = func_predictive(predictive_params, *args, **kwargs) + data = {k: data[k] for k in point.keys()} + fvp = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, *args, **kwargs + ) point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) return cg_solver(fvp, point_score) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index ce219666b..2cfa48dc4 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -30,6 +30,7 @@ def influence_fn( return linearized target = functional(model, guide) + # TODO check that target_params == model_params | guide_params assert isinstance(target, torch.nn.Module) target_params, func_target = make_functional_call(target) diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py index 808ae101a..9d43c86ce 100644 --- a/tests/robust/test_autograd.py +++ b/tests/robust/test_autograd.py @@ -1,78 +1,76 @@ -from typing import Mapping +from typing import ParamSpec, TypeVar import pyro import pyro.distributions as dist +import pytest import torch -from chirho.observational.handlers import condition -from chirho.robust.internals import ( - NMCLogPredictiveLikelihood, - Point, - make_empirical_inverse_fisher_vp, - make_flatten_unflatten, - make_functional_call, -) +from chirho.robust.ops import influence_fn pyro.settings.set(module_local_params=True) +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") -def test_nmc_log_likelihood(): - # Create simple pyro model - class SimpleModel(pyro.nn.PyroModule): - def forward(self): - a = pyro.sample("a", dist.Normal(0, 1)) - with pyro.plate("data", 3, dim=-1): - b = pyro.sample("b", dist.Normal(0, 1)) - return pyro.sample("y", dist.Normal(a + b, 1)) - - class SimpleGuide(torch.nn.Module): - def __init__(self): - super().__init__() - self.loc_a = torch.nn.Parameter(torch.rand(())) - self.loc_b = torch.nn.Parameter(torch.rand(())) - - def forward(self): - a = pyro.sample("a", 
dist.Normal(self.loc_a, 1.0)) - with pyro.plate("data", 3, dim=-1): - b = pyro.sample("b", dist.Normal(self.loc_b, 1.0)) - return {"a": a, "b": b} - - model = SimpleModel() - guide = SimpleGuide() - - # Create guide on latents a and b - num_samples_outer = 10000 - data = pyro.infer.Predictive( + +class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(a, 1)) + return pyro.sample("y", dist.Normal(a + b, 1)) + + +class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand((3,))) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1.0)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1.0)) + return {"a": a, "b": b} + + +@pytest.mark.parametrize("model,guide", [(SimpleModel(), SimpleGuide())]) +def test_nmc_influence_smoke(model, guide): + num_samples_outer = 100 + param_eif = influence_fn( model, - guide=guide, - num_samples=num_samples_outer, - return_sites=["y"], - parallel=True, - )() - - # Create log likelihood function - log_prob = NMCLogPredictiveLikelihood( - model, guide, num_samples=1, max_plate_nesting=1 + guide, + max_plate_nesting=1, + num_samples_outer=num_samples_outer, ) - v = {k: torch.ones_like(v) for k, v in log_prob.named_parameters()} + with torch.no_grad(): + test_datum = { + k: v[0] + for k, v in pyro.infer.Predictive( + model, num_samples=2, return_sites=["y"], parallel=True + )().items() + } - # fvp = make_empirical_fisher_vp(log_prob, data) - # print(v, fvp(v)) + print(test_datum, param_eif(test_datum)) - flatten_v, unflatten_v = make_flatten_unflatten(v) - assert unflatten_v(flatten_v(v)) == v - fivp = make_empirical_inverse_fisher_vp(log_prob, data, cg_iters=10) - print(v, fivp(v)) - d2 = pyro.infer.Predictive( - model, num_samples=30, return_sites=["y"], parallel=True - )() - log_prob_params, func_log_prob = make_functional_call(log_prob) +@pytest.mark.parametrize("model,guide", [(SimpleModel(), SimpleGuide())]) +def test_nmc_influence_vmap_smoke(model, guide): + num_samples_outer = 100 + param_eif = influence_fn( + model, + guide, + max_plate_nesting=1, + num_samples_outer=num_samples_outer, + ) - def eif(d: Point[torch.Tensor]) -> Mapping[str, torch.Tensor]: - return fivp( - torch.func.grad(lambda params: func_log_prob(params, d))(log_prob_params) - ) + with torch.no_grad(): + test_data = pyro.infer.Predictive( + model, num_samples=4, return_sites=["y"], parallel=True + )() - print(torch.vmap(eif)(d2)) + batch_param_eif = torch.vmap(param_eif, randomness="different") + print(test_data, batch_param_eif(test_data)) From da0bc5ce69e99d9d1e03af9815e21a27ecdd3294 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 09:53:45 -0500 Subject: [PATCH 18/66] make tests more generic --- tests/robust/test_autograd.py | 67 +++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 15 deletions(-) diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py index 9d43c86ce..ee53ef866 100644 --- a/tests/robust/test_autograd.py +++ b/tests/robust/test_autograd.py @@ -1,4 +1,4 @@ -from typing import ParamSpec, TypeVar +from typing import Callable, List, Mapping, Optional, ParamSpec, Set, Tuple, TypeVar import pyro import pyro.distributions as dist @@ -36,41 +36,78 @@ def forward(self): return {"a": a, "b": b} -@pytest.mark.parametrize("model,guide", [(SimpleModel(), 
SimpleGuide())]) -def test_nmc_influence_smoke(model, guide): - num_samples_outer = 100 +TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ + (SimpleModel(), SimpleGuide(), {"y"}, 1), +] + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize( + "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] +) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +def test_nmc_param_influence_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, +): param_eif = influence_fn( model, guide, - max_plate_nesting=1, + max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, ) with torch.no_grad(): test_datum = { k: v[0] for k, v in pyro.infer.Predictive( - model, num_samples=2, return_sites=["y"], parallel=True + model, num_samples=2, return_sites=obs_names, parallel=True )().items() } - print(test_datum, param_eif(test_datum)) - - -@pytest.mark.parametrize("model,guide", [(SimpleModel(), SimpleGuide())]) -def test_nmc_influence_vmap_smoke(model, guide): - num_samples_outer = 100 + test_datum_eif: Mapping[str, torch.Tensor] = param_eif(test_datum) + for k, v in test_datum_eif.items(): + assert not torch.isnan(v).any() + assert not torch.isinf(v).any() + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize( + "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] +) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +def test_nmc_param_influence_vmap_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, +): param_eif = influence_fn( model, guide, - max_plate_nesting=1, + max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, ) with torch.no_grad(): test_data = pyro.infer.Predictive( - model, num_samples=4, return_sites=["y"], parallel=True + model, num_samples=4, return_sites=obs_names, parallel=True )() batch_param_eif = torch.vmap(param_eif, randomness="different") - print(test_data, batch_param_eif(test_data)) + test_data_eif: Mapping[str, torch.Tensor] = batch_param_eif(test_data) + for k, v in test_data_eif.items(): + assert not torch.isnan(v).any() + assert not torch.isinf(v).any() From 4d027e478c69ff789ddec0f5d6e32e73b6573b99 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 10:44:44 -0500 Subject: [PATCH 19/66] guess max plate nesting --- chirho/robust/internals.py | 20 ++++++++++++++++++-- tests/robust/test_autograd.py | 28 ++++++++++++++++++++++------ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 9a3150ff4..54502169a 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -182,6 +182,17 @@ def get_mask( return torch.tensor(name is None or name in self.names, device=device) +@pyro.poutine.block() +@pyro.validation_enabled(False) +@torch.no_grad() +def _guess_max_plate_nesting( + model: Model[P], guide: Model[P], *args: P.args, **kwargs: P.kwargs +) -> int: + elbo = pyro.infer.Trace_ELBO() + elbo._guess_max_plate_nesting(model, guide, args, kwargs) + return elbo.max_plate_nesting + + class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): def __init__( self, @@ -189,7 +200,7 @@ def __init__( guide: torch.nn.Module, *, num_samples: int = 1, - max_plate_nesting: 
int = 1, + max_plate_nesting: Optional[int] = None, ): super().__init__() self.model = model @@ -200,6 +211,11 @@ def __init__( def forward( self, data: Point[T], *args: P.args, **kwargs: P.kwargs ) -> torch.Tensor: + if self.max_plate_nesting is None: + self.max_plate_nesting = _guess_max_plate_nesting( + self.model, self.guide, *args, **kwargs + ) + masked_guide = pyro.poutine.mask(mask=False)(self.guide) masked_model = UnmaskNamedSites(names=set(data.keys()))( condition(data=data)(self.model) @@ -219,9 +235,9 @@ def linearize( model: Model[P], guide: Model[P], *, - max_plate_nesting: int, num_samples_outer: int, num_samples_inner: Optional[int] = None, + max_plate_nesting: Optional[int] = None, cg_iters: Optional[int] = None, residual_tol: float = 1e-10, ) -> Callable[Concatenate[Point[T], P], ParamDict]: diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py index ee53ef866..9091c872f 100644 --- a/tests/robust/test_autograd.py +++ b/tests/robust/test_autograd.py @@ -30,14 +30,24 @@ def __init__(self): self.loc_b = torch.nn.Parameter(torch.rand((3,))) def forward(self): - a = pyro.sample("a", dist.Normal(self.loc_a, 1.0)) + a = pyro.sample("a", dist.Normal(self.loc_a, 1)) with pyro.plate("data", 3, dim=-1): - b = pyro.sample("b", dist.Normal(self.loc_b, 1.0)) + b = pyro.sample("b", dist.Normal(self.loc_b, 1)) return {"a": a, "b": b} TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ (SimpleModel(), SimpleGuide(), {"y"}, 1), + (SimpleModel(), SimpleGuide(), {"y"}, None), + pytest.param( + (m := SimpleModel()), + pyro.infer.autoguide.AutoNormal(m), + {"y"}, + 1, + marks=pytest.mark.xfail( + reason="torch.func autograd doesnt work with PyroParam" + ), + ), ] @@ -55,6 +65,8 @@ def test_nmc_param_influence_smoke( num_samples_inner, cg_iters, ): + model(), guide() # initialize + param_eif = influence_fn( model, guide, @@ -74,8 +86,9 @@ def test_nmc_param_influence_smoke( test_datum_eif: Mapping[str, torch.Tensor] = param_eif(test_datum) for k, v in test_datum_eif.items(): - assert not torch.isnan(v).any() - assert not torch.isinf(v).any() + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" @pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) @@ -92,6 +105,8 @@ def test_nmc_param_influence_vmap_smoke( num_samples_inner, cg_iters, ): + model(), guide() # initialize + param_eif = influence_fn( model, guide, @@ -109,5 +124,6 @@ def test_nmc_param_influence_vmap_smoke( batch_param_eif = torch.vmap(param_eif, randomness="different") test_data_eif: Mapping[str, torch.Tensor] = batch_param_eif(test_data) for k, v in test_data_eif.items(): - assert not torch.isnan(v).any() - assert not torch.isinf(v).any() + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" From e85e33f652fee75428aeeccc2d60b4ff7807322e Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 10:53:09 -0500 Subject: [PATCH 20/66] linearize --- tests/robust/test_autograd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/robust/test_autograd.py b/tests/robust/test_autograd.py index 9091c872f..08d562725 100644 --- a/tests/robust/test_autograd.py +++ b/tests/robust/test_autograd.py @@ -5,7 +5,7 @@ import pytest import torch -from chirho.robust.ops 
import influence_fn +from chirho.robust.internals import linearize pyro.settings.set(module_local_params=True) @@ -67,7 +67,7 @@ def test_nmc_param_influence_smoke( ): model(), guide() # initialize - param_eif = influence_fn( + param_eif = linearize( model, guide, max_plate_nesting=max_plate_nesting, @@ -107,7 +107,7 @@ def test_nmc_param_influence_vmap_smoke( ): model(), guide() # initialize - param_eif = influence_fn( + param_eif = linearize( model, guide, max_plate_nesting=max_plate_nesting, From 1734191633a44f251937e5380e56d739818fc5c5 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 10:53:56 -0500 Subject: [PATCH 21/66] rename file --- tests/robust/{test_autograd.py => test_internals_linearize.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/robust/{test_autograd.py => test_internals_linearize.py} (100%) diff --git a/tests/robust/test_autograd.py b/tests/robust/test_internals_linearize.py similarity index 100% rename from tests/robust/test_autograd.py rename to tests/robust/test_internals_linearize.py From f46556bc41a16894910f88ec3c52c3e76f494373 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 11:04:31 -0500 Subject: [PATCH 22/66] tensor flatten --- chirho/robust/internals.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index 54502169a..cb980e214 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -34,6 +34,23 @@ def make_flatten_unflatten( raise NotImplementedError +@make_flatten_unflatten.register(torch.Tensor) +def _make_flatten_unflatten_tensor(v: torch.Tensor): + def flatten(v: torch.Tensor) -> torch.Tensor: + r""" + Flatten a tensor into a single vector. + """ + return v.flatten() + + def unflatten(x: torch.Tensor) -> torch.Tensor: + r""" + Unflatten a vector into a tensor. 
+ """ + return x.reshape(v.shape) + + return flatten, unflatten + + @make_flatten_unflatten.register(dict) def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: From 1abc5e0b9bcee4c094b352c7db68ceb869ee0fca Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 13:47:41 -0500 Subject: [PATCH 23/66] predictive eif --- chirho/robust/internals.py | 71 +++++++++++++++ chirho/robust/ops.py | 7 +- tests/robust/test_internals_linearize.py | 105 ++++++++++++++++++++++- 3 files changed, 180 insertions(+), 3 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index cb980e214..b21a8ec55 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -1,11 +1,14 @@ +import contextlib import functools import math from typing import ( + Any, Callable, Concatenate, Container, Dict, Generic, + Mapping, Optional, ParamSpec, Tuple, @@ -211,6 +214,11 @@ def _guess_max_plate_nesting( class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): + model: Model[P] + guide: Model[P] + num_samples: int + max_plate_nesting: Optional[int] + def __init__( self, model: torch.nn.Module, @@ -248,6 +256,69 @@ def forward( return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) +class PredictiveFunctional(Generic[P, T], torch.nn.Module): + model: Model[P] + guide: Model[P] + num_samples: int + max_plate_nesting: Optional[int] + + def __init__( + self, + model: Model[P], + guide: Model[P], + *, + num_samples: int = 1, + max_plate_nesting: Optional[int] = None, + ): + super().__init__() + assert isinstance(model, torch.nn.Module) + assert isinstance(guide, torch.nn.Module) + self.model = model + self.guide = guide + self.num_samples = num_samples + self.max_plate_nesting = max_plate_nesting + + def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: + if self.max_plate_nesting is None: + self.max_plate_nesting = _guess_max_plate_nesting( + self.model, self.guide, *args, **kwargs + ) + + particles_plate = ( + contextlib.nullcontext() + if self.num_samples == 1 + else pyro.plate( + "__predictive_particles", + self.num_samples, + dim=-self.max_plate_nesting - 1, + ) + ) + + with pyro.poutine.trace() as guide_tr, particles_plate: + self.guide(*args, **kwargs) + + block_guide_sample_sites = pyro.poutine.block( + hide=[ + name + for name, node in guide_tr.trace.nodes.items() + if node["type"] == "sample" + and not pyro.poutine.util.site_is_subsample(node) + ] + ) + + with pyro.poutine.trace() as model_tr: + with block_guide_sample_sites: + with pyro.poutine.replay(trace=guide_tr.trace), particles_plate: + self.model(*args, **kwargs) + + return { + name: node["value"] + for name, node in model_tr.trace.nodes.items() + if node["type"] == "sample" + and not pyro.poutine.util.site_is_subsample(node) + } + + def linearize( model: Model[P], guide: Model[P], diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 2cfa48dc4..562ae8b96 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -27,9 +27,12 @@ def influence_fn( linearized = linearize(model, guide, **linearize_kwargs) if functional is None: - return linearized + from chirho.robust.internals import PredictiveFunctional + + target = PredictiveFunctional(model, guide) + else: + target = functional(model, guide) - target = functional(model, guide) # TODO check that target_params == model_params | guide_params assert isinstance(target, torch.nn.Module) target_params, func_target = make_functional_call(target) diff --git 
a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 08d562725..9347b293f 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -1,3 +1,4 @@ +import functools from typing import Callable, List, Mapping, Optional, ParamSpec, Set, Tuple, TypeVar import pyro @@ -5,7 +6,8 @@ import pytest import torch -from chirho.robust.internals import linearize +from chirho.robust.internals import PredictiveFunctional, linearize +from chirho.robust.ops import influence_fn pyro.settings.set(module_local_params=True) @@ -48,6 +50,15 @@ def forward(self): reason="torch.func autograd doesnt work with PyroParam" ), ), + pytest.param( + (m := SimpleModel()), + pyro.infer.autoguide.AutoDelta(m), + {"y"}, + 1, + marks=pytest.mark.xfail( + reason="torch.func autograd doesnt work with PyroParam" + ), + ), ] @@ -85,6 +96,7 @@ def test_nmc_param_influence_smoke( } test_datum_eif: Mapping[str, torch.Tensor] = param_eif(test_datum) + assert len(test_datum_eif) > 0 for k, v in test_datum_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" assert not torch.isinf(v).any(), f"eif for {k} had infs" @@ -123,6 +135,97 @@ def test_nmc_param_influence_vmap_smoke( batch_param_eif = torch.vmap(param_eif, randomness="different") test_data_eif: Mapping[str, torch.Tensor] = batch_param_eif(test_data) + assert len(test_data_eif) > 0 + for k, v in test_data_eif.items(): + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize( + "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] +) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +@pytest.mark.parametrize("num_predictive_samples", [1, 5]) +def test_nmc_predictive_influence_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, + num_predictive_samples, +): + model(), guide() # initialize + + predictive_eif = influence_fn( + model, + guide, + functional=functools.partial( + PredictiveFunctional, num_samples=num_predictive_samples + ), + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + ) + + with torch.no_grad(): + test_datum = { + k: v[0] + for k, v in pyro.infer.Predictive( + model, num_samples=2, return_sites=obs_names, parallel=True + )().items() + } + + test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum) + assert len(test_datum_eif) > 0 + for k, v in test_datum_eif.items(): + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize( + "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] +) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +@pytest.mark.parametrize("num_predictive_samples", [1, 5]) +def test_nmc_predictive_influence_vmap_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, + num_predictive_samples, +): + model(), guide() # initialize + + predictive_eif = influence_fn( + model, + guide, + 
functional=functools.partial( + PredictiveFunctional, num_samples=num_predictive_samples + ), + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + ) + + with torch.no_grad(): + test_data = pyro.infer.Predictive( + model, num_samples=4, return_sites=obs_names, parallel=True + )() + + batch_predictive_eif = torch.vmap(predictive_eif, randomness="different") + test_data_eif: Mapping[str, torch.Tensor] = batch_predictive_eif(test_data) + assert len(test_data_eif) > 0 for k, v in test_data_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" assert not torch.isinf(v).any(), f"eif for {k} had infs" From 9c80b602e767f2870755ebf63787db8ecabdd03c Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 14:11:43 -0500 Subject: [PATCH 24/66] jvp type --- chirho/robust/internals.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py index b21a8ec55..376385d4b 100644 --- a/chirho/robust/internals.py +++ b/chirho/robust/internals.py @@ -174,13 +174,13 @@ def make_empirical_fisher_vp( def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: return batched_func_log_prob(params, data) - jvp_fn = functools.partial( - torch.func.jvp, bound_batched_func_log_prob, (log_prob_params,) - ) + def jvp_fn(v: ParamDict) -> torch.Tensor: + return torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,))[1] + vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] def _empirical_fisher_vp(v: ParamDict) -> ParamDict: - jvp_log_prob_v = jvp_fn((v,))[1] + jvp_log_prob_v = jvp_fn(v) return vjp_fn(jvp_log_prob_v / jvp_log_prob_v.shape[0])[0] return _empirical_fisher_vp From 931da4f6e09a79b85ef44190f0ccda65e960c402 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 16:21:47 -0500 Subject: [PATCH 25/66] reorganize files --- chirho/robust/internals.py | 366 ----------------------- chirho/robust/internals/__init__.py | 0 chirho/robust/internals/linearize.py | 91 ++++++ chirho/robust/internals/predictive.py | 150 ++++++++++ chirho/robust/internals/utils.py | 152 ++++++++++ chirho/robust/ops.py | 6 +- tests/robust/test_dice_correction.py | 2 +- tests/robust/test_internals_linearize.py | 3 +- 8 files changed, 399 insertions(+), 371 deletions(-) delete mode 100644 chirho/robust/internals.py create mode 100644 chirho/robust/internals/__init__.py create mode 100644 chirho/robust/internals/linearize.py create mode 100644 chirho/robust/internals/predictive.py create mode 100644 chirho/robust/internals/utils.py diff --git a/chirho/robust/internals.py b/chirho/robust/internals.py deleted file mode 100644 index 376385d4b..000000000 --- a/chirho/robust/internals.py +++ /dev/null @@ -1,366 +0,0 @@ -import contextlib -import functools -import math -from typing import ( - Any, - Callable, - Concatenate, - Container, - Dict, - Generic, - Mapping, - Optional, - ParamSpec, - Tuple, - TypeVar, -) - -import pyro -import torch - -from chirho.indexed.handlers import DependentMaskMessenger -from chirho.observational.handlers import condition -from chirho.robust.ops import Model, ParamDict, Point - -pyro.settings.set(module_local_params=True) - -P = ParamSpec("P") -Q = ParamSpec("Q") -S = TypeVar("S") -T = TypeVar("T") - - -@functools.singledispatch -def make_flatten_unflatten( - v, -) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: - raise NotImplementedError - - 
-@make_flatten_unflatten.register(torch.Tensor) -def _make_flatten_unflatten_tensor(v: torch.Tensor): - def flatten(v: torch.Tensor) -> torch.Tensor: - r""" - Flatten a tensor into a single vector. - """ - return v.flatten() - - def unflatten(x: torch.Tensor) -> torch.Tensor: - r""" - Unflatten a vector into a tensor. - """ - return x.reshape(v.shape) - - return flatten, unflatten - - -@make_flatten_unflatten.register(dict) -def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): - def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: - r""" - Flatten a dictionary of tensors into a single vector. - """ - return torch.cat([v.flatten() for k, v in d.items()]) - - def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: - r""" - Unflatten a vector into a dictionary of tensors. - """ - return dict( - zip( - d.keys(), - [ - v_flat.reshape(v.shape) - for v, v_flat in zip( - d.values(), torch.split(x, [v.numel() for k, v in d.items()]) - ) - ], - ) - ) - - return flatten, unflatten - - -def _flat_conjugate_gradient_solve( - f_Ax: Callable[[torch.Tensor], torch.Tensor], - b: torch.Tensor, - *, - cg_iters: Optional[int] = None, - residual_tol: float = 1e-10, -) -> torch.Tensor: - r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. - - Args: - f_Ax (callable): A function to compute matrix vector product. - b (torch.Tensor): Right hand side of the equation to solve. - cg_iters (int): Number of iterations to run conjugate gradient - algorithm. - residual_tol (float): Tolerence for convergence. - - Returns: - torch.Tensor: Solution x* for equation Ax = b. - - Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py - """ - if cg_iters is None: - cg_iters = b.numel() - - p = b.clone() - r = b.clone() - x = torch.zeros_like(b) - z = f_Ax(p) - rdotr = torch.dot(r, r) - v = rdotr / torch.dot(p, z) - newrdotr = rdotr - mu = newrdotr / rdotr - - zeros_xr = torch.zeros_like(x) - - for _ in range(cg_iters): - not_converged = rdotr > residual_tol - z = torch.where(not_converged, f_Ax(p), z) - v = torch.where(not_converged, rdotr / torch.dot(p, z), v) - x += torch.where(not_converged, v * p, zeros_xr) - r -= torch.where(not_converged, v * z, zeros_xr) - newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) - mu = torch.where(not_converged, newrdotr / rdotr, mu) - p = torch.where(not_converged, r + mu * p, p) - rdotr = torch.where(not_converged, newrdotr, rdotr) - - # rdotr = newrdotr - # if rdotr < residual_tol: - # break - return x - - -def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: - flatten, unflatten = make_flatten_unflatten(b) - - def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: - v_unflattened: T = unflatten(v) - result_unflattened = f_Ax(v_unflattened) - return flatten(result_unflattened) - - return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) - - -def make_functional_call( - mod: Callable[P, T] -) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: - assert isinstance(mod, torch.nn.Module) - param_dict: ParamDict = dict(mod.named_parameters()) - - @torch.func.functionalize - def mod_func(params: ParamDict, *args: P.args, **kwargs: P.kwargs) -> T: - with pyro.validation_enabled(False): - return torch.func.functional_call(mod, params, args, dict(**kwargs)) - - return param_dict, mod_func - - -def make_empirical_fisher_vp( - func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], - log_prob_params: 
ParamDict, - data: Point[T], - *args: P.args, - **kwargs: P.kwargs, -) -> Callable[[ParamDict], ParamDict]: - batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap( - lambda p, data: func_log_prob(p, data, *args, **kwargs), - in_dims=(None, 0), - randomness="different", - ) - - def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: - return batched_func_log_prob(params, data) - - def jvp_fn(v: ParamDict) -> torch.Tensor: - return torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,))[1] - - vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] - - def _empirical_fisher_vp(v: ParamDict) -> ParamDict: - jvp_log_prob_v = jvp_fn(v) - return vjp_fn(jvp_log_prob_v / jvp_log_prob_v.shape[0])[0] - - return _empirical_fisher_vp - - -class UnmaskNamedSites(DependentMaskMessenger): - names: Container[str] - - def __init__(self, names: Container[str]): - self.names = names - - def get_mask( - self, - dist: pyro.distributions.Distribution, - value: Optional[torch.Tensor], - device: torch.device = torch.device("cpu"), - name: Optional[str] = None, - ) -> torch.Tensor: - return torch.tensor(name is None or name in self.names, device=device) - - -@pyro.poutine.block() -@pyro.validation_enabled(False) -@torch.no_grad() -def _guess_max_plate_nesting( - model: Model[P], guide: Model[P], *args: P.args, **kwargs: P.kwargs -) -> int: - elbo = pyro.infer.Trace_ELBO() - elbo._guess_max_plate_nesting(model, guide, args, kwargs) - return elbo.max_plate_nesting - - -class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): - model: Model[P] - guide: Model[P] - num_samples: int - max_plate_nesting: Optional[int] - - def __init__( - self, - model: torch.nn.Module, - guide: torch.nn.Module, - *, - num_samples: int = 1, - max_plate_nesting: Optional[int] = None, - ): - super().__init__() - self.model = model - self.guide = guide - self.num_samples = num_samples - self.max_plate_nesting = max_plate_nesting - - def forward( - self, data: Point[T], *args: P.args, **kwargs: P.kwargs - ) -> torch.Tensor: - if self.max_plate_nesting is None: - self.max_plate_nesting = _guess_max_plate_nesting( - self.model, self.guide, *args, **kwargs - ) - - masked_guide = pyro.poutine.mask(mask=False)(self.guide) - masked_model = UnmaskNamedSites(names=set(data.keys()))( - condition(data=data)(self.model) - ) - log_weights = pyro.infer.importance.vectorized_importance_weights( - masked_model, - masked_guide, - *args, - num_samples=self.num_samples, - max_plate_nesting=self.max_plate_nesting, - **kwargs, - )[0] - return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) - - -class PredictiveFunctional(Generic[P, T], torch.nn.Module): - model: Model[P] - guide: Model[P] - num_samples: int - max_plate_nesting: Optional[int] - - def __init__( - self, - model: Model[P], - guide: Model[P], - *, - num_samples: int = 1, - max_plate_nesting: Optional[int] = None, - ): - super().__init__() - assert isinstance(model, torch.nn.Module) - assert isinstance(guide, torch.nn.Module) - self.model = model - self.guide = guide - self.num_samples = num_samples - self.max_plate_nesting = max_plate_nesting - - def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: - if self.max_plate_nesting is None: - self.max_plate_nesting = _guess_max_plate_nesting( - self.model, self.guide, *args, **kwargs - ) - - particles_plate = ( - contextlib.nullcontext() - if self.num_samples == 1 - else pyro.plate( - "__predictive_particles", - self.num_samples, - 
dim=-self.max_plate_nesting - 1, - ) - ) - - with pyro.poutine.trace() as guide_tr, particles_plate: - self.guide(*args, **kwargs) - - block_guide_sample_sites = pyro.poutine.block( - hide=[ - name - for name, node in guide_tr.trace.nodes.items() - if node["type"] == "sample" - and not pyro.poutine.util.site_is_subsample(node) - ] - ) - - with pyro.poutine.trace() as model_tr: - with block_guide_sample_sites: - with pyro.poutine.replay(trace=guide_tr.trace), particles_plate: - self.model(*args, **kwargs) - - return { - name: node["value"] - for name, node in model_tr.trace.nodes.items() - if node["type"] == "sample" - and not pyro.poutine.util.site_is_subsample(node) - } - - -def linearize( - model: Model[P], - guide: Model[P], - *, - num_samples_outer: int, - num_samples_inner: Optional[int] = None, - max_plate_nesting: Optional[int] = None, - cg_iters: Optional[int] = None, - residual_tol: float = 1e-10, -) -> Callable[Concatenate[Point[T], P], ParamDict]: - assert isinstance(model, torch.nn.Module) - assert isinstance(guide, torch.nn.Module) - if num_samples_inner is None: - num_samples_inner = num_samples_outer**2 - - predictive = pyro.infer.Predictive( - model, - guide=guide, - num_samples=num_samples_outer, - parallel=True, - ) - predictive_params, func_predictive = make_functional_call(predictive) - - log_prob = NMCLogPredictiveLikelihood( - model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting - ) - log_prob_params, func_log_prob = make_functional_call(log_prob) - score_fn = torch.func.grad(func_log_prob) - - cg_solver = functools.partial( - conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol - ) - - @functools.wraps(score_fn) - def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict: - with torch.no_grad(): - data: Point[T] = func_predictive(predictive_params, *args, **kwargs) - data = {k: data[k] for k in point.keys()} - fvp = make_empirical_fisher_vp( - func_log_prob, log_prob_params, data, *args, **kwargs - ) - point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) - return cg_solver(fvp, point_score) - - return _fn diff --git a/chirho/robust/internals/__init__.py b/chirho/robust/internals/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py new file mode 100644 index 000000000..863084017 --- /dev/null +++ b/chirho/robust/internals/linearize.py @@ -0,0 +1,91 @@ +import functools +from typing import Callable, Concatenate, Optional, ParamSpec, TypeVar + +import pyro +import torch + +from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood +from chirho.robust.internals.utils import conjugate_gradient_solve, make_functional_call +from chirho.robust.ops import Model, ParamDict, Point + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +def make_empirical_fisher_vp( + func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], + log_prob_params: ParamDict, + data: Point[T], + *args: P.args, + **kwargs: P.kwargs, +) -> Callable[[ParamDict], ParamDict]: + batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap( + lambda p, data: func_log_prob(p, data, *args, **kwargs), + in_dims=(None, 0), + randomness="different", + ) + + def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: + return batched_func_log_prob(params, data) + + def jvp_fn(v: ParamDict) -> 
torch.Tensor: + return torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,))[1] + + vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] + + def _empirical_fisher_vp(v: ParamDict) -> ParamDict: + jvp_log_prob_v = jvp_fn(v) + return vjp_fn(jvp_log_prob_v / jvp_log_prob_v.shape[0])[0] + + return _empirical_fisher_vp + + +def linearize( + model: Model[P], + guide: Model[P], + *, + num_samples_outer: int, + num_samples_inner: Optional[int] = None, + max_plate_nesting: Optional[int] = None, + cg_iters: Optional[int] = None, + residual_tol: float = 1e-10, +) -> Callable[Concatenate[Point[T], P], ParamDict]: + assert isinstance(model, torch.nn.Module) + assert isinstance(guide, torch.nn.Module) + if num_samples_inner is None: + num_samples_inner = num_samples_outer**2 + + predictive = pyro.infer.Predictive( + model, + guide=guide, + num_samples=num_samples_outer, + parallel=True, + ) + predictive_params, func_predictive = make_functional_call(predictive) + + log_prob = NMCLogPredictiveLikelihood( + model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting + ) + log_prob_params, func_log_prob = make_functional_call(log_prob) + score_fn = torch.func.grad(func_log_prob) + + cg_solver = functools.partial( + conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol + ) + + @functools.wraps(score_fn) + def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict: + with torch.no_grad(): + data: Point[T] = func_predictive(predictive_params, *args, **kwargs) + data = {k: data[k] for k in point.keys()} + fvp = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, *args, **kwargs + ) + point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) + return cg_solver(fvp, point_score) + + return _fn diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py new file mode 100644 index 000000000..7c26622c7 --- /dev/null +++ b/chirho/robust/internals/predictive.py @@ -0,0 +1,150 @@ +import contextlib +import math +from typing import Container, Generic, Optional, ParamSpec, TypeVar + +import pyro +import torch + +from chirho.indexed.handlers import DependentMaskMessenger +from chirho.observational.handlers import condition +from chirho.robust.ops import Model, Point + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +@pyro.poutine.block() +@pyro.validation_enabled(False) +@torch.no_grad() +def _guess_max_plate_nesting( + model: Model[P], guide: Model[P], *args: P.args, **kwargs: P.kwargs +) -> int: + elbo = pyro.infer.Trace_ELBO() + elbo._guess_max_plate_nesting(model, guide, args, kwargs) + return elbo.max_plate_nesting + + +class UnmaskNamedSites(DependentMaskMessenger): + names: Container[str] + + def __init__(self, names: Container[str]): + self.names = names + + def get_mask( + self, + dist: pyro.distributions.Distribution, + value: Optional[torch.Tensor], + device: torch.device = torch.device("cpu"), + name: Optional[str] = None, + ) -> torch.Tensor: + return torch.tensor(name is None or name in self.names, device=device) + + +class PredictiveFunctional(Generic[P, T], torch.nn.Module): + model: Model[P] + guide: Model[P] + num_samples: int + max_plate_nesting: Optional[int] + + def __init__( + self, + model: Model[P], + guide: Model[P], + *, + num_samples: int = 1, + max_plate_nesting: Optional[int] = None, + ): + super().__init__() + assert isinstance(model, torch.nn.Module) + assert 
isinstance(guide, torch.nn.Module) + self.model = model + self.guide = guide + self.num_samples = num_samples + self.max_plate_nesting = max_plate_nesting + + def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: + if self.max_plate_nesting is None: + self.max_plate_nesting = _guess_max_plate_nesting( + self.model, self.guide, *args, **kwargs + ) + + particles_plate = ( + contextlib.nullcontext() + if self.num_samples == 1 + else pyro.plate( + "__predictive_particles", + self.num_samples, + dim=-self.max_plate_nesting - 1, + ) + ) + + with pyro.poutine.trace() as guide_tr, particles_plate: + self.guide(*args, **kwargs) + + block_guide_sample_sites = pyro.poutine.block( + hide=[ + name + for name, node in guide_tr.trace.nodes.items() + if node["type"] == "sample" + and not pyro.poutine.util.site_is_subsample(node) + ] + ) + + with pyro.poutine.trace() as model_tr: + with block_guide_sample_sites: + with pyro.poutine.replay(trace=guide_tr.trace), particles_plate: + self.model(*args, **kwargs) + + return { + name: node["value"] + for name, node in model_tr.trace.nodes.items() + if node["type"] == "sample" + and not pyro.poutine.util.site_is_subsample(node) + } + + +class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): + model: Model[P] + guide: Model[P] + num_samples: int + max_plate_nesting: Optional[int] + + def __init__( + self, + model: torch.nn.Module, + guide: torch.nn.Module, + *, + num_samples: int = 1, + max_plate_nesting: Optional[int] = None, + ): + super().__init__() + self.model = model + self.guide = guide + self.num_samples = num_samples + self.max_plate_nesting = max_plate_nesting + + def forward( + self, data: Point[T], *args: P.args, **kwargs: P.kwargs + ) -> torch.Tensor: + if self.max_plate_nesting is None: + self.max_plate_nesting = _guess_max_plate_nesting( + self.model, self.guide, *args, **kwargs + ) + + masked_guide = pyro.poutine.mask(mask=False)(self.guide) + masked_model = UnmaskNamedSites(names=set(data.keys()))( + condition(data=data)(self.model) + ) + log_weights = pyro.infer.importance.vectorized_importance_weights( + masked_model, + masked_guide, + *args, + num_samples=self.num_samples, + max_plate_nesting=self.max_plate_nesting, + **kwargs, + )[0] + return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py new file mode 100644 index 000000000..2b0107533 --- /dev/null +++ b/chirho/robust/internals/utils.py @@ -0,0 +1,152 @@ +import functools +from typing import ( + Any, + Callable, + Concatenate, + Dict, + Mapping, + Optional, + ParamSpec, + Tuple, + TypeVar, +) + +import pyro +import torch + +from chirho.robust.ops import ParamDict + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +@functools.singledispatch +def make_flatten_unflatten( + v, +) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: + raise NotImplementedError + + +@make_flatten_unflatten.register(torch.Tensor) +def _make_flatten_unflatten_tensor(v: torch.Tensor): + def flatten(v: torch.Tensor) -> torch.Tensor: + r""" + Flatten a tensor into a single vector. + """ + return v.flatten() + + def unflatten(x: torch.Tensor) -> torch.Tensor: + r""" + Unflatten a vector into a tensor. 
+ """ + return x.reshape(v.shape) + + return flatten, unflatten + + +@make_flatten_unflatten.register(dict) +def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): + def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: + r""" + Flatten a dictionary of tensors into a single vector. + """ + return torch.cat([v.flatten() for k, v in d.items()]) + + def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: + r""" + Unflatten a vector into a dictionary of tensors. + """ + return dict( + zip( + d.keys(), + [ + v_flat.reshape(v.shape) + for v, v_flat in zip( + d.values(), torch.split(x, [v.numel() for k, v in d.items()]) + ) + ], + ) + ) + + return flatten, unflatten + + +def _flat_conjugate_gradient_solve( + f_Ax: Callable[[torch.Tensor], torch.Tensor], + b: torch.Tensor, + *, + cg_iters: Optional[int] = None, + residual_tol: float = 1e-10, +) -> torch.Tensor: + r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. + + Args: + f_Ax (callable): A function to compute matrix vector product. + b (torch.Tensor): Right hand side of the equation to solve. + cg_iters (int): Number of iterations to run conjugate gradient + algorithm. + residual_tol (float): Tolerence for convergence. + + Returns: + torch.Tensor: Solution x* for equation Ax = b. + + Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py + """ + if cg_iters is None: + cg_iters = b.numel() + + p = b.clone() + r = b.clone() + x = torch.zeros_like(b) + z = f_Ax(p) + rdotr = torch.dot(r, r) + v = rdotr / torch.dot(p, z) + newrdotr = rdotr + mu = newrdotr / rdotr + + zeros_xr = torch.zeros_like(x) + + for _ in range(cg_iters): + not_converged = rdotr > residual_tol + z = torch.where(not_converged, f_Ax(p), z) + v = torch.where(not_converged, rdotr / torch.dot(p, z), v) + x += torch.where(not_converged, v * p, zeros_xr) + r -= torch.where(not_converged, v * z, zeros_xr) + newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) + mu = torch.where(not_converged, newrdotr / rdotr, mu) + p = torch.where(not_converged, r + mu * p, p) + rdotr = torch.where(not_converged, newrdotr, rdotr) + + # rdotr = newrdotr + # if rdotr < residual_tol: + # break + return x + + +def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: + flatten, unflatten = make_flatten_unflatten(b) + + def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: + v_unflattened: T = unflatten(v) + result_unflattened = f_Ax(v_unflattened) + return flatten(result_unflattened) + + return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) + + +def make_functional_call( + mod: Callable[P, T] +) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: + assert isinstance(mod, torch.nn.Module) + param_dict: ParamDict = dict(mod.named_parameters()) + + @torch.func.functionalize + def mod_func(params: ParamDict, *args: P.args, **kwargs: P.kwargs) -> T: + with pyro.validation_enabled(False): + return torch.func.functional_call(mod, params, args, dict(**kwargs)) + + return param_dict, mod_func diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 562ae8b96..6ae7d27be 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -22,13 +22,13 @@ def influence_fn( functional: Optional[Functional[P, S]] = None, **linearize_kwargs ) -> Callable[Concatenate[Point[T], P], S]: - from chirho.robust.internals import linearize, make_functional_call + from chirho.robust.internals.linearize import linearize + from 
chirho.robust.internals.predictive import PredictiveFunctional + from chirho.robust.internals.utils import make_functional_call linearized = linearize(model, guide, **linearize_kwargs) if functional is None: - from chirho.robust.internals import PredictiveFunctional - target = PredictiveFunctional(model, guide) else: target = functional(model, guide) diff --git a/tests/robust/test_dice_correction.py b/tests/robust/test_dice_correction.py index e968e9660..77c5b17b8 100644 --- a/tests/robust/test_dice_correction.py +++ b/tests/robust/test_dice_correction.py @@ -9,7 +9,7 @@ from chirho.robust import one_step_correction from chirho.robust.functionals import average_treatment_effect, dice_correction -from chirho.robust.utils import _flatten_dict, _unflatten_dict +from chirho.robust.internals.utils import _flatten_dict, _unflatten_dict class HighDimLinearModel(pyro.nn.PyroModule): diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 9347b293f..1de2041ff 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -6,7 +6,8 @@ import pytest import torch -from chirho.robust.internals import PredictiveFunctional, linearize +from chirho.robust.internals.linearize import linearize +from chirho.robust.internals.predictive import PredictiveFunctional from chirho.robust.ops import influence_fn pyro.settings.set(module_local_params=True) From dc63f31cb8c3cc645fadb0d108fde098a9a3bfa1 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 16:23:47 -0500 Subject: [PATCH 26/66] shrink test case --- tests/robust/test_internals_linearize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 1de2041ff..5d9d34056 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -65,7 +65,7 @@ def forward(self): @pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) @pytest.mark.parametrize( - "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] + "num_samples_outer,num_samples_inner", [(10, None), (10, 100)] ) @pytest.mark.parametrize("cg_iters", [None, 1, 10]) def test_nmc_param_influence_smoke( From be3bc8de43610e3d43e843cefd42e7a41ebf42fd Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 17:03:00 -0500 Subject: [PATCH 27/66] move guess_max_plate_nesting --- chirho/robust/internals/linearize.py | 2 -- chirho/robust/internals/predictive.py | 16 +++------------- chirho/robust/internals/utils.py | 15 ++++++++++++--- tests/robust/test_internals_linearize.py | 4 +--- 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 863084017..e75c3be08 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -8,8 +8,6 @@ from chirho.robust.internals.utils import conjugate_gradient_solve, make_functional_call from chirho.robust.ops import Model, ParamDict, Point -pyro.settings.set(module_local_params=True) - P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py index 7c26622c7..17866f0f9 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/predictive.py @@ -7,6 +7,7 @@ from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition +from 
chirho.robust.internals.utils import guess_max_plate_nesting from chirho.robust.ops import Model, Point pyro.settings.set(module_local_params=True) @@ -17,17 +18,6 @@ T = TypeVar("T") -@pyro.poutine.block() -@pyro.validation_enabled(False) -@torch.no_grad() -def _guess_max_plate_nesting( - model: Model[P], guide: Model[P], *args: P.args, **kwargs: P.kwargs -) -> int: - elbo = pyro.infer.Trace_ELBO() - elbo._guess_max_plate_nesting(model, guide, args, kwargs) - return elbo.max_plate_nesting - - class UnmaskNamedSites(DependentMaskMessenger): names: Container[str] @@ -68,7 +58,7 @@ def __init__( def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: if self.max_plate_nesting is None: - self.max_plate_nesting = _guess_max_plate_nesting( + self.max_plate_nesting = guess_max_plate_nesting( self.model, self.guide, *args, **kwargs ) @@ -131,7 +121,7 @@ def forward( self, data: Point[T], *args: P.args, **kwargs: P.kwargs ) -> torch.Tensor: if self.max_plate_nesting is None: - self.max_plate_nesting = _guess_max_plate_nesting( + self.max_plate_nesting = guess_max_plate_nesting( self.model, self.guide, *args, **kwargs ) diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index 2b0107533..cc6a13a7f 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -14,9 +14,7 @@ import pyro import torch -from chirho.robust.ops import ParamDict - -pyro.settings.set(module_local_params=True) +from chirho.robust.ops import Model, ParamDict P = ParamSpec("P") Q = ParamSpec("Q") @@ -150,3 +148,14 @@ def mod_func(params: ParamDict, *args: P.args, **kwargs: P.kwargs) -> T: return torch.func.functional_call(mod, params, args, dict(**kwargs)) return param_dict, mod_func + + +@pyro.poutine.block() +@pyro.validation_enabled(False) +@torch.no_grad() +def guess_max_plate_nesting( + model: Model[P], guide: Model[P], *args: P.args, **kwargs: P.kwargs +) -> int: + elbo = pyro.infer.Trace_ELBO() + elbo._guess_max_plate_nesting(model, guide, args, kwargs) + return elbo.max_plate_nesting diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 5d9d34056..2cdf0842b 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -64,9 +64,7 @@ def forward(self): @pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) -@pytest.mark.parametrize( - "num_samples_outer,num_samples_inner", [(10, None), (10, 100)] -) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) @pytest.mark.parametrize("cg_iters", [None, 1, 10]) def test_nmc_param_influence_smoke( model, From 9ce164a94689dd47ee05fd06e74ed13ab5cbcd8b Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 17:07:05 -0500 Subject: [PATCH 28/66] move cg solver to linearze --- chirho/robust/internals/linearize.py | 65 ++++++++++++++++++++++++++- chirho/robust/internals/utils.py | 66 ---------------------------- 2 files changed, 64 insertions(+), 67 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index e75c3be08..c8d4f8b06 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -5,7 +5,7 @@ import torch from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood -from chirho.robust.internals.utils import conjugate_gradient_solve, make_functional_call +from chirho.robust.internals.utils import make_flatten_unflatten, make_functional_call from chirho.robust.ops 
import Model, ParamDict, Point P = ParamSpec("P") @@ -14,6 +14,69 @@ T = TypeVar("T") +def _flat_conjugate_gradient_solve( + f_Ax: Callable[[torch.Tensor], torch.Tensor], + b: torch.Tensor, + *, + cg_iters: Optional[int] = None, + residual_tol: float = 1e-10, +) -> torch.Tensor: + r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. + + Args: + f_Ax (callable): A function to compute matrix vector product. + b (torch.Tensor): Right hand side of the equation to solve. + cg_iters (int): Number of iterations to run conjugate gradient + algorithm. + residual_tol (float): Tolerence for convergence. + + Returns: + torch.Tensor: Solution x* for equation Ax = b. + + Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py + """ + if cg_iters is None: + cg_iters = b.numel() + + p = b.clone() + r = b.clone() + x = torch.zeros_like(b) + z = f_Ax(p) + rdotr = torch.dot(r, r) + v = rdotr / torch.dot(p, z) + newrdotr = rdotr + mu = newrdotr / rdotr + + zeros_xr = torch.zeros_like(x) + + for _ in range(cg_iters): + not_converged = rdotr > residual_tol + z = torch.where(not_converged, f_Ax(p), z) + v = torch.where(not_converged, rdotr / torch.dot(p, z), v) + x += torch.where(not_converged, v * p, zeros_xr) + r -= torch.where(not_converged, v * z, zeros_xr) + newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) + mu = torch.where(not_converged, newrdotr / rdotr, mu) + p = torch.where(not_converged, r + mu * p, p) + rdotr = torch.where(not_converged, newrdotr, rdotr) + + # rdotr = newrdotr + # if rdotr < residual_tol: + # break + return x + + +def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: + flatten, unflatten = make_flatten_unflatten(b) + + def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: + v_unflattened: T = unflatten(v) + result_unflattened = f_Ax(v_unflattened) + return flatten(result_unflattened) + + return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) + + def make_empirical_fisher_vp( func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], log_prob_params: ParamDict, diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index cc6a13a7f..3f7acbe71 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -1,11 +1,8 @@ import functools from typing import ( - Any, Callable, Concatenate, Dict, - Mapping, - Optional, ParamSpec, Tuple, TypeVar, @@ -73,69 +70,6 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: return flatten, unflatten -def _flat_conjugate_gradient_solve( - f_Ax: Callable[[torch.Tensor], torch.Tensor], - b: torch.Tensor, - *, - cg_iters: Optional[int] = None, - residual_tol: float = 1e-10, -) -> torch.Tensor: - r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. - - Args: - f_Ax (callable): A function to compute matrix vector product. - b (torch.Tensor): Right hand side of the equation to solve. - cg_iters (int): Number of iterations to run conjugate gradient - algorithm. - residual_tol (float): Tolerence for convergence. - - Returns: - torch.Tensor: Solution x* for equation Ax = b. 
- - Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py - """ - if cg_iters is None: - cg_iters = b.numel() - - p = b.clone() - r = b.clone() - x = torch.zeros_like(b) - z = f_Ax(p) - rdotr = torch.dot(r, r) - v = rdotr / torch.dot(p, z) - newrdotr = rdotr - mu = newrdotr / rdotr - - zeros_xr = torch.zeros_like(x) - - for _ in range(cg_iters): - not_converged = rdotr > residual_tol - z = torch.where(not_converged, f_Ax(p), z) - v = torch.where(not_converged, rdotr / torch.dot(p, z), v) - x += torch.where(not_converged, v * p, zeros_xr) - r -= torch.where(not_converged, v * z, zeros_xr) - newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) - mu = torch.where(not_converged, newrdotr / rdotr, mu) - p = torch.where(not_converged, r + mu * p, p) - rdotr = torch.where(not_converged, newrdotr, rdotr) - - # rdotr = newrdotr - # if rdotr < residual_tol: - # break - return x - - -def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: - flatten, unflatten = make_flatten_unflatten(b) - - def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: - v_unflattened: T = unflatten(v) - result_unflattened = f_Ax(v_unflattened) - return flatten(result_unflattened) - - return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) - - def make_functional_call( mod: Callable[P, T] ) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: From 81196d4b7a88399c56241c05611cf737c512e882 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 17:16:25 -0500 Subject: [PATCH 29/66] type alias --- chirho/robust/internals/linearize.py | 14 +++++++++----- chirho/robust/internals/predictive.py | 18 ++++++++---------- chirho/robust/internals/utils.py | 16 +++++----------- chirho/robust/ops.py | 10 +++++----- 4 files changed, 27 insertions(+), 31 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index c8d4f8b06..67b20a9d3 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -1,12 +1,16 @@ import functools -from typing import Callable, Concatenate, Optional, ParamSpec, TypeVar +from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar import pyro import torch from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood -from chirho.robust.internals.utils import make_flatten_unflatten, make_functional_call -from chirho.robust.ops import Model, ParamDict, Point +from chirho.robust.internals.utils import ( + ParamDict, + make_flatten_unflatten, + make_functional_call, +) +from chirho.robust.ops import Point P = ParamSpec("P") Q = ParamSpec("Q") @@ -106,8 +110,8 @@ def _empirical_fisher_vp(v: ParamDict) -> ParamDict: def linearize( - model: Model[P], - guide: Model[P], + model: Callable[P, Any], + guide: Callable[P, Any], *, num_samples_outer: int, num_samples_inner: Optional[int] = None, diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py index 17866f0f9..d7bd4a4db 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/predictive.py @@ -1,6 +1,6 @@ import contextlib import math -from typing import Container, Generic, Optional, ParamSpec, TypeVar +from typing import Any, Callable, Container, Generic, Optional, ParamSpec, TypeVar import pyro import torch @@ -8,7 +8,7 @@ from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition from 
chirho.robust.internals.utils import guess_max_plate_nesting -from chirho.robust.ops import Model, Point +from chirho.robust.ops import Point pyro.settings.set(module_local_params=True) @@ -35,22 +35,20 @@ def get_mask( class PredictiveFunctional(Generic[P, T], torch.nn.Module): - model: Model[P] - guide: Model[P] + model: Callable[P, Any] + guide: Callable[P, Any] num_samples: int max_plate_nesting: Optional[int] def __init__( self, - model: Model[P], - guide: Model[P], + model: torch.nn.Module, + guide: torch.nn.Module, *, num_samples: int = 1, max_plate_nesting: Optional[int] = None, ): super().__init__() - assert isinstance(model, torch.nn.Module) - assert isinstance(guide, torch.nn.Module) self.model = model self.guide = guide self.num_samples = num_samples @@ -98,8 +96,8 @@ def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): - model: Model[P] - guide: Model[P] + model: Callable[P, Any] + guide: Callable[P, Any] num_samples: int max_plate_nesting: Optional[int] diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index 3f7acbe71..7de1c3b91 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -1,18 +1,9 @@ import functools -from typing import ( - Callable, - Concatenate, - Dict, - ParamSpec, - Tuple, - TypeVar, -) +from typing import Any, Callable, Concatenate, Dict, Mapping, ParamSpec, Tuple, TypeVar import pyro import torch -from chirho.robust.ops import Model, ParamDict - P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") @@ -70,6 +61,9 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: return flatten, unflatten +ParamDict = Mapping[str, torch.Tensor] + + def make_functional_call( mod: Callable[P, T] ) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: @@ -88,7 +82,7 @@ def mod_func(params: ParamDict, *args: P.args, **kwargs: P.kwargs) -> T: @pyro.validation_enabled(False) @torch.no_grad() def guess_max_plate_nesting( - model: Model[P], guide: Model[P], *args: P.args, **kwargs: P.kwargs + model: Callable[P, Any], guide: Callable[P, Any], *args: P.args, **kwargs: P.kwargs ) -> int: elbo = pyro.infer.Trace_ELBO() elbo._guess_max_plate_nesting(model, guide, args, kwargs) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 6ae7d27be..b03ba9935 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -10,15 +10,13 @@ S = TypeVar("S") T = TypeVar("T") -Model = Callable[P, Any] Point = Mapping[str, Observation[T]] -Functional = Callable[[Model[P], Model[P]], Callable[P, S]] -ParamDict = Mapping[str, torch.Tensor] +Functional = Callable[[Callable[P, Any], Callable[P, Any]], Callable[P, S]] def influence_fn( - model: Model[P], - guide: Model[P], + model: Callable[P, Any], + guide: Callable[P, Any], functional: Optional[Functional[P, S]] = None, **linearize_kwargs ) -> Callable[Concatenate[Point[T], P], S]: @@ -29,6 +27,8 @@ def influence_fn( linearized = linearize(model, guide, **linearize_kwargs) if functional is None: + assert isinstance(model, torch.nn.Module) + assert isinstance(guide, torch.nn.Module) target = PredictiveFunctional(model, guide) else: target = functional(model, guide) From 30cb2e7d8a16b68b1dc207298606c0c33dee5cf5 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 17:24:30 -0500 Subject: [PATCH 30/66] test_ops --- tests/robust/test_internals_linearize.py | 93 -------------- tests/robust/test_ops.py | 152 +++++++++++++++++++++++ 2 files changed, 152 insertions(+), 93 deletions(-) 
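The smoke tests moved into tests/robust/test_ops.py below exercise the influence_fn API from chirho/robust/ops.py end to end. As a rough sketch of that usage (SimpleModel and SimpleGuide stand in for the toy fixtures defined in the test module itself; every keyword argument other than functional is forwarded to linearize, and the sample sizes here are illustrative only):

    import functools

    import pyro
    import torch

    from chirho.robust.internals.predictive import PredictiveFunctional
    from chirho.robust.ops import influence_fn

    model, guide = SimpleModel(), SimpleGuide()
    model(), guide()  # call once so lazily-created parameters are initialized

    predictive_eif = influence_fn(
        model,
        guide,
        functional=functools.partial(PredictiveFunctional, num_samples=5),
        max_plate_nesting=1,
        num_samples_outer=10,
        num_samples_inner=100,
        cg_iters=None,
    )

    # Evaluate the estimated efficient influence function at one observed data point.
    with torch.no_grad():
        test_datum = {
            k: v[0]
            for k, v in pyro.infer.Predictive(
                model, num_samples=2, return_sites=["y"], parallel=True
            )().items()
        }
    eif_at_datum = predictive_eif(test_datum)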
create mode 100644 tests/robust/test_ops.py diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 2cdf0842b..77615f82b 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -1,4 +1,3 @@ -import functools from typing import Callable, List, Mapping, Optional, ParamSpec, Set, Tuple, TypeVar import pyro @@ -7,8 +6,6 @@ import torch from chirho.robust.internals.linearize import linearize -from chirho.robust.internals.predictive import PredictiveFunctional -from chirho.robust.ops import influence_fn pyro.settings.set(module_local_params=True) @@ -139,93 +136,3 @@ def test_nmc_param_influence_vmap_smoke( assert not torch.isnan(v).any(), f"eif for {k} had nans" assert not torch.isinf(v).any(), f"eif for {k} had infs" assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" - - -@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) -@pytest.mark.parametrize( - "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] -) -@pytest.mark.parametrize("cg_iters", [None, 1, 10]) -@pytest.mark.parametrize("num_predictive_samples", [1, 5]) -def test_nmc_predictive_influence_smoke( - model, - guide, - obs_names, - max_plate_nesting, - num_samples_outer, - num_samples_inner, - cg_iters, - num_predictive_samples, -): - model(), guide() # initialize - - predictive_eif = influence_fn( - model, - guide, - functional=functools.partial( - PredictiveFunctional, num_samples=num_predictive_samples - ), - max_plate_nesting=max_plate_nesting, - num_samples_outer=num_samples_outer, - num_samples_inner=num_samples_inner, - cg_iters=cg_iters, - ) - - with torch.no_grad(): - test_datum = { - k: v[0] - for k, v in pyro.infer.Predictive( - model, num_samples=2, return_sites=obs_names, parallel=True - )().items() - } - - test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum) - assert len(test_datum_eif) > 0 - for k, v in test_datum_eif.items(): - assert not torch.isnan(v).any(), f"eif for {k} had nans" - assert not torch.isinf(v).any(), f"eif for {k} had infs" - assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" - - -@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) -@pytest.mark.parametrize( - "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] -) -@pytest.mark.parametrize("cg_iters", [None, 1, 10]) -@pytest.mark.parametrize("num_predictive_samples", [1, 5]) -def test_nmc_predictive_influence_vmap_smoke( - model, - guide, - obs_names, - max_plate_nesting, - num_samples_outer, - num_samples_inner, - cg_iters, - num_predictive_samples, -): - model(), guide() # initialize - - predictive_eif = influence_fn( - model, - guide, - functional=functools.partial( - PredictiveFunctional, num_samples=num_predictive_samples - ), - max_plate_nesting=max_plate_nesting, - num_samples_outer=num_samples_outer, - num_samples_inner=num_samples_inner, - cg_iters=cg_iters, - ) - - with torch.no_grad(): - test_data = pyro.infer.Predictive( - model, num_samples=4, return_sites=obs_names, parallel=True - )() - - batch_predictive_eif = torch.vmap(predictive_eif, randomness="different") - test_data_eif: Mapping[str, torch.Tensor] = batch_predictive_eif(test_data) - assert len(test_data_eif) > 0 - for k, v in test_data_eif.items(): - assert not torch.isnan(v).any(), f"eif for {k} had nans" - assert not torch.isinf(v).any(), f"eif for {k} had infs" - assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for 
{k} was zero" diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py new file mode 100644 index 000000000..d1dbb6616 --- /dev/null +++ b/tests/robust/test_ops.py @@ -0,0 +1,152 @@ +import functools +from typing import Callable, List, Mapping, Optional, ParamSpec, Set, Tuple, TypeVar + +import pyro +import pyro.distributions as dist +import pytest +import torch + +from chirho.robust.internals.predictive import PredictiveFunctional +from chirho.robust.ops import influence_fn + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(a, 1)) + return pyro.sample("y", dist.Normal(a + b, 1)) + + +class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand((3,))) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1)) + return {"a": a, "b": b} + + +TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ + (SimpleModel(), SimpleGuide(), {"y"}, 1), + (SimpleModel(), SimpleGuide(), {"y"}, None), + pytest.param( + (m := SimpleModel()), + pyro.infer.autoguide.AutoNormal(m), + {"y"}, + 1, + marks=pytest.mark.xfail( + reason="torch.func autograd doesnt work with PyroParam" + ), + ), + pytest.param( + (m := SimpleModel()), + pyro.infer.autoguide.AutoDelta(m), + {"y"}, + 1, + marks=pytest.mark.xfail( + reason="torch.func autograd doesnt work with PyroParam" + ), + ), +] + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize( + "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] +) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +@pytest.mark.parametrize("num_predictive_samples", [1, 5]) +def test_nmc_predictive_influence_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, + num_predictive_samples, +): + model(), guide() # initialize + + predictive_eif = influence_fn( + model, + guide, + functional=functools.partial( + PredictiveFunctional, num_samples=num_predictive_samples + ), + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + ) + + with torch.no_grad(): + test_datum = { + k: v[0] + for k, v in pyro.infer.Predictive( + model, num_samples=2, return_sites=obs_names, parallel=True + )().items() + } + + test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum) + assert len(test_datum_eif) > 0 + for k, v in test_datum_eif.items(): + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize( + "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] +) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +@pytest.mark.parametrize("num_predictive_samples", [1, 5]) +def test_nmc_predictive_influence_vmap_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, + num_predictive_samples, 
+): + model(), guide() # initialize + + predictive_eif = influence_fn( + model, + guide, + functional=functools.partial( + PredictiveFunctional, num_samples=num_predictive_samples + ), + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + ) + + with torch.no_grad(): + test_data = pyro.infer.Predictive( + model, num_samples=4, return_sites=obs_names, parallel=True + )() + + batch_predictive_eif = torch.vmap(predictive_eif, randomness="different") + test_data_eif: Mapping[str, torch.Tensor] = batch_predictive_eif(test_data) + assert len(test_data_eif) > 0 + for k, v in test_data_eif.items(): + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" From 21cf2d719f341d8ebfa373fbf731d859bcbeef68 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 17:57:07 -0500 Subject: [PATCH 31/66] basic cg tests --- tests/robust/test_internals_linearize.py | 50 ++++++++++++++++++++++-- tests/robust/test_ops.py | 6 +-- 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 77615f82b..49bd6daff 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -1,3 +1,4 @@ +import functools from typing import Callable, List, Mapping, Optional, ParamSpec, Set, Tuple, TypeVar import pyro @@ -5,7 +6,7 @@ import pytest import torch -from chirho.robust.internals.linearize import linearize +from chirho.robust.internals.linearize import conjugate_gradient_solve, linearize pyro.settings.set(module_local_params=True) @@ -15,6 +16,47 @@ T = TypeVar("T") +@pytest.mark.parametrize("ndim", [1, 2, 3, 10]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_cg_solve(ndim: int, dtype: torch.dtype): + cg_iters = None + residual_tol = 1e-10 + + A = torch.eye(ndim, dtype=dtype) + 0.1 * torch.rand(ndim, ndim, dtype=dtype) + expected_x = torch.randn(ndim, dtype=dtype) + b = A @ expected_x + + actual_x = conjugate_gradient_solve( + lambda v: A @ v, b, cg_iters=cg_iters, residual_tol=residual_tol + ) + assert torch.sum((actual_x - expected_x) ** 2) < 1e-4 + + +@pytest.mark.parametrize("ndim", [1, 2, 3, 10]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("num_particles", [1, 4]) +def test_batch_cg_solve(ndim: int, dtype: torch.dtype, num_particles: int): + cg_iters = None + residual_tol = 1e-10 + + A = torch.eye(ndim, dtype=dtype) + 0.1 * torch.rand(ndim, ndim, dtype=dtype) + expected_x = torch.randn(num_particles, ndim, dtype=dtype) + b = torch.einsum("ij,nj->ni", A, expected_x) + assert b.shape == (num_particles, ndim) + + batch_solve = torch.vmap( + functools.partial( + conjugate_gradient_solve, + lambda v: A @ v, + cg_iters=cg_iters, + residual_tol=residual_tol, + ), + ) + actual_x = batch_solve(b) + + assert torch.all(torch.sum((actual_x - expected_x) ** 2, dim=1) < 1e-4) + + class SimpleModel(pyro.nn.PyroModule): def forward(self): a = pyro.sample("a", dist.Normal(0, 1)) @@ -36,7 +78,7 @@ def forward(self): return {"a": a, "b": b} -TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ +MODEL_TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ (SimpleModel(), SimpleGuide(), {"y"}, 1), (SimpleModel(), SimpleGuide(), {"y"}, None), pytest.param( @@ -60,7 
+102,7 @@ def forward(self): ] -@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) @pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) @pytest.mark.parametrize("cg_iters", [None, 1, 10]) def test_nmc_param_influence_smoke( @@ -99,7 +141,7 @@ def test_nmc_param_influence_smoke( assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" -@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) @pytest.mark.parametrize( "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] ) diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index d1dbb6616..6a5e169a8 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -38,7 +38,7 @@ def forward(self): return {"a": a, "b": b} -TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ +MODEL_TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ (SimpleModel(), SimpleGuide(), {"y"}, 1), (SimpleModel(), SimpleGuide(), {"y"}, None), pytest.param( @@ -62,7 +62,7 @@ def forward(self): ] -@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) @pytest.mark.parametrize( "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] ) @@ -108,7 +108,7 @@ def test_nmc_predictive_influence_smoke( assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" -@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", TEST_CASES) +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) @pytest.mark.parametrize( "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] ) From 720661f394c810e3bfb5a623acab561110ce9b19 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 17:59:32 -0500 Subject: [PATCH 32/66] remove failing test case --- tests/robust/test_internals_linearize.py | 9 --------- tests/robust/test_ops.py | 9 --------- 2 files changed, 18 deletions(-) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 49bd6daff..9b3ad9afb 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -90,15 +90,6 @@ def forward(self): reason="torch.func autograd doesnt work with PyroParam" ), ), - pytest.param( - (m := SimpleModel()), - pyro.infer.autoguide.AutoDelta(m), - {"y"}, - 1, - marks=pytest.mark.xfail( - reason="torch.func autograd doesnt work with PyroParam" - ), - ), ] diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 6a5e169a8..e4acf0950 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -50,15 +50,6 @@ def forward(self): reason="torch.func autograd doesnt work with PyroParam" ), ), - pytest.param( - (m := SimpleModel()), - pyro.infer.autoguide.AutoDelta(m), - {"y"}, - 1, - marks=pytest.mark.xfail( - reason="torch.func autograd doesnt work with PyroParam" - ), - ), ] From 91833dac0f2c19678988a2c73468e32e69a41bf3 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 18:03:02 -0500 Subject: [PATCH 33/66] format --- tests/robust/test_internals_linearize.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py 
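The conjugate-gradient helper covered by test_cg_solve above can also be called on its own. A minimal sketch, reusing the construction from that test (only a matrix-vector product closure is required, never the matrix itself; cg_iters=None defaults to one iteration per dimension of b):

    import torch

    from chirho.robust.internals.linearize import conjugate_gradient_solve

    ndim = 5
    A = torch.eye(ndim) + 0.1 * torch.rand(ndim, ndim)
    x_true = torch.randn(ndim)
    b = A @ x_true

    x_hat = conjugate_gradient_solve(
        lambda v: A @ v,  # matvec closure standing in for A
        b,
        cg_iters=None,
        residual_tol=1e-10,
    )
    assert torch.sum((x_hat - x_true) ** 2) < 1e-4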
index 9b3ad9afb..e30ca604c 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -133,9 +133,7 @@ def test_nmc_param_influence_smoke( @pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) -@pytest.mark.parametrize( - "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] -) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) @pytest.mark.parametrize("cg_iters", [None, 1, 10]) def test_nmc_param_influence_vmap_smoke( model, From 548069ad0c4a1cc755b8572ec5dd1cc8db659352 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 18:04:39 -0500 Subject: [PATCH 34/66] move paramdict up --- chirho/robust/internals/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index 7de1c3b91..5d9a8c5d7 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -9,6 +9,8 @@ S = TypeVar("S") T = TypeVar("T") +ParamDict = Mapping[str, torch.Tensor] + @functools.singledispatch def make_flatten_unflatten( @@ -61,9 +63,6 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: return flatten, unflatten -ParamDict = Mapping[str, torch.Tensor] - - def make_functional_call( mod: Callable[P, T] ) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: From 12b22c09b47db69f36e663f59f74da19b05227a3 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 18:09:13 -0500 Subject: [PATCH 35/66] remove obsolete test files --- chirho/robust/internals/predictive.py | 4 +- tests/robust/test_dice_correction.py | 141 -------------------------- tests/robust/test_internals.py | 40 -------- 3 files changed, 2 insertions(+), 183 deletions(-) delete mode 100644 tests/robust/test_dice_correction.py delete mode 100644 tests/robust/test_internals.py diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py index d7bd4a4db..79fd16c9e 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/predictive.py @@ -18,7 +18,7 @@ T = TypeVar("T") -class UnmaskNamedSites(DependentMaskMessenger): +class _UnmaskNamedSites(DependentMaskMessenger): names: Container[str] def __init__(self, names: Container[str]): @@ -124,7 +124,7 @@ def forward( ) masked_guide = pyro.poutine.mask(mask=False)(self.guide) - masked_model = UnmaskNamedSites(names=set(data.keys()))( + masked_model = _UnmaskNamedSites(names=set(data.keys()))( condition(data=data)(self.model) ) log_weights = pyro.infer.importance.vectorized_importance_weights( diff --git a/tests/robust/test_dice_correction.py b/tests/robust/test_dice_correction.py deleted file mode 100644 index 77c5b17b8..000000000 --- a/tests/robust/test_dice_correction.py +++ /dev/null @@ -1,141 +0,0 @@ -import collections -import math -from functools import partial -from typing import Callable, Dict, List, Optional - -import pyro -import pyro.distributions as dist -import torch - -from chirho.robust import one_step_correction -from chirho.robust.functionals import average_treatment_effect, dice_correction -from chirho.robust.internals.utils import _flatten_dict, _unflatten_dict - - -class HighDimLinearModel(pyro.nn.PyroModule): - def __init__( - self, - p: int, - link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0), - prior_scale: float = None, - ): - super().__init__() - self.p = p - self.link_fn = link_fn - if prior_scale is None: - self.prior_scale = 1 / math.sqrt(self.p) - else: 
- self.prior_scale = prior_scale - - def sample_outcome_weights(self): - return pyro.sample( - "outcome_weights", - dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1), - ) - - def sample_intercept(self): - return pyro.sample("intercept", dist.Normal(0.0, 1.0)) - - def sample_propensity_weights(self): - return pyro.sample( - "propensity_weights", - dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1), - ) - - def sample_treatment_weight(self): - return pyro.sample("treatment_weight", dist.Normal(0.0, 1.0)) - - def sample_covariate_loc_scale(self): - loc = pyro.sample( - "covariate_loc", dist.Normal(0.0, 1.0).expand((self.p,)).to_event(1) - ) - scale = pyro.sample( - "covariate_scale", dist.LogNormal(0, 1).expand((self.p,)).to_event(1) - ) - return loc, scale - - def forward(self, N: int): - intercept = self.sample_intercept() - outcome_weights = self.sample_outcome_weights() - propensity_weights = self.sample_propensity_weights() - tau = self.sample_treatment_weight() - x_loc, x_scale = self.sample_covariate_loc_scale() - with pyro.plate("obs", N, dim=-1): - X = pyro.sample("X", dist.Normal(x_loc, x_scale).to_event(1)) - A = pyro.sample( - "A", - dist.Bernoulli( - logits=torch.einsum("...np,...p->...n", X, propensity_weights) - ).mask(False), - ) - return pyro.sample( - "Y", - self.link_fn( - torch.einsum("...np,...p->...n", X, outcome_weights) - + A * tau - + intercept - ), - ) - - -# Internal structure of the model (that's given) -# and then outer monte carlo samples - - -class KnownCovariateDistModel(HighDimLinearModel): - def sample_covariate_loc_scale(self): - return torch.zeros(self.p), torch.ones(self.p) - - -class FakeNormal(dist.Normal): - has_rsample = False - - -def test_bernoulli_model(): - p = 1 - n_monte_carlo_outer = 1000 - avg_plug_in_grads = torch.zeros(4) - for _ in range(n_monte_carlo_outer): - n_monte_carlo_inner = 100 - target_functional = partial( - dice_correction(average_treatment_effect), n_monte_carlo=n_monte_carlo_inner - ) - # bernoulli_link = lambda mu: dist.Bernoulli(logits=mu) - link = lambda mu: FakeNormal(mu, 1.0) - # link = lambda mu: dist.Normal(mu, 1.0) - model = KnownCovariateDistModel(p, link) - theta_hat = { - "intercept": torch.tensor(0.0).requires_grad_(True), - "outcome_weights": torch.tensor([1.0]).requires_grad_(True), - "propensity_weights": torch.tensor([1.0]).requires_grad_(True), - "treatment_weight": torch.tensor(1.0).requires_grad_(True), - } - - # Canonical ordering of parameters when flattening and unflattening - theta_hat = collections.OrderedDict( - (k, theta_hat[k]) for k in sorted(theta_hat.keys()) - ) - flat_theta = _flatten_dict(theta_hat) - - # Compute gradient of plug-in functional - plug_in = target_functional(model, theta_hat) - plug_in += ( - 0 * flat_theta.sum() - ) # hack for full gradient (maintain flattened shape) - - avg_plug_in_grads += ( - _flatten_dict( - collections.OrderedDict( - zip( - theta_hat.keys(), - torch.autograd.grad(plug_in, theta_hat.values()), - ) - ) - ) - / n_monte_carlo_outer - ) - - correct_grad = torch.tensor([0, 0, 0, 1.0]) - # assert (avg_plug_in_grads - correct_grad).abs().sum() < 1 / torch.sqrt( - # torch.tensor(n_monte_carlo_outer) - # ) diff --git a/tests/robust/test_internals.py b/tests/robust/test_internals.py deleted file mode 100644 index 30725429e..000000000 --- a/tests/robust/test_internals.py +++ /dev/null @@ -1,40 +0,0 @@ -import pyro -import pyro.distributions as dist -from pyro.infer import Predictive - -from chirho.robust.internals import * - - -class 
SimpleModel(pyro.nn.PyroModule): - def forward(self): - a = pyro.sample("a", dist.Normal(0, 1)) - b = pyro.sample("b", dist.Normal(0, 1)) - return pyro.sample("y", dist.Normal(a + b, 1)) - - -class SimpleGuide(torch.nn.Module): - def __init__(self): - super().__init__() - self.a_loc = torch.nn.Parameter(torch.tensor(0.0)) - self.b_loc = torch.nn.Parameter(torch.tensor(0.0)) - - def forward(self): - pyro.sample("a", dist.Delta(self.a_loc)) - pyro.sample("b", dist.Delta(self.b_loc)) - - -def test_nmc_log_likelihood(): - model = SimpleModel() - guide = SimpleGuide() - num_monte_carlo_outer = 100 - data = Predictive( - model, guide=guide, num_samples=num_monte_carlo_outer, return_sites=["y"] - )() - nmc_ll = NMCLogLikelihood(model, guide, num_samples=100) - ll_at_data = nmc_ll(data) - print(ll_at_data) - - nmc_ll_single = NMCLogLikelihoodSingle(model, guide, num_samples=10) - nmc_ll_single._vectorized_log_prob({"y": torch.tensor(1.0)}) - nmc_ll({"y": torch.tensor([1.0])}) - ll_at_data_single = nmc_ll_single.vectorized_log_prob(data) From 3b72bb01cd27ae03be8fbc963873b89dad84ce9b Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 18:39:46 -0500 Subject: [PATCH 36/66] add empty handlers --- chirho/robust/handlers/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 chirho/robust/handlers/__init__.py diff --git a/chirho/robust/handlers/__init__.py b/chirho/robust/handlers/__init__.py new file mode 100644 index 000000000..e69de29bb From 89d9f6b773ebac38a4e5b50effac2b7e69e29cd0 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 22 Nov 2023 18:42:07 -0500 Subject: [PATCH 37/66] add chirho.robust to docs --- docs/source/index.rst | 1 + docs/source/robust.rst | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 docs/source/robust.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 0eb51ab2f..43a3c51e4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -41,6 +41,7 @@ Table of Contents observational indexed dynamical + robust .. toctree:: :maxdepth: 2 diff --git a/docs/source/robust.rst b/docs/source/robust.rst new file mode 100644 index 000000000..38ed8dc0e --- /dev/null +++ b/docs/source/robust.rst @@ -0,0 +1,39 @@ +Robust +====== + +.. automodule:: chirho.robust + :members: + :undoc-members: + +Operations +---------- + +.. automodule:: chirho.robust.ops + :members: + :undoc-members: + +Handlers +-------- + +.. automodule:: chirho.robust.handlers + :members: + :undoc-members: + +Internals +--------- + +.. automodule:: chirho.robust.internals + :members: + :undoc-members: + +.. automodule:: chirho.robust.internals.linearize + :members: + :undoc-members: + +.. automodule:: chirho.robust.internals.predictive + :members: + :undoc-members: + +.. 
automodule:: chirho.robust.internals.utils + :members: + :undoc-members: From 7582c221a648387db38fcc4e7977f88d84cbbf5f Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 27 Nov 2023 09:58:01 -0500 Subject: [PATCH 38/66] fix memory leak in tests --- tests/robust/test_internals_linearize.py | 20 +++++++++++++----- tests/robust/test_ops.py | 27 ++++++++++++++---------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index e30ca604c..0ef9a0b78 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -78,12 +78,16 @@ def forward(self): return {"a": a, "b": b} -MODEL_TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ - (SimpleModel(), SimpleGuide(), {"y"}, 1), - (SimpleModel(), SimpleGuide(), {"y"}, None), +ModelTestCase = Tuple[ + Callable[[], Callable], Callable[[Callable], Callable], Set[str], Optional[int] +] + +MODEL_TEST_CASES: List[ModelTestCase] = [ + (SimpleModel, lambda _: SimpleGuide(), {"y"}, 1), + (SimpleModel, lambda _: SimpleGuide(), {"y"}, None), pytest.param( - (m := SimpleModel()), - pyro.infer.autoguide.AutoNormal(m), + SimpleModel, + pyro.infer.autoguide.AutoNormal, {"y"}, 1, marks=pytest.mark.xfail( @@ -105,6 +109,9 @@ def test_nmc_param_influence_smoke( num_samples_inner, cg_iters, ): + model = model() + guide = guide(model) + model(), guide() # initialize param_eif = linearize( @@ -144,6 +151,9 @@ def test_nmc_param_influence_vmap_smoke( num_samples_inner, cg_iters, ): + model = model() + guide = guide(model) + model(), guide() # initialize param_eif = linearize( diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index e4acf0950..6304a49f5 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -38,12 +38,16 @@ def forward(self): return {"a": a, "b": b} -MODEL_TEST_CASES: List[Tuple[Callable, Callable, Set[str], Optional[int]]] = [ - (SimpleModel(), SimpleGuide(), {"y"}, 1), - (SimpleModel(), SimpleGuide(), {"y"}, None), +ModelTestCase = Tuple[ + Callable[[], Callable], Callable[[Callable], Callable], Set[str], Optional[int] +] + +MODEL_TEST_CASES: List[ModelTestCase] = [ + (SimpleModel, lambda _: SimpleGuide(), {"y"}, 1), + (SimpleModel, lambda _: SimpleGuide(), {"y"}, None), pytest.param( - (m := SimpleModel()), - pyro.infer.autoguide.AutoNormal(m), + SimpleModel, + pyro.infer.autoguide.AutoNormal, {"y"}, 1, marks=pytest.mark.xfail( @@ -54,9 +58,7 @@ def forward(self): @pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) -@pytest.mark.parametrize( - "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] -) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) @pytest.mark.parametrize("cg_iters", [None, 1, 10]) @pytest.mark.parametrize("num_predictive_samples", [1, 5]) def test_nmc_predictive_influence_smoke( @@ -69,6 +71,8 @@ def test_nmc_predictive_influence_smoke( cg_iters, num_predictive_samples, ): + model = model() + guide = guide(model) model(), guide() # initialize predictive_eif = influence_fn( @@ -100,9 +104,7 @@ def test_nmc_predictive_influence_smoke( @pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) -@pytest.mark.parametrize( - "num_samples_outer,num_samples_inner", [(100, None), (10, 100)] -) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) @pytest.mark.parametrize("cg_iters", [None, 1, 10]) 
@pytest.mark.parametrize("num_predictive_samples", [1, 5]) def test_nmc_predictive_influence_vmap_smoke( @@ -115,6 +117,9 @@ def test_nmc_predictive_influence_vmap_smoke( cg_iters, num_predictive_samples, ): + model = model() + guide = guide(model) + model(), guide() # initialize predictive_eif = influence_fn( From 82c23e8b89cd87db68cb00519f373afc8410d575 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 27 Nov 2023 10:08:58 -0500 Subject: [PATCH 39/66] make typing compatible with python 3.8 --- chirho/robust/internals/linearize.py | 6 ++++-- chirho/robust/internals/utils.py | 3 ++- chirho/robust/ops.py | 5 +++-- tests/robust/test_internals_linearize.py | 3 ++- tests/robust/test_ops.py | 3 ++- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 67b20a9d3..2d9615ef1 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -1,8 +1,9 @@ import functools -from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar +from typing import Any, Callable, Optional, TypeVar import pyro import torch +from typing_extensions import Concatenate, ParamSpec from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood from chirho.robust.internals.utils import ( @@ -37,7 +38,8 @@ def _flat_conjugate_gradient_solve( Returns: torch.Tensor: Solution x* for equation Ax = b. - Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py + Notes: This code is adapted from + https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py """ if cg_iters is None: cg_iters = b.numel() diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index 5d9a8c5d7..fe0bcf77b 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -1,8 +1,9 @@ import functools -from typing import Any, Callable, Concatenate, Dict, Mapping, ParamSpec, Tuple, TypeVar +from typing import Any, Callable, Dict, Mapping, Tuple, TypeVar import pyro import torch +from typing_extensions import Concatenate, ParamSpec P = ParamSpec("P") Q = ParamSpec("Q") diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index b03ba9935..2d6d95793 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -1,7 +1,8 @@ import functools -from typing import Any, Callable, Concatenate, Mapping, Optional, ParamSpec, TypeVar +from typing import Any, Callable, Mapping, Optional, TypeVar import torch +from typing_extensions import Concatenate, ParamSpec from chirho.observational.ops import Observation @@ -24,7 +25,7 @@ def influence_fn( from chirho.robust.internals.predictive import PredictiveFunctional from chirho.robust.internals.utils import make_functional_call - linearized = linearize(model, guide, **linearize_kwargs) + linearized: Callable = linearize(model, guide, **linearize_kwargs) if functional is None: assert isinstance(model, torch.nn.Module) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 0ef9a0b78..05d9bddea 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -1,10 +1,11 @@ import functools -from typing import Callable, List, Mapping, Optional, ParamSpec, Set, Tuple, TypeVar +from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar import pyro import pyro.distributions as dist import pytest import 
torch +from typing_extensions import ParamSpec from chirho.robust.internals.linearize import conjugate_gradient_solve, linearize diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 6304a49f5..48081cd1f 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -1,10 +1,11 @@ import functools -from typing import Callable, List, Mapping, Optional, ParamSpec, Set, Tuple, TypeVar +from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar import pyro import pyro.distributions as dist import pytest import torch +from typing_extensions import ParamSpec from chirho.robust.internals.predictive import PredictiveFunctional from chirho.robust.ops import influence_fn From e08d9d61f859ba28fb5ff3b63b8bfaccbec37711 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 27 Nov 2023 10:09:59 -0500 Subject: [PATCH 40/66] typing_extensions --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index a982d3158..fc7079309 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ "pytorch-lightning", "scikit-image", "tensorboard", + "typing_extensions", ] DYNAMICAL_REQUIRE = ["torchdiffeq"] From 22eae09effe641ca722b612799971405d23443c5 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 27 Nov 2023 10:12:18 -0500 Subject: [PATCH 41/66] add branch to ci --- .github/workflows/lint.yml | 2 +- .github/workflows/test.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 39640f3b5..126b74bca 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,7 +4,7 @@ on: push: branches: [ master ] pull_request: - branches: [ master ] + branches: [ master, staging-robust ] jobs: build: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6c01b89df..578ab04c4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,7 +4,7 @@ on: push: branches: [ master ] pull_request: - branches: [ master ] + branches: [ master, staging-robust ] jobs: build: From d0014db513d4ee6094d37fc91eab051f60369647 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 27 Nov 2023 10:13:57 -0500 Subject: [PATCH 42/66] predictive --- chirho/robust/internals/predictive.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py index 79fd16c9e..ba529edf7 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/predictive.py @@ -1,9 +1,10 @@ import contextlib import math -from typing import Any, Callable, Container, Generic, Optional, ParamSpec, TypeVar +from typing import Any, Callable, Container, Generic, Optional, TypeVar import pyro import torch +from typing_extensions import ParamSpec from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition From e5342dc75943ccf90fe3226c8d839ba876d89b36 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 27 Nov 2023 10:33:24 -0500 Subject: [PATCH 43/66] remove imprecise annotation --- chirho/robust/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 2d6d95793..d3ea8259a 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -25,7 +25,7 @@ def influence_fn( from chirho.robust.internals.predictive import PredictiveFunctional from chirho.robust.internals.utils import make_functional_call - linearized: Callable = linearize(model, guide, **linearize_kwargs) + linearized = linearize(model, guide, 
**linearize_kwargs) if functional is None: assert isinstance(model, torch.nn.Module) From c5fe64b24b1d9549ac4f6af5757bb2c61d84ac9f Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Wed, 6 Dec 2023 12:54:58 -0800 Subject: [PATCH 44/66] Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments --- tests/robust/__init__.py | 0 tests/robust/robust_fixtures.py | 223 ++++++++++++++++++++ tests/robust/test_internals_compositions.py | 61 ++++++ tests/robust/test_internals_linearize.py | 212 +++++++++++++++++-- tests/robust/test_ops.py | 24 +-- 5 files changed, 476 insertions(+), 44 deletions(-) create mode 100644 tests/robust/__init__.py create mode 100644 tests/robust/robust_fixtures.py create mode 100644 tests/robust/test_internals_compositions.py diff --git a/tests/robust/__init__.py b/tests/robust/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/robust/robust_fixtures.py b/tests/robust/robust_fixtures.py new file mode 100644 index 000000000..a74c94ebc --- /dev/null +++ b/tests/robust/robust_fixtures.py @@ -0,0 +1,223 @@ +import math +from typing import Callable, Optional, Tuple, TypedDict, TypeVar + +import pyro +import pyro.distributions as dist +import torch +from pyro.nn import PyroModule + +from chirho.observational.handlers import condition +from chirho.robust.internals.utils import ParamDict +from chirho.robust.ops import Point + +pyro.settings.set(module_local_params=True) +T = TypeVar("T") + + +class SimpleModel(PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(a, 1)) + return pyro.sample("y", dist.Normal(a + b, 1)) + + +class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand((3,))) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1)) + return {"a": a, "b": b} + + +class GaussianModel(PyroModule): + def __init__(self, cov_mat: torch.Tensor): + super().__init__() + self.register_buffer("cov_mat", cov_mat) + + def forward(self, loc): + pyro.sample( + "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat) + ) + + +# Note: `gaussian_log_prob` is separate from the GaussianModel above because of upstream obstacles +# in the interaction between `pyro.nn.PyroModule` and `torch.func`. +# See https://github.com/BasisResearch/chirho/issues/393 +def gaussian_log_prob(params: ParamDict, data_point: Point[T], cov_mat) -> T: + with pyro.validation_enabled(False): + return dist.MultivariateNormal( + loc=params["loc"], covariance_matrix=cov_mat + ).log_prob(data_point["x"]) + + +class DataConditionedModel(PyroModule): + r""" + Helper class for conditioning on data. 
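+    The wrapped model is run under condition(data=D), and the number of
+    datapoints N passed to the model is taken from the leading dimension of the data.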
+ """ + + def __init__(self, model: PyroModule): + super().__init__() + self.model = model + + def forward(self, D: Point[torch.Tensor]): + with condition(data=D): + # Assume first dimension corresponds to # of datapoints + N = D[next(iter(D))].shape[0] + return self.model.forward(N=N) + + +class HighDimLinearModel(pyro.nn.PyroModule): + def __init__( + self, + p: int, + link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0), + prior_scale: Optional[float] = None, + ): + super().__init__() + self.p = p + self.link_fn = link_fn + if prior_scale is None: + self.prior_scale = 1 / math.sqrt(self.p) + else: + self.prior_scale = prior_scale + + def sample_outcome_weights(self): + return pyro.sample( + "outcome_weights", + dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1), + ) + + def sample_intercept(self): + return pyro.sample("intercept", dist.Normal(0.0, 1.0)) + + def sample_propensity_weights(self): + return pyro.sample( + "propensity_weights", + dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1), + ) + + def sample_treatment_weight(self): + return pyro.sample("treatment_weight", dist.Normal(0.0, 1.0)) + + def sample_covariate_loc_scale(self): + loc = pyro.sample( + "covariate_loc", dist.Normal(0.0, 1.0).expand((self.p,)).to_event(1) + ) + scale = pyro.sample( + "covariate_scale", dist.LogNormal(0, 1).expand((self.p,)).to_event(1) + ) + return loc, scale + + def forward(self, N: int = 1): + intercept = self.sample_intercept() + outcome_weights = self.sample_outcome_weights() + propensity_weights = self.sample_propensity_weights() + tau = self.sample_treatment_weight() + x_loc, x_scale = self.sample_covariate_loc_scale() + with pyro.plate("obs", N, dim=-1): + X = pyro.sample("X", dist.Normal(x_loc, x_scale).to_event(1)) + A = pyro.sample( + "A", + dist.Bernoulli( + logits=torch.einsum("...np,...p->...n", X, propensity_weights) + ), + ) + return pyro.sample( + "Y", + self.link_fn( + torch.einsum("...np,...p->...n", X, outcome_weights) + + A * tau + + intercept + ), + ) + + +class KnownCovariateDistModel(HighDimLinearModel): + def sample_covariate_loc_scale(self): + return torch.zeros(self.p), torch.ones(self.p) + + +class BenchmarkLinearModel(HighDimLinearModel): + def __init__( + self, + p: int, + link_fn: Callable[..., dist.Distribution], + alpha: int, + beta: int, + treatment_weight: float = 0.0, + ): + super().__init__(p, link_fn) + self.alpha = alpha # sparsity of propensity weights + self.beta = beta # sparisty of outcome weights + self.treatment_weight = treatment_weight + + def sample_outcome_weights(self): + outcome_weights = 1 / math.sqrt(self.beta) * torch.ones(self.p) + outcome_weights[self.beta :] = 0.0 + return outcome_weights + + def sample_treatment_null_weight(self): + return torch.tensor(0.0) + + def sample_propensity_weights(self): + propensity_weights = 1 / math.sqrt(self.alpha) * torch.ones(self.p) + propensity_weights[self.alpha :] = 0.0 + return propensity_weights + + def sample_treatment_weight(self): + return torch.tensor(self.treatment_weight) + + def sample_intercept(self): + return torch.tensor(0.0) + + def sample_covariate_loc_scale(self): + return torch.zeros(self.p), torch.ones(self.p) + + +class MLEGuide(torch.nn.Module): + def __init__(self, mle_est: ParamDict): + super().__init__() + self.names = list(mle_est.keys()) + for name, value in mle_est.items(): + setattr(self, name + "_param", torch.nn.Parameter(value)) + + def forward(self, *args, **kwargs): + for name in self.names: + value = getattr(self, name + 
"_param") + pyro.sample(name, dist.Delta(value)) + + +class ATETestPoint(TypedDict): + X: torch.Tensor + A: torch.Tensor + Y: torch.Tensor + + +class ATEParamDict(TypedDict): + propensity_weights: torch.Tensor + outcome_weights: torch.Tensor + treatment_weight: torch.Tensor + intercept: torch.Tensor + + +def closed_form_ate_correction( + X_test: ATETestPoint, theta: ATEParamDict +) -> Tuple[torch.Tensor, torch.Tensor]: + X = X_test["X"] + A = X_test["A"] + Y = X_test["Y"] + pi_X = torch.sigmoid(X.mv(theta["propensity_weights"])) + mu_X = ( + X.mv(theta["outcome_weights"]) + + A * theta["treatment_weight"] + + theta["intercept"] + ) + analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X) + analytic_correction = analytic_eif_at_test_pts.mean() + return analytic_correction, analytic_eif_at_test_pts diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py new file mode 100644 index 000000000..a559560a9 --- /dev/null +++ b/tests/robust/test_internals_compositions.py @@ -0,0 +1,61 @@ +import functools +import warnings + +import pyro +import torch +from pyro.poutine.seed_messenger import SeedMessenger + +from chirho.robust.internals.linearize import ( + conjugate_gradient_solve, + make_empirical_fisher_vp, +) +from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood +from chirho.robust.internals.utils import make_functional_call + +from .robust_fixtures import SimpleGuide, SimpleModel + +pyro.settings.set(module_local_params=True) + + +def test_empirical_fisher_vp_nmclikelihood_cg_composition(): + model = SimpleModel() + guide = SimpleGuide() + model(), guide() # initialize + log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=100) + log_prob_params, func_log_prob = make_functional_call(log_prob) + func_log_prob = SeedMessenger(123)(func_log_prob) + + predictive = pyro.infer.Predictive( + model, guide=guide, num_samples=1000, parallel=True, return_sites=["y"] + ) + predictive_params, func_predictive = make_functional_call(predictive) + + cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=10) + + with torch.no_grad(): + data = func_predictive(predictive_params) + fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) + + v = {k: torch.ones_like(v) for k, v in log_prob_params.items()} + + assert fvp(v)["guide.loc_a"].abs().max() > 0 # sanity check for non-zero fvp + + solve_one = cg_solver(fvp, v) + solve_two = cg_solver(fvp, v) + + if solve_one["guide.loc_a"].abs().max() > 1e6: + warnings.warn( + "solve_one['guide.loc_a'] is large (max entry={}).".format( + solve_one["guide.loc_a"].abs().max() + ) + ) + + if solve_one["guide.loc_b"].abs().max() > 1e6: + warnings.warn( + "solve_one['guide.loc_b'] is large (max entry={}).".format( + solve_one["guide.loc_b"].abs().max() + ) + ) + + assert torch.allclose(solve_one["guide.loc_a"], solve_two["guide.loc_a"]) + assert torch.allclose(solve_one["guide.loc_b"], solve_two["guide.loc_b"]) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 05d9bddea..cb4e13af7 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -5,9 +5,26 @@ import pyro.distributions as dist import pytest import torch +from pyro.infer.predictive import Predictive from typing_extensions import ParamSpec -from chirho.robust.internals.linearize import conjugate_gradient_solve, linearize +from chirho.robust.internals.linearize import ( + conjugate_gradient_solve, + linearize, 
+ make_empirical_fisher_vp, +) + +from .robust_fixtures import ( + BenchmarkLinearModel, + DataConditionedModel, + GaussianModel, + KnownCovariateDistModel, + MLEGuide, + SimpleGuide, + SimpleModel, + closed_form_ate_correction, + gaussian_log_prob, +) pyro.settings.set(module_local_params=True) @@ -58,27 +75,6 @@ def test_batch_cg_solve(ndim: int, dtype: torch.dtype, num_particles: int): assert torch.all(torch.sum((actual_x - expected_x) ** 2, dim=1) < 1e-4) -class SimpleModel(pyro.nn.PyroModule): - def forward(self): - a = pyro.sample("a", dist.Normal(0, 1)) - with pyro.plate("data", 3, dim=-1): - b = pyro.sample("b", dist.Normal(a, 1)) - return pyro.sample("y", dist.Normal(a + b, 1)) - - -class SimpleGuide(torch.nn.Module): - def __init__(self): - super().__init__() - self.loc_a = torch.nn.Parameter(torch.rand(())) - self.loc_b = torch.nn.Parameter(torch.rand((3,))) - - def forward(self): - a = pyro.sample("a", dist.Normal(self.loc_a, 1)) - with pyro.plate("data", 3, dim=-1): - b = pyro.sample("b", dist.Normal(self.loc_b, 1)) - return {"a": a, "b": b} - - ModelTestCase = Tuple[ Callable[[], Callable], Callable[[Callable], Callable], Set[str], Optional[int] ] @@ -178,3 +174,175 @@ def test_nmc_param_influence_vmap_smoke( assert not torch.isnan(v).any(), f"eif for {k} had nans" assert not torch.isinf(v).any(), f"eif for {k} had infs" assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" + + +@pytest.mark.parametrize( + "loc", [torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)] +) +@pytest.mark.parametrize( + "cov_mat", + [ + torch.eye(2, requires_grad=False), + torch.tensor(torch.ones(2, 2) + torch.eye(2), requires_grad=False), + ], +) +@pytest.mark.parametrize( + "v", + [ + torch.tensor([1.0, 0.0], requires_grad=False), + torch.tensor([0.0, 1.0], requires_grad=False), + torch.tensor([1.0, 1.0], requires_grad=False), + torch.tensor([0.0, 0.0], requires_grad=False), + ], +) +def test_empirical_fisher_vp_against_analytical( + loc: torch.Tensor, cov_mat: torch.Tensor, v: torch.Tensor +): + func_log_prob = gaussian_log_prob + log_prob_params = {"loc": loc} + N_monte_carlo = 10000 + data = Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc) + empirical_fisher_vp_func = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, cov_mat=cov_mat + ) + + empirical_fisher_vp = empirical_fisher_vp_func({"loc": v})["loc"] + + prec_matrix = torch.linalg.inv(cov_mat) + true_vp = prec_matrix.mv(v) + + assert torch.all(torch.isclose(empirical_fisher_vp, true_vp, atol=0.1)) + + +@pytest.mark.parametrize( + "data_config", + [ + (torch.zeros(1, requires_grad=True), torch.eye(1)), + (torch.ones(2, requires_grad=True), torch.eye(2)), + ], +) +def test_fisher_vmap(data_config): + loc, cov_mat = data_config + func_log_prob = gaussian_log_prob + log_prob_params = {"loc": loc} + N_monte_carlo = 10000 + data = Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc) + empirical_fisher_vp_func = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, cov_mat=cov_mat + ) + v_single_one = torch.ones(cov_mat.shape[1]) + v_single_two = 0.4 * torch.ones(cov_mat.shape[1]) + v_batch = torch.stack([v_single_one, v_single_two], axis=0) + empirical_fisher_vp_func_batched = torch.func.vmap(empirical_fisher_vp_func) + + # Check if fisher vector product works on a single vector and a batch of vectors + single_one_out = empirical_fisher_vp_func({"loc": v_single_one}) + single_two_out = empirical_fisher_vp_func({"loc": v_single_two}) + 
batch_out = empirical_fisher_vp_func_batched({"loc": v_batch}) + + assert torch.allclose(batch_out["loc"][0], single_one_out["loc"]) + assert torch.allclose(batch_out["loc"][1], single_two_out["loc"]) + + with pytest.raises(RuntimeError): + # Fisher vector product should not work on a batch of vectors + empirical_fisher_vp_func({"loc": v_batch}) + with pytest.raises(RuntimeError): + # Batched Fisher vector product should not work on a single vector + empirical_fisher_vp_func_batched({"loc": v_single_one}) + + +@pytest.mark.parametrize( + "data_config", + [ + (torch.zeros(1, requires_grad=True), torch.eye(1)), + (torch.ones(2, requires_grad=True), torch.eye(2)), + ], +) +def test_fisher_grad_smoke(data_config): + loc, cov_mat = data_config + func_log_prob = gaussian_log_prob + log_prob_params = {"loc": loc} + N_monte_carlo = 10000 + data = Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc) + empirical_fisher_vp_func = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, cov_mat=cov_mat + ) + + v = 0.5 * torch.ones(cov_mat.shape[1], requires_grad=True) + + def f(x): + return empirical_fisher_vp_func({"loc": x})["loc"].sum() + + # Check using `torch.func.grad` + assert ( + torch.func.grad(f)(v).sum() != 0 + ), "Zero gradients but expected non-zero gradients" + + # Check using autograd + assert torch.autograd.gradcheck( + f, v, atol=0.2 + ), "Finite difference gradients do not match autograd gradients" + + +def test_linearize_against_analytic_ate(): + p = 1 + alpha = 1 + beta = 1 + N_train = 100 + N_test = 100 + + def link(mu): + return dist.Normal(mu, 1.0) + + # Generate data + benchmark_model = BenchmarkLinearModel(p, link, alpha, beta) + D_train = Predictive( + benchmark_model, num_samples=N_train, return_sites=["X", "A", "Y"] + )() + D_train = {k: v.squeeze(-1) for k, v in D_train.items()} + D_test = Predictive( + benchmark_model, num_samples=N_test, return_sites=["X", "A", "Y"] + )() + D_test_flat = {k: v.squeeze(-1) for k, v in D_test.items()} + + model = KnownCovariateDistModel(p, link) + conditioned_model = DataConditionedModel(model) + guide_train = pyro.infer.autoguide.AutoDelta(conditioned_model) + elbo = pyro.infer.Trace_ELBO()(conditioned_model, guide_train) + + # initialize parameters + elbo(D_train) + + adam = torch.optim.Adam(elbo.parameters(), lr=0.03) + + # Do gradient steps + for _ in range(500): + adam.zero_grad() + loss = elbo(D_train) + loss.backward() + adam.step() + + theta_hat = { + k: v.clone().detach().requires_grad_(True) for k, v in guide_train().items() + } + _, analytic_eif_at_test_pts = closed_form_ate_correction(D_test_flat, theta_hat) + + mle_guide = MLEGuide(theta_hat) + param_eif = linearize( + model, + mle_guide, + num_samples_outer=10000, + num_samples_inner=1, + cg_iters=4, # dimension of params = 4 + ) + + batch_param_eif = torch.vmap(param_eif, randomness="different") + test_data_eif = batch_param_eif(D_test) + median_abs_error = torch.abs( + test_data_eif["guide.treatment_weight_param"] - analytic_eif_at_test_pts + ).median() + median_scale = torch.abs(analytic_eif_at_test_pts).median() + if median_scale > 1: + assert median_abs_error / median_scale < 0.5 + else: + assert median_abs_error < 0.5 diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 48081cd1f..7eeb99641 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -2,7 +2,6 @@ from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar import pyro -import pyro.distributions as dist import pytest import torch from 
typing_extensions import ParamSpec @@ -10,6 +9,8 @@ from chirho.robust.internals.predictive import PredictiveFunctional from chirho.robust.ops import influence_fn +from .robust_fixtures import SimpleGuide, SimpleModel + pyro.settings.set(module_local_params=True) P = ParamSpec("P") @@ -18,27 +19,6 @@ T = TypeVar("T") -class SimpleModel(pyro.nn.PyroModule): - def forward(self): - a = pyro.sample("a", dist.Normal(0, 1)) - with pyro.plate("data", 3, dim=-1): - b = pyro.sample("b", dist.Normal(a, 1)) - return pyro.sample("y", dist.Normal(a + b, 1)) - - -class SimpleGuide(torch.nn.Module): - def __init__(self): - super().__init__() - self.loc_a = torch.nn.Parameter(torch.rand(())) - self.loc_b = torch.nn.Parameter(torch.rand((3,))) - - def forward(self): - a = pyro.sample("a", dist.Normal(self.loc_a, 1)) - with pyro.plate("data", 3, dim=-1): - b = pyro.sample("b", dist.Normal(self.loc_b, 1)) - return {"a": a, "b": b} - - ModelTestCase = Tuple[ Callable[[], Callable], Callable[[Callable], Callable], Set[str], Optional[int] ] From 117d6455c25355e5e3667340c14c2519473fcceb Mon Sep 17 00:00:00 2001 From: eb8680 Date: Thu, 7 Dec 2023 14:09:42 -0500 Subject: [PATCH 45/66] Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment --- chirho/robust/internals/linearize.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 2d9615ef1..8a13dc2c7 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -66,9 +66,6 @@ def _flat_conjugate_gradient_solve( p = torch.where(not_converged, r + mu * p, p) rdotr = torch.where(not_converged, newrdotr, rdotr) - # rdotr = newrdotr - # if rdotr < residual_tol: - # break return x @@ -140,6 +137,11 @@ def linearize( log_prob_params, func_log_prob = make_functional_call(log_prob) score_fn = torch.func.grad(func_log_prob) + log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values()) + if cg_iters is None: + cg_iters = log_prob_params_numel + else: + cg_iters = min(cg_iters, log_prob_params_numel) cg_solver = functools.partial( conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol ) From 8fe1b25777cce8e04a26a3499540718d3dc13737 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Thu, 7 Dec 2023 12:25:41 -0800 Subject: [PATCH 46/66] fixed test for non-symmetric matrix (#437) --- tests/robust/test_internals_linearize.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index cb4e13af7..dddb104af 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -39,8 +39,8 @@ def test_cg_solve(ndim: int, dtype: torch.dtype): cg_iters = None residual_tol = 1e-10 - - A = torch.eye(ndim, dtype=dtype) + 0.1 * torch.rand(ndim, ndim, dtype=dtype) + U = torch.rand(ndim, ndim, dtype=dtype) + A = torch.eye(ndim, dtype=dtype) + 0.1 * U.mm(U.t()) expected_x = torch.randn(ndim, dtype=dtype) b = A @ expected_x @@ -57,7 +57,8 @@ def test_batch_cg_solve(ndim: int, dtype: torch.dtype, num_particles: int): cg_iters = None residual_tol = 1e-10 - A = torch.eye(ndim, dtype=dtype) + 0.1 * torch.rand(ndim, ndim, dtype=dtype) + U = torch.rand(ndim, ndim, dtype=dtype) + A = torch.eye(ndim, dtype=dtype) + 0.1 * U.mm(U.t()) expected_x = torch.randn(num_particles, ndim, dtype=dtype) b = torch.einsum("ij,nj->ni", A, expected_x) assert b.shape == (num_particles, 
ndim) From 3f0c83d50ac9081f3414d79102b84df557d2bd63 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 8 Dec 2023 09:12:08 -0800 Subject: [PATCH 47/66] Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different --- chirho/robust/internals/linearize.py | 5 +- chirho/robust/internals/predictive.py | 6 +++ chirho/robust/internals/utils.py | 10 ++++ tests/robust/test_internals_compositions.py | 54 +++++++++++++++++++-- 4 files changed, 70 insertions(+), 5 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 8a13dc2c7..439440bad 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -10,6 +10,7 @@ ParamDict, make_flatten_unflatten, make_functional_call, + reset_rng_state, ) from chirho.robust.ops import Point @@ -154,7 +155,9 @@ def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict: fvp = make_empirical_fisher_vp( func_log_prob, log_prob_params, data, *args, **kwargs ) + + pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp) point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) - return cg_solver(fvp, point_score) + return cg_solver(pinned_fvp, point_score) return _fn diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py index ba529edf7..6e011de93 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/predictive.py @@ -1,5 +1,6 @@ import contextlib import math +import warnings from typing import Any, Callable, Container, Generic, Optional, TypeVar import pyro @@ -123,6 +124,11 @@ def forward( self.max_plate_nesting = guess_max_plate_nesting( self.model, self.guide, *args, **kwargs ) + warnings.warn( + "Since max_plate_nesting is not specified, \ + the first call to NMCLogPredictiveLikelihood will not be seeded properly. 
\ + See https://github.com/BasisResearch/chirho/pull/408" + ) masked_guide = pyro.poutine.mask(mask=False)(self.guide) masked_model = _UnmaskNamedSites(names=set(data.keys()))( diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index fe0bcf77b..57c7df3e4 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -1,3 +1,4 @@ +import contextlib import functools from typing import Any, Callable, Dict, Mapping, Tuple, TypeVar @@ -87,3 +88,12 @@ def guess_max_plate_nesting( elbo = pyro.infer.Trace_ELBO() elbo._guess_max_plate_nesting(model, guide, args, kwargs) return elbo.max_plate_nesting + + +@contextlib.contextmanager +def reset_rng_state(rng_state: T): + try: + prev_rng_state: T = pyro.util.get_rng_state() + yield pyro.util.set_rng_state(rng_state) + finally: + pyro.util.set_rng_state(prev_rng_state) diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py index a559560a9..0bf8b38ca 100644 --- a/tests/robust/test_internals_compositions.py +++ b/tests/robust/test_internals_compositions.py @@ -3,14 +3,13 @@ import pyro import torch -from pyro.poutine.seed_messenger import SeedMessenger from chirho.robust.internals.linearize import ( conjugate_gradient_solve, make_empirical_fisher_vp, ) from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood -from chirho.robust.internals.utils import make_functional_call +from chirho.robust.internals.utils import make_functional_call, reset_rng_state from .robust_fixtures import SimpleGuide, SimpleModel @@ -23,14 +22,14 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): model(), guide() # initialize log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=100) log_prob_params, func_log_prob = make_functional_call(log_prob) - func_log_prob = SeedMessenger(123)(func_log_prob) + func_log_prob = reset_rng_state(pyro.util.get_rng_state())(func_log_prob) predictive = pyro.infer.Predictive( model, guide=guide, num_samples=1000, parallel=True, return_sites=["y"] ) predictive_params, func_predictive = make_functional_call(predictive) - cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=10) + cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=2) with torch.no_grad(): data = func_predictive(predictive_params) @@ -59,3 +58,50 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): assert torch.allclose(solve_one["guide.loc_a"], solve_two["guide.loc_a"]) assert torch.allclose(solve_one["guide.loc_b"], solve_two["guide.loc_b"]) + + +def test_nmc_likelihood_seeded(): + model = SimpleModel() + guide = SimpleGuide() + model(), guide() # initialize + + log_prob = NMCLogPredictiveLikelihood( + model, guide, num_samples=3, max_plate_nesting=3 + ) + log_prob_params, func_log_prob = make_functional_call(log_prob) + + func_log_prob = reset_rng_state(pyro.util.get_rng_state())(func_log_prob) + + datapoint = {"y": torch.tensor([1.0, 2.0, 3.0])} + prob_call_one = func_log_prob(log_prob_params, datapoint) + prob_call_two = func_log_prob(log_prob_params, datapoint) + prob_call_three = func_log_prob(log_prob_params, datapoint) + assert torch.allclose(prob_call_two, prob_call_three) + assert torch.allclose(prob_call_one, prob_call_two) + + data = {"y": torch.tensor([[0.3665, 1.5440, 2.2210], [0.3665, 1.5440, 2.2210]])} + + fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) + + v = {k: torch.ones_like(v) for k, v in log_prob_params.items()} + + assert fvp(v)["guide.loc_a"].abs().max() > 0 + 
assert fvp(v)["guide.loc_b"].abs().max() > 0 + + # Check if fvp agrees across multiple calls of same `fvp` object + assert torch.allclose(fvp(v)["guide.loc_a"], fvp(v)["guide.loc_a"]) + assert torch.allclose(fvp(v)["guide.loc_b"], fvp(v)["guide.loc_b"]) + + # Check if fvp agrees across different `fvp` objects with same inputs + fvp_one = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) + fvp_two = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) + assert torch.allclose(fvp_one(v)["guide.loc_a"], fvp_two(v)["guide.loc_a"]) + assert torch.allclose(fvp_one(v)["guide.loc_b"], fvp_two(v)["guide.loc_b"]) + + # Since `data` has same datapoint twice, fvp with redundant + # data should agree when only the single datapoint is used if the + # seeding works correctly + data_two = {"y": torch.tensor([[0.3665, 1.5440, 2.2210]])} + fvp_two = make_empirical_fisher_vp(func_log_prob, log_prob_params, data_two) + assert torch.allclose(fvp_one(v)["guide.loc_a"], fvp_two(v)["guide.loc_a"]) + assert torch.allclose(fvp_one(v)["guide.loc_b"], fvp_two(v)["guide.loc_b"]) From 4d418073122f02c79b2bbdba1c16874988ccc54d Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 8 Dec 2023 10:31:36 -0800 Subject: [PATCH 48/66] Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error --- chirho/robust/internals/linearize.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 439440bad..09e8e5216 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -94,17 +94,21 @@ def make_empirical_fisher_vp( randomness="different", ) + N = data[next(iter(data))].shape[0] # type: ignore + mean_vector = 1 / N * torch.ones(N) + def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: return batched_func_log_prob(params, data) - def jvp_fn(v: ParamDict) -> torch.Tensor: - return torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,))[1] - - vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] - def _empirical_fisher_vp(v: ParamDict) -> ParamDict: - jvp_log_prob_v = jvp_fn(v) - return vjp_fn(jvp_log_prob_v / jvp_log_prob_v.shape[0])[0] + def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: + return torch.func.jvp( + bound_batched_func_log_prob, (log_prob_params,), (v,) + )[1] + + # Perlmutter's trick + vjp_fn = torch.func.vjp(jvp_fn, log_prob_params)[1] + return vjp_fn(-1 * mean_vector)[0] # Fisher = -E[Hessian] return _empirical_fisher_vp From 2e01b7b457e2902951f1766c58d18242a7b22a30 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 8 Dec 2023 10:36:20 -0800 Subject: [PATCH 49/66] Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's 
contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting --- tests/robust/robust_fixtures.py | 9 ++++- tests/robust/test_internals_compositions.py | 43 +++++++++++---------- tests/robust/test_internals_linearize.py | 18 ++++++++- 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/tests/robust/robust_fixtures.py b/tests/robust/robust_fixtures.py index a74c94ebc..4496e7da6 100644 --- a/tests/robust/robust_fixtures.py +++ b/tests/robust/robust_fixtures.py @@ -15,11 +15,18 @@ class SimpleModel(PyroModule): + def __init__( + self, + link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0), + ): + super().__init__() + self.link_fn = link_fn + def forward(self): a = pyro.sample("a", dist.Normal(0, 1)) with pyro.plate("data", 3, dim=-1): b = pyro.sample("b", dist.Normal(a, 1)) - return pyro.sample("y", dist.Normal(a + b, 1)) + return pyro.sample("y", dist.Normal(b, 1)) class SimpleGuide(torch.nn.Module): diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py index 0bf8b38ca..f4bc79505 100644 --- a/tests/robust/test_internals_compositions.py +++ b/tests/robust/test_internals_compositions.py @@ -2,6 +2,7 @@ import warnings import pyro +import pytest import torch from chirho.robust.internals.linearize import ( @@ -35,9 +36,14 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): data = func_predictive(predictive_params) fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) - v = {k: torch.ones_like(v) for k, v in log_prob_params.items()} + v = { + k: torch.ones_like(v) if k != "guide.loc_a" else torch.zeros_like(v) + for k, v in log_prob_params.items() + } - assert fvp(v)["guide.loc_a"].abs().max() > 0 # sanity check for non-zero fvp + # For this model, fvp for loc_a is zero. 
See + # https://github.com/BasisResearch/chirho/issues/427 + assert fvp(v)["guide.loc_a"].abs().max() == 0 solve_one = cg_solver(fvp, v) solve_two = cg_solver(fvp, v) @@ -56,12 +62,24 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): ) ) + assert torch.allclose( + solve_one["guide.loc_a"], torch.zeros_like(log_prob_params["guide.loc_a"]) + ) assert torch.allclose(solve_one["guide.loc_a"], solve_two["guide.loc_a"]) assert torch.allclose(solve_one["guide.loc_b"], solve_two["guide.loc_b"]) -def test_nmc_likelihood_seeded(): - model = SimpleModel() +link_functions = [ + lambda mu: pyro.distributions.Normal(mu, 1.0), + lambda mu: pyro.distributions.Bernoulli(logits=mu), + lambda mu: pyro.distributions.Beta(concentration1=mu, concentration0=1.0), + lambda mu: pyro.distributions.Exponential(rate=mu), +] + + +@pytest.mark.parametrize("link_fn", link_functions) +def test_nmc_likelihood_seeded(link_fn): + model = SimpleModel(link_fn=link_fn) guide = SimpleGuide() model(), guide() # initialize @@ -85,23 +103,8 @@ def test_nmc_likelihood_seeded(): v = {k: torch.ones_like(v) for k, v in log_prob_params.items()} - assert fvp(v)["guide.loc_a"].abs().max() > 0 - assert fvp(v)["guide.loc_b"].abs().max() > 0 + assert (fvp(v)["guide.loc_a"].abs().max() + fvp(v)["guide.loc_b"].abs().max()) > 0 # Check if fvp agrees across multiple calls of same `fvp` object assert torch.allclose(fvp(v)["guide.loc_a"], fvp(v)["guide.loc_a"]) assert torch.allclose(fvp(v)["guide.loc_b"], fvp(v)["guide.loc_b"]) - - # Check if fvp agrees across different `fvp` objects with same inputs - fvp_one = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) - fvp_two = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) - assert torch.allclose(fvp_one(v)["guide.loc_a"], fvp_two(v)["guide.loc_a"]) - assert torch.allclose(fvp_one(v)["guide.loc_b"], fvp_two(v)["guide.loc_b"]) - - # Since `data` has same datapoint twice, fvp with redundant - # data should agree when only the single datapoint is used if the - # seeding works correctly - data_two = {"y": torch.tensor([[0.3665, 1.5440, 2.2210]])} - fvp_two = make_empirical_fisher_vp(func_log_prob, log_prob_params, data_two) - assert torch.allclose(fvp_one(v)["guide.loc_a"], fvp_two(v)["guide.loc_a"]) - assert torch.allclose(fvp_one(v)["guide.loc_b"], fvp_two(v)["guide.loc_b"]) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index dddb104af..16f869885 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -134,7 +134,14 @@ def test_nmc_param_influence_smoke( for k, v in test_datum_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" assert not torch.isinf(v).any(), f"eif for {k} had infs" - assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" + if k != "guide.loc_a": + assert not torch.isclose( + v, torch.zeros_like(v) + ).all(), f"eif for {k} was zero" + else: + assert torch.isclose( + v, torch.zeros_like(v) + ).all(), f"eif for {k} should be zero" @pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) @@ -174,7 +181,14 @@ def test_nmc_param_influence_vmap_smoke( for k, v in test_data_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" assert not torch.isinf(v).any(), f"eif for {k} had infs" - assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" + if k != "guide.loc_a": + assert not torch.isclose( + v, torch.zeros_like(v) + ).all(), f"eif for {k} 
was zero" + else: + assert torch.isclose( + v, torch.zeros_like(v) + ).all(), f"eif for {k} should be zero" @pytest.mark.parametrize( From 538cef8e368010c704de8e12e92897a687600f09 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 22 Dec 2023 13:41:46 -0800 Subject: [PATCH 50/66] Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error --- chirho/robust/internals/linearize.py | 35 ++++++++++++++++++------ chirho/robust/ops.py | 14 ++++++---- tests/robust/test_internals_linearize.py | 24 +++++++++++++--- tests/robust/test_ops.py | 3 +- 4 files changed, 57 insertions(+), 19 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 09e8e5216..8d008942d 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -122,6 +122,7 @@ def linearize( max_plate_nesting: Optional[int] = None, cg_iters: Optional[int] = None, residual_tol: float = 1e-10, + pointwise_influence: bool = True, ) -> Callable[Concatenate[Point[T], P], ParamDict]: assert isinstance(model, torch.nn.Module) assert isinstance(guide, torch.nn.Module) @@ -140,8 +141,6 @@ def linearize( model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting ) log_prob_params, func_log_prob = make_functional_call(log_prob) - score_fn = torch.func.grad(func_log_prob) - log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values()) if cg_iters is None: cg_iters = log_prob_params_numel @@ -151,17 +150,37 @@ def linearize( conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol ) - @functools.wraps(score_fn) - def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict: + def _fn( + points: Point[T], + *args: P.args, + **kwargs: P.kwargs, + ) -> ParamDict: with torch.no_grad(): data: Point[T] = func_predictive(predictive_params, *args, **kwargs) - data = {k: data[k] for k in point.keys()} + data = {k: data[k] for k in points.keys()} fvp = make_empirical_fisher_vp( func_log_prob, log_prob_params, data, *args, **kwargs ) - pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp) - point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) - return cg_solver(pinned_fvp, point_score) + batched_func_log_prob = torch.vmap( + lambda p, data: func_log_prob(p, data, *args, **kwargs), + in_dims=(None, 0), + randomness="different", + ) + + def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor: + return batched_func_log_prob(p, points) + + if pointwise_influence: + score_fn = torch.func.jacrev(bound_batched_func_log_prob) + point_scores = score_fn(log_prob_params) + else: + score_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] + N_pts = points[next(iter(points))].shape[0] # type: ignore + point_scores = score_fn(1 / N_pts * torch.ones(N_pts))[0] + point_scores = {k: v.unsqueeze(0) for k, v in point_scores.items()} + return torch.func.vmap( + lambda v: cg_solver(pinned_fvp, v), randomness="different" + )(point_scores) return _fn diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index d3ea8259a..806f7b4d9 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -39,10 +39,14 @@ def influence_fn( target_params, func_target = make_functional_call(target) @functools.wraps(target) - def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: - param_eif = linearized(point, *args, **kwargs) - return torch.func.jvp( - lambda p: 
func_target(p, *args, **kwargs), (target_params,), (param_eif,) - )[1] + def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: + param_eif = linearized(points, *args, **kwargs) + return torch.vmap( + lambda d: torch.func.jvp( + lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) + )[1], + in_dims=0, + randomness="different", + )(param_eif) return _fn diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index 16f869885..c1e6e4d6d 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -175,8 +175,7 @@ def test_nmc_param_influence_vmap_smoke( model, num_samples=4, return_sites=obs_names, parallel=True )() - batch_param_eif = torch.vmap(param_eif, randomness="different") - test_data_eif: Mapping[str, torch.Tensor] = batch_param_eif(test_data) + test_data_eif: Mapping[str, torch.Tensor] = param_eif(test_data) assert len(test_data_eif) > 0 for k, v in test_data_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" @@ -349,10 +348,10 @@ def link(mu): num_samples_outer=10000, num_samples_inner=1, cg_iters=4, # dimension of params = 4 + pointwise_influence=True, ) - batch_param_eif = torch.vmap(param_eif, randomness="different") - test_data_eif = batch_param_eif(D_test) + test_data_eif = param_eif(D_test) median_abs_error = torch.abs( test_data_eif["guide.treatment_weight_param"] - analytic_eif_at_test_pts ).median() @@ -361,3 +360,20 @@ def link(mu): assert median_abs_error / median_scale < 0.5 else: assert median_abs_error < 0.5 + + # Test w/ pointwise_influence=False + param_eif = linearize( + model, + mle_guide, + num_samples_outer=10000, + num_samples_inner=1, + cg_iters=4, # dimension of params = 4 + pointwise_influence=False, + ) + + test_data_eif = param_eif(D_test) + assert torch.allclose( + test_data_eif["guide.treatment_weight_param"][0], + analytic_eif_at_test_pts.mean(), + atol=1e-1, + ) diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 7eeb99641..881faa63c 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -120,8 +120,7 @@ def test_nmc_predictive_influence_vmap_smoke( model, num_samples=4, return_sites=obs_names, parallel=True )() - batch_predictive_eif = torch.vmap(predictive_eif, randomness="different") - test_data_eif: Mapping[str, torch.Tensor] = batch_predictive_eif(test_data) + test_data_eif: Mapping[str, torch.Tensor] = predictive_eif(test_data) assert len(test_data_eif) > 0 for k, v in test_data_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" From 6bba70b3361834d6bc07961043c510c0a5f8155a Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 22 Dec 2023 14:08:55 -0800 Subject: [PATCH 51/66] batched cg (#466) --- chirho/robust/internals/linearize.py | 43 +++++++++++++-------- chirho/robust/internals/utils.py | 15 +++++-- tests/robust/test_internals_compositions.py | 9 ++++- tests/robust/test_internals_linearize.py | 29 +++----------- 4 files changed, 52 insertions(+), 44 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 8d008942d..436ae033b 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -25,7 +25,7 @@ def _flat_conjugate_gradient_solve( b: torch.Tensor, *, cg_iters: Optional[int] = None, - residual_tol: float = 1e-10, + residual_tol: float = 1e-3, ) -> torch.Tensor: r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. 
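A minimal sketch (not part of the diff) of how the batched solver introduced in the hunk below can be exercised: conjugate_gradient_solve now takes a batch of right-hand sides along dim 0, so a small symmetric positive-definite system with known solutions (mirroring test_batch_cg_solve) is enough to see the batching at work.

import torch
from chirho.robust.internals.linearize import conjugate_gradient_solve

ndim, num_particles = 3, 5
U = torch.rand(ndim, ndim)
A = torch.eye(ndim) + 0.1 * U.mm(U.t())        # small SPD matrix
expected_x = torch.randn(num_particles, ndim)  # batch of solutions along dim 0
b = torch.einsum("ij,nj->ni", A, expected_x)   # matching batch of right-hand sides

x = conjugate_gradient_solve(
    lambda v: torch.einsum("ij,nj->ni", A, v), b, cg_iters=None, residual_tol=1e-10
)
assert torch.all(torch.sum((x - expected_x) ** 2, dim=1) < 1e-4)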
@@ -42,31 +42,41 @@ def _flat_conjugate_gradient_solve( Notes: This code is adapted from https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py """ + assert len(b.shape), "b must be a 2D matrix" + if cg_iters is None: - cg_iters = b.numel() + cg_iters = b.shape[1] + else: + cg_iters = min(cg_iters, b.shape[1]) + + def _batched_dot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return (x1 * x2).sum(axis=-1) # type: ignore + + def _batched_product(a: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + return a.unsqueeze(0).t() * B p = b.clone() r = b.clone() x = torch.zeros_like(b) z = f_Ax(p) - rdotr = torch.dot(r, r) - v = rdotr / torch.dot(p, z) + rdotr = _batched_dot(r, r) + v = rdotr / _batched_dot(p, z) newrdotr = rdotr mu = newrdotr / rdotr - zeros_xr = torch.zeros_like(x) - for _ in range(cg_iters): not_converged = rdotr > residual_tol - z = torch.where(not_converged, f_Ax(p), z) - v = torch.where(not_converged, rdotr / torch.dot(p, z), v) - x += torch.where(not_converged, v * p, zeros_xr) - r -= torch.where(not_converged, v * z, zeros_xr) - newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr) + not_converged_broadcasted = not_converged.unsqueeze(0).t() + z = torch.where(not_converged_broadcasted, f_Ax(p), z) + v = torch.where(not_converged, rdotr / _batched_dot(p, z), v) + x += torch.where(not_converged_broadcasted, _batched_product(v, p), zeros_xr) + r -= torch.where(not_converged_broadcasted, _batched_product(v, z), zeros_xr) + newrdotr = torch.where(not_converged, _batched_dot(r, r), newrdotr) mu = torch.where(not_converged, newrdotr / rdotr, mu) - p = torch.where(not_converged, r + mu * p, p) + p = torch.where(not_converged_broadcasted, r + _batched_product(mu, p), p) rdotr = torch.where(not_converged, newrdotr, rdotr) - + if torch.all(~not_converged): + return x return x @@ -162,6 +172,9 @@ def _fn( func_log_prob, log_prob_params, data, *args, **kwargs ) pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp) + pinned_fvp_batched = torch.func.vmap( + lambda v: pinned_fvp(v), randomness="different" + ) batched_func_log_prob = torch.vmap( lambda p, data: func_log_prob(p, data, *args, **kwargs), in_dims=(None, 0), @@ -179,8 +192,6 @@ def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor: N_pts = points[next(iter(points))].shape[0] # type: ignore point_scores = score_fn(1 / N_pts * torch.ones(N_pts))[0] point_scores = {k: v.unsqueeze(0) for k, v in point_scores.items()} - return torch.func.vmap( - lambda v: cg_solver(pinned_fvp, v), randomness="different" - )(point_scores) + return cg_solver(pinned_fvp_batched, point_scores) return _fn diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index 57c7df3e4..7af0af4e0 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -23,11 +23,13 @@ def make_flatten_unflatten( @make_flatten_unflatten.register(torch.Tensor) def _make_flatten_unflatten_tensor(v: torch.Tensor): + batch_size = v.shape[0] + def flatten(v: torch.Tensor) -> torch.Tensor: r""" Flatten a tensor into a single vector. 
""" - return v.flatten() + return v.reshape((batch_size, -1)) def unflatten(x: torch.Tensor) -> torch.Tensor: r""" @@ -40,11 +42,13 @@ def unflatten(x: torch.Tensor) -> torch.Tensor: @make_flatten_unflatten.register(dict) def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): + batch_size = next(iter(d.values())).shape[0] + def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: r""" Flatten a dictionary of tensors into a single vector. """ - return torch.cat([v.flatten() for k, v in d.items()]) + return torch.hstack([v.reshape((batch_size, -1)) for k, v in d.items()]) def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: r""" @@ -56,7 +60,12 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: [ v_flat.reshape(v.shape) for v, v_flat in zip( - d.values(), torch.split(x, [v.numel() for k, v in d.items()]) + d.values(), + torch.split( + x, + [int(v.numel() / batch_size) for k, v in d.items()], + dim=1, + ), ) ], ) diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py index f4bc79505..1e7aea907 100644 --- a/tests/robust/test_internals_compositions.py +++ b/tests/robust/test_internals_compositions.py @@ -34,10 +34,15 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): with torch.no_grad(): data = func_predictive(predictive_params) - fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) + + fvp = torch.func.vmap( + make_empirical_fisher_vp(func_log_prob, log_prob_params, data) + ) v = { - k: torch.ones_like(v) if k != "guide.loc_a" else torch.zeros_like(v) + k: torch.ones_like(v).unsqueeze(0) + if k != "guide.loc_a" + else torch.zeros_like(v).unsqueeze(0) for k, v in log_prob_params.items() } diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index c1e6e4d6d..afdca254f 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -34,22 +34,6 @@ T = TypeVar("T") -@pytest.mark.parametrize("ndim", [1, 2, 3, 10]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_cg_solve(ndim: int, dtype: torch.dtype): - cg_iters = None - residual_tol = 1e-10 - U = torch.rand(ndim, ndim, dtype=dtype) - A = torch.eye(ndim, dtype=dtype) + 0.1 * U.mm(U.t()) - expected_x = torch.randn(ndim, dtype=dtype) - b = A @ expected_x - - actual_x = conjugate_gradient_solve( - lambda v: A @ v, b, cg_iters=cg_iters, residual_tol=residual_tol - ) - assert torch.sum((actual_x - expected_x) ** 2) < 1e-4 - - @pytest.mark.parametrize("ndim", [1, 2, 3, 10]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @pytest.mark.parametrize("num_particles", [1, 4]) @@ -63,14 +47,13 @@ def test_batch_cg_solve(ndim: int, dtype: torch.dtype, num_particles: int): b = torch.einsum("ij,nj->ni", A, expected_x) assert b.shape == (num_particles, ndim) - batch_solve = torch.vmap( - functools.partial( - conjugate_gradient_solve, - lambda v: A @ v, - cg_iters=cg_iters, - residual_tol=residual_tol, - ), + batch_solve = functools.partial( + conjugate_gradient_solve, + lambda v: torch.einsum("ij,nj->ni", A, v), + cg_iters=cg_iters, + residual_tol=residual_tol, ) + actual_x = batch_solve(b) assert torch.all(torch.sum((actual_x - expected_x) ** 2, dim=1) < 1e-4) From f143d3a95331933cddf91eed90db5f1f36fbf3b2 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 22 Dec 2023 15:05:57 -0800 Subject: [PATCH 52/66] One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue --- 
chirho/robust/handlers/estimators.py | 21 ++++++ tests/robust/test_handlers.py | 86 ++++++++++++++++++++++++ tests/robust/test_internals_linearize.py | 2 +- 3 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 chirho/robust/handlers/estimators.py create mode 100644 tests/robust/test_handlers.py diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py new file mode 100644 index 000000000..146282445 --- /dev/null +++ b/chirho/robust/handlers/estimators.py @@ -0,0 +1,21 @@ +from typing import Any, Callable, Optional + +from typing_extensions import Concatenate + +from chirho.robust.ops import Functional, P, Point, S, T, influence_fn + + +def one_step_correction( + model: Callable[P, Any], + guide: Callable[P, Any], + functional: Optional[Functional[P, S]] = None, + **influence_kwargs, +) -> Callable[Concatenate[Point[T], P], S]: + influence_kwargs_one_step = influence_kwargs.copy() + influence_kwargs_one_step["pointwise_influence"] = False + eif_fn = influence_fn(model, guide, functional, **influence_kwargs_one_step) + + def _one_step(test_data: Point[T], *args, **kwargs) -> S: + return eif_fn(test_data, *args, **kwargs) + + return _one_step diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py new file mode 100644 index 000000000..fc36d0b49 --- /dev/null +++ b/tests/robust/test_handlers.py @@ -0,0 +1,86 @@ +import functools +from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar + +import pyro +import pytest +import torch +from typing_extensions import ParamSpec + +from chirho.robust.handlers.estimators import one_step_correction +from chirho.robust.internals.predictive import PredictiveFunctional + +from .robust_fixtures import SimpleGuide, SimpleModel + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +ModelTestCase = Tuple[ + Callable[[], Callable], Callable[[Callable], Callable], Set[str], Optional[int] +] + +MODEL_TEST_CASES: List[ModelTestCase] = [ + (SimpleModel, lambda _: SimpleGuide(), {"y"}, 1), + (SimpleModel, lambda _: SimpleGuide(), {"y"}, None), + pytest.param( + SimpleModel, + pyro.infer.autoguide.AutoNormal, + {"y"}, + 1, + marks=pytest.mark.xfail( + reason="torch.func autograd doesnt work with PyroParam" + ), + ), +] + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +@pytest.mark.parametrize("num_predictive_samples", [1, 5]) +def test_one_step_correction_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, + num_predictive_samples, +): + model = model() + guide = guide(model) + model(), guide() # initialize + + one_step = one_step_correction( + model, + guide, + functional=functools.partial( + PredictiveFunctional, num_samples=num_predictive_samples + ), + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + ) + + with torch.no_grad(): + test_datum = { + k: v[0] + for k, v in pyro.infer.Predictive( + model, num_samples=2, return_sites=obs_names, parallel=True + )().items() + } + + one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum) + assert len(one_step_on_test) > 0 + for k, v in one_step_on_test.items(): + assert not torch.isnan(v).any(), f"one_step for {k} had nans" + 
assert not torch.isinf(v).any(), f"one_step for {k} had infs" + assert not torch.isclose( + v, torch.zeros_like(v) + ).all(), f"one_step for {k} was zero" diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index afdca254f..a8a80a536 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -358,5 +358,5 @@ def link(mu): assert torch.allclose( test_data_eif["guide.treatment_weight_param"][0], analytic_eif_at_test_pts.mean(), - atol=1e-1, + atol=0.5, ) From 878eb0d502322fd5d24d5a12fe5bc43ba5c9f824 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Tue, 2 Jan 2024 14:38:14 -0500 Subject: [PATCH 53/66] Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob --- chirho/observational/handlers/condition.py | 8 +- chirho/robust/internals/linearize.py | 28 +- chirho/robust/internals/predictive.py | 303 +++++++++++++++----- chirho/robust/internals/utils.py | 129 ++++++++- tests/robust/test_internals_compositions.py | 106 ++++++- tests/robust/test_performance.py | 178 ++++++++++++ 6 files changed, 647 insertions(+), 105 deletions(-) create mode 100644 tests/robust/test_performance.py diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py index a097743c6..ebdd43bc0 100644 --- a/chirho/observational/handlers/condition.py +++ b/chirho/observational/handlers/condition.py @@ -1,10 +1,10 @@ -from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union +from typing import Callable, Generic, Mapping, TypeVar, Union import pyro import torch from chirho.observational.internals import ObserveNameMessenger -from chirho.observational.ops import AtomicObservation, observe +from chirho.observational.ops import Observation, observe T = TypeVar("T") R = Union[float, torch.Tensor] @@ -62,7 +62,9 @@ class Observations(Generic[T], ObserveNameMessenger): a richer set of observational data types and enables counterfactual inference. 
""" - def __init__(self, data: Mapping[Hashable, AtomicObservation[T]]): + data: Mapping[str, Observation[T]] + + def __init__(self, data: Mapping[str, Observation[T]]): self.data = data super().__init__() diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 436ae033b..d02ef1207 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -5,7 +5,7 @@ import torch from typing_extensions import Concatenate, ParamSpec -from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood +from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood from chirho.robust.internals.utils import ( ParamDict, make_flatten_unflatten, @@ -92,23 +92,17 @@ def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: def make_empirical_fisher_vp( - func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], + batched_func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], log_prob_params: ParamDict, data: Point[T], *args: P.args, **kwargs: P.kwargs, ) -> Callable[[ParamDict], ParamDict]: - batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap( - lambda p, data: func_log_prob(p, data, *args, **kwargs), - in_dims=(None, 0), - randomness="different", - ) - N = data[next(iter(data))].shape[0] # type: ignore mean_vector = 1 / N * torch.ones(N) def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: - return batched_func_log_prob(params, data) + return batched_func_log_prob(params, data, *args, **kwargs) def _empirical_fisher_vp(v: ParamDict) -> ParamDict: def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: @@ -145,12 +139,11 @@ def linearize( num_samples=num_samples_outer, parallel=True, ) - predictive_params, func_predictive = make_functional_call(predictive) - log_prob = NMCLogPredictiveLikelihood( + batched_log_prob = BatchedNMCLogPredictiveLikelihood( model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting ) - log_prob_params, func_log_prob = make_functional_call(log_prob) + log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob) log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values()) if cg_iters is None: cg_iters = log_prob_params_numel @@ -166,23 +159,18 @@ def _fn( **kwargs: P.kwargs, ) -> ParamDict: with torch.no_grad(): - data: Point[T] = func_predictive(predictive_params, *args, **kwargs) + data: Point[T] = predictive(*args, **kwargs) data = {k: data[k] for k in points.keys()} fvp = make_empirical_fisher_vp( - func_log_prob, log_prob_params, data, *args, **kwargs + batched_func_log_prob, log_prob_params, data, *args, **kwargs ) pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp) pinned_fvp_batched = torch.func.vmap( lambda v: pinned_fvp(v), randomness="different" ) - batched_func_log_prob = torch.vmap( - lambda p, data: func_log_prob(p, data, *args, **kwargs), - in_dims=(None, 0), - randomness="different", - ) def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor: - return batched_func_log_prob(p, points) + return batched_func_log_prob(p, points, *args, **kwargs) if pointwise_influence: score_fn = torch.func.jacrev(bound_batched_func_log_prob) diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py index 6e011de93..19369924f 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/predictive.py @@ -1,15 +1,21 @@ -import contextlib +import collections import math -import warnings -from 
typing import Any, Callable, Container, Generic, Optional, TypeVar +import typing +from typing import Any, Callable, Generic, Optional, TypeVar import pyro import torch from typing_extensions import ParamSpec -from chirho.indexed.handlers import DependentMaskMessenger -from chirho.observational.handlers import condition -from chirho.robust.internals.utils import guess_max_plate_nesting +from chirho.indexed.handlers import IndexPlatesMessenger +from chirho.indexed.ops import get_index_plates, indices_of +from chirho.observational.handlers.condition import Observations +from chirho.robust.internals.utils import ( + bind_leftmost_dim, + get_importance_traces, + site_is_delta, + unbind_leftmost_dim, +) from chirho.robust.ops import Point pyro.settings.set(module_local_params=True) @@ -20,27 +26,151 @@ T = TypeVar("T") -class _UnmaskNamedSites(DependentMaskMessenger): - names: Container[str] +class BatchedLatents(pyro.poutine.messenger.Messenger): + """ + Effect handler that adds a fresh batch dimension to all latent ``sample`` sites. + Similar to wrapping a Pyro model in a ``pyro.plate`` context, but uses the machinery + in ``chirho.indexed`` to automatically allocate and track the fresh batch dimension + based on the ``name`` argument to ``BatchedLatents`` . - def __init__(self, names: Container[str]): - self.names = names + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . - def get_mask( + :param int num_particles: Number of particles to use for parallelization. + :param str name: Name of the fresh batch dimension. + """ + + num_particles: int + name: str + + def __init__(self, num_particles: int, *, name: str = "__particles_mc"): + assert num_particles > 0 + assert len(name) > 0 + self.num_particles = num_particles + self.name = name + super().__init__() + + def _pyro_sample(self, msg: dict) -> None: + if ( + self.num_particles > 1 + and msg["value"] is None + and not pyro.poutine.util.site_is_factor(msg) + and not pyro.poutine.util.site_is_subsample(msg) + and not site_is_delta(msg) + and self.name not in indices_of(msg["fn"]) + ): + msg["fn"] = unbind_leftmost_dim( + msg["fn"].expand((1,) + msg["fn"].batch_shape), + self.name, + size=self.num_particles, + ) + + +class BatchedObservations(Generic[T], Observations[T]): + """ + Effect handler that takes a dictionary of observation values for ``sample`` sites + that are assumed to be batched along their leftmost dimension, adds a fresh named + dimension using the machinery in ``chirho.indexed``, and reshapes the observation + values so that the new ``chirho.observational.observe`` sites are batched along + the fresh named dimension. + + Useful in combination with ``pyro.infer.Predictive`` which returns a dictionary + of values whose leftmost dimension is a batch dimension over independent samples. + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param Point[T] data: Dictionary of observation values. + :param str name: Name of the fresh batch dimension. 
+ """ + + name: str + + def __init__(self, data: Point[T], *, name: str = "__particles_data"): + assert len(name) > 0 + self.name = name + super().__init__(data) + + def _pyro_observe(self, msg: dict) -> None: + super()._pyro_observe(msg) + if msg["kwargs"]["name"] in self.data: + rv, obs = msg["args"] + event_dim = ( + len(rv.event_shape) + if hasattr(rv, "event_shape") + else msg["kwargs"].get("event_dim", 0) + ) + batch_obs = unbind_leftmost_dim(obs, self.name, event_dim=event_dim) + msg["args"] = (rv, batch_obs) + + +class PredictiveModel(Generic[P, T], torch.nn.Module): + """ + Given a Pyro model and guide, constructs a new model that behaves as if + the latent ``sample`` sites in the original model (i.e. the prior) + were replaced by their counterparts in the guide (i.e. the posterior). + + .. note:: Sites that only appear in the model are annotated in traces + produced by the predictive model with ``infer={"_model_predictive_site": True}`` . + + :param model: Pyro model. + :param guide: Pyro guide. + """ + + model: Callable[P, T] + guide: Callable[P, Any] + + def __init__( self, - dist: pyro.distributions.Distribution, - value: Optional[torch.Tensor], - device: torch.device = torch.device("cpu"), - name: Optional[str] = None, - ) -> torch.Tensor: - return torch.tensor(name is None or name in self.names, device=device) + model: Callable[P, T], + guide: Callable[P, Any], + ): + super().__init__() + self.model = model + self.guide = guide + + def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: + with pyro.poutine.trace() as guide_tr: + self.guide(*args, **kwargs) + + block_guide_sample_sites = pyro.poutine.block( + hide=[ + name + for name, node in guide_tr.trace.nodes.items() + if node["type"] == "sample" + ] + ) + + with pyro.poutine.infer_config( + config_fn=lambda msg: {"_model_predictive_site": True} + ): + with block_guide_sample_sites: + with pyro.poutine.replay(trace=guide_tr.trace): + return self.model(*args, **kwargs) class PredictiveFunctional(Generic[P, T], torch.nn.Module): + """ + Functional that returns a batch of samples from the posterior predictive + distribution of a Pyro model given a guide. As with ``pyro.infer.Predictive`` , + the returned values are batched along their leftmost positional dimension. + + Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)`` + but uses :class:`~PredictiveModel` to construct the predictive distribution + and infer the model ``sample`` sites whose values should be returned, + and uses :class:`~BatchedLatents` to parallelize over samples from the guide. + + .. warning:: ``PredictiveFunctional`` currently applies its own internal instance of + :class:`~chirho.indexed.handlers.IndexPlatesMessenger` , + so it may not behave as expected if used within another enclosing + :class:`~chirho.indexed.handlers.IndexPlatesMessenger` context. + + :param model: Pyro model. + :param guide: Pyro guide. + :param num_samples: Number of samples to return. 
+ """ + model: Callable[P, Any] guide: Callable[P, Any] num_samples: int - max_plate_nesting: Optional[int] def __init__( self, @@ -49,59 +179,41 @@ def __init__( *, num_samples: int = 1, max_plate_nesting: Optional[int] = None, + name: str = "__particles_predictive", ): super().__init__() self.model = model self.guide = guide self.num_samples = num_samples - self.max_plate_nesting = max_plate_nesting - - def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: - if self.max_plate_nesting is None: - self.max_plate_nesting = guess_max_plate_nesting( - self.model, self.guide, *args, **kwargs - ) - - particles_plate = ( - contextlib.nullcontext() - if self.num_samples == 1 - else pyro.plate( - "__predictive_particles", - self.num_samples, - dim=-self.max_plate_nesting - 1, - ) + self._predictive_model: PredictiveModel[P, Any] = PredictiveModel(model, guide) + self._first_available_dim = ( + -max_plate_nesting - 1 if max_plate_nesting is not None else None ) + self._mc_plate_name = name - with pyro.poutine.trace() as guide_tr, particles_plate: - self.guide(*args, **kwargs) + def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: + with IndexPlatesMessenger(first_available_dim=self._first_available_dim): + with pyro.poutine.trace() as model_tr: + with BatchedLatents(self.num_samples, name=self._mc_plate_name): + self._predictive_model(*args, **kwargs) - block_guide_sample_sites = pyro.poutine.block( - hide=[ - name - for name, node in guide_tr.trace.nodes.items() + return { + name: bind_leftmost_dim( + node["value"], + self._mc_plate_name, + event_dim=len(node["fn"].event_shape), + ) + for name, node in model_tr.trace.nodes.items() if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample(node) - ] - ) + and node["infer"].get("_model_predictive_site", False) + } - with pyro.poutine.trace() as model_tr: - with block_guide_sample_sites: - with pyro.poutine.replay(trace=guide_tr.trace), particles_plate: - self.model(*args, **kwargs) - return { - name: node["value"] - for name, node in model_tr.trace.nodes.items() - if node["type"] == "sample" - and not pyro.poutine.util.site_is_subsample(node) - } - - -class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): +class BatchedNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): model: Callable[P, Any] guide: Callable[P, Any] num_samples: int - max_plate_nesting: Optional[int] def __init__( self, @@ -110,36 +222,71 @@ def __init__( *, num_samples: int = 1, max_plate_nesting: Optional[int] = None, + data_plate_name: str = "__particles_data", + mc_plate_name: str = "__particles_mc", ): super().__init__() self.model = model self.guide = guide self.num_samples = num_samples - self.max_plate_nesting = max_plate_nesting + self._first_available_dim = ( + -max_plate_nesting - 1 if max_plate_nesting is not None else None + ) + self._data_plate_name = data_plate_name + self._mc_plate_name = mc_plate_name def forward( self, data: Point[T], *args: P.args, **kwargs: P.kwargs ) -> torch.Tensor: - if self.max_plate_nesting is None: - self.max_plate_nesting = guess_max_plate_nesting( - self.model, self.guide, *args, **kwargs - ) - warnings.warn( - "Since max_plate_nesting is not specified, \ - the first call to NMCLogPredictiveLikelihood will not be seeded properly. 
\ - See https://github.com/BasisResearch/chirho/pull/408" - ) + get_nmc_traces = get_importance_traces(PredictiveModel(self.model, self.guide)) + + with IndexPlatesMessenger(first_available_dim=self._first_available_dim): + with BatchedLatents(self.num_samples, name=self._mc_plate_name): + with BatchedObservations(data, name=self._data_plate_name): + model_trace, guide_trace = get_nmc_traces(*args, **kwargs) + index_plates = get_index_plates() - masked_guide = pyro.poutine.mask(mask=False)(self.guide) - masked_model = _UnmaskNamedSites(names=set(data.keys()))( - condition(data=data)(self.model) + plate_name_to_dim = collections.OrderedDict( + (p, index_plates[p]) + for p in [self._mc_plate_name, self._data_plate_name] + if p in index_plates ) - log_weights = pyro.infer.importance.vectorized_importance_weights( - masked_model, - masked_guide, - *args, - num_samples=self.num_samples, - max_plate_nesting=self.max_plate_nesting, - **kwargs, - )[0] - return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) + plate_frames = set(plate_name_to_dim.values()) + + log_weights = typing.cast(torch.Tensor, 0.0) + for site in model_trace.nodes.values(): + if site["type"] != "sample": + continue + site_log_prob = site["log_prob"] + for f in site["cond_indep_stack"]: + if f.dim is not None and f not in plate_frames: + site_log_prob = site_log_prob.sum(f.dim, keepdim=True) + log_weights = log_weights + site_log_prob + + for site in guide_trace.nodes.values(): + if site["type"] != "sample": + continue + site_log_prob = site["log_prob"] + for f in site["cond_indep_stack"]: + if f.dim is not None and f not in plate_frames: + site_log_prob = site_log_prob.sum(f.dim, keepdim=True) + log_weights = log_weights - site_log_prob + + # sum out particle dimension and discard + if self._mc_plate_name in index_plates: + log_weights = torch.logsumexp( + log_weights, + dim=plate_name_to_dim[self._mc_plate_name].dim, + keepdim=True, + ) - math.log(self.num_samples) + plate_name_to_dim.pop(self._mc_plate_name) + + # move data plate dimension to the left + for name in reversed(plate_name_to_dim.keys()): + log_weights = bind_leftmost_dim(log_weights, name) + + # pack log_weights by squeezing out rightmost dimensions + for _ in range(len(log_weights.shape) - len(plate_name_to_dim)): + log_weights = log_weights.squeeze(-1) + + return log_weights diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index 7af0af4e0..12094a1ef 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -1,11 +1,15 @@ import contextlib import functools -from typing import Any, Callable, Dict, Mapping, Tuple, TypeVar +import math +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, TypeVar import pyro import torch from typing_extensions import Concatenate, ParamSpec +from chirho.indexed.handlers import add_indices +from chirho.indexed.ops import IndexSet, get_index_plates, indices_of + P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") @@ -106,3 +110,126 @@ def reset_rng_state(rng_state: T): yield pyro.util.set_rng_state(rng_state) finally: pyro.util.set_rng_state(prev_rng_state) + + +@functools.singledispatch +def unbind_leftmost_dim(v, name: str, size: int = 1, **kwargs): + """ + Helper function to move the leftmost dimension of a ``torch.Tensor`` + or ``pyro.distributions.Distribution`` or other batched value + into a fresh named dimension using the machinery in ``chirho.indexed`` , + allocating a new dimension with the given name if necessary + via an 
enclosing :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param v: Batched value. + :param name: Name of the fresh dimension. + :param size: Size of the fresh dimension. If 1, the size is inferred from ``v`` . + """ + raise NotImplementedError + + +@unbind_leftmost_dim.register +def _unbind_leftmost_dim_tensor( + v: torch.Tensor, name: str, size: int = 1, *, event_dim: int = 0 +) -> torch.Tensor: + size = max(size, v.shape[0]) + v = v.expand((size,) + v.shape[1:]) + + if name not in get_index_plates(): + add_indices(IndexSet(**{name: set(range(size))})) + + new_dim: int = get_index_plates()[name].dim + orig_shape = v.shape + while new_dim - event_dim < -len(v.shape): + v = v[None] + if v.shape[0] == 1 and orig_shape[0] != 1: + v = torch.transpose(v, -len(orig_shape), new_dim - event_dim) + return v + + +@unbind_leftmost_dim.register +def _unbind_leftmost_dim_distribution( + v: pyro.distributions.Distribution, name: str, size: int = 1, **kwargs +) -> pyro.distributions.Distribution: + size = max(size, v.batch_shape[0]) + if v.batch_shape[0] != 1: + raise NotImplementedError("Cannot freely reshape distribution") + + if name not in get_index_plates(): + add_indices(IndexSet(**{name: set(range(size))})) + + new_dim: int = get_index_plates()[name].dim + orig_shape = v.batch_shape + + new_shape = (size,) + (1,) * (-new_dim - len(orig_shape)) + orig_shape[1:] + return v.expand(new_shape) + + +@functools.singledispatch +def bind_leftmost_dim(v, name: str, **kwargs): + """ + Helper function to move a named dimension managed by ``chirho.indexed`` + into a new unnamed dimension to the left of all named dimensions in the value. + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + """ + raise NotImplementedError + + +@bind_leftmost_dim.register +def _bind_leftmost_dim_tensor( + v: torch.Tensor, name: str, *, event_dim: int = 0, **kwargs +) -> torch.Tensor: + if name not in indices_of(v, event_dim=event_dim): + return v + return torch.transpose( + v[None], -len(v.shape) - 1, get_index_plates()[name].dim - event_dim + ) + + +def get_importance_traces( + model: Callable[P, Any], + guide: Optional[Callable[P, Any]] = None, +) -> Callable[P, Tuple[pyro.poutine.Trace, pyro.poutine.Trace]]: + """ + Thin functional wrapper around :func:`~pyro.infer.enum.get_importance_trace` + that cleans up the original interface to avoid unnecessary arguments + and efficiently supports using the prior in a model as a default guide. + + :param model: Model to run. + :param guide: Guide to run. If ``None``, use the prior in ``model`` as a guide. + :returns: A function that takes the same arguments as ``model`` and ``guide`` and returns + a tuple of importance traces ``(model_trace, guide_trace)``. 
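+
+    **Example usage** (one possible way to consume the returned traces, assuming
+    ``model`` and ``guide`` take no arguments; ``log_prob`` is already populated
+    on each sample site by :func:`~pyro.infer.enum.get_importance_trace`):
+
+    .. code-block:: python
+
+        get_traces = get_importance_traces(model, guide)
+        model_trace, guide_trace = get_traces()
+
+        # unnormalized importance log-weight of this single joint draw
+        log_weight = sum(
+            site["log_prob"].sum()
+            for site in model_trace.nodes.values()
+            if site["type"] == "sample"
+        ) - sum(
+            site["log_prob"].sum()
+            for site in guide_trace.nodes.values()
+            if site["type"] == "sample"
+        )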
+ """ + + def _fn( + *args: P.args, **kwargs: P.kwargs + ) -> Tuple[pyro.poutine.Trace, pyro.poutine.Trace]: + if guide is not None: + model_trace, guide_trace = pyro.infer.enum.get_importance_trace( + "flat", math.inf, model, guide, args, kwargs + ) + return model_trace, guide_trace + else: # use prior as default guide, but don't run model twice + model_trace, _ = pyro.infer.enum.get_importance_trace( + "flat", math.inf, model, lambda *_, **__: None, args, kwargs + ) + + guide_trace = model_trace.copy() + for name, node in list(guide_trace.nodes.items()): + if node["type"] != "sample": + del model_trace.nodes[name] + elif pyro.poutine.util.site_is_factor(node) or node["is_observed"]: + del guide_trace.nodes[name] + return model_trace, guide_trace + + return _fn + + +def site_is_delta(msg: dict) -> bool: + d = msg["fn"] + while hasattr(d, "base_dist"): + d = d.base_dist + return isinstance(d, pyro.distributions.Delta) diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py index 1e7aea907..b9924fab5 100644 --- a/tests/robust/test_internals_compositions.py +++ b/tests/robust/test_internals_compositions.py @@ -5,11 +5,17 @@ import pytest import torch +from chirho.indexed.handlers import IndexPlatesMessenger +from chirho.indexed.ops import indices_of from chirho.robust.internals.linearize import ( conjugate_gradient_solve, make_empirical_fisher_vp, ) -from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood +from chirho.robust.internals.predictive import ( + BatchedLatents, + BatchedNMCLogPredictiveLikelihood, + BatchedObservations, +) from chirho.robust.internals.utils import make_functional_call, reset_rng_state from .robust_fixtures import SimpleGuide, SimpleModel @@ -21,7 +27,7 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): model = SimpleModel() guide = SimpleGuide() model(), guide() # initialize - log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=100) + log_prob = BatchedNMCLogPredictiveLikelihood(model, guide, num_samples=100) log_prob_params, func_log_prob = make_functional_call(log_prob) func_log_prob = reset_rng_state(pyro.util.get_rng_state())(func_log_prob) @@ -49,6 +55,7 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): # For this model, fvp for loc_a is zero. 
See # https://github.com/BasisResearch/chirho/issues/427 assert fvp(v)["guide.loc_a"].abs().max() == 0 + assert all(fvp_vk.shape == v[k].shape for k, fvp_vk in fvp(v).items()) solve_one = cg_solver(fvp, v) solve_two = cg_solver(fvp, v) @@ -88,7 +95,7 @@ def test_nmc_likelihood_seeded(link_fn): guide = SimpleGuide() model(), guide() # initialize - log_prob = NMCLogPredictiveLikelihood( + log_prob = BatchedNMCLogPredictiveLikelihood( model, guide, num_samples=3, max_plate_nesting=3 ) log_prob_params, func_log_prob = make_functional_call(log_prob) @@ -113,3 +120,96 @@ def test_nmc_likelihood_seeded(link_fn): # Check if fvp agrees across multiple calls of same `fvp` object assert torch.allclose(fvp(v)["guide.loc_a"], fvp(v)["guide.loc_a"]) assert torch.allclose(fvp(v)["guide.loc_b"], fvp(v)["guide.loc_b"]) + + +@pytest.mark.parametrize("pad_dim", [0, 1, 2]) +def test_batched_observations(pad_dim: int): + max_plate_nesting = 1 + pad_dim + obs_plate_name = "__dummy_plate__" + num_particles_obs = 3 + model = SimpleModel() + guide = SimpleGuide() + + model(), guide() # initialize + + predictive = pyro.infer.Predictive( + model, num_samples=num_particles_obs, return_sites=["y"] + ) + + test_data = predictive() + + with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1): + with pyro.poutine.trace() as tr: + with BatchedObservations(test_data, name=obs_plate_name): + model() + + tr.trace.compute_log_prob() + + for name, node in tr.trace.nodes.items(): + if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample( + node + ): + if name in test_data: + assert obs_plate_name in indices_of(node["log_prob"], event_dim=0) + assert obs_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + else: + assert obs_plate_name not in indices_of( + node["log_prob"], event_dim=0 + ) + assert obs_plate_name not in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + + +@pytest.mark.parametrize("pad_dim", [0, 1, 2]) +def test_batched_latents_observations(pad_dim: int): + max_plate_nesting = 1 + pad_dim + num_particles_latent = 5 + num_particles_obs = 3 + obs_plate_name = "__dummy_plate__" + latent_plate_name = "__dummy_latents__" + model = SimpleModel() + guide = SimpleGuide() + + model(), guide() # initialize + + predictive = pyro.infer.Predictive( + model, num_samples=num_particles_obs, return_sites=["y"] + ) + + test_data = predictive() + + with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1): + with pyro.poutine.trace() as tr: + with BatchedLatents( + num_particles=num_particles_latent, name=latent_plate_name + ): + with BatchedObservations(test_data, name=obs_plate_name): + model() + + tr.trace.compute_log_prob() + + for name, node in tr.trace.nodes.items(): + if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample( + node + ): + if name in test_data: + assert obs_plate_name in indices_of(node["log_prob"], event_dim=0) + assert obs_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + assert latent_plate_name in indices_of( + node["log_prob"], event_dim=0 + ) + assert latent_plate_name not in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + else: + assert latent_plate_name in indices_of( + node["log_prob"], event_dim=0 + ) + assert latent_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) diff --git a/tests/robust/test_performance.py b/tests/robust/test_performance.py new file mode 100644 index 000000000..b1ec08f29 
--- /dev/null +++ b/tests/robust/test_performance.py @@ -0,0 +1,178 @@ +import math +import time +import warnings +from functools import partial +from typing import Any, Callable, Container, Generic, Optional, TypeVar + +import pyro +import pytest +import torch +from typing_extensions import ParamSpec + +from chirho.indexed.handlers import DependentMaskMessenger +from chirho.observational.handlers import condition +from chirho.robust.internals.linearize import make_empirical_fisher_vp +from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood +from chirho.robust.internals.utils import guess_max_plate_nesting, make_functional_call +from chirho.robust.ops import Point + +from .robust_fixtures import SimpleGuide, SimpleModel + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +class _UnmaskNamedSites(DependentMaskMessenger): + names: Container[str] + + def __init__(self, names: Container[str]): + self.names = names + + def get_mask( + self, + dist: pyro.distributions.Distribution, + value: Optional[torch.Tensor], + device: torch.device = torch.device("cpu"), + name: Optional[str] = None, + ) -> torch.Tensor: + return torch.tensor(name is None or name in self.names, device=device) + + +class OldNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): + model: Callable[P, Any] + guide: Callable[P, Any] + num_samples: int + max_plate_nesting: Optional[int] + + def __init__( + self, + model: torch.nn.Module, + guide: torch.nn.Module, + *, + num_samples: int = 1, + max_plate_nesting: Optional[int] = None, + ): + super().__init__() + self.model = model + self.guide = guide + self.num_samples = num_samples + self.max_plate_nesting = max_plate_nesting + + def forward( + self, data: Point[T], *args: P.args, **kwargs: P.kwargs + ) -> torch.Tensor: + if self.max_plate_nesting is None: + self.max_plate_nesting = guess_max_plate_nesting( + self.model, self.guide, *args, **kwargs + ) + warnings.warn( + "Since max_plate_nesting is not specified, \ + the first call to NMCLogPredictiveLikelihood will not be seeded properly. 
\ + See https://github.com/BasisResearch/chirho/pull/408" + ) + + masked_guide = pyro.poutine.mask(mask=False)(self.guide) + masked_model = _UnmaskNamedSites(names=set(data.keys()))( + condition(data=data)(self.model) + ) + log_weights = pyro.infer.importance.vectorized_importance_weights( + masked_model, + masked_guide, + *args, + num_samples=self.num_samples, + max_plate_nesting=self.max_plate_nesting, + **kwargs, + )[0] + return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) + + +class SimpleMultivariateGaussianModel(pyro.nn.PyroModule): + def __init__(self, p): + super().__init__() + self.p = p + + def forward(self): + loc = pyro.sample( + "loc", pyro.distributions.Normal(torch.zeros(self.p), 1.0).to_event(1) + ) + cov_mat = torch.eye(self.p) + return pyro.sample("y", pyro.distributions.MultivariateNormal(loc, cov_mat)) + + +class SimpleMultivariateGuide(torch.nn.Module): + def __init__(self, p): + super().__init__() + self.loc_ = torch.nn.Parameter(torch.rand((p,))) + self.p = p + + def forward(self): + return pyro.sample("loc", pyro.distributions.Normal(self.loc_, 1).to_event(1)) + + +model_guide_types = [ + ( + partial(SimpleMultivariateGaussianModel, p=500), + partial(SimpleMultivariateGuide, p=500), + ), + (SimpleModel, SimpleGuide), +] + + +@pytest.mark.skip(reason="This test is too slow to run on CI") +@pytest.mark.parametrize("model_guide", model_guide_types) +def test_empirical_fisher_vp_performance_with_likelihood(model_guide): + num_monte_carlo = 10000 + model_family, guide_family = model_guide + + model = model_family() + guide = guide_family() + + model() + guide() + + start_time = time.time() + data = pyro.infer.Predictive( + model, guide=guide, num_samples=num_monte_carlo, return_sites=["y"] + )() + end_time = time.time() + print("Data generation time (s): ", end_time - start_time) + + log1_prob_params, func1_log_prob = make_functional_call( + OldNMCLogPredictiveLikelihood(model, guide, max_plate_nesting=1) + ) + batched_func1_log_prob = torch.func.vmap( + func1_log_prob, in_dims=(None, 0), randomness="different" + ) + + log2_prob_params, func2_log_prob = make_functional_call( + BatchedNMCLogPredictiveLikelihood(model, guide) + ) + + fisher_hessian_vmapped = make_empirical_fisher_vp( + batched_func1_log_prob, log1_prob_params, data + ) + + fisher_hessian_batched = make_empirical_fisher_vp( + func2_log_prob, log2_prob_params, data + ) + + v = { + k: torch.ones_like(v) if k != "guide.loc_a" else torch.zeros_like(v) + for k, v in log1_prob_params.items() + } + + func2_log_prob(log2_prob_params, data) + + start_time = time.time() + fisher_hessian_vmapped(v) + end_time = time.time() + print("Hessian vmapped time (s): ", end_time - start_time) + + start_time = time.time() + fisher_hessian_batched(v) + end_time = time.time() + print("Hessian manual batched time (s): ", end_time - start_time) From 3cfe31902c9db6ff7af3466a43d0cf1759eec2a5 Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Tue, 2 Jan 2024 13:27:56 -0800 Subject: [PATCH 54/66] Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue --- chirho/robust/handlers/estimators.py | 20 +++ chirho/robust/internals/linearize.py | 227 ++++++++++++++++++++++++-- chirho/robust/internals/predictive.py | 41 +++++ chirho/robust/internals/utils.py | 37 +++++ chirho/robust/ops.py | 103 ++++++++++++ docs/source/robust.rst | 4 + 6 files changed, 420 insertions(+), 12 deletions(-) diff --git a/chirho/robust/handlers/estimators.py 
b/chirho/robust/handlers/estimators.py index 146282445..cab461ead 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -11,6 +11,26 @@ def one_step_correction( functional: Optional[Functional[P, S]] = None, **influence_kwargs, ) -> Callable[Concatenate[Point[T], P], S]: + """ + Returns a function that computes the one-step correction for the + functional at a specified set of test points as discussed in + [1]. + + :param model: Python callable containing Pyro primitives. + :type model: Callable[P, Any] + :param guide: Python callable containing Pyro primitives. + :type guide: Callable[P, Any] + :param functional: model summary of interest, which is a function of the + model and guide. If ``None``, defaults to :class:`PredictiveFunctional`. + :type functional: Optional[Functional[P, S]], optional + :return: function to compute the one-step correction + :rtype: Callable[Concatenate[Point[T], P], S] + + **References** + + [1] `Semiparametric doubly robust targeted double machine learning: a review`, + Edward H. Kennedy, 2022. + """ influence_kwargs_one_step = influence_kwargs.copy() influence_kwargs_one_step["pointwise_influence"] = False eif_fn = influence_fn(model, guide, functional, **influence_kwargs_one_step) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index d02ef1207..e4fbdd115 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -27,20 +27,26 @@ def _flat_conjugate_gradient_solve( cg_iters: Optional[int] = None, residual_tol: float = 1e-3, ) -> torch.Tensor: - r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. + """ + Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. + + :param f_Ax: a function to compute matrix vector products over a batch + of vectors ``x``. + :type f_Ax: Callable[[torch.Tensor], torch.Tensor] + :param b: batch of right hand sides of the equation to solve. + :type b: torch.Tensor + :param cg_iters: number of conjugate iterations to run, defaults to None + :type cg_iters: Optional[int], optional + :param residual_tol: tolerance for convergence, defaults to 1e-3 + :type residual_tol: float, optional + :return: batch of solutions ``x*`` for equation Ax = b. + :rtype: torch.Tensor - Args: - f_Ax (callable): A function to compute matrix vector product. - b (torch.Tensor): Right hand side of the equation to solve. - cg_iters (int): Number of iterations to run conjugate gradient - algorithm. - residual_tol (float): Tolerence for convergence. + .. note:: - Returns: - torch.Tensor: Solution x* for equation Ax = b. + Code is adapted from + https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py # noqa: E501 - Notes: This code is adapted from - https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py """ assert len(b.shape), "b must be a 2D matrix" @@ -81,6 +87,17 @@ def _batched_product(a: torch.Tensor, B: torch.Tensor) -> torch.Tensor: def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: + """ + Use Conjugate Gradient iteration to solve Ax = b. + + :param f_Ax: a function to compute matrix vector products over a batch + of vectors ``x``. + :type f_Ax: Callable[[T], T] + :param b: batch of right hand sides of the equation to solve. + :type b: T + :return: batch of solutions ``x*`` for equation Ax = b. 
+ :rtype: T + """ flatten, unflatten = make_flatten_unflatten(b) def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: @@ -98,6 +115,90 @@ def make_empirical_fisher_vp( *args: P.args, **kwargs: P.kwargs, ) -> Callable[[ParamDict], ParamDict]: + r""" + Returns a function that computes the empirical Fisher vector product for an arbitrary + vector :math:`v` using only Hessian vector products via a batched version of + Perlmutter's trick [1]. + + .. math:: + + -\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) v, + + where :math:`\phi` corresponds to ``log_prob_params``, :math:`\tilde{p}_{\phi}` denotes the + predictive distribution ``log_prob``, and :math:`x_n` are the data points in ``data``. + + :param func_log_prob: computes the log probability of ``data`` given ``log_prob_params`` + :type func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor] + :param log_prob_params: parameters of the predictive distribution + :type log_prob_params: ParamDict + :param data: data points + :type data: Point[T] + :param is_batched: if ``False``, ``func_log_prob`` is batched over ``data`` + using ``torch.func.vmap``. Otherwise, assumes ``func_log_prob`` is already batched + over multiple data points. ``Defaults to False``. + :type is_batched: bool, optional + :return: a function that computes the empirical Fisher vector product for an arbitrary + vector :math:`v` + :rtype: Callable[[ParamDict], ParamDict] + + **Example usage**: + + .. code-block:: python + + import pyro + import pyro.distributions as dist + import torch + + from chirho.robust.internals.linearize import make_empirical_fisher_vp + + pyro.settings.set(module_local_params=True) + + + class GaussianModel(pyro.nn.PyroModule): + def __init__(self, cov_mat: torch.Tensor): + super().__init__() + self.register_buffer("cov_mat", cov_mat) + + def forward(self, loc): + pyro.sample( + "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat) + ) + + + def gaussian_log_prob(params, data_point, cov_mat): + with pyro.validation_enabled(False): + return dist.MultivariateNormal( + loc=params["loc"], covariance_matrix=cov_mat + ).log_prob(data_point["x"]) + + + v = torch.tensor([1.0, 0.0], requires_grad=False) + loc = torch.ones(2, requires_grad=True) + cov_mat = torch.ones(2, 2) + torch.eye(2) + + func_log_prob = gaussian_log_prob + log_prob_params = {"loc": loc} + N_monte_carlo = 10000 + data = pyro.infer.Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc) + empirical_fisher_vp_func = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, cov_mat=cov_mat + ) + + empirical_fisher_vp = empirical_fisher_vp_func({"loc": v})["loc"] + + # Closed form solution for the Fisher vector product + # See "Multivariate normal distribution" in https://en.wikipedia.org/wiki/Fisher_information + prec_matrix = torch.linalg.inv(cov_mat) + true_vp = prec_matrix.mv(v) + + assert torch.all(torch.isclose(empirical_fisher_vp, true_vp, atol=0.1)) + + + **References** + + [1] `Fast Exact Multiplication by the Hessian`, + Barak A. Pearlmutter, 1999. 
+ """ N = data[next(iter(data))].shape[0] # type: ignore mean_vector = 1 / N * torch.ones(N) @@ -125,9 +226,111 @@ def linearize( num_samples_inner: Optional[int] = None, max_plate_nesting: Optional[int] = None, cg_iters: Optional[int] = None, - residual_tol: float = 1e-10, + residual_tol: float = 1e-4, pointwise_influence: bool = True, ) -> Callable[Concatenate[Point[T], P], ParamDict]: + r""" + Returns the influence function associated with the parameters + of ``guide`` and probabilistic program ``model``. This function + computes the following quantity at an arbitrary point :math:`x^{\prime}`: + + .. math:: + + \left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) \right] + \nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}), \quad + \tilde{p}_{\phi}(x) = \int p(x \mid \theta) q_{\phi}(\theta) d\theta, + + where :math:`\phi` corresponds to ``log_prob_params``, + :math:`p(x \mid \theta)` denotes the ``model``, :math:`q_{\phi}` denotes the ``guide``, + :math:`\tilde{p}_{\phi}` denotes the predictive distribution ``log_prob`` induced + from the ``model`` and ``guide``, and :math:`\{x_n\}_{n=1}^N` are the + data points drawn iid from the predictive distribution. + + :param model: Python callable containing Pyro primitives. + :type model: Callable[P, Any] + :param guide: Python callable containing Pyro primitives. + Must only contain continuous latent variables. + :type guide: Callable[P, Any] + :param num_samples_outer: number of Monte Carlo samples to + approximate Fisher information in :func:`make_empirical_fisher_vp` + :type num_samples_outer: int + :param num_samples_inner: number of Monte Carlo samples used in + :class:`BatchedNMCLogPredictiveLikelihood`. Defaults to ``num_samples_outer**2``. + :type num_samples_inner: Optional[int], optional + :param max_plate_nesting: bound on max number of nested :func:`pyro.plate` + contexts. Defaults to ``None``. + :type max_plate_nesting: Optional[int], optional + :param cg_iters: number of conjugate gradient steps used to + invert Fisher information matrix, defaults to None + :type cg_iters: Optional[int], optional + :param residual_tol: tolerance used to terminate conjugate gradients + early, defaults to 1e-4 + :type residual_tol: float, optional + :param pointwise_influence: if ``True``, computes the influence function at each + point in ``points``. If ``False``, computes the efficient influence averaged + over ``points``. Defaults to True. + :type pointwise_influence: bool, optional + :return: the influence function associated with the parameters + :rtype: Callable[Concatenate[Point[T], P], ParamDict] + + **Example usage**: + + .. 
code-block:: python + import pyro + import pyro.distributions as dist + import torch + + from chirho.robust.internals.linearize import linearize + + pyro.settings.set(module_local_params=True) + + + class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(a, 1)) + return pyro.sample("y", dist.Normal(b, 1)) + + + class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand((3,))) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1)) + return {"a": a, "b": b} + + model = SimpleModel() + guide = SimpleGuide() + predictive = pyro.infer.Predictive( + model, guide=guide, num_samples=10, return_sites=["y"] + ) + points = predictive() + influence = linearize( + model, + guide, + num_samples_outer=1000, + num_samples_inner=1000, + ) + + influence(points) + + .. note:: + + * Since the efficient influence function is approximated using Monte Carlo, the result + of this function is stochastic, i.e., evaluating this function on the same ``points`` + can result in different values. To reduce variance, increase ``num_samples_outer`` and + ``num_samples_inner`` in ``linearize_kwargs``. + + * Currently, ``model`` and ``guide`` cannot contain any ``pyro.param`` statements. + This issue will be addressed in a future release: + https://github.com/BasisResearch/chirho/issues/393. + """ assert isinstance(model, torch.nn.Module) assert isinstance(guide, torch.nn.Module) if num_samples_inner is None: diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py index 19369924f..8b9721f44 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/predictive.py @@ -128,6 +128,12 @@ def __init__( self.guide = guide def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: + """ + Returns a sample from the posterior predictive distribution. + + :return: Sample from the posterior predictive distribution. + :rtype: T + """ with pyro.poutine.trace() as guide_tr: self.guide(*args, **kwargs) @@ -192,6 +198,12 @@ def __init__( self._mc_plate_name = name def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: + """ + Returns a batch of samples from the posterior predictive distribution. + + :return: Dictionary of samples from the posterior predictive distribution. + :rtype: Point[T] + """ with IndexPlatesMessenger(first_available_dim=self._first_available_dim): with pyro.poutine.trace() as model_tr: with BatchedLatents(self.num_samples, name=self._mc_plate_name): @@ -211,6 +223,27 @@ def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: class BatchedNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): + r""" + Approximates the log predictive likelihood induced by ``model`` and ``guide`` + using Monte Carlo sampling at an arbitrary batch of :math:`N` + points :math:`\{x_n\}_{n=1}^N`. + + .. math:: + \log \left(\frac{1}{M} \sum_{m=1}^M p(x_n \mid \theta_m)\right), + \quad \theta_m \sim q_{\phi}(\theta), + + where :math:`q_{\phi}(\theta)` is the guide and :math:`p(x_n \mid \theta_m)` + is the model conditioned on the latents from the guide. + + :param model: Python callable containing Pyro primitives. + :type model: torch.nn.Module + :param guide: Python callable containing Pyro primitives. 
+ Must only contain continuous latent variables. + :type guide: torch.nn.Module + :param num_samples: Number of Monte Carlo draws :math:`M` + used to approximate predictive distribution, defaults to 1 + :type num_samples: int, optional + """ model: Callable[P, Any] guide: Callable[P, Any] num_samples: int @@ -238,6 +271,14 @@ def __init__( def forward( self, data: Point[T], *args: P.args, **kwargs: P.kwargs ) -> torch.Tensor: + """ + Computes the log predictive likelihood of ``data`` given ``model`` and ``guide``. + + :param data: Dictionary of observations. + :type data: Point[T] + :return: Log predictive likelihood at each datapoint. + :rtype: torch.Tensor + """ get_nmc_traces = get_importance_traces(PredictiveModel(self.model, self.guide)) with IndexPlatesMessenger(first_available_dim=self._first_available_dim): diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index 12094a1ef..9289027a7 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -22,11 +22,23 @@ def make_flatten_unflatten( v, ) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: + """ + Returns functions to flatten and unflatten an object. Used as a helper + in :func:`chirho.robust.internals.linearize.conjugate_gradient_solve` + + :param v: some object + :raises NotImplementedError: + :return: flatten and unflatten functions + :rtype: Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]] + """ raise NotImplementedError @make_flatten_unflatten.register(torch.Tensor) def _make_flatten_unflatten_tensor(v: torch.Tensor): + """ + Returns functions to flatten and unflatten a `torch.Tensor`. + """ batch_size = v.shape[0] def flatten(v: torch.Tensor) -> torch.Tensor: @@ -46,6 +58,9 @@ def unflatten(x: torch.Tensor) -> torch.Tensor: @make_flatten_unflatten.register(dict) def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): + """ + Returns functions to flatten and unflatten a dictionary of `torch.Tensor`s. + """ batch_size = next(iter(d.values())).shape[0] def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: @@ -81,6 +96,15 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: def make_functional_call( mod: Callable[P, T] ) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: + """ + Converts a PyTorch module into a functional call for use with + functions in :class:`torch.func`. + + :param mod: PyTorch module + :type mod: Callable[P, T] + :return: parameter dictionary and functional call + :rtype: Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]] + """ assert isinstance(mod, torch.nn.Module) param_dict: ParamDict = dict(mod.named_parameters()) @@ -98,6 +122,16 @@ def mod_func(params: ParamDict, *args: P.args, **kwargs: P.kwargs) -> T: def guess_max_plate_nesting( model: Callable[P, Any], guide: Callable[P, Any], *args: P.args, **kwargs: P.kwargs ) -> int: + """ + Guesses the maximum plate nesting level by running `pyro.infer.Trace_ELBO` + + :param model: Python callable containing Pyro primitives. + :type model: Callable[P, Any] + :param guide: Python callable containing Pyro primitives. + :type guide: Callable[P, Any] + :return: maximum plate nesting level + :rtype: int + """ elbo = pyro.infer.Trace_ELBO() elbo._guess_max_plate_nesting(model, guide, args, kwargs) return elbo.max_plate_nesting @@ -105,6 +139,9 @@ def guess_max_plate_nesting( @contextlib.contextmanager def reset_rng_state(rng_state: T): + """ + Helper to temporarily reset the Pyro RNG state. 
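+
+    For example (a sketch; ``noisy`` stands in for any function that consumes the
+    global RNG), the decorator form used in the tests restores ``rng_state``
+    before every call, so repeated calls produce identical draws:
+
+    .. code-block:: python
+
+        import pyro
+        import torch
+
+        def noisy():
+            return torch.randn(3)
+
+        rng_state = pyro.util.get_rng_state()
+        seeded_noisy = reset_rng_state(rng_state)(noisy)
+        assert torch.allclose(seeded_noisy(), seeded_noisy())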
+ """ try: prev_rng_state: T = pyro.util.get_rng_state() yield pyro.util.set_rng_state(rng_state) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 806f7b4d9..3745dc09e 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -21,6 +21,100 @@ def influence_fn( functional: Optional[Functional[P, S]] = None, **linearize_kwargs ) -> Callable[Concatenate[Point[T], P], S]: + """ + Returns the efficient influence function for ``functional`` + with respect to the parameters of ``guide`` and probabilistic + program ``model``. + + :param model: Python callable containing Pyro primitives. + :type model: Callable[P, Any] + :param guide: Python callable containing Pyro primitives. + Must only contain continuous latent variables. + :type guide: Callable[P, Any] + :param functional: model summary of interest, which is a function of the + model and guide. If ``None``, defaults to :class:`PredictiveFunctional`. + :type functional: Optional[Functional[P, S]], optional + :return: the efficient influence function for ``functional`` + :rtype: Callable[Concatenate[Point[T], P], S] + + **Example usage**: + + .. code-block:: python + + import pyro + import pyro.distributions as dist + import torch + + from chirho.robust.ops import influence_fn + + pyro.settings.set(module_local_params=True) + + + class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(a, 1)) + return pyro.sample("y", dist.Normal(b, 1)) + + + class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand((3,))) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1)) + return {"a": a, "b": b} + + + class SimpleFunctional(torch.nn.Module): + def __init__(self, model, guide, num_monte_carlo=1000): + super().__init__() + self.model = model + self.guide = guide + self.num_monte_carlo = num_monte_carlo + + def forward(self): + with pyro.plate("monte_carlo_functional", size=self.num_monte_carlo, dim=-2): + posterior_guide_samples = pyro.poutine.trace(self.guide).get_trace() + model_at_theta = pyro.poutine.replay(trace=posterior_guide_samples)( + self.model + ) + model_samples = pyro.poutine.trace(model_at_theta).get_trace() + return model_samples.nodes["b"]["value"].mean(axis=0) + + + model = SimpleModel() + guide = SimpleGuide() + predictive = pyro.infer.Predictive( + model, guide=guide, num_samples=10, return_sites=["y"] + ) + points = predictive() + influence = influence_fn( + model, + guide, + SimpleFunctional, + num_samples_outer=1000, + num_samples_inner=1000, + ) + + influence(points) + + .. note:: + + * ``functional`` must compose with ``torch.func.jvp`` + * Since the efficient influence function is approximated using Monte Carlo, the result + of this function is stochastic, i.e., evaluating this function on the same ``points`` + can result in different values. To reduce variance, increase ``num_samples_outer`` and + ``num_samples_inner`` in ``linearize_kwargs``. + * Currently, ``model`` and ``guide`` cannot contain any ``pyro.param`` statements. + This issue will be addressed in a future release: + https://github.com/BasisResearch/chirho/issues/393. 
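+
+    The same ingredients compose with
+    :func:`~chirho.robust.handlers.estimators.one_step_correction` (a sketch,
+    reusing ``model``, ``guide``, ``SimpleFunctional`` and ``points`` from the
+    example above; the extra keyword arguments are the same ``linearize_kwargs``
+    described in the notes):
+
+    .. code-block:: python
+
+        from chirho.robust.handlers.estimators import one_step_correction
+
+        correction_fn = one_step_correction(
+            model,
+            guide,
+            functional=SimpleFunctional,
+            num_samples_outer=1000,
+            num_samples_inner=1000,
+        )
+        correction = correction_fn(points)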
+ """ from chirho.robust.internals.linearize import linearize from chirho.robust.internals.predictive import PredictiveFunctional from chirho.robust.internals.utils import make_functional_call @@ -40,6 +134,15 @@ def influence_fn( @functools.wraps(target) def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: + """ + Evaluates the efficient influence function for ``functional`` at each + point in ``points``. + + :param points: points at which to compute the efficient influence function + :type points: Point[T] + :return: efficient influence function evaluated at each point in ``points`` or averaged + :rtype: S + """ param_eif = linearized(points, *args, **kwargs) return torch.vmap( lambda d: torch.func.jvp( diff --git a/docs/source/robust.rst b/docs/source/robust.rst index 38ed8dc0e..a5a3a3493 100644 --- a/docs/source/robust.rst +++ b/docs/source/robust.rst @@ -19,6 +19,10 @@ Handlers :members: :undoc-members: +.. automodule:: chirho.robust.handlers.estimators + :members: + :undoc-members: + Internals --------- From 5d77fe0bf71c87e948bbb91ac4c087725a097597 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Tue, 9 Jan 2024 18:36:03 -0500 Subject: [PATCH 55/66] Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring --- chirho/robust/handlers/estimators.py | 8 ++++---- chirho/robust/ops.py | 17 +++++------------ tests/robust/test_ops.py | 8 ++------ 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py index cab461ead..9d2d70f2d 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable from typing_extensions import Concatenate @@ -8,7 +8,7 @@ def one_step_correction( model: Callable[P, Any], guide: Callable[P, Any], - functional: Optional[Functional[P, S]] = None, + functional: Functional[P, S], **influence_kwargs, ) -> Callable[Concatenate[Point[T], P], S]: """ @@ -21,8 +21,8 @@ def one_step_correction( :param guide: Python callable containing Pyro primitives. :type guide: Callable[P, Any] :param functional: model summary of interest, which is a function of the - model and guide. If ``None``, defaults to :class:`PredictiveFunctional`. - :type functional: Optional[Functional[P, S]], optional + model and guide. + :type functional: Functional[P, S] :return: function to compute the one-step correction :rtype: Callable[Concatenate[Point[T], P], S] diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 3745dc09e..34b0dd0f5 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Callable, Mapping, Optional, TypeVar +from typing import Any, Callable, Mapping, TypeVar import torch from typing_extensions import Concatenate, ParamSpec @@ -18,7 +18,7 @@ def influence_fn( model: Callable[P, Any], guide: Callable[P, Any], - functional: Optional[Functional[P, S]] = None, + functional: Functional[P, S], **linearize_kwargs ) -> Callable[Concatenate[Point[T], P], S]: """ @@ -32,8 +32,8 @@ def influence_fn( Must only contain continuous latent variables. :type guide: Callable[P, Any] :param functional: model summary of interest, which is a function of the - model and guide. If ``None``, defaults to :class:`PredictiveFunctional`. - :type functional: Optional[Functional[P, S]], optional + model and guide. 
+ :type functional: Functional[P, S] :return: the efficient influence function for ``functional`` :rtype: Callable[Concatenate[Point[T], P], S] @@ -116,17 +116,10 @@ def forward(self): https://github.com/BasisResearch/chirho/issues/393. """ from chirho.robust.internals.linearize import linearize - from chirho.robust.internals.predictive import PredictiveFunctional from chirho.robust.internals.utils import make_functional_call linearized = linearize(model, guide, **linearize_kwargs) - - if functional is None: - assert isinstance(model, torch.nn.Module) - assert isinstance(guide, torch.nn.Module) - target = PredictiveFunctional(model, guide) - else: - target = functional(model, guide) + target = functional(model, guide) # TODO check that target_params == model_params | guide_params assert isinstance(target, torch.nn.Module) diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 881faa63c..3f91377c7 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -59,9 +59,7 @@ def test_nmc_predictive_influence_smoke( predictive_eif = influence_fn( model, guide, - functional=functools.partial( - PredictiveFunctional, num_samples=num_predictive_samples - ), + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, @@ -106,9 +104,7 @@ def test_nmc_predictive_influence_vmap_smoke( predictive_eif = influence_fn( model, guide, - functional=functools.partial( - PredictiveFunctional, num_samples=num_predictive_samples - ), + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, From 013d518f9b6091757fe232c891c3d69999da3bd2 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Tue, 9 Jan 2024 18:56:35 -0500 Subject: [PATCH 56/66] Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty --- chirho/robust/handlers/estimators.py | 18 +-- chirho/robust/handlers/predictive.py | 140 +++++++++++++++++ chirho/robust/internals/linearize.py | 27 ++-- .../internals/{predictive.py => nmc.py} | 142 ++---------------- chirho/robust/ops.py | 25 ++- docs/source/robust.rst | 6 +- tests/robust/test_handlers.py | 9 +- tests/robust/test_internals_compositions.py | 45 +++--- tests/robust/test_internals_linearize.py | 21 ++- tests/robust/test_ops.py | 8 +- tests/robust/test_performance.py | 12 +- 11 files changed, 235 insertions(+), 218 deletions(-) create mode 100644 chirho/robust/handlers/predictive.py rename chirho/robust/internals/{predictive.py => nmc.py} (58%) diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py index 9d2d70f2d..16e8ab227 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -1,13 +1,16 @@ -from typing import Any, Callable +from typing import Any, Callable, TypeVar -from typing_extensions import Concatenate +from typing_extensions import Concatenate, ParamSpec -from chirho.robust.ops import Functional, P, Point, S, T, influence_fn +from chirho.robust.ops import Functional, Point, influence_fn + +P = ParamSpec("P") +S = TypeVar("S") +T = TypeVar("T") 
def one_step_correction( model: Callable[P, Any], - guide: Callable[P, Any], functional: Functional[P, S], **influence_kwargs, ) -> Callable[Concatenate[Point[T], P], S]: @@ -18,10 +21,7 @@ def one_step_correction( :param model: Python callable containing Pyro primitives. :type model: Callable[P, Any] - :param guide: Python callable containing Pyro primitives. - :type guide: Callable[P, Any] - :param functional: model summary of interest, which is a function of the - model and guide. + :param functional: model summary of interest, which is a function of the model. :type functional: Functional[P, S] :return: function to compute the one-step correction :rtype: Callable[Concatenate[Point[T], P], S] @@ -33,7 +33,7 @@ def one_step_correction( """ influence_kwargs_one_step = influence_kwargs.copy() influence_kwargs_one_step["pointwise_influence"] = False - eif_fn = influence_fn(model, guide, functional, **influence_kwargs_one_step) + eif_fn = influence_fn(model, functional, **influence_kwargs_one_step) def _one_step(test_data: Point[T], *args, **kwargs) -> S: return eif_fn(test_data, *args, **kwargs) diff --git a/chirho/robust/handlers/predictive.py b/chirho/robust/handlers/predictive.py new file mode 100644 index 000000000..f73c38bfd --- /dev/null +++ b/chirho/robust/handlers/predictive.py @@ -0,0 +1,140 @@ +from typing import Any, Callable, Generic, Optional, TypeVar + +import pyro +import torch +from typing_extensions import ParamSpec + +from chirho.indexed.handlers import IndexPlatesMessenger +from chirho.robust.internals.nmc import BatchedLatents +from chirho.robust.internals.utils import bind_leftmost_dim +from chirho.robust.ops import Point + +P = ParamSpec("P") +S = TypeVar("S") +T = TypeVar("T") + + +class PredictiveModel(Generic[P, T], torch.nn.Module): + """ + Given a Pyro model and guide, constructs a new model that behaves as if + the latent ``sample`` sites in the original model (i.e. the prior) + were replaced by their counterparts in the guide (i.e. the posterior). + + .. note:: Sites that only appear in the model are annotated in traces + produced by the predictive model with ``infer={"_model_predictive_site": True}`` . + + :param model: Pyro model. + :param guide: Pyro guide. + """ + + model: Callable[P, T] + guide: Optional[Callable[P, Any]] + + def __init__( + self, + model: Callable[P, T], + guide: Optional[Callable[P, Any]] = None, + ): + super().__init__() + self.model = model + self.guide = guide + + def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: + """ + Returns a sample from the posterior predictive distribution. + + :return: Sample from the posterior predictive distribution. + :rtype: T + """ + with pyro.poutine.infer_config( + config_fn=lambda msg: {"_model_predictive_site": False} + ): + with pyro.poutine.trace() as guide_tr: + if self.guide is not None: + self.guide(*args, **kwargs) + + block_guide_sample_sites = pyro.poutine.block( + hide=[ + name + for name, node in guide_tr.trace.nodes.items() + if node["type"] == "sample" + ] + ) + + with pyro.poutine.infer_config( + config_fn=lambda msg: {"_model_predictive_site": True} + ): + with block_guide_sample_sites: + with pyro.poutine.replay(trace=guide_tr.trace): + return self.model(*args, **kwargs) + + +class PredictiveFunctional(Generic[P, T], torch.nn.Module): + """ + Functional that returns a batch of samples from the predictive + distribution of a Pyro model. As with ``pyro.infer.Predictive`` , + the returned values are batched along their leftmost positional dimension. 
+ + Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)`` + when :class:`~chirho.robust.handlers.predictive.PredictiveModel` is used to construct + the ``model`` argument and infer the ``sample`` sites whose values should be returned, + and uses :class:`~BatchedLatents` to parallelize over samples from the model. + + .. warning:: ``PredictiveFunctional`` currently applies its own internal instance of + :class:`~chirho.indexed.handlers.IndexPlatesMessenger` , + so it may not behave as expected if used within another enclosing + :class:`~chirho.indexed.handlers.IndexPlatesMessenger` context. + + :param model: Pyro model. + :param num_samples: Number of samples to return. + """ + + model: Callable[P, Any] + num_samples: int + + def __init__( + self, + model: torch.nn.Module, + *, + num_samples: int = 1, + max_plate_nesting: Optional[int] = None, + name: str = "__particles_predictive", + ): + super().__init__() + self.model = model + self.num_samples = num_samples + self._first_available_dim = ( + -max_plate_nesting - 1 if max_plate_nesting is not None else None + ) + self._mc_plate_name = name + + def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: + """ + Returns a batch of samples from the posterior predictive distribution. + + :return: Dictionary of samples from the posterior predictive distribution. + :rtype: Point[T] + """ + with IndexPlatesMessenger(first_available_dim=self._first_available_dim): + with pyro.poutine.trace() as model_tr: + with BatchedLatents(self.num_samples, name=self._mc_plate_name): + with pyro.poutine.infer_config( + config_fn=lambda msg: { + "_model_predictive_site": msg["infer"].get( + "_model_predictive_site", True + ) + } + ): + self.model(*args, **kwargs) + + return { + name: bind_leftmost_dim( + node["value"], + self._mc_plate_name, + event_dim=len(node["fn"].event_shape), + ) + for name, node in model_tr.trace.nodes.items() + if node["type"] == "sample" + and not pyro.poutine.util.site_is_subsample(node) + and node["infer"].get("_model_predictive_site", False) + } diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index e4fbdd115..27ce8da39 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -5,7 +5,7 @@ import torch from typing_extensions import Concatenate, ParamSpec -from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood +from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood from chirho.robust.internals.utils import ( ParamDict, make_flatten_unflatten, @@ -220,7 +220,6 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: def linearize( model: Callable[P, Any], - guide: Callable[P, Any], *, num_samples_outer: int, num_samples_inner: Optional[int] = None, @@ -231,26 +230,23 @@ def linearize( ) -> Callable[Concatenate[Point[T], P], ParamDict]: r""" Returns the influence function associated with the parameters - of ``guide`` and probabilistic program ``model``. This function + of a normalized probabilistic program ``model``. This function computes the following quantity at an arbitrary point :math:`x^{\prime}`: .. 
math:: \left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) \right] \nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}), \quad - \tilde{p}_{\phi}(x) = \int p(x \mid \theta) q_{\phi}(\theta) d\theta, + \tilde{p}_{\phi}(x) = \int p_{\phi}(x, \theta) d\theta, where :math:`\phi` corresponds to ``log_prob_params``, - :math:`p(x \mid \theta)` denotes the ``model``, :math:`q_{\phi}` denotes the ``guide``, + :math:`p(x, \theta)` denotes the ``model``, :math:`\tilde{p}_{\phi}` denotes the predictive distribution ``log_prob`` induced - from the ``model`` and ``guide``, and :math:`\{x_n\}_{n=1}^N` are the + from the ``model``, and :math:`\{x_n\}_{n=1}^N` are the data points drawn iid from the predictive distribution. :param model: Python callable containing Pyro primitives. :type model: Callable[P, Any] - :param guide: Python callable containing Pyro primitives. - Must only contain continuous latent variables. - :type guide: Callable[P, Any] :param num_samples_outer: number of Monte Carlo samples to approximate Fisher information in :func:`make_empirical_fisher_vp` :type num_samples_outer: int @@ -276,10 +272,12 @@ def linearize( **Example usage**: .. code-block:: python + import pyro import pyro.distributions as dist import torch + from chirho.robust.handlers.predictive import PredictiveModel from chirho.robust.internals.linearize import linearize pyro.settings.set(module_local_params=True) @@ -312,8 +310,7 @@ def forward(self): ) points = predictive() influence = linearize( - model, - guide, + PredictiveModel(model, guide), num_samples_outer=1000, num_samples_inner=1000, ) @@ -327,24 +324,22 @@ def forward(self): can result in different values. To reduce variance, increase ``num_samples_outer`` and ``num_samples_inner`` in ``linearize_kwargs``. - * Currently, ``model`` and ``guide`` cannot contain any ``pyro.param`` statements. + * Currently, ``model`` cannot contain any ``pyro.param`` statements. This issue will be addressed in a future release: https://github.com/BasisResearch/chirho/issues/393. """ assert isinstance(model, torch.nn.Module) - assert isinstance(guide, torch.nn.Module) if num_samples_inner is None: num_samples_inner = num_samples_outer**2 predictive = pyro.infer.Predictive( model, - guide=guide, num_samples=num_samples_outer, parallel=True, ) - batched_log_prob = BatchedNMCLogPredictiveLikelihood( - model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting + batched_log_prob = BatchedNMCLogMarginalLikelihood( + model, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting ) log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob) log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values()) diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/nmc.py similarity index 58% rename from chirho/robust/internals/predictive.py rename to chirho/robust/internals/nmc.py index 8b9721f44..342abdcc0 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/nmc.py @@ -102,138 +102,18 @@ def _pyro_observe(self, msg: dict) -> None: msg["args"] = (rv, batch_obs) -class PredictiveModel(Generic[P, T], torch.nn.Module): - """ - Given a Pyro model and guide, constructs a new model that behaves as if - the latent ``sample`` sites in the original model (i.e. the prior) - were replaced by their counterparts in the guide (i.e. the posterior). - - .. 
note:: Sites that only appear in the model are annotated in traces - produced by the predictive model with ``infer={"_model_predictive_site": True}`` . - - :param model: Pyro model. - :param guide: Pyro guide. - """ - - model: Callable[P, T] - guide: Callable[P, Any] - - def __init__( - self, - model: Callable[P, T], - guide: Callable[P, Any], - ): - super().__init__() - self.model = model - self.guide = guide - - def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: - """ - Returns a sample from the posterior predictive distribution. - - :return: Sample from the posterior predictive distribution. - :rtype: T - """ - with pyro.poutine.trace() as guide_tr: - self.guide(*args, **kwargs) - - block_guide_sample_sites = pyro.poutine.block( - hide=[ - name - for name, node in guide_tr.trace.nodes.items() - if node["type"] == "sample" - ] - ) - - with pyro.poutine.infer_config( - config_fn=lambda msg: {"_model_predictive_site": True} - ): - with block_guide_sample_sites: - with pyro.poutine.replay(trace=guide_tr.trace): - return self.model(*args, **kwargs) - - -class PredictiveFunctional(Generic[P, T], torch.nn.Module): - """ - Functional that returns a batch of samples from the posterior predictive - distribution of a Pyro model given a guide. As with ``pyro.infer.Predictive`` , - the returned values are batched along their leftmost positional dimension. - - Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)`` - but uses :class:`~PredictiveModel` to construct the predictive distribution - and infer the model ``sample`` sites whose values should be returned, - and uses :class:`~BatchedLatents` to parallelize over samples from the guide. - - .. warning:: ``PredictiveFunctional`` currently applies its own internal instance of - :class:`~chirho.indexed.handlers.IndexPlatesMessenger` , - so it may not behave as expected if used within another enclosing - :class:`~chirho.indexed.handlers.IndexPlatesMessenger` context. - - :param model: Pyro model. - :param guide: Pyro guide. - :param num_samples: Number of samples to return. - """ - - model: Callable[P, Any] - guide: Callable[P, Any] - num_samples: int - - def __init__( - self, - model: torch.nn.Module, - guide: torch.nn.Module, - *, - num_samples: int = 1, - max_plate_nesting: Optional[int] = None, - name: str = "__particles_predictive", - ): - super().__init__() - self.model = model - self.guide = guide - self.num_samples = num_samples - self._predictive_model: PredictiveModel[P, Any] = PredictiveModel(model, guide) - self._first_available_dim = ( - -max_plate_nesting - 1 if max_plate_nesting is not None else None - ) - self._mc_plate_name = name - - def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: - """ - Returns a batch of samples from the posterior predictive distribution. - - :return: Dictionary of samples from the posterior predictive distribution. 
-        :rtype: Point[T]
-        """
-        with IndexPlatesMessenger(first_available_dim=self._first_available_dim):
-            with pyro.poutine.trace() as model_tr:
-                with BatchedLatents(self.num_samples, name=self._mc_plate_name):
-                    self._predictive_model(*args, **kwargs)
-
-            return {
-                name: bind_leftmost_dim(
-                    node["value"],
-                    self._mc_plate_name,
-                    event_dim=len(node["fn"].event_shape),
-                )
-                for name, node in model_tr.trace.nodes.items()
-                if node["type"] == "sample"
-                and not pyro.poutine.util.site_is_subsample(node)
-                and node["infer"].get("_model_predictive_site", False)
-            }
-
-
-class BatchedNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module):
+class BatchedNMCLogMarginalLikelihood(Generic[P, T], torch.nn.Module):
     r"""
-    Approximates the log predictive likelihood induced by ``model`` and ``guide``
-    using Monte Carlo sampling at an arbitrary batch of :math:`N`
+    Approximates the log marginal likelihood induced by ``model`` and ``guide``
+    using importance sampling at an arbitrary batch of :math:`N`
     points :math:`\{x_n\}_{n=1}^N`.
 
     .. math::
-        \log \left(\frac{1}{M} \sum_{m=1}^M p(x_n \mid \theta_m)\right),
+        \log \left(\frac{1}{M} \sum_{m=1}^M \frac{p(x_n \mid \theta_m) p(\theta_m)}{q_{\phi}(\theta_m)} \right),
         \quad \theta_m \sim q_{\phi}(\theta),
 
-    where :math:`q_{\phi}(\theta)` is the guide and :math:`p(x_n \mid \theta_m)`
-    is the model conditioned on the latents from the guide.
+    where :math:`q_{\phi}(\theta)` is the guide, and :math:`p(x_n \mid \theta_m) p(\theta_m)`
+    is the model's joint density of the data and the latents sampled from the guide.
 
     :param model: Python callable containing Pyro primitives.
     :type model: torch.nn.Module
@@ -241,17 +121,17 @@ class BatchedNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module):
         Must only contain continuous latent variables.
     :type guide: torch.nn.Module
     :param num_samples: Number of Monte Carlo draws :math:`M`
-        used to approximate predictive distribution, defaults to 1
+        used to approximate marginal distribution, defaults to 1
    :type num_samples: int, optional
     """
     model: Callable[P, Any]
-    guide: Callable[P, Any]
+    guide: Optional[Callable[P, Any]]
     num_samples: int
 
     def __init__(
         self,
         model: torch.nn.Module,
-        guide: torch.nn.Module,
+        guide: Optional[torch.nn.Module] = None,
         *,
         num_samples: int = 1,
         max_plate_nesting: Optional[int] = None,
@@ -276,10 +156,10 @@ def forward(
 
         :param data: Dictionary of observations.
        :type data: Point[T]
-        :return: Log predictive likelihood at each datapoint.
+        :return: Log marginal likelihood at each datapoint.
:rtype: torch.Tensor """ - get_nmc_traces = get_importance_traces(PredictiveModel(self.model, self.guide)) + get_nmc_traces = get_importance_traces(self.model, self.guide) with IndexPlatesMessenger(first_available_dim=self._first_available_dim): with BatchedLatents(self.num_samples, name=self._mc_plate_name): diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 34b0dd0f5..b02bfa47e 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -12,27 +12,19 @@ T = TypeVar("T") Point = Mapping[str, Observation[T]] -Functional = Callable[[Callable[P, Any], Callable[P, Any]], Callable[P, S]] +Functional = Callable[[Callable[P, Any]], Callable[P, S]] def influence_fn( - model: Callable[P, Any], - guide: Callable[P, Any], - functional: Functional[P, S], - **linearize_kwargs + model: Callable[P, Any], functional: Functional[P, S], **linearize_kwargs ) -> Callable[Concatenate[Point[T], P], S]: """ Returns the efficient influence function for ``functional`` - with respect to the parameters of ``guide`` and probabilistic - program ``model``. + with respect to the parameters of probabilistic program ``model``. :param model: Python callable containing Pyro primitives. :type model: Callable[P, Any] - :param guide: Python callable containing Pyro primitives. - Must only contain continuous latent variables. - :type guide: Callable[P, Any] - :param functional: model summary of interest, which is a function of the - model and guide. + :param functional: model summary of interest, which is a function of ``model`` :type functional: Functional[P, S] :return: the efficient influence function for ``functional`` :rtype: Callable[Concatenate[Point[T], P], S] @@ -45,6 +37,7 @@ def influence_fn( import pyro.distributions as dist import torch + from chirho.robust.handlers.predictive import PredictiveModel from chirho.robust.ops import influence_fn pyro.settings.set(module_local_params=True) @@ -111,17 +104,17 @@ def forward(self): of this function is stochastic, i.e., evaluating this function on the same ``points`` can result in different values. To reduce variance, increase ``num_samples_outer`` and ``num_samples_inner`` in ``linearize_kwargs``. - * Currently, ``model`` and ``guide`` cannot contain any ``pyro.param`` statements. + * Currently, ``model`` cannot contain any ``pyro.param`` statements. This issue will be addressed in a future release: https://github.com/BasisResearch/chirho/issues/393. """ from chirho.robust.internals.linearize import linearize from chirho.robust.internals.utils import make_functional_call - linearized = linearize(model, guide, **linearize_kwargs) - target = functional(model, guide) + linearized = linearize(model, **linearize_kwargs) + target = functional(model) - # TODO check that target_params == model_params | guide_params + # TODO check that target_params == model_params assert isinstance(target, torch.nn.Module) target_params, func_target = make_functional_call(target) diff --git a/docs/source/robust.rst b/docs/source/robust.rst index a5a3a3493..172cfaf28 100644 --- a/docs/source/robust.rst +++ b/docs/source/robust.rst @@ -23,6 +23,10 @@ Handlers :members: :undoc-members: +.. automodule:: chirho.robust.handlers.predictive + :members: + :undoc-members: + Internals --------- @@ -34,7 +38,7 @@ Internals :members: :undoc-members: -.. automodule:: chirho.robust.internals.predictive +.. 
automodule:: chirho.robust.internals.nmc :members: :undoc-members: diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py index fc36d0b49..f1849959a 100644 --- a/tests/robust/test_handlers.py +++ b/tests/robust/test_handlers.py @@ -7,7 +7,7 @@ from typing_extensions import ParamSpec from chirho.robust.handlers.estimators import one_step_correction -from chirho.robust.internals.predictive import PredictiveFunctional +from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel from .robust_fixtures import SimpleGuide, SimpleModel @@ -57,11 +57,8 @@ def test_one_step_correction_smoke( model(), guide() # initialize one_step = one_step_correction( - model, - guide, - functional=functools.partial( - PredictiveFunctional, num_samples=num_predictive_samples - ), + PredictiveModel(model, guide), + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py index b9924fab5..6f9dfde8d 100644 --- a/tests/robust/test_internals_compositions.py +++ b/tests/robust/test_internals_compositions.py @@ -7,13 +7,14 @@ from chirho.indexed.handlers import IndexPlatesMessenger from chirho.indexed.ops import indices_of +from chirho.robust.handlers.predictive import PredictiveModel from chirho.robust.internals.linearize import ( conjugate_gradient_solve, make_empirical_fisher_vp, ) -from chirho.robust.internals.predictive import ( +from chirho.robust.internals.nmc import ( BatchedLatents, - BatchedNMCLogPredictiveLikelihood, + BatchedNMCLogMarginalLikelihood, BatchedObservations, ) from chirho.robust.internals.utils import make_functional_call, reset_rng_state @@ -27,7 +28,9 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): model = SimpleModel() guide = SimpleGuide() model(), guide() # initialize - log_prob = BatchedNMCLogPredictiveLikelihood(model, guide, num_samples=100) + log_prob = BatchedNMCLogMarginalLikelihood( + PredictiveModel(model, guide), num_samples=100 + ) log_prob_params, func_log_prob = make_functional_call(log_prob) func_log_prob = reset_rng_state(pyro.util.get_rng_state())(func_log_prob) @@ -47,38 +50,43 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): v = { k: torch.ones_like(v).unsqueeze(0) - if k != "guide.loc_a" + if k != "model.guide.loc_a" else torch.zeros_like(v).unsqueeze(0) for k, v in log_prob_params.items() } # For this model, fvp for loc_a is zero. 
See # https://github.com/BasisResearch/chirho/issues/427 - assert fvp(v)["guide.loc_a"].abs().max() == 0 + assert fvp(v)["model.guide.loc_a"].abs().max() == 0 assert all(fvp_vk.shape == v[k].shape for k, fvp_vk in fvp(v).items()) solve_one = cg_solver(fvp, v) solve_two = cg_solver(fvp, v) - if solve_one["guide.loc_a"].abs().max() > 1e6: + if solve_one["model.guide.loc_a"].abs().max() > 1e6: warnings.warn( "solve_one['guide.loc_a'] is large (max entry={}).".format( - solve_one["guide.loc_a"].abs().max() + solve_one["model.guide.loc_a"].abs().max() ) ) - if solve_one["guide.loc_b"].abs().max() > 1e6: + if solve_one["model.guide.loc_b"].abs().max() > 1e6: warnings.warn( "solve_one['guide.loc_b'] is large (max entry={}).".format( - solve_one["guide.loc_b"].abs().max() + solve_one["model.guide.loc_b"].abs().max() ) ) assert torch.allclose( - solve_one["guide.loc_a"], torch.zeros_like(log_prob_params["guide.loc_a"]) + solve_one["model.guide.loc_a"], + torch.zeros_like(log_prob_params["model.guide.loc_a"]), + ) + assert torch.allclose( + solve_one["model.guide.loc_a"], solve_two["model.guide.loc_a"] + ) + assert torch.allclose( + solve_one["model.guide.loc_b"], solve_two["model.guide.loc_b"] ) - assert torch.allclose(solve_one["guide.loc_a"], solve_two["guide.loc_a"]) - assert torch.allclose(solve_one["guide.loc_b"], solve_two["guide.loc_b"]) link_functions = [ @@ -95,8 +103,8 @@ def test_nmc_likelihood_seeded(link_fn): guide = SimpleGuide() model(), guide() # initialize - log_prob = BatchedNMCLogPredictiveLikelihood( - model, guide, num_samples=3, max_plate_nesting=3 + log_prob = BatchedNMCLogMarginalLikelihood( + PredictiveModel(model, guide), num_samples=3, max_plate_nesting=3 ) log_prob_params, func_log_prob = make_functional_call(log_prob) @@ -115,11 +123,14 @@ def test_nmc_likelihood_seeded(link_fn): v = {k: torch.ones_like(v) for k, v in log_prob_params.items()} - assert (fvp(v)["guide.loc_a"].abs().max() + fvp(v)["guide.loc_b"].abs().max()) > 0 + assert ( + fvp(v)["model.guide.loc_a"].abs().max() + + fvp(v)["model.guide.loc_b"].abs().max() + ) > 0 # Check if fvp agrees across multiple calls of same `fvp` object - assert torch.allclose(fvp(v)["guide.loc_a"], fvp(v)["guide.loc_a"]) - assert torch.allclose(fvp(v)["guide.loc_b"], fvp(v)["guide.loc_b"]) + assert torch.allclose(fvp(v)["model.guide.loc_a"], fvp(v)["model.guide.loc_a"]) + assert torch.allclose(fvp(v)["model.guide.loc_b"], fvp(v)["model.guide.loc_b"]) @pytest.mark.parametrize("pad_dim", [0, 1, 2]) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py index a8a80a536..435632789 100644 --- a/tests/robust/test_internals_linearize.py +++ b/tests/robust/test_internals_linearize.py @@ -8,6 +8,7 @@ from pyro.infer.predictive import Predictive from typing_extensions import ParamSpec +from chirho.robust.handlers.predictive import PredictiveModel from chirho.robust.internals.linearize import ( conjugate_gradient_solve, linearize, @@ -96,8 +97,7 @@ def test_nmc_param_influence_smoke( model(), guide() # initialize param_eif = linearize( - model, - guide, + PredictiveModel(model, guide), max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, @@ -117,7 +117,7 @@ def test_nmc_param_influence_smoke( for k, v in test_datum_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" assert not torch.isinf(v).any(), f"eif for {k} had infs" - if k != "guide.loc_a": + if not k.endswith("guide.loc_a"): assert not torch.isclose( v, 
torch.zeros_like(v) ).all(), f"eif for {k} was zero" @@ -145,8 +145,7 @@ def test_nmc_param_influence_vmap_smoke( model(), guide() # initialize param_eif = linearize( - model, - guide, + PredictiveModel(model, guide), max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, @@ -163,7 +162,7 @@ def test_nmc_param_influence_vmap_smoke( for k, v in test_data_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" assert not torch.isinf(v).any(), f"eif for {k} had infs" - if k != "guide.loc_a": + if not k.endswith("guide.loc_a"): assert not torch.isclose( v, torch.zeros_like(v) ).all(), f"eif for {k} was zero" @@ -326,8 +325,7 @@ def link(mu): mle_guide = MLEGuide(theta_hat) param_eif = linearize( - model, - mle_guide, + PredictiveModel(model, mle_guide), num_samples_outer=10000, num_samples_inner=1, cg_iters=4, # dimension of params = 4 @@ -336,7 +334,7 @@ def link(mu): test_data_eif = param_eif(D_test) median_abs_error = torch.abs( - test_data_eif["guide.treatment_weight_param"] - analytic_eif_at_test_pts + test_data_eif["model.guide.treatment_weight_param"] - analytic_eif_at_test_pts ).median() median_scale = torch.abs(analytic_eif_at_test_pts).median() if median_scale > 1: @@ -346,8 +344,7 @@ def link(mu): # Test w/ pointwise_influence=False param_eif = linearize( - model, - mle_guide, + PredictiveModel(model, mle_guide), num_samples_outer=10000, num_samples_inner=1, cg_iters=4, # dimension of params = 4 @@ -356,7 +353,7 @@ def link(mu): test_data_eif = param_eif(D_test) assert torch.allclose( - test_data_eif["guide.treatment_weight_param"][0], + test_data_eif["model.guide.treatment_weight_param"][0], analytic_eif_at_test_pts.mean(), atol=0.5, ) diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 3f91377c7..1bdb2461b 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -6,7 +6,7 @@ import torch from typing_extensions import ParamSpec -from chirho.robust.internals.predictive import PredictiveFunctional +from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel from chirho.robust.ops import influence_fn from .robust_fixtures import SimpleGuide, SimpleModel @@ -57,8 +57,7 @@ def test_nmc_predictive_influence_smoke( model(), guide() # initialize predictive_eif = influence_fn( - model, - guide, + PredictiveModel(model, guide), functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, @@ -102,8 +101,7 @@ def test_nmc_predictive_influence_vmap_smoke( model(), guide() # initialize predictive_eif = influence_fn( - model, - guide, + PredictiveModel(model, guide), functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, diff --git a/tests/robust/test_performance.py b/tests/robust/test_performance.py index b1ec08f29..34d5e4d02 100644 --- a/tests/robust/test_performance.py +++ b/tests/robust/test_performance.py @@ -11,8 +11,9 @@ from chirho.indexed.handlers import DependentMaskMessenger from chirho.observational.handlers import condition +from chirho.robust.handlers.predictive import PredictiveModel from chirho.robust.internals.linearize import make_empirical_fisher_vp -from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood +from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood from chirho.robust.internals.utils import guess_max_plate_nesting, 
make_functional_call from chirho.robust.ops import Point @@ -149,7 +150,7 @@ def test_empirical_fisher_vp_performance_with_likelihood(model_guide): ) log2_prob_params, func2_log_prob = make_functional_call( - BatchedNMCLogPredictiveLikelihood(model, guide) + BatchedNMCLogMarginalLikelihood(PredictiveModel(model, guide)) ) fisher_hessian_vmapped = make_empirical_fisher_vp( @@ -160,19 +161,20 @@ def test_empirical_fisher_vp_performance_with_likelihood(model_guide): func2_log_prob, log2_prob_params, data ) - v = { + v1 = { k: torch.ones_like(v) if k != "guide.loc_a" else torch.zeros_like(v) for k, v in log1_prob_params.items() } + v2 = {f"model.{k}": v for k, v in v1.items()} func2_log_prob(log2_prob_params, data) start_time = time.time() - fisher_hessian_vmapped(v) + fisher_hessian_vmapped(v1) end_time = time.time() print("Hessian vmapped time (s): ", end_time - start_time) start_time = time.time() - fisher_hessian_batched(v) + fisher_hessian_batched(v2) end_time = time.time() print("Hessian manual batched time (s): ", end_time - start_time) From c4346c800400ad064744e3cf1035bcc07c01f980 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Thu, 11 Jan 2024 14:19:05 -0500 Subject: [PATCH 57/66] Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring --- chirho/robust/handlers/estimators.py | 29 +++----- chirho/robust/internals/linearize.py | 8 ++- chirho/robust/ops.py | 98 ++++++++++++++++------------ tests/robust/test_handlers.py | 20 +++--- tests/robust/test_ops.py | 36 +++++----- 5 files changed, 100 insertions(+), 91 deletions(-) diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py index 16e8ab227..4f60ddcd6 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -1,6 +1,6 @@ -from typing import Any, Callable, TypeVar +from typing import TypeVar -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec from chirho.robust.ops import Functional, Point, influence_fn @@ -10,21 +10,17 @@ def one_step_correction( - model: Callable[P, Any], functional: Functional[P, S], + *test_points: Point[T], **influence_kwargs, -) -> Callable[Concatenate[Point[T], P], S]: +) -> Functional[P, S]: """ - Returns a function that computes the one-step correction for the - functional at a specified set of test points as discussed in - [1]. + Returns a functional that computes the one-step correction for the + functional at a specified set of test points as discussed in [1]. - :param model: Python callable containing Pyro primitives. - :type model: Callable[P, Any] - :param functional: model summary of interest, which is a function of the model. 
- :type functional: Functional[P, S] - :return: function to compute the one-step correction - :rtype: Callable[Concatenate[Point[T], P], S] + :param functional: model summary functional of interest + :param test_points: points at which to compute the one-step correction + :return: functional to compute the one-step correction **References** @@ -33,9 +29,4 @@ def one_step_correction( """ influence_kwargs_one_step = influence_kwargs.copy() influence_kwargs_one_step["pointwise_influence"] = False - eif_fn = influence_fn(model, functional, **influence_kwargs_one_step) - - def _one_step(test_data: Point[T], *args, **kwargs) -> S: - return eif_fn(test_data, *args, **kwargs) - - return _one_step + return influence_fn(functional, *test_points, **influence_kwargs_one_step) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 27ce8da39..29447c736 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -219,8 +219,7 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: def linearize( - model: Callable[P, Any], - *, + *models: Callable[P, Any], num_samples_outer: int, num_samples_inner: Optional[int] = None, max_plate_nesting: Optional[int] = None, @@ -328,6 +327,11 @@ def forward(self): This issue will be addressed in a future release: https://github.com/BasisResearch/chirho/issues/393. """ + if len(models) > 1: + raise NotImplementedError("Only unary version of linearize is implemented.") + else: + (model,) = models + assert isinstance(model, torch.nn.Module) if num_samples_inner is None: num_samples_inner = num_samples_outer**2 diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index b02bfa47e..86ed8f89d 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -1,33 +1,35 @@ -import functools -from typing import Any, Callable, Mapping, TypeVar +from typing import Any, Callable, Mapping, Protocol, TypeVar import torch -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec from chirho.observational.ops import Observation P = ParamSpec("P") Q = ParamSpec("Q") -S = TypeVar("S") +S = TypeVar("S", covariant=True) T = TypeVar("T") Point = Mapping[str, Observation[T]] -Functional = Callable[[Callable[P, Any]], Callable[P, S]] + + +class Functional(Protocol[P, S]): + def __call__( + self, __model: Callable[P, Any], *models: Callable[P, Any] + ) -> Callable[P, S]: + ... def influence_fn( - model: Callable[P, Any], functional: Functional[P, S], **linearize_kwargs -) -> Callable[Concatenate[Point[T], P], S]: + functional: Functional[P, S], *points: Point[T], **linearize_kwargs +) -> Functional[P, S]: """ - Returns the efficient influence function for ``functional`` - with respect to the parameters of probabilistic program ``model``. + Returns a new functional that computes the efficient influence function for ``functional`` + at the given ``points`` with respect to the parameters of its probabilistic program arguments. - :param model: Python callable containing Pyro primitives. 
- :type model: Callable[P, Any] :param functional: model summary of interest, which is a function of ``model`` - :type functional: Functional[P, S] - :return: the efficient influence function for ``functional`` - :rtype: Callable[Concatenate[Point[T], P], S] + :param points: points for each input to ``functional`` at which to compute the efficient influence function + :return: functional that computes the efficient influence function for ``functional`` at ``points`` **Example usage**: @@ -88,14 +90,13 @@ def forward(self): ) points = predictive() influence = influence_fn( - model, - guide, SimpleFunctional, + points, num_samples_outer=1000, num_samples_inner=1000, - ) + )(PredictiveModel(model, guide)) - influence(points) + influence() .. note:: @@ -111,31 +112,44 @@ def forward(self): from chirho.robust.internals.linearize import linearize from chirho.robust.internals.utils import make_functional_call - linearized = linearize(model, **linearize_kwargs) - target = functional(model) - - # TODO check that target_params == model_params - assert isinstance(target, torch.nn.Module) - target_params, func_target = make_functional_call(target) + if len(points) != 1: + raise NotImplementedError( + "influence_fn currently only supports unary functionals" + ) - @functools.wraps(target) - def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: + def _influence_functional(*models: Callable[P, Any]) -> Callable[P, S]: """ - Evaluates the efficient influence function for ``functional`` at each - point in ``points``. + Functional representing the efficient influence function of ``functional`` at ``points`` . - :param points: points at which to compute the efficient influence function - :type points: Point[T] - :return: efficient influence function evaluated at each point in ``points`` or averaged - :rtype: S + :param models: Python callables containing Pyro primitives. + :return: efficient influence function for ``functional`` evaluated at ``model`` and ``points`` """ - param_eif = linearized(points, *args, **kwargs) - return torch.vmap( - lambda d: torch.func.jvp( - lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) - )[1], - in_dims=0, - randomness="different", - )(param_eif) - - return _fn + if len(models) != len(points): + raise ValueError("mismatch between number of models and points") + + linearized = linearize(*models, **linearize_kwargs) + target = functional(*models) + + # TODO check that target_params == model_params + assert isinstance(target, torch.nn.Module) + target_params, func_target = make_functional_call(target) + + def _fn(*args: P.args, **kwargs: P.kwargs) -> S: + """ + Evaluates the efficient influence function for ``functional`` at each + point in ``points``. 
+ + :return: efficient influence function evaluated at each point in ``points`` or averaged + """ + param_eif = linearized(*points, *args, **kwargs) + return torch.vmap( + lambda d: torch.func.jvp( + lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) + )[1], + in_dims=0, + randomness="different", + )(param_eif) + + return _fn + + return _influence_functional diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py index f1849959a..6168d563a 100644 --- a/tests/robust/test_handlers.py +++ b/tests/robust/test_handlers.py @@ -56,15 +56,6 @@ def test_one_step_correction_smoke( guide = guide(model) model(), guide() # initialize - one_step = one_step_correction( - PredictiveModel(model, guide), - functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), - max_plate_nesting=max_plate_nesting, - num_samples_outer=num_samples_outer, - num_samples_inner=num_samples_inner, - cg_iters=cg_iters, - ) - with torch.no_grad(): test_datum = { k: v[0] @@ -73,7 +64,16 @@ def test_one_step_correction_smoke( )().items() } - one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum) + one_step = one_step_correction( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_datum, + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + + one_step_on_test: Mapping[str, torch.Tensor] = one_step() assert len(one_step_on_test) > 0 for k, v in one_step_on_test.items(): assert not torch.isnan(v).any(), f"one_step for {k} had nans" diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 1bdb2461b..e3d5e5290 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -56,15 +56,6 @@ def test_nmc_predictive_influence_smoke( guide = guide(model) model(), guide() # initialize - predictive_eif = influence_fn( - PredictiveModel(model, guide), - functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), - max_plate_nesting=max_plate_nesting, - num_samples_outer=num_samples_outer, - num_samples_inner=num_samples_inner, - cg_iters=cg_iters, - ) - with torch.no_grad(): test_datum = { k: v[0] @@ -73,7 +64,16 @@ def test_nmc_predictive_influence_smoke( )().items() } - test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum) + predictive_eif = influence_fn( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_datum, + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + + test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif() assert len(test_datum_eif) > 0 for k, v in test_datum_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" @@ -100,21 +100,21 @@ def test_nmc_predictive_influence_vmap_smoke( model(), guide() # initialize + with torch.no_grad(): + test_data = pyro.infer.Predictive( + model, num_samples=4, return_sites=obs_names, parallel=True + )() + predictive_eif = influence_fn( - PredictiveModel(model, guide), functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_data, max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, cg_iters=cg_iters, - ) - - with torch.no_grad(): - test_data = pyro.infer.Predictive( - model, num_samples=4, return_sites=obs_names, parallel=True - )() + )(PredictiveModel(model, 
guide)) - test_data_eif: Mapping[str, torch.Tensor] = predictive_eif(test_data) + test_data_eif: Mapping[str, torch.Tensor] = predictive_eif() assert len(test_data_eif) > 0 for k, v in test_data_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" From 9207e3eb3589d6d52e9b916c2920fcdf8ab32bb6 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Fri, 12 Jan 2024 09:56:04 -0500 Subject: [PATCH 58/66] Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter --- chirho/robust/handlers/estimators.py | 29 +++++++++++++++++++++++++--- tests/robust/test_handlers.py | 20 ++++++++++--------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py index 4f60ddcd6..eb6e8d6ee 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -1,5 +1,6 @@ -from typing import TypeVar +from typing import Any, Callable, TypeVar +import torch from typing_extensions import ParamSpec from chirho.robust.ops import Functional, Point, influence_fn @@ -9,7 +10,7 @@ T = TypeVar("T") -def one_step_correction( +def one_step_corrected_estimator( functional: Functional[P, S], *test_points: Point[T], **influence_kwargs, @@ -29,4 +30,26 @@ def one_step_correction( """ influence_kwargs_one_step = influence_kwargs.copy() influence_kwargs_one_step["pointwise_influence"] = False - return influence_fn(functional, *test_points, **influence_kwargs_one_step) + eif_fn = influence_fn(functional, *test_points, **influence_kwargs_one_step) + + def _corrected_functional(*model: Callable[P, Any]) -> Callable[P, S]: + plug_in_estimator = functional(*model) + correction_estimator = eif_fn(*model) + + def _estimator(*args, **kwargs) -> S: + plug_in_estimate = plug_in_estimator(*args, **kwargs) + correction = correction_estimator(*args, **kwargs) + + flat_plug_in_estimate, treespec = torch.utils._pytree.tree_flatten( + plug_in_estimate + ) + flat_correction, _ = torch.utils._pytree.tree_flatten(correction) + + return torch.utils._pytree.tree_unflatten( + [a + b for a, b in zip(flat_plug_in_estimate, flat_correction)], + treespec, + ) + + return _estimator + + return _corrected_functional diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py index 6168d563a..e43015282 100644 --- a/tests/robust/test_handlers.py +++ b/tests/robust/test_handlers.py @@ -6,7 +6,7 @@ import torch from typing_extensions import ParamSpec -from chirho.robust.handlers.estimators import one_step_correction +from chirho.robust.handlers.estimators import one_step_corrected_estimator from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel from .robust_fixtures import SimpleGuide, SimpleModel @@ -42,7 +42,8 @@ @pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) @pytest.mark.parametrize("cg_iters", [None, 1, 10]) @pytest.mark.parametrize("num_predictive_samples", [1, 5]) -def test_one_step_correction_smoke( +@pytest.mark.parametrize("estimation_method", [one_step_corrected_estimator]) +def test_estimator_smoke( model, guide, 
obs_names, @@ -51,6 +52,7 @@ def test_one_step_correction_smoke( num_samples_inner, cg_iters, num_predictive_samples, + estimation_method, ): model = model() guide = guide(model) @@ -64,7 +66,7 @@ def test_one_step_correction_smoke( )().items() } - one_step = one_step_correction( + estimator = estimation_method( functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), test_datum, max_plate_nesting=max_plate_nesting, @@ -73,11 +75,11 @@ def test_one_step_correction_smoke( cg_iters=cg_iters, )(PredictiveModel(model, guide)) - one_step_on_test: Mapping[str, torch.Tensor] = one_step() - assert len(one_step_on_test) > 0 - for k, v in one_step_on_test.items(): - assert not torch.isnan(v).any(), f"one_step for {k} had nans" - assert not torch.isinf(v).any(), f"one_step for {k} had infs" + estimate_on_test: Mapping[str, torch.Tensor] = estimator() + assert len(estimate_on_test) > 0 + for k, v in estimate_on_test.items(): + assert not torch.isnan(v).any(), f"{estimation_method} for {k} had nans" + assert not torch.isinf(v).any(), f"{estimation_method} for {k} had infs" assert not torch.isclose( v, torch.zeros_like(v) - ).all(), f"one_step for {k} was zero" + ).all(), f"{estimation_method} estimator for {k} was zero" From a7875c6ae049c37948c17b1d5059add492c1892d Mon Sep 17 00:00:00 2001 From: Andy Zane Date: Fri, 12 Jan 2024 15:09:29 -0500 Subject: [PATCH 59/66] add abstractions and simple temp scratch to test with squared unit normal functional with perturbation. --- chirho/robust/handlers/fd_model.py | 49 +++ docs/source/fd_scratch.ipynb | 384 ++++++++++++++++++ docs/source/robust_fd/__init__.py | 0 .../robust_fd/squared_normal_density.py | 52 +++ docs/source/robust_fd_scratch.py | 29 ++ 5 files changed, 514 insertions(+) create mode 100644 chirho/robust/handlers/fd_model.py create mode 100644 docs/source/fd_scratch.ipynb create mode 100644 docs/source/robust_fd/__init__.py create mode 100644 docs/source/robust_fd/squared_normal_density.py create mode 100644 docs/source/robust_fd_scratch.py diff --git a/chirho/robust/handlers/fd_model.py b/chirho/robust/handlers/fd_model.py new file mode 100644 index 000000000..df067f857 --- /dev/null +++ b/chirho/robust/handlers/fd_model.py @@ -0,0 +1,49 @@ +import torch +import pyro +import pyro.distributions as dist +from typing import Dict + + +class ModelWithMarginalDensity(torch.nn.Module): + def density(self, *args, **kwargs): + raise NotImplementedError() + + def forward(self, *args, **kwargs): + raise NotImplementedError() + + +class FDModel(ModelWithMarginalDensity): + + model: ModelWithMarginalDensity + kernel: ModelWithMarginalDensity + + def __init__(self, eps=0.): + super().__init__() + self.eps = eps + self.weights = torch.tensor([1. - eps, eps]) + + def density(self, model_kwargs: Dict, kernel_kwargs: Dict): + mpart = self.weights[0] * self.model.density(**model_kwargs) + kpart = self.weights[1] * self.kernel.density(**kernel_kwargs) + return mpart + kpart + + def forward(self, model_kwargs: Dict, kernel_kwargs: Dict): + _from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.weights)) + + if _from_kernel: + return self.kernel_forward(**kernel_kwargs) + else: + return self.model_forward(**model_kwargs) + + def functional(self, *args, **kwargs): + """ + The functional target for this model. 
This is tightly coupled to a particular + pyro model because finite differencing operates in the space of densities, and + automatically exploit any structure of the pyro model the functional + is being evaluated with respect to. As such, the functional must be implemented + with the specific structure of coupled pyro model in mind. + :param args: + :param kwargs: + :return: An estimate of the functional for ths model. + """ + raise NotImplementedError() diff --git a/docs/source/fd_scratch.ipynb b/docs/source/fd_scratch.ipynb new file mode 100644 index 000000000..70e160e77 --- /dev/null +++ b/docs/source/fd_scratch.ipynb @@ -0,0 +1,384 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-01-11T20:04:15.169353Z", + "start_time": "2024-01-11T20:04:15.160243Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from scipy.integrate import nquad\n", + "from scipy.stats import multivariate_normal\n", + "\n", + "# Thanks chatgpt for this cell.\n", + "\n", + "def squared_multivariate_normal_density(x, mean, cov):\n", + " \"\"\"\n", + " Squared density function of a multivariate normal distribution.\n", + "\n", + " :param x: A point in n-dimensional space.\n", + " :param mean: Mean vector of the multivariate normal distribution.\n", + " :param cov: Covariance matrix of the multivariate normal distribution.\n", + " :return: Squared density at the point x.\n", + " \"\"\"\n", + " density = multivariate_normal.pdf(x, mean, cov)\n", + " return density ** 2\n", + "\n", + "def integrate_squared_density_multivariate(mean, cov, dims):\n", + " \"\"\"\n", + " Numerically approximate the integrated squared density of a multivariate normal distribution.\n", + "\n", + " :param mean: Mean vector of the multivariate normal distribution.\n", + " :param cov: Covariance matrix of the multivariate normal distribution.\n", + " :param dims: Number of dimensions.\n", + " :return: Numerical approximation of the integrated squared density.\n", + " \"\"\"\n", + " # Integration limits for each dimension\n", + " limits = [(-6, 6)] * dims\n", + "\n", + " # Wrapper function to adapt the multivariate function for nquad\n", + " def integrand(*args):\n", + " return squared_multivariate_normal_density(np.array(args), mean, cov)\n", + "\n", + " return nquad(integrand, limits)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [], + "source": [ + "# quad_results = []\n", + "# # Integrated squared densities\n", + "# for d in [1, 2, 3]:\n", + "# # Example usage\n", + "# mean = np.zeros(d)\n", + "# cov = np.eye(d)\n", + "# \n", + "# quad_result = integrate_squared_density_multivariate(mean, cov, d)\n", + "# quad_results.append(quad_result)\n", + "# print(f\"with ndim={d}; \\int p(x)^2 dx =\", quad_result)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-11T20:04:15.169506Z", + "start_time": "2024-01-11T20:04:15.166068Z" + } + }, + "id": "646a660b67fa99fb" + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "from chirho.robust.ops import fd_influence_fn\n", + "import pyro\n", + "import pyro.distributions as dist\n", + "import torch" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-11T20:04:15.176080Z", + "start_time": "2024-01-11T20:04:15.169911Z" + } + }, + "id": "36f8603d2044d036" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + 
"def diagnormal(mean=0, std=1):\n", + " return dict(x=pyro.sample('x', dist.Normal(mean, std)))" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-11T20:04:15.176196Z", + "start_time": "2024-01-11T20:04:15.172819Z" + } + }, + "id": "82c73c5bbf49581d" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "# Generate training data.\n", + "datas = []\n", + "N = 100\n", + "\n", + "for d in [1, 2, 3]:\n", + " with pyro.plate('N', N, dim=-2):\n", + " with pyro.plate('d', d, dim=-1):\n", + " datas.append(diagnormal(0, 1))" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-11T20:04:15.182621Z", + "start_time": "2024-01-11T20:04:15.176516Z" + } + }, + "id": "2fff3ef4c1360b7a" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [], + "source": [ + "# Train models on the first half of the data.\n", + "inferred_models = []\n", + "for data in datas:\n", + " # A model fit to the first half of the data.\n", + " mean = torch.mean(data['x'][:N//2])\n", + " std = torch.std(data['x'][:N//2])\n", + " d = data['x'].shape[-1]\n", + " def _model():\n", + " with pyro.plate('d', d, dim=-1):\n", + " return diagnormal(mean=mean, std=std)\n", + " inferred_models.append(_model)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-11T20:05:43.095256Z", + "start_time": "2024-01-11T20:05:43.090240Z" + } + }, + "id": "2ba8b23ee37e52d4" + }, + { + "cell_type": "code", + "execution_count": 16, + "outputs": [], + "source": [ + "def squared_density_functional(model):\n", + " def target(d, nmc=100):\n", + " res = 0\n", + " for _ in range(nmc):\n", + " with pyro.poutine.trace() as tr:\n", + " model()\n", + " res += tr.trace.log_prob_sum().exp() / N\n", + " return res\n", + " return target" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-11T20:05:43.711955Z", + "start_time": "2024-01-11T20:05:43.705190Z" + } + }, + "id": "4d7e81d2cad0bfff" + }, + { + "cell_type": "code", + "execution_count": 18, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "plugin with ndim=1; \\int p(x)^2 dx = tensor(0.2571)\n", + "tensor([[-0.2172],\n", + " [-0.8808],\n", + " [-0.8434],\n", + " [-0.8053],\n", + " [-0.0255],\n", + " [ 1.9628],\n", + " [-0.6141],\n", + " [ 1.5846],\n", + " [ 0.5951],\n", + " [-2.1053],\n", + " [-0.1268],\n", + " [-2.2127],\n", + " [ 0.5829],\n", + " [-0.0794],\n", + " [-1.4608],\n", + " [-2.0613],\n", + " [-0.8474],\n", + " [-1.6414],\n", + " [ 0.2076],\n", + " [ 0.8621],\n", + " [ 0.3099],\n", + " [ 0.4285],\n", + " [-0.3761],\n", + " [-0.4787],\n", + " [ 0.5645],\n", + " [ 0.8255],\n", + " [ 0.2588],\n", + " [-2.1734],\n", + " [ 0.7421],\n", + " [-0.5980],\n", + " [-1.7096],\n", + " [ 0.1787],\n", + " [ 0.5939],\n", + " [-0.6935],\n", + " [-0.4807],\n", + " [ 2.2234],\n", + " [ 0.0458],\n", + " [-0.6098],\n", + " [-0.3696],\n", + " [ 1.9040],\n", + " [-1.0408],\n", + " [-0.7124],\n", + " [ 0.3746],\n", + " [-0.5014],\n", + " [ 0.3685],\n", + " [-0.7191],\n", + " [ 1.1040],\n", + " [-0.1124],\n", + " [-1.0923],\n", + " [ 1.0834]])\n", + "correction with ndim=1; \\int p(x)^2 dx = tensor(-7.8416)\n", + "corrected with ndim=1; \\int p(x)^2 dx = tensor(-7.5844)\n", + "plugin with ndim=2; \\int p(x)^2 dx = tensor(0.0710)\n", + "tensor([[ 1.0300, 0.5128],\n", + " [-2.1107, 1.0861],\n", + " [-0.2092, 0.5601],\n", + " [ 1.8190, 0.5242],\n", + " [-0.9258, -1.1970],\n", + " [ 0.1553, -1.4582],\n", + " [-0.9955, 
0.8934],\n", + " [ 0.0463, 0.1219],\n", + " [-0.1988, 0.0465],\n", + " [-0.1220, -0.5855],\n", + " [-1.5542, -1.4126],\n", + " [ 0.0987, -0.4780],\n", + " [ 1.5085, 1.2975],\n", + " [ 0.9800, -0.5606],\n", + " [-0.1584, -1.5427],\n", + " [-1.1436, -2.4300],\n", + " [-0.4134, -1.5894],\n", + " [-0.0390, -0.2948],\n", + " [ 0.5064, 0.0847],\n", + " [ 2.4972, -0.4473],\n", + " [ 2.3271, -0.1996],\n", + " [-1.4356, 1.2007],\n", + " [-0.0953, -0.9698],\n", + " [-0.1491, 0.4237],\n", + " [-0.1791, 0.2849],\n", + " [ 0.4424, -0.5289],\n", + " [ 0.5607, -0.4383],\n", + " [-0.8822, -1.1802],\n", + " [-0.3455, -1.5077],\n", + " [ 0.4472, -2.1917],\n", + " [-1.0550, -1.7937],\n", + " [ 1.1272, 0.7234],\n", + " [-0.3967, -0.5824],\n", + " [-1.1693, 0.2371],\n", + " [-0.2727, -2.2177],\n", + " [-1.5342, 1.0350],\n", + " [-0.0880, -0.1815],\n", + " [-1.0534, 1.4517],\n", + " [-1.5974, -0.8562],\n", + " [ 0.9563, -1.3292],\n", + " [-0.4638, -0.0427],\n", + " [-1.6517, 1.7289],\n", + " [-0.2355, 1.1272],\n", + " [-0.5346, -0.1717],\n", + " [-0.2318, 1.0640],\n", + " [ 0.9861, 1.7652],\n", + " [-0.0561, 1.0491],\n", + " [-0.3232, -0.4149],\n", + " [ 1.0538, -0.3474],\n", + " [ 0.7251, -0.5657]])\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'Tensor' object has no attribute 'items'", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mAttributeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[18], line 14\u001B[0m\n\u001B[1;32m 12\u001B[0m \u001B[38;5;28mprint\u001B[39m(data)\n\u001B[1;32m 13\u001B[0m eif \u001B[38;5;241m=\u001B[39m fd_influence_fn(inferred_model, squared_density_functional, eps\u001B[38;5;241m=\u001B[39mtorch\u001B[38;5;241m.\u001B[39mtensor(\u001B[38;5;241m1e-3\u001B[39m))\n\u001B[0;32m---> 14\u001B[0m correction_result \u001B[38;5;241m=\u001B[39m \u001B[43meif\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdata\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43md\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43md\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 15\u001B[0m correction_results\u001B[38;5;241m.\u001B[39mappend(correction_result)\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcorrection with ndim=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00md\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m; \u001B[39m\u001B[38;5;124m\\\u001B[39m\u001B[38;5;124mint p(x)^2 dx =\u001B[39m\u001B[38;5;124m\"\u001B[39m, correction_result)\n", + "File \u001B[0;32m~/GitRepo/causal_pyro/chirho/robust/ops.py:187\u001B[0m, in \u001B[0;36mfd_influence_fn.._fn\u001B[0;34m(points, *args, **kwargs)\u001B[0m\n\u001B[1;32m 184\u001B[0m target_perturbed \u001B[38;5;241m=\u001B[39m functional(perturbed_model)\n\u001B[1;32m 186\u001B[0m \u001B[38;5;66;03m# FIXME bdbjdis vmap with func_target etc. 
here, this is basically just pseudo code right now.\u001B[39;00m\n\u001B[0;32m--> 187\u001B[0m t_p_eps \u001B[38;5;241m=\u001B[39m \u001B[43mtarget_perturbed\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 188\u001B[0m t_p_hat \u001B[38;5;241m=\u001B[39m target(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m 189\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m (t_p_eps \u001B[38;5;241m-\u001B[39m t_p_hat) \u001B[38;5;241m/\u001B[39m eps\n", + "Cell \u001B[0;32mIn[16], line 6\u001B[0m, in \u001B[0;36msquared_density_functional..target\u001B[0;34m(d, nmc)\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m _ \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(nmc):\n\u001B[1;32m 5\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m pyro\u001B[38;5;241m.\u001B[39mpoutine\u001B[38;5;241m.\u001B[39mtrace() \u001B[38;5;28;01mas\u001B[39;00m tr:\n\u001B[0;32m----> 6\u001B[0m \u001B[43mmodel\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 7\u001B[0m res \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m tr\u001B[38;5;241m.\u001B[39mtrace\u001B[38;5;241m.\u001B[39mlog_prob_sum()\u001B[38;5;241m.\u001B[39mexp() \u001B[38;5;241m/\u001B[39m N\n\u001B[1;32m 8\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m res\n", + "File \u001B[0;32m~/miniconda3/envs/basis_general/lib/python3.11/site-packages/torch/nn/modules/module.py:1501\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 1496\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[1;32m 1497\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[1;32m 1498\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[1;32m 1499\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[1;32m 1500\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[0;32m-> 1501\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1502\u001B[0m \u001B[38;5;66;03m# Do not call functions when jit is used\u001B[39;00m\n\u001B[1;32m 1503\u001B[0m full_backward_hooks, non_full_backward_hooks \u001B[38;5;241m=\u001B[39m [], []\n", + "File \u001B[0;32m~/GitRepo/causal_pyro/chirho/robust/handlers/predictive.py:75\u001B[0m, in \u001B[0;36mKernelPerturbedModel.forward\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m 
_from_kernel:\n\u001B[1;32m 74\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m pyro\u001B[38;5;241m.\u001B[39mpoutine\u001B[38;5;241m.\u001B[39mtrace() \u001B[38;5;28;01mas\u001B[39;00m kernel_tr:\n\u001B[0;32m---> 75\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mkernel\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpoints\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 76\u001B[0m replay_from_kernel \u001B[38;5;241m=\u001B[39m pyro\u001B[38;5;241m.\u001B[39mpoutine\u001B[38;5;241m.\u001B[39mreplay(trace\u001B[38;5;241m=\u001B[39mkernel_tr\u001B[38;5;241m.\u001B[39mtrace)\n\u001B[1;32m 78\u001B[0m \u001B[38;5;66;03m# This prevents any outer traces from seeing the kernel sites twice, once in the kernel and again\u001B[39;00m\n\u001B[1;32m 79\u001B[0m \u001B[38;5;66;03m# in the model.\u001B[39;00m\n", + "File \u001B[0;32m~/GitRepo/causal_pyro/chirho/robust/handlers/predictive.py:43\u001B[0m, in \u001B[0;36mKernelPerturbedModel.__init__..kernel\u001B[0;34m(_kernel_loc)\u001B[0m\n\u001B[1;32m 41\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mkernel\u001B[39m(_kernel_loc: Point[T]):\n\u001B[1;32m 42\u001B[0m ret \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mdict\u001B[39m()\n\u001B[0;32m---> 43\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m name, value \u001B[38;5;129;01min\u001B[39;00m \u001B[43m_kernel_loc\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mitems\u001B[49m():\n\u001B[1;32m 44\u001B[0m ret[name] \u001B[38;5;241m=\u001B[39m pyro\u001B[38;5;241m.\u001B[39msample(name, dist\u001B[38;5;241m.\u001B[39mDelta(value))\n\u001B[1;32m 45\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m ret\n", + "\u001B[0;31mAttributeError\u001B[0m: 'Tensor' object has no attribute 'items'" + ] + } + ], + "source": [ + "plugin_results = []\n", + "correction_results = []\n", + "corrected_results = []\n", + "for d, inferred_model, alldata in zip([1, 2, 3], inferred_models, datas):\n", + " # Compute plugin.\n", + " plugin_result = squared_density_functional(inferred_model)(d)\n", + " plugin_results.append(plugin_result)\n", + " print(f\"plugin with ndim={d}; \\int p(x)^2 dx =\", plugin_result)\n", + " \n", + " # Estimate the expected influence function on the second half of the data.\n", + " data = alldata['x'][N//2:]\n", + " print(data)\n", + " eif = fd_influence_fn(inferred_model, squared_density_functional, eps=torch.tensor(1e-3))\n", + " correction_result = eif(data, d=d)\n", + " correction_results.append(correction_result)\n", + " print(f\"correction with ndim={d}; \\int p(x)^2 dx =\", correction_result)\n", + " \n", + " corrected_result = plugin_result + correction_result\n", + " corrected_results.append(corrected_result)\n", + " print(f\"corrected with ndim={d}; \\int p(x)^2 dx =\", corrected_result)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-11T20:06:23.022754Z", + "start_time": "2024-01-11T20:06:22.921330Z" + } + }, + "id": "66e7f29d984cb010" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-01-11T20:04:15.299976Z" + } + }, + "id": "46379e54f229c5b" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/robust_fd/__init__.py b/docs/source/robust_fd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs/source/robust_fd/squared_normal_density.py b/docs/source/robust_fd/squared_normal_density.py new file mode 100644 index 000000000..07c5bd497 --- /dev/null +++ b/docs/source/robust_fd/squared_normal_density.py @@ -0,0 +1,52 @@ +from chirho.robust.handlers.fd_model import FDModel, ModelWithMarginalDensity +import pyro +import pyro.distributions as dist +import torch +from scipy.stats import multivariate_normal +from scipy.integrate import nquad +import numpy as np + + +class FDMultivariateNormal(ModelWithMarginalDensity): + + def __init__(self, mean, cov): + super().__init__() + + self.mean = mean + self.cov = cov + + def density(self, x): + return multivariate_normal.pdf(x, mean=self.mean, cov=self.cov) + + def forward(self): + return pyro.sample("x", dist.MultivariateNormal(self.mean, self.cov)) + + +class _SquaredNormalDensity(FDModel): + + def __init__(self, *args, mean, cov, lambda_: float, **kwargs): + ndims = mean.shape[-1] + super().__init__(*args, **kwargs) + + self.model = FDMultivariateNormal(mean, cov) + self.kernel = FDMultivariateNormal(torch.zeros(ndims), torch.eye(ndims) * lambda_) + self.lambda_ = lambda_ + + self.mean = mean + self.cov = cov + + +class SquaredNormalDensityQuad(_SquaredNormalDensity): + """ + Compute the squared normal density using quadrature. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def functional(self): + def integrand(*args): + model_kwargs = kernel_kwargs = dict(x=np.array(args)) + return self.density(model_kwargs, kernel_kwargs) ** 2 + + return nquad(integrand, [[-np.inf, np.inf]] * self.mean.shape[-1])[0] \ No newline at end of file diff --git a/docs/source/robust_fd_scratch.py b/docs/source/robust_fd_scratch.py new file mode 100644 index 000000000..dabfd8008 --- /dev/null +++ b/docs/source/robust_fd_scratch.py @@ -0,0 +1,29 @@ +from robust_fd.squared_normal_density import SquaredNormalDensityQuad +import numpy as np +import torch +import matplotlib.pyplot as plt + +ndim = 1 +eps = 0.01 +mean = torch.tensor([0.,] * ndim) +cov = torch.eye(ndim) +lambda_ = 0.01 + +sndq = SquaredNormalDensityQuad( + mean=mean, + cov=cov, + eps=eps, + lambda_=lambda_, +) + +print(sndq.functional()) + +xx = np.linspace(-5, 5, 1000) +yy = [sndq.density( + {'x': torch.tensor([x])}, + {'x': torch.tensor([x])}) + for x in xx +] + +plt.plot(xx, yy) +plt.show() From ad519bec0cb6d2291cdd53ccab6b0adbfe5d7c7a Mon Sep 17 00:00:00 2001 From: Andy Zane Date: Fri, 12 Jan 2024 15:13:24 -0500 Subject: [PATCH 60/66] removes old scratch notebook --- docs/source/fd_scratch.ipynb | 384 ----------------------------------- 1 file changed, 384 deletions(-) delete mode 100644 docs/source/fd_scratch.ipynb diff --git a/docs/source/fd_scratch.ipynb b/docs/source/fd_scratch.ipynb deleted file mode 100644 index 70e160e77..000000000 --- a/docs/source/fd_scratch.ipynb +++ /dev/null @@ -1,384 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 6, - "id": "initial_id", - "metadata": { - "collapsed": true, - "ExecuteTime": { - "end_time": "2024-01-11T20:04:15.169353Z", - "start_time": "2024-01-11T20:04:15.160243Z" - } - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "from scipy.integrate import nquad\n", - "from scipy.stats import 
multivariate_normal\n", - "\n", - "# Thanks chatgpt for this cell.\n", - "\n", - "def squared_multivariate_normal_density(x, mean, cov):\n", - " \"\"\"\n", - " Squared density function of a multivariate normal distribution.\n", - "\n", - " :param x: A point in n-dimensional space.\n", - " :param mean: Mean vector of the multivariate normal distribution.\n", - " :param cov: Covariance matrix of the multivariate normal distribution.\n", - " :return: Squared density at the point x.\n", - " \"\"\"\n", - " density = multivariate_normal.pdf(x, mean, cov)\n", - " return density ** 2\n", - "\n", - "def integrate_squared_density_multivariate(mean, cov, dims):\n", - " \"\"\"\n", - " Numerically approximate the integrated squared density of a multivariate normal distribution.\n", - "\n", - " :param mean: Mean vector of the multivariate normal distribution.\n", - " :param cov: Covariance matrix of the multivariate normal distribution.\n", - " :param dims: Number of dimensions.\n", - " :return: Numerical approximation of the integrated squared density.\n", - " \"\"\"\n", - " # Integration limits for each dimension\n", - " limits = [(-6, 6)] * dims\n", - "\n", - " # Wrapper function to adapt the multivariate function for nquad\n", - " def integrand(*args):\n", - " return squared_multivariate_normal_density(np.array(args), mean, cov)\n", - "\n", - " return nquad(integrand, limits)[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "outputs": [], - "source": [ - "# quad_results = []\n", - "# # Integrated squared densities\n", - "# for d in [1, 2, 3]:\n", - "# # Example usage\n", - "# mean = np.zeros(d)\n", - "# cov = np.eye(d)\n", - "# \n", - "# quad_result = integrate_squared_density_multivariate(mean, cov, d)\n", - "# quad_results.append(quad_result)\n", - "# print(f\"with ndim={d}; \\int p(x)^2 dx =\", quad_result)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-01-11T20:04:15.169506Z", - "start_time": "2024-01-11T20:04:15.166068Z" - } - }, - "id": "646a660b67fa99fb" - }, - { - "cell_type": "code", - "execution_count": 8, - "outputs": [], - "source": [ - "from chirho.robust.ops import fd_influence_fn\n", - "import pyro\n", - "import pyro.distributions as dist\n", - "import torch" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-01-11T20:04:15.176080Z", - "start_time": "2024-01-11T20:04:15.169911Z" - } - }, - "id": "36f8603d2044d036" - }, - { - "cell_type": "code", - "execution_count": 9, - "outputs": [], - "source": [ - "def diagnormal(mean=0, std=1):\n", - " return dict(x=pyro.sample('x', dist.Normal(mean, std)))" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-01-11T20:04:15.176196Z", - "start_time": "2024-01-11T20:04:15.172819Z" - } - }, - "id": "82c73c5bbf49581d" - }, - { - "cell_type": "code", - "execution_count": 10, - "outputs": [], - "source": [ - "# Generate training data.\n", - "datas = []\n", - "N = 100\n", - "\n", - "for d in [1, 2, 3]:\n", - " with pyro.plate('N', N, dim=-2):\n", - " with pyro.plate('d', d, dim=-1):\n", - " datas.append(diagnormal(0, 1))" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-01-11T20:04:15.182621Z", - "start_time": "2024-01-11T20:04:15.176516Z" - } - }, - "id": "2fff3ef4c1360b7a" - }, - { - "cell_type": "code", - "execution_count": 15, - "outputs": [], - "source": [ - "# Train models on the first half of the data.\n", - "inferred_models = []\n", - "for data in datas:\n", - " # A model fit to the first half of the 
data.\n", - " mean = torch.mean(data['x'][:N//2])\n", - " std = torch.std(data['x'][:N//2])\n", - " d = data['x'].shape[-1]\n", - " def _model():\n", - " with pyro.plate('d', d, dim=-1):\n", - " return diagnormal(mean=mean, std=std)\n", - " inferred_models.append(_model)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-01-11T20:05:43.095256Z", - "start_time": "2024-01-11T20:05:43.090240Z" - } - }, - "id": "2ba8b23ee37e52d4" - }, - { - "cell_type": "code", - "execution_count": 16, - "outputs": [], - "source": [ - "def squared_density_functional(model):\n", - " def target(d, nmc=100):\n", - " res = 0\n", - " for _ in range(nmc):\n", - " with pyro.poutine.trace() as tr:\n", - " model()\n", - " res += tr.trace.log_prob_sum().exp() / N\n", - " return res\n", - " return target" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-01-11T20:05:43.711955Z", - "start_time": "2024-01-11T20:05:43.705190Z" - } - }, - "id": "4d7e81d2cad0bfff" - }, - { - "cell_type": "code", - "execution_count": 18, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "plugin with ndim=1; \\int p(x)^2 dx = tensor(0.2571)\n", - "tensor([[-0.2172],\n", - " [-0.8808],\n", - " [-0.8434],\n", - " [-0.8053],\n", - " [-0.0255],\n", - " [ 1.9628],\n", - " [-0.6141],\n", - " [ 1.5846],\n", - " [ 0.5951],\n", - " [-2.1053],\n", - " [-0.1268],\n", - " [-2.2127],\n", - " [ 0.5829],\n", - " [-0.0794],\n", - " [-1.4608],\n", - " [-2.0613],\n", - " [-0.8474],\n", - " [-1.6414],\n", - " [ 0.2076],\n", - " [ 0.8621],\n", - " [ 0.3099],\n", - " [ 0.4285],\n", - " [-0.3761],\n", - " [-0.4787],\n", - " [ 0.5645],\n", - " [ 0.8255],\n", - " [ 0.2588],\n", - " [-2.1734],\n", - " [ 0.7421],\n", - " [-0.5980],\n", - " [-1.7096],\n", - " [ 0.1787],\n", - " [ 0.5939],\n", - " [-0.6935],\n", - " [-0.4807],\n", - " [ 2.2234],\n", - " [ 0.0458],\n", - " [-0.6098],\n", - " [-0.3696],\n", - " [ 1.9040],\n", - " [-1.0408],\n", - " [-0.7124],\n", - " [ 0.3746],\n", - " [-0.5014],\n", - " [ 0.3685],\n", - " [-0.7191],\n", - " [ 1.1040],\n", - " [-0.1124],\n", - " [-1.0923],\n", - " [ 1.0834]])\n", - "correction with ndim=1; \\int p(x)^2 dx = tensor(-7.8416)\n", - "corrected with ndim=1; \\int p(x)^2 dx = tensor(-7.5844)\n", - "plugin with ndim=2; \\int p(x)^2 dx = tensor(0.0710)\n", - "tensor([[ 1.0300, 0.5128],\n", - " [-2.1107, 1.0861],\n", - " [-0.2092, 0.5601],\n", - " [ 1.8190, 0.5242],\n", - " [-0.9258, -1.1970],\n", - " [ 0.1553, -1.4582],\n", - " [-0.9955, 0.8934],\n", - " [ 0.0463, 0.1219],\n", - " [-0.1988, 0.0465],\n", - " [-0.1220, -0.5855],\n", - " [-1.5542, -1.4126],\n", - " [ 0.0987, -0.4780],\n", - " [ 1.5085, 1.2975],\n", - " [ 0.9800, -0.5606],\n", - " [-0.1584, -1.5427],\n", - " [-1.1436, -2.4300],\n", - " [-0.4134, -1.5894],\n", - " [-0.0390, -0.2948],\n", - " [ 0.5064, 0.0847],\n", - " [ 2.4972, -0.4473],\n", - " [ 2.3271, -0.1996],\n", - " [-1.4356, 1.2007],\n", - " [-0.0953, -0.9698],\n", - " [-0.1491, 0.4237],\n", - " [-0.1791, 0.2849],\n", - " [ 0.4424, -0.5289],\n", - " [ 0.5607, -0.4383],\n", - " [-0.8822, -1.1802],\n", - " [-0.3455, -1.5077],\n", - " [ 0.4472, -2.1917],\n", - " [-1.0550, -1.7937],\n", - " [ 1.1272, 0.7234],\n", - " [-0.3967, -0.5824],\n", - " [-1.1693, 0.2371],\n", - " [-0.2727, -2.2177],\n", - " [-1.5342, 1.0350],\n", - " [-0.0880, -0.1815],\n", - " [-1.0534, 1.4517],\n", - " [-1.5974, -0.8562],\n", - " [ 0.9563, -1.3292],\n", - " [-0.4638, -0.0427],\n", - " [-1.6517, 1.7289],\n", - " [-0.2355, 1.1272],\n", - " [-0.5346, 
-0.1717],\n", - " [-0.2318, 1.0640],\n", - " [ 0.9861, 1.7652],\n", - " [-0.0561, 1.0491],\n", - " [-0.3232, -0.4149],\n", - " [ 1.0538, -0.3474],\n", - " [ 0.7251, -0.5657]])\n" - ] - }, - { - "ename": "AttributeError", - "evalue": "'Tensor' object has no attribute 'items'", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mAttributeError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[0;32mIn[18], line 14\u001B[0m\n\u001B[1;32m 12\u001B[0m \u001B[38;5;28mprint\u001B[39m(data)\n\u001B[1;32m 13\u001B[0m eif \u001B[38;5;241m=\u001B[39m fd_influence_fn(inferred_model, squared_density_functional, eps\u001B[38;5;241m=\u001B[39mtorch\u001B[38;5;241m.\u001B[39mtensor(\u001B[38;5;241m1e-3\u001B[39m))\n\u001B[0;32m---> 14\u001B[0m correction_result \u001B[38;5;241m=\u001B[39m \u001B[43meif\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdata\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43md\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43md\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 15\u001B[0m correction_results\u001B[38;5;241m.\u001B[39mappend(correction_result)\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcorrection with ndim=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00md\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m; \u001B[39m\u001B[38;5;124m\\\u001B[39m\u001B[38;5;124mint p(x)^2 dx =\u001B[39m\u001B[38;5;124m\"\u001B[39m, correction_result)\n", - "File \u001B[0;32m~/GitRepo/causal_pyro/chirho/robust/ops.py:187\u001B[0m, in \u001B[0;36mfd_influence_fn.._fn\u001B[0;34m(points, *args, **kwargs)\u001B[0m\n\u001B[1;32m 184\u001B[0m target_perturbed \u001B[38;5;241m=\u001B[39m functional(perturbed_model)\n\u001B[1;32m 186\u001B[0m \u001B[38;5;66;03m# FIXME bdbjdis vmap with func_target etc. 
here, this is basically just pseudo code right now.\u001B[39;00m\n\u001B[0;32m--> 187\u001B[0m t_p_eps \u001B[38;5;241m=\u001B[39m \u001B[43mtarget_perturbed\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 188\u001B[0m t_p_hat \u001B[38;5;241m=\u001B[39m target(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m 189\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m (t_p_eps \u001B[38;5;241m-\u001B[39m t_p_hat) \u001B[38;5;241m/\u001B[39m eps\n", - "Cell \u001B[0;32mIn[16], line 6\u001B[0m, in \u001B[0;36msquared_density_functional..target\u001B[0;34m(d, nmc)\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m _ \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(nmc):\n\u001B[1;32m 5\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m pyro\u001B[38;5;241m.\u001B[39mpoutine\u001B[38;5;241m.\u001B[39mtrace() \u001B[38;5;28;01mas\u001B[39;00m tr:\n\u001B[0;32m----> 6\u001B[0m \u001B[43mmodel\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 7\u001B[0m res \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m tr\u001B[38;5;241m.\u001B[39mtrace\u001B[38;5;241m.\u001B[39mlog_prob_sum()\u001B[38;5;241m.\u001B[39mexp() \u001B[38;5;241m/\u001B[39m N\n\u001B[1;32m 8\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m res\n", - "File \u001B[0;32m~/miniconda3/envs/basis_general/lib/python3.11/site-packages/torch/nn/modules/module.py:1501\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 1496\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[1;32m 1497\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[1;32m 1498\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[1;32m 1499\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[1;32m 1500\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[0;32m-> 1501\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1502\u001B[0m \u001B[38;5;66;03m# Do not call functions when jit is used\u001B[39;00m\n\u001B[1;32m 1503\u001B[0m full_backward_hooks, non_full_backward_hooks \u001B[38;5;241m=\u001B[39m [], []\n", - "File \u001B[0;32m~/GitRepo/causal_pyro/chirho/robust/handlers/predictive.py:75\u001B[0m, in \u001B[0;36mKernelPerturbedModel.forward\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m 
_from_kernel:\n\u001B[1;32m 74\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m pyro\u001B[38;5;241m.\u001B[39mpoutine\u001B[38;5;241m.\u001B[39mtrace() \u001B[38;5;28;01mas\u001B[39;00m kernel_tr:\n\u001B[0;32m---> 75\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mkernel\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpoints\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 76\u001B[0m replay_from_kernel \u001B[38;5;241m=\u001B[39m pyro\u001B[38;5;241m.\u001B[39mpoutine\u001B[38;5;241m.\u001B[39mreplay(trace\u001B[38;5;241m=\u001B[39mkernel_tr\u001B[38;5;241m.\u001B[39mtrace)\n\u001B[1;32m 78\u001B[0m \u001B[38;5;66;03m# This prevents any outer traces from seeing the kernel sites twice, once in the kernel and again\u001B[39;00m\n\u001B[1;32m 79\u001B[0m \u001B[38;5;66;03m# in the model.\u001B[39;00m\n", - "File \u001B[0;32m~/GitRepo/causal_pyro/chirho/robust/handlers/predictive.py:43\u001B[0m, in \u001B[0;36mKernelPerturbedModel.__init__..kernel\u001B[0;34m(_kernel_loc)\u001B[0m\n\u001B[1;32m 41\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mkernel\u001B[39m(_kernel_loc: Point[T]):\n\u001B[1;32m 42\u001B[0m ret \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mdict\u001B[39m()\n\u001B[0;32m---> 43\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m name, value \u001B[38;5;129;01min\u001B[39;00m \u001B[43m_kernel_loc\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mitems\u001B[49m():\n\u001B[1;32m 44\u001B[0m ret[name] \u001B[38;5;241m=\u001B[39m pyro\u001B[38;5;241m.\u001B[39msample(name, dist\u001B[38;5;241m.\u001B[39mDelta(value))\n\u001B[1;32m 45\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m ret\n", - "\u001B[0;31mAttributeError\u001B[0m: 'Tensor' object has no attribute 'items'" - ] - } - ], - "source": [ - "plugin_results = []\n", - "correction_results = []\n", - "corrected_results = []\n", - "for d, inferred_model, alldata in zip([1, 2, 3], inferred_models, datas):\n", - " # Compute plugin.\n", - " plugin_result = squared_density_functional(inferred_model)(d)\n", - " plugin_results.append(plugin_result)\n", - " print(f\"plugin with ndim={d}; \\int p(x)^2 dx =\", plugin_result)\n", - " \n", - " # Estimate the expected influence function on the second half of the data.\n", - " data = alldata['x'][N//2:]\n", - " print(data)\n", - " eif = fd_influence_fn(inferred_model, squared_density_functional, eps=torch.tensor(1e-3))\n", - " correction_result = eif(data, d=d)\n", - " correction_results.append(correction_result)\n", - " print(f\"correction with ndim={d}; \\int p(x)^2 dx =\", correction_result)\n", - " \n", - " corrected_result = plugin_result + correction_result\n", - " corrected_results.append(corrected_result)\n", - " print(f\"corrected with ndim={d}; \\int p(x)^2 dx =\", corrected_result)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-01-11T20:06:23.022754Z", - "start_time": "2024-01-11T20:06:22.921330Z" - } - }, - "id": "66e7f29d984cb010" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "start_time": "2024-01-11T20:04:15.299976Z" - } - }, - "id": "46379e54f229c5b" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - 
"nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 1efe6ea221da6f47df8ff9e71f4e86aa401e4bc7 Mon Sep 17 00:00:00 2001 From: Andy Zane Date: Fri, 12 Jan 2024 17:30:41 -0500 Subject: [PATCH 61/66] gets squared density running under abstraction that couples functionals and models --- chirho/robust/handlers/fd_model.py | 104 ++++++++++++++++-- .../robust_fd/squared_normal_density.py | 43 ++++++-- docs/source/robust_fd_scratch.py | 60 ++++++++-- 3 files changed, 177 insertions(+), 30 deletions(-) diff --git a/chirho/robust/handlers/fd_model.py b/chirho/robust/handlers/fd_model.py index df067f857..ba586528d 100644 --- a/chirho/robust/handlers/fd_model.py +++ b/chirho/robust/handlers/fd_model.py @@ -1,41 +1,97 @@ import torch import pyro import pyro.distributions as dist -from typing import Dict +from typing import Dict, Optional +from contextlib import contextmanager +from chirho.robust.ops import Functional, Point, T class ModelWithMarginalDensity(torch.nn.Module): def density(self, *args, **kwargs): + # TODO this can probably default to using BatchedNMCLogMarginalLikelihood applied to self, + # but providing here to avail of analytic densities. Or have a constructor that takes a + # regular model and puts the marginal density here. raise NotImplementedError() def forward(self, *args, **kwargs): raise NotImplementedError() -class FDModel(ModelWithMarginalDensity): +class FDModelFunctionalDensity(ModelWithMarginalDensity): + """ + This class serves to couple the forward sampling model, density, and functional. Finite differencing + operates in the space of densities, and therefore requires of its functionals that they "know about" + the causal structure of the generative model. Thus, the three components are coupled together here. + + """ model: ModelWithMarginalDensity - kernel: ModelWithMarginalDensity - def __init__(self, eps=0.): + # TODO These managers are weird but lets you define a valid model at init time and then temporarily + # modify the perturbation later, eg. in the influence function approximatoin. + # TODO pull out boilerplate + @contextmanager + def set_eps(self, eps): + original_eps = self._eps + self._eps = eps + try: + yield + finally: + self._eps = original_eps + + @contextmanager + def set_lambda(self, lambda_): + original_lambda = self._lambda + self._lambda = lambda_ + try: + yield + finally: + self._lambda = original_lambda + + @contextmanager + def set_kernel_point(self, kernel_point: Dict): + original_kernel_point = self._kernel_point + self._kernel_point = kernel_point + try: + yield + finally: + self._kernel_point = original_kernel_point + + @property + def kernel(self) -> ModelWithMarginalDensity: + # TODO implementation of a kernel could be brought up to this level. User would need to pass a kernel type + # that's parameterized by the kernel point and lambda. + """ + Inheritors should construct the kernel here as a function of self._kernel_point and self._lambda. + :return: + """ + raise NotImplementedError() + + def __init__(self, default_kernel_point: Dict, default_eps=0., default_lambda=0.1): super().__init__() - self.eps = eps - self.weights = torch.tensor([1. - eps, eps]) + self._eps = default_eps + self._lambda = default_lambda + self._kernel_point = default_kernel_point + + @property + def mixture_weights(self): + return torch.tensor([1. 
- self._eps, self._eps]) def density(self, model_kwargs: Dict, kernel_kwargs: Dict): - mpart = self.weights[0] * self.model.density(**model_kwargs) - kpart = self.weights[1] * self.kernel.density(**kernel_kwargs) + mpart = self.mixture_weights[0] * self.model.density(**model_kwargs) + kpart = self.mixture_weights[1] * self.kernel.density(**kernel_kwargs) return mpart + kpart - def forward(self, model_kwargs: Dict, kernel_kwargs: Dict): - _from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.weights)) + def forward(self, model_kwargs: Optional[Dict] = None, kernel_kwargs: Optional[Dict] = None): + _from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.mixture_weights)) if _from_kernel: - return self.kernel_forward(**kernel_kwargs) + return self.kernel_forward(**(kernel_kwargs or dict())) else: - return self.model_forward(**model_kwargs) + return self.model_forward(**(model_kwargs or dict())) def functional(self, *args, **kwargs): + # TODO update docstring to this being build_functional instead of just functional """ The functional target for this model. This is tightly coupled to a particular pyro model because finite differencing operates in the space of densities, and @@ -47,3 +103,27 @@ def functional(self, *args, **kwargs): :return: An estimate of the functional for ths model. """ raise NotImplementedError() + + +# TODO move this to chirho/robust/ops.py and resolve signature mismatches? Maybe. The problem is that the ops +# signature (rightly) decouples models and functionals, whereas for finite differencing they must be coupled +# because the functional (in many cases) must know about the causal structure of the model. +def fd_influence_fn(model: FDModelFunctionalDensity, points: Point[T], eps: float, lambda_: float): + + def _influence_fn(*args, **kwargs): + + # Length of first value in points mappping. + len_points = len(list(points.values())[0]) + for i in range(len_points): + kernel_point = {k: v[i] for k, v in points.items()} + + psi_p = model.functional(*args, **kwargs) + + with model.set_eps(eps), model.set_lambda(lambda_), model.set_kernel_point(kernel_point): + psi_p_eps = model.functional(*args, **kwargs) + + return (psi_p_eps - psi_p) / eps + + return _influence_fn + + diff --git a/docs/source/robust_fd/squared_normal_density.py b/docs/source/robust_fd/squared_normal_density.py index 07c5bd497..a6363c436 100644 --- a/docs/source/robust_fd/squared_normal_density.py +++ b/docs/source/robust_fd/squared_normal_density.py @@ -1,4 +1,4 @@ -from chirho.robust.handlers.fd_model import FDModel, ModelWithMarginalDensity +from chirho.robust.handlers.fd_model import FDModelFunctionalDensity, ModelWithMarginalDensity import pyro import pyro.distributions as dist import torch @@ -6,6 +6,11 @@ from scipy.integrate import nquad import numpy as np +# TODO after putting this together, a mixin model would be more appropriate, as we still +# want explicit coupling between models and functionals but it can be M:M. I.e. mixin the +# functional that could apply to a number of models, and/or mixin the model that could work +# with a number of functionals. 
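# Rough usage sketch of how this file is meant to compose with fd_model.py (illustrative
# values only; `fd_quad` is a stand-in name, and the classes it uses are defined below):
#
#     fd_quad = ExpectedNormalDensityQuad(
#         mean=torch.zeros(1), cov=torch.eye(1),
#         default_kernel_point=dict(x=torch.zeros(1)),
#     )
#     psi_plugin = fd_quad.functional()
#     with fd_quad.set_eps(1e-3), fd_quad.set_lambda(1e-2), \
#             fd_quad.set_kernel_point(dict(x=torch.tensor([1.0]))):
#         psi_eps = fd_quad.functional()
#     eif_at_point = (psi_eps - psi_plugin) / 1e-3  # what fd_influence_fn computes per point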
+ class FDMultivariateNormal(ModelWithMarginalDensity): @@ -22,21 +27,27 @@ def forward(self): return pyro.sample("x", dist.MultivariateNormal(self.mean, self.cov)) -class _SquaredNormalDensity(FDModel): +class _ExpectedNormalDensity(FDModelFunctionalDensity): + + @property + def kernel(self): + try: + mean = self._kernel_point['x'] + except TypeError as e: + raise + return FDMultivariateNormal(mean, torch.eye(self.ndims) * self._lambda) - def __init__(self, *args, mean, cov, lambda_: float, **kwargs): - ndims = mean.shape[-1] + def __init__(self, *args, mean, cov, **kwargs): super().__init__(*args, **kwargs) + self.ndims = mean.shape[-1] self.model = FDMultivariateNormal(mean, cov) - self.kernel = FDMultivariateNormal(torch.zeros(ndims), torch.eye(ndims) * lambda_) - self.lambda_ = lambda_ self.mean = mean self.cov = cov -class SquaredNormalDensityQuad(_SquaredNormalDensity): +class ExpectedNormalDensityQuad(_ExpectedNormalDensity): """ Compute the squared normal density using quadrature. """ @@ -49,4 +60,20 @@ def integrand(*args): model_kwargs = kernel_kwargs = dict(x=np.array(args)) return self.density(model_kwargs, kernel_kwargs) ** 2 - return nquad(integrand, [[-np.inf, np.inf]] * self.mean.shape[-1])[0] \ No newline at end of file + return nquad(integrand, [[-np.inf, np.inf]] * self.mean.shape[-1])[0] + + +class ExpectedNormalDensityMC(_ExpectedNormalDensity): + """ + Compute the squared normal density using Monte Carlo. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def functional(self, nmc=1000): + with pyro.plate('samples', nmc): + with pyro.poutine.trace() as tr: + self() + points = tr.trace.nodes['x']['value'] + return torch.mean(self.density(model_kwargs=dict(x=points), kernel_kwargs=dict(x=points))) diff --git a/docs/source/robust_fd_scratch.py b/docs/source/robust_fd_scratch.py index dabfd8008..aaef3b2af 100644 --- a/docs/source/robust_fd_scratch.py +++ b/docs/source/robust_fd_scratch.py @@ -1,4 +1,5 @@ -from robust_fd.squared_normal_density import SquaredNormalDensityQuad +from robust_fd.squared_normal_density import ExpectedNormalDensityQuad, ExpectedNormalDensityMC +from chirho.robust.handlers.fd_model import fd_influence_fn import numpy as np import torch import matplotlib.pyplot as plt @@ -9,21 +10,60 @@ cov = torch.eye(ndim) lambda_ = 0.01 -sndq = SquaredNormalDensityQuad( +end_quad = ExpectedNormalDensityQuad( mean=mean, cov=cov, - eps=eps, - lambda_=lambda_, + default_kernel_point=dict(x=torch.tensor([0.,] * ndim)), + default_eps=eps, + default_lambda=lambda_, ) -print(sndq.functional()) +print(end_quad.functional()) xx = np.linspace(-5, 5, 1000) -yy = [sndq.density( - {'x': torch.tensor([x])}, - {'x': torch.tensor([x])}) - for x in xx -] + +with end_quad.set_kernel_point(dict(x=torch.tensor([1., ] * ndim))), end_quad.set_lambda(.01), end_quad.set_eps(0.1): + yy = [end_quad.density( + {'x': torch.tensor([x])}, + {'x': torch.tensor([x])}) + for x in xx + ] + + end_quad.model + plt.plot(xx, yy) plt.show() + +# Sample points from a slightly more entropoic model. 
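# (The uniform grid below stands in for draws from a flatter "data" distribution than the
#  fitted N(0, 1) model; these are the points at which the influence function is evaluated,
#  not training data.)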
+points = dict(x=torch.linspace(-3, 3, 100)) + +target_quad = fd_influence_fn( + model=end_quad, + points=points, + eps=0.1, + lambda_=0.1, +) + +correction_quad = target_quad() + +print(correction_quad) + +end_mc = ExpectedNormalDensityMC( + mean=mean, + cov=cov, + default_kernel_point=dict(x=torch.tensor([0.,] * ndim)), + default_eps=eps, + default_lambda=lambda_, +) + +target_mc = fd_influence_fn( + model=end_mc, + points=points, + eps=0.1, + lambda_=0.1, +) + +correction_mc = target_mc(nmc=10000000).item() + +print(correction_mc) From 44785d89db70573f42797bd5c01125b3e4a27430 Mon Sep 17 00:00:00 2001 From: Andy Zane Date: Fri, 12 Jan 2024 17:54:12 -0500 Subject: [PATCH 62/66] gets quad and mc approximations to match, vectorization hacky. --- chirho/robust/handlers/fd_model.py | 40 +++++++++++++++++-- .../robust_fd/squared_normal_density.py | 4 +- docs/source/robust_fd_scratch.py | 5 +-- 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/chirho/robust/handlers/fd_model.py b/chirho/robust/handlers/fd_model.py index ba586528d..c2de2ed47 100644 --- a/chirho/robust/handlers/fd_model.py +++ b/chirho/robust/handlers/fd_model.py @@ -17,6 +17,15 @@ def forward(self, *args, **kwargs): raise NotImplementedError() +class PrefixMessenger(pyro.poutine.messenger.Messenger): + + def __init__(self, prefix: str): + self.prefix = prefix + + def _pyro_sample(self, msg) -> None: + msg["name"] = f"{self.prefix}{msg['name']}" + + class FDModelFunctionalDensity(ModelWithMarginalDensity): """ This class serves to couple the forward sampling model, density, and functional. Finite differencing @@ -83,12 +92,35 @@ def density(self, model_kwargs: Dict, kernel_kwargs: Dict): return mpart + kpart def forward(self, model_kwargs: Optional[Dict] = None, kernel_kwargs: Optional[Dict] = None): + # _from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.mixture_weights)) + # + # if _from_kernel: + # return self.kernel(**(kernel_kwargs or dict())) + # else: + # return self.model(**(model_kwargs or dict())) + _from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.mixture_weights)) - if _from_kernel: - return self.kernel_forward(**(kernel_kwargs or dict())) - else: - return self.model_forward(**(model_kwargs or dict())) + kernel_mask = _from_kernel.bool() # Convert to boolean mask + + # Apply the respective functions using the masks + with PrefixMessenger('kernel_'), pyro.poutine.trace() as kernel_tr: + kernel_result = self.kernel(**(kernel_kwargs or dict())) + with PrefixMessenger('model_'), pyro.poutine.trace() as model_tr: + model_result = self.model(**(model_kwargs or dict())) + + # FIXME to make log likelihoods work properly, the log likelihoods need to be masked/not added + # for particular elements. See e.g. MaskedMixture for a non-general example of how to do this (it + # uses torch distributions instead of arbitrary probabilistic programs. + # https://docs.pyro.ai/en/stable/distributions.html?highlight=MaskedMixture#maskedmixture + # FIXME ideally the trace would have elements of the same name as well here. + + # FIXME where isn't shape agnostic. 
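        # Hedged sketch of the MaskedMixture route referenced in the FIXME above; it only
        # applies when both components can be written as torch/pyro Distributions rather than
        # arbitrary programs (`model_dist` / `kernel_dist` are hypothetical attributes here):
        #
        #     mixed = dist.MaskedMixture(kernel_mask, model_dist, kernel_dist)
        #     result = pyro.sample("x", mixed)
        #
        # component0 is used where the mask is False (the model) and component1 where it is
        # True (the kernel), so per-element log-likelihoods come from the distribution itself
        # rather than the torch.where selection below.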
+ + # Use masks to select the appropriate result for each sample + result = torch.where(kernel_mask[:, None], kernel_result, model_result) + + return result def functional(self, *args, **kwargs): # TODO update docstring to this being build_functional instead of just functional diff --git a/docs/source/robust_fd/squared_normal_density.py b/docs/source/robust_fd/squared_normal_density.py index a6363c436..fdf606503 100644 --- a/docs/source/robust_fd/squared_normal_density.py +++ b/docs/source/robust_fd/squared_normal_density.py @@ -73,7 +73,5 @@ def __init__(self, *args, **kwargs): def functional(self, nmc=1000): with pyro.plate('samples', nmc): - with pyro.poutine.trace() as tr: - self() - points = tr.trace.nodes['x']['value'] + points = self() return torch.mean(self.density(model_kwargs=dict(x=points), kernel_kwargs=dict(x=points))) diff --git a/docs/source/robust_fd_scratch.py b/docs/source/robust_fd_scratch.py index aaef3b2af..4de9beca7 100644 --- a/docs/source/robust_fd_scratch.py +++ b/docs/source/robust_fd_scratch.py @@ -29,14 +29,13 @@ for x in xx ] - end_quad.model - plt.plot(xx, yy) plt.show() # Sample points from a slightly more entropoic model. -points = dict(x=torch.linspace(-3, 3, 100)) +# FIXME not generalized for ndim > 1 +points = dict(x=torch.linspace(-3, 3, 100)[:, None]) target_quad = fd_influence_fn( model=end_quad, From 31cc9acf0bc9a1118d6f69f68c3fca479d440856 Mon Sep 17 00:00:00 2001 From: Andy Zane Date: Tue, 16 Jan 2024 15:32:14 -0800 Subject: [PATCH 63/66] adds plotting and comparative to analytic. --- chirho/robust/handlers/fd_model.py | 4 +- docs/source/robust_fd_scratch.py | 98 +++++++++++++++++++++++------- 2 files changed, 80 insertions(+), 22 deletions(-) diff --git a/chirho/robust/handlers/fd_model.py b/chirho/robust/handlers/fd_model.py index c2de2ed47..150758487 100644 --- a/chirho/robust/handlers/fd_model.py +++ b/chirho/robust/handlers/fd_model.py @@ -146,6 +146,7 @@ def _influence_fn(*args, **kwargs): # Length of first value in points mappping. 
len_points = len(list(points.values())[0]) + eif_vals = [] for i in range(len_points): kernel_point = {k: v[i] for k, v in points.items()} @@ -154,7 +155,8 @@ def _influence_fn(*args, **kwargs): with model.set_eps(eps), model.set_lambda(lambda_), model.set_kernel_point(kernel_point): psi_p_eps = model.functional(*args, **kwargs) - return (psi_p_eps - psi_p) / eps + eif_vals.append(-(psi_p_eps - psi_p) / eps) + return eif_vals return _influence_fn diff --git a/docs/source/robust_fd_scratch.py b/docs/source/robust_fd_scratch.py index 4de9beca7..b674bc389 100644 --- a/docs/source/robust_fd_scratch.py +++ b/docs/source/robust_fd_scratch.py @@ -1,68 +1,124 @@ -from robust_fd.squared_normal_density import ExpectedNormalDensityQuad, ExpectedNormalDensityMC +from robust_fd.squared_normal_density import ExpectedNormalDensityQuad, ExpectedNormalDensityMC, _ExpectedNormalDensity from chirho.robust.handlers.fd_model import fd_influence_fn import numpy as np import torch import matplotlib.pyplot as plt +from scipy.stats import multivariate_normal, norm ndim = 1 eps = 0.01 mean = torch.tensor([0.,] * ndim) cov = torch.eye(ndim) -lambda_ = 0.01 +lambda_ = 0.001 end_quad = ExpectedNormalDensityQuad( mean=mean, cov=cov, default_kernel_point=dict(x=torch.tensor([0.,] * ndim)), - default_eps=eps, - default_lambda=lambda_, ) -print(end_quad.functional()) +guess = end_quad.functional() +print(f"Guess: {guess}") -xx = np.linspace(-5, 5, 1000) +if ndim == 1: + xx = np.linspace(-5, 5, 1000) + with (end_quad.set_kernel_point(dict(x=torch.tensor([1., ] * ndim))), + end_quad.set_lambda(lambda_), + end_quad.set_eps(eps)): + yy = [end_quad.density( + {'x': torch.tensor([x])}, + {'x': torch.tensor([x])}) + for x in xx + ] + + plt.plot(xx, yy) -with end_quad.set_kernel_point(dict(x=torch.tensor([1., ] * ndim))), end_quad.set_lambda(.01), end_quad.set_eps(0.1): yy = [end_quad.density( {'x': torch.tensor([x])}, {'x': torch.tensor([x])}) for x in xx ] - -plt.plot(xx, yy) -plt.show() + plt.plot(xx, yy) # Sample points from a slightly more entropoic model. # FIXME not generalized for ndim > 1 -points = dict(x=torch.linspace(-3, 3, 100)[:, None]) +points = dict(x=torch.linspace(-3, 3, 50)[:, None]) + +print(f"Analytic: {((1./(3. - -3.))**2) * (3. - -3.)}") target_quad = fd_influence_fn( model=end_quad, points=points, - eps=0.1, - lambda_=0.1, + eps=eps, + lambda_=lambda_, ) -correction_quad = target_quad() +correction_quad_eif = np.array(target_quad()) -print(correction_quad) +if ndim == 1: + plt.figure() + plt.plot(points['x'].numpy(), correction_quad_eif, label='quad eif') + +correction_quad = np.mean(correction_quad_eif) + +print(f"Correction (Quad): {correction_quad}") end_mc = ExpectedNormalDensityMC( mean=mean, cov=cov, default_kernel_point=dict(x=torch.tensor([0.,] * ndim)), - default_eps=eps, - default_lambda=lambda_, ) target_mc = fd_influence_fn( model=end_mc, points=points, - eps=0.1, - lambda_=0.1, + eps=eps, + lambda_=lambda_, ) -correction_mc = target_mc(nmc=10000000).item() +correction_mc_eif = np.array(target_mc(nmc=4000)) + +if ndim == 1: + plt.plot(points['x'].numpy(), correction_mc_eif, linewidth=0.3, alpha=0.8) + +correction_mc = np.mean(correction_mc_eif) + +print(f"Correction (MC): {correction_mc}") + + +def compute_analytic_eif(model: _ExpectedNormalDensity, points): + funcval = model.functional() + density = model.density(points, points) + + return 2. 
* (funcval - density) + + +analytic_eif = compute_analytic_eif(end_quad, points).numpy() + +analytic = np.mean(analytic_eif) + +print(f"Analytic: {analytic}") + +print(f"Analytic Corrected: {guess - analytic}") + + +if ndim == 1: + + plt.suptitle(f"ndim={ndim}, eps={eps}, lambda={lambda_}") + + pxsamps = points['x'].numpy().squeeze() + + plt.plot(pxsamps, analytic_eif, label="analytic") + + # Plot the corresponding uniform and normal densities. + plt.plot(points['x'].numpy(), [1./(3. - -3.)] * len(points['x']), color='black', label='uniform') + + # plt.plot(xx, norm.pdf(xx, loc=0, scale=1), color='green', label='normal') + plt.plot(pxsamps, norm.pdf(pxsamps, loc=0, scale=1), color='green', label='normal') + # Plot the correction, just quad. + plt.plot(pxsamps, norm.pdf(pxsamps, loc=0, scale=1) - 0.1 * np.array(correction_quad_eif), + linestyle='--', color='green', label='normal (corrected)') -print(correction_mc) + plt.legend() + plt.show() From f867f2a1759d56cbd23e9c54d9e5bbf1f9e6648f Mon Sep 17 00:00:00 2001 From: Andy Zane Date: Tue, 16 Jan 2024 21:59:00 -0800 Subject: [PATCH 64/66] adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas --- chirho/robust/handlers/fd_model.py | 12 +- .../robust_fd/squared_normal_density.py | 39 +-- docs/source/robust_fd_scratch.py | 4 +- docs/source/robust_fd_scratch_sqd_epslam.py | 277 ++++++++++++++++++ 4 files changed, 310 insertions(+), 22 deletions(-) create mode 100644 docs/source/robust_fd_scratch_sqd_epslam.py diff --git a/chirho/robust/handlers/fd_model.py b/chirho/robust/handlers/fd_model.py index 150758487..2ba9e7343 100644 --- a/chirho/robust/handlers/fd_model.py +++ b/chirho/robust/handlers/fd_model.py @@ -4,9 +4,13 @@ from typing import Dict, Optional from contextlib import contextmanager from chirho.robust.ops import Functional, Point, T +import numpy as np class ModelWithMarginalDensity(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def density(self, *args, **kwargs): # TODO this can probably default to using BatchedNMCLogMarginalLikelihood applied to self, # but providing here to avail of analytic densities. 
Or have a constructor that takes a @@ -76,11 +80,13 @@ def kernel(self) -> ModelWithMarginalDensity: """ raise NotImplementedError() - def __init__(self, default_kernel_point: Dict, default_eps=0., default_lambda=0.1): - super().__init__() + def __init__(self, default_kernel_point: Dict, *args, default_eps=0., default_lambda=0.1, **kwargs): + super().__init__(*args, **kwargs) self._eps = default_eps self._lambda = default_lambda self._kernel_point = default_kernel_point + # TODO don't assume .shape[-1] + self.ndims = np.sum([v.shape[-1] for v in self._kernel_point.values()]) @property def mixture_weights(self): @@ -155,7 +161,7 @@ def _influence_fn(*args, **kwargs): with model.set_eps(eps), model.set_lambda(lambda_), model.set_kernel_point(kernel_point): psi_p_eps = model.functional(*args, **kwargs) - eif_vals.append(-(psi_p_eps - psi_p) / eps) + eif_vals.append((psi_p_eps - psi_p) / eps) return eif_vals return _influence_fn diff --git a/docs/source/robust_fd/squared_normal_density.py b/docs/source/robust_fd/squared_normal_density.py index fdf606503..675355fe0 100644 --- a/docs/source/robust_fd/squared_normal_density.py +++ b/docs/source/robust_fd/squared_normal_density.py @@ -6,16 +6,11 @@ from scipy.integrate import nquad import numpy as np -# TODO after putting this together, a mixin model would be more appropriate, as we still -# want explicit coupling between models and functionals but it can be M:M. I.e. mixin the -# functional that could apply to a number of models, and/or mixin the model that could work -# with a number of functionals. +class MultivariateNormalwDensity(ModelWithMarginalDensity): -class FDMultivariateNormal(ModelWithMarginalDensity): - - def __init__(self, mean, cov): - super().__init__() + def __init__(self, mean, cov, *args, **kwargs): + super().__init__(*args, **kwargs) self.mean = mean self.cov = cov @@ -27,27 +22,31 @@ def forward(self): return pyro.sample("x", dist.MultivariateNormal(self.mean, self.cov)) -class _ExpectedNormalDensity(FDModelFunctionalDensity): +class NormalKernel(FDModelFunctionalDensity): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) @property def kernel(self): - try: - mean = self._kernel_point['x'] - except TypeError as e: - raise - return FDMultivariateNormal(mean, torch.eye(self.ndims) * self._lambda) + # TODO agnostic to names. + mean = self._kernel_point['x'] + return MultivariateNormalwDensity(mean, torch.eye(self.ndims) * self._lambda) + + +class PerturbableNormal(FDModelFunctionalDensity): def __init__(self, *args, mean, cov, **kwargs): super().__init__(*args, **kwargs) self.ndims = mean.shape[-1] - self.model = FDMultivariateNormal(mean, cov) + self.model = MultivariateNormalwDensity(mean, cov) self.mean = mean self.cov = cov -class ExpectedNormalDensityQuad(_ExpectedNormalDensity): +class ExpectedDensityQuadFunctional(FDModelFunctionalDensity): """ Compute the squared normal density using quadrature. """ @@ -57,13 +56,16 @@ def __init__(self, *args, **kwargs): def functional(self): def integrand(*args): + # TODO agnostic to kwarg names. model_kwargs = kernel_kwargs = dict(x=np.array(args)) return self.density(model_kwargs, kernel_kwargs) ** 2 - return nquad(integrand, [[-np.inf, np.inf]] * self.mean.shape[-1])[0] + ndim = self._kernel_point['x'].shape[-1] + + return nquad(integrand, [[-np.inf, np.inf]] * ndim)[0] -class ExpectedNormalDensityMC(_ExpectedNormalDensity): +class ExpectedDensityMCFunctional(FDModelFunctionalDensity): """ Compute the squared normal density using Monte Carlo. 
""" @@ -72,6 +74,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def functional(self, nmc=1000): + # TODO agnostic to kwarg names with pyro.plate('samples', nmc): points = self() return torch.mean(self.density(model_kwargs=dict(x=points), kernel_kwargs=dict(x=points))) diff --git a/docs/source/robust_fd_scratch.py b/docs/source/robust_fd_scratch.py index b674bc389..8bf257581 100644 --- a/docs/source/robust_fd_scratch.py +++ b/docs/source/robust_fd_scratch.py @@ -1,3 +1,5 @@ +raise NotImplementedError() + from robust_fd.squared_normal_density import ExpectedNormalDensityQuad, ExpectedNormalDensityMC, _ExpectedNormalDensity from chirho.robust.handlers.fd_model import fd_influence_fn import numpy as np @@ -91,7 +93,7 @@ def compute_analytic_eif(model: _ExpectedNormalDensity, points): funcval = model.functional() density = model.density(points, points) - return 2. * (funcval - density) + return 2. * (density - funcval) analytic_eif = compute_analytic_eif(end_quad, points).numpy() diff --git a/docs/source/robust_fd_scratch_sqd_epslam.py b/docs/source/robust_fd_scratch_sqd_epslam.py new file mode 100644 index 000000000..d81b55c80 --- /dev/null +++ b/docs/source/robust_fd_scratch_sqd_epslam.py @@ -0,0 +1,277 @@ +from robust_fd.squared_normal_density import ( + NormalKernel, + PerturbableNormal, + ExpectedDensityQuadFunctional, + ExpectedDensityMCFunctional +) +from chirho.robust.handlers.fd_model import ( + fd_influence_fn, + ModelWithMarginalDensity, + FDModelFunctionalDensity +) +import numpy as np +import torch +import matplotlib.pyplot as plt +from scipy.stats import multivariate_normal as mvn, norm +from itertools import product +from typing import List, Dict, Tuple, Optional +from scipy.stats._multivariate import _squeeze_output + + +EPS = [0.1, 0.01, 0.001] +LAMBDA = [0.1, 0.01, 0.001] +NDIM = [1, 2] +NDATASETS = 15 +NGUESS = 100 +NEIF = 20 + + +def analytic_eif(model: FDModelFunctionalDensity, points, funcval=None): + if funcval is None: + funcval = model.functional() + density = model.density(points, points) + + return 2. * (density - funcval) + + +class MultivariateSkewnormFDModel(ModelWithMarginalDensity): + + # From https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/: + + def __init__(self, shape, cov, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dim = len(shape) + self.shape = np.asarray(shape) + self.mean = np.zeros(self.dim) + self.cov = np.eye(self.dim) if cov is None else np.asarray(cov) + + def pdf(self, x): + return np.exp(self.logpdf(x)) + + def density(self, x): + return self.pdf(x) + + def logpdf(self, x): + x = mvn._process_quantiles(x, self.dim) + pdf = mvn(self.mean, self.cov).logpdf(x) + cdf = norm(0, 1).logcdf(np.dot(x, self.shape)) + return _squeeze_output(np.log(2) + pdf + cdf) + + def rvs_fast(self, size=1): + aCa = self.shape @ self.cov @ self.shape + delta = (1 / np.sqrt(1 + aCa)) * self.cov @ self.shape + cov_star = np.block([[np.ones(1), delta], + [delta[:, None], self.cov]]) + x = mvn(np.zeros(self.dim + 1), cov_star).rvs(size) + x0, x1 = x[:, 0], x[:, 1:] + inds = x0 <= 0 + x1[inds] = -1 * x1[inds] + return x1 + + def forward(self, *args, **kwargs): + # TODO whatever the pyro version of this is? If there is one just get rid of this class. 
+ raise NotImplementedError() + + +class PerturbableSkewNormal(FDModelFunctionalDensity): + def __init__(self, shape, cov, *args, **kwargs): + default_kernel_point = dict(x=np.zeros(len(shape))) + super().__init__(*args, default_kernel_point=default_kernel_point, **kwargs) + + self.model = MultivariateSkewnormFDModel(shape, cov) + + self.shape = shape + self.cov = cov + + +class ExpectedSkewNormalDensityQuadFunctional( + NormalKernel, + PerturbableSkewNormal, + ExpectedDensityQuadFunctional, +): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ExpectedNormalDensityQuadFunctional( + NormalKernel, + PerturbableNormal, + ExpectedDensityQuadFunctional, +): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ExpectedNormalDensityMCFunctional( + NormalKernel, + PerturbableNormal, + ExpectedDensityMCFunctional, +): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +def skew_unit_norm_dataset_generator(ndim: int, num_datasets: int, nguess: int, neif: int): + + for _ in range(num_datasets): + cov = np.eye(ndim) + shape = np.random.normal(size=ndim, scale=3.) + + datadist = ExpectedSkewNormalDensityQuadFunctional(shape, cov) + yield datadist.model.rvs_fast(nguess + neif), datadist + + +def main2d(): + for dataset, oracle in skew_unit_norm_dataset_generator(2, NDATASETS, NGUESS): + + print(f"Oracle: {oracle}") + + plt.figure() + plt.scatter(dataset[:, 0], dataset[:, 1], alpha=0.5, s=0.2) + plt.show() + + +def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=True): + + fd_cors_quad = dict() # type: Dict[Tuple[float, float], List[float]] + fd_cors_mc = dict() # type: Dict[Tuple[float, float], List[float]] + ana_cors_quad = list() + ana_cors_mc = list() + + for dataset, datadist in skew_unit_norm_dataset_generator(ndim, NDATASETS, NGUESS, NEIF): + + oracle_fval = datadist.functional() + print(f"Oracle: {oracle_fval}") + + if ndim == 1 and plot_densities: + # Plot a density estimate. + f1 = plt.figure() + plt.hist(dataset[NGUESS:], bins=100, density=True, alpha=0.5) + xx = np.linspace(dataset.min(), dataset.max(), 100).reshape(-1, ndim) + yy = datadist.density(dict(x=xx), dict(x=xx)) + plt.plot(xx, yy) + + # Train a model on the first nguess points in the dataset. This is a misspecified + # model in that it is a normal with no skew. + mean = torch.tensor(np.atleast_1d(np.mean(dataset[:NGUESS], axis=0))).float() + cov = torch.tensor(np.atleast_2d(np.cov(dataset[:NGUESS], rowvar=False))).float() + + # Compute the functionals of the unperturbed models. + fd_quad = ExpectedNormalDensityQuadFunctional( + default_kernel_point=dict(x=torch.zeros(len(mean))), + mean=mean, + cov=cov + ) + quad_guess = fd_quad.functional() + print(f"Quad: {quad_guess}") + fd_mc = ExpectedNormalDensityMCFunctional( + default_kernel_point=dict(x=torch.zeros(len(mean))), + mean=mean, + cov=cov + ) + mc_guess = fd_mc.functional() + print(f"MC: {mc_guess}") + + # Compute the analytic eif on the samples after the first nguess. + correction_points = dict(x=torch.tensor(dataset[NGUESS:]).float()) + + if ndim == 1 and plot_densities: + # Plot the influence function across the linspace in the same figure. + plt.plot(xx, analytic_eif(fd_quad, points=dict(x=xx), funcval=quad_guess), color='blue') + plt.plot(xx, analytic_eif(fd_mc, points=dict(x=xx), funcval=mc_guess), color='red') + + # Quick check that the two have the same density. 
+ assert np.allclose( + fd_quad.density(correction_points, correction_points), + fd_mc.density(correction_points, correction_points) + ) + + # And compute the analytic corrections. + ana_eif_quad = analytic_eif(fd_quad, correction_points, funcval=quad_guess) + ana_cor_quad = ana_eif_quad.mean() + ana_eif_mc = analytic_eif(fd_mc, correction_points, funcval=mc_guess) + ana_cor_mc = ana_eif_mc.mean() + + print(f"Quad (Ana Correction): {ana_cor_quad}") + print(f"MC (Ana Correction): {ana_cor_mc}") + + print(f"Oracle: {oracle_fval}") + print(f"Quad (Ana Corrected): {quad_guess + ana_cor_quad}") + print(f"MC (Ana Corrected): {mc_guess + ana_cor_mc}") + + if plot_corrections: + f2 = plt.figure() + # Plot the oracle value. + plt.axhline(oracle_fval, color='black', label='Oracle', linestyle='--') + # Plot lines from guesses to corrected values. + plt.plot([0, 1], [quad_guess, quad_guess + ana_cor_quad], color='blue', label='Quad') + plt.plot([0, 1], [mc_guess, mc_guess + ana_cor_mc], color='red', label='MC') + plt.legend() + + for eps, lambda_ in product(EPS, LAMBDA): + fd_eif_quad = fd_influence_fn( + model=fd_quad, + points=correction_points, + eps=eps, + lambda_=lambda_)() + fd_cor_quad = np.mean(fd_eif_quad) + fd_eif_mc = fd_influence_fn( + model=fd_mc, + points=correction_points, + eps=eps, + # Scale the nmc with epsilon so that the kernel gets seen. + lambda_=lambda_)(nmc=(1. / eps) * 10) + fd_cor_mc = np.mean(fd_eif_mc) + + fd_cors_quad[(eps, lambda_)] = fd_cors_quad.get((eps, lambda_), []) + [fd_cor_quad] + fd_cors_mc[(eps, lambda_)] = fd_cors_mc.get((eps, lambda_), []) + [fd_cor_mc] + + ana_cors_quad.append(ana_cor_quad) + ana_cors_mc.append(ana_cor_mc) + + print() + + plt.show() + if ndim == 1 and plot_densities: + plt.close(f1) + if plot_corrections: + plt.close(f2) + + def plot_diag(ax, x1, x2s, x1lab, x2labs, title): + ax.set_title(title) + for x2, lab in zip(x2s, x2labs): + ax.scatter(x1, x2, alpha=0.5, label=lab) + ax.set_xlabel(x1lab) + ax.set_ylabel("FD Correction") + # Draw the diagonal in figure coordinates. + xmin, xmax = ax.get_xlim() + ymin, ymax = ax.get_ylim() + # Make the axes the same (min/max of each xy) + xymin = min(xmin, ymin) + xymax = max(xmax, ymax) + ax.set_xlim(xymin, xymax) + ax.set_ylim(xymin, xymax) + ax.plot([xymin, xymax], [xymin, xymax], color='black', linestyle='--') + + # Plot the finite difference diagonals. + if plot_fd_ana_diag: + # Prep gridplot over eps and lambda. 
+ f, axes = plt.subplots(len(EPS), len(LAMBDA), figsize=(30, 30)) + for (eps, lambda_), ax in zip(product(EPS, LAMBDA), axes.flatten()): + plot_diag( + ax=ax, + x1=ana_cors_quad, + x2s=[fd_cors_quad[(eps, lambda_)], fd_cors_mc[(eps, lambda_)]], + x1lab='Analytic Correction', + x2labs=['Quad', 'MC'], + title=f"Quad (ndim={ndim}, eps={eps}, lambda={lambda_})" + ) + plt.tight_layout() + plt.show() + plt.close(f) + + return + + +if __name__ == '__main__': + main(plot_densities=False, plot_corrections=False, plot_fd_ana_diag=True) From 7f106675177c24f09cdaf67127998633fdb91d75 Mon Sep 17 00:00:00 2001 From: Andy Zane Date: Wed, 17 Jan 2024 10:27:21 -0800 Subject: [PATCH 65/66] fixes dataset splitting, breaks analytic eif --- docs/source/robust_fd_scratch_sqd_epslam.py | 61 ++++++++++++--------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/docs/source/robust_fd_scratch_sqd_epslam.py b/docs/source/robust_fd_scratch_sqd_epslam.py index d81b55c80..977022859 100644 --- a/docs/source/robust_fd_scratch_sqd_epslam.py +++ b/docs/source/robust_fd_scratch_sqd_epslam.py @@ -18,12 +18,14 @@ from scipy.stats._multivariate import _squeeze_output -EPS = [0.1, 0.01, 0.001] -LAMBDA = [0.1, 0.01, 0.001] +# EPS = [0.01, 0.001] +# LAMBDA = [0.01, 0.001] +EPS = [0.01] +LAMBDA = [0.01] NDIM = [1, 2] -NDATASETS = 15 -NGUESS = 100 -NEIF = 20 +NDATASETS = 3 +NGUESS = 50 +NEIF = 50 def analytic_eif(model: FDModelFunctionalDensity, points, funcval=None): @@ -118,17 +120,21 @@ def skew_unit_norm_dataset_generator(ndim: int, num_datasets: int, nguess: int, shape = np.random.normal(size=ndim, scale=3.) datadist = ExpectedSkewNormalDensityQuadFunctional(shape, cov) - yield datadist.model.rvs_fast(nguess + neif), datadist + dataset = datadist.model.rvs_fast(nguess + neif) + guess_dataset, correction_dataset = dataset[:nguess], dataset[nguess:] + assert len(guess_dataset) == nguess + assert len(correction_dataset) == neif + yield guess_dataset, correction_dataset, datadist -def main2d(): - for dataset, oracle in skew_unit_norm_dataset_generator(2, NDATASETS, NGUESS): - - print(f"Oracle: {oracle}") - - plt.figure() - plt.scatter(dataset[:, 0], dataset[:, 1], alpha=0.5, s=0.2) - plt.show() +# def main2d(): +# for guess_dataset, correction_dataset, oracle in skew_unit_norm_dataset_generator(2, NDATASETS, NGUESS): +# +# print(f"Oracle: {oracle}") +# +# plt.figure() +# plt.scatter(dataset[:, 0], dataset[:, 1], alpha=0.5, s=0.2) +# plt.show() def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=True): @@ -138,7 +144,7 @@ def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=Tr ana_cors_quad = list() ana_cors_mc = list() - for dataset, datadist in skew_unit_norm_dataset_generator(ndim, NDATASETS, NGUESS, NEIF): + for guess_dataset, correction_dataset, datadist in skew_unit_norm_dataset_generator(ndim, NDATASETS, NGUESS, NEIF): oracle_fval = datadist.functional() print(f"Oracle: {oracle_fval}") @@ -146,15 +152,16 @@ def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=Tr if ndim == 1 and plot_densities: # Plot a density estimate. f1 = plt.figure() - plt.hist(dataset[NGUESS:], bins=100, density=True, alpha=0.5) - xx = np.linspace(dataset.min(), dataset.max(), 100).reshape(-1, ndim) + # Using correction dataset to see where things fall on the influence function curve. 
+ plt.hist(correction_dataset, bins=100, density=True, alpha=0.5) + xx = np.linspace(correction_dataset.min(), correction_dataset.max(), 100).reshape(-1, ndim) yy = datadist.density(dict(x=xx), dict(x=xx)) plt.plot(xx, yy) # Train a model on the first nguess points in the dataset. This is a misspecified # model in that it is a normal with no skew. - mean = torch.tensor(np.atleast_1d(np.mean(dataset[:NGUESS], axis=0))).float() - cov = torch.tensor(np.atleast_2d(np.cov(dataset[:NGUESS], rowvar=False))).float() + mean = torch.tensor(np.atleast_1d(np.mean(guess_dataset, axis=0))).float() + cov = torch.tensor(np.atleast_2d(np.cov(guess_dataset, rowvar=False))).float() # Compute the functionals of the unperturbed models. fd_quad = ExpectedNormalDensityQuadFunctional( @@ -173,12 +180,12 @@ def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=Tr print(f"MC: {mc_guess}") # Compute the analytic eif on the samples after the first nguess. - correction_points = dict(x=torch.tensor(dataset[NGUESS:]).float()) + correction_points = dict(x=torch.tensor(correction_dataset).float()) if ndim == 1 and plot_densities: # Plot the influence function across the linspace in the same figure. - plt.plot(xx, analytic_eif(fd_quad, points=dict(x=xx), funcval=quad_guess), color='blue') - plt.plot(xx, analytic_eif(fd_mc, points=dict(x=xx), funcval=mc_guess), color='red') + plt.plot(xx, analytic_eif(fd_quad, points=dict(x=xx), funcval=oracle_fval), color='blue') + plt.plot(xx, analytic_eif(fd_mc, points=dict(x=xx), funcval=oracle_fval), color='red') # Quick check that the two have the same density. assert np.allclose( @@ -187,9 +194,9 @@ def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=Tr ) # And compute the analytic corrections. - ana_eif_quad = analytic_eif(fd_quad, correction_points, funcval=quad_guess) + ana_eif_quad = analytic_eif(fd_quad, correction_points, funcval=oracle_fval) ana_cor_quad = ana_eif_quad.mean() - ana_eif_mc = analytic_eif(fd_mc, correction_points, funcval=mc_guess) + ana_eif_mc = analytic_eif(fd_mc, correction_points, funcval=oracle_fval) ana_cor_mc = ana_eif_mc.mean() print(f"Quad (Ana Correction): {ana_cor_quad}") @@ -220,7 +227,7 @@ def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=Tr points=correction_points, eps=eps, # Scale the nmc with epsilon so that the kernel gets seen. - lambda_=lambda_)(nmc=(1. / eps) * 10) + lambda_=lambda_)(nmc=(1. / eps) * 100) fd_cor_mc = np.mean(fd_eif_mc) fd_cors_quad[(eps, lambda_)] = fd_cors_quad.get((eps, lambda_), []) + [fd_cor_quad] @@ -256,7 +263,7 @@ def plot_diag(ax, x1, x2s, x1lab, x2labs, title): # Plot the finite difference diagonals. if plot_fd_ana_diag: # Prep gridplot over eps and lambda. - f, axes = plt.subplots(len(EPS), len(LAMBDA), figsize=(30, 30)) + f, axes = plt.subplots(len(EPS), len(LAMBDA), figsize=(30, 20)) for (eps, lambda_), ax in zip(product(EPS, LAMBDA), axes.flatten()): plot_diag( ax=ax, @@ -274,4 +281,4 @@ def plot_diag(ax, x1, x2s, x1lab, x2labs, title): if __name__ == '__main__': - main(plot_densities=False, plot_corrections=False, plot_fd_ana_diag=True) + main(plot_densities=True, plot_corrections=True, plot_fd_ana_diag=True) From 094562ac3b99cdd0700a9604b8122ec98a2a1892 Mon Sep 17 00:00:00 2001 From: Andy Zane Date: Wed, 17 Jan 2024 10:49:23 -0800 Subject: [PATCH 66/66] unfixes an incorrect fix, working now. 
--- docs/source/robust_fd_scratch_sqd_epslam.py | 22 ++++++++++----------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/docs/source/robust_fd_scratch_sqd_epslam.py b/docs/source/robust_fd_scratch_sqd_epslam.py index 977022859..8aace8eb1 100644 --- a/docs/source/robust_fd_scratch_sqd_epslam.py +++ b/docs/source/robust_fd_scratch_sqd_epslam.py @@ -18,14 +18,12 @@ from scipy.stats._multivariate import _squeeze_output -# EPS = [0.01, 0.001] -# LAMBDA = [0.01, 0.001] -EPS = [0.01] -LAMBDA = [0.01] +EPS = [0.01, 0.001] +LAMBDA = [0.01, 0.001] NDIM = [1, 2] -NDATASETS = 3 +NDATASETS = 30 NGUESS = 50 -NEIF = 50 +NEIF = 25 def analytic_eif(model: FDModelFunctionalDensity, points, funcval=None): @@ -184,8 +182,8 @@ def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=Tr if ndim == 1 and plot_densities: # Plot the influence function across the linspace in the same figure. - plt.plot(xx, analytic_eif(fd_quad, points=dict(x=xx), funcval=oracle_fval), color='blue') - plt.plot(xx, analytic_eif(fd_mc, points=dict(x=xx), funcval=oracle_fval), color='red') + plt.plot(xx, analytic_eif(fd_quad, points=dict(x=xx), funcval=quad_guess), color='blue') + plt.plot(xx, analytic_eif(fd_mc, points=dict(x=xx), funcval=mc_guess), color='red') # Quick check that the two have the same density. assert np.allclose( @@ -194,9 +192,9 @@ def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=Tr ) # And compute the analytic corrections. - ana_eif_quad = analytic_eif(fd_quad, correction_points, funcval=oracle_fval) + ana_eif_quad = analytic_eif(fd_quad, correction_points, funcval=quad_guess) ana_cor_quad = ana_eif_quad.mean() - ana_eif_mc = analytic_eif(fd_mc, correction_points, funcval=oracle_fval) + ana_eif_mc = analytic_eif(fd_mc, correction_points, funcval=mc_guess) ana_cor_mc = ana_eif_mc.mean() print(f"Quad (Ana Correction): {ana_cor_quad}") @@ -227,7 +225,7 @@ def main(ndim=1, plot_densities=True, plot_corrections=True, plot_fd_ana_diag=Tr points=correction_points, eps=eps, # Scale the nmc with epsilon so that the kernel gets seen. - lambda_=lambda_)(nmc=(1. / eps) * 100) + lambda_=lambda_)(nmc=(1. / eps) * 1000) fd_cor_mc = np.mean(fd_eif_mc) fd_cors_quad[(eps, lambda_)] = fd_cors_quad.get((eps, lambda_), []) + [fd_cor_quad] @@ -281,4 +279,4 @@ def plot_diag(ax, x1, x2s, x1lab, x2labs, title): if __name__ == '__main__': - main(plot_densities=True, plot_corrections=True, plot_fd_ana_diag=True) + main(plot_densities=False, plot_corrections=False, plot_fd_ana_diag=True)
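A minimal standalone sketch of the comparison these last patches automate, assuming only
numpy/scipy, a 1-D unit-normal model, and a Gaussian kernel with variance lambda_ (the helper
names `psi` and `fd_eif` are illustrative, not part of chirho): a one-sided finite difference of
the expected-density functional at an eps-contaminated density, checked against the analytic
influence function 2 * (p(x) - psi(P)) used above.

import numpy as np
from scipy.integrate import quad
from scipy.stats import norm

# Model density p and the functional psi(p) = integral of p(x)^2 dx (expected density).
p = norm(0.0, 1.0).pdf

def psi(density, lo=-10.0, hi=10.0):
    return quad(lambda t: density(t) ** 2, lo, hi)[0]

psi_p = psi(p)  # analytic value for N(0, 1) is 1 / (2 * sqrt(pi)) ~ 0.2821

def fd_eif(x, eps=1e-3, lambda_=1e-2):
    # eps-contaminate p with a narrow Gaussian kernel of variance lambda_ centred at x,
    # then take a one-sided finite difference of the functional.
    k = norm(x, np.sqrt(lambda_)).pdf

    def p_eps(t):
        return (1.0 - eps) * p(t) + eps * k(t)

    return (psi(p_eps) - psi_p) / eps

x = 1.0
print("finite difference:", fd_eif(x))           # ~ -0.08 for small eps and lambda_
print("analytic EIF     :", 2 * (p(x) - psi_p))  # also ~ -0.08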