From 9a9264260931fbfdf521c974ad909c2df3e4a51d Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 10 Jan 2024 14:48:35 -0500 Subject: [PATCH 1/6] make influence a functional --- chirho/robust/handlers/estimators.py | 21 +++----- chirho/robust/ops.py | 74 +++++++++++++++------------- tests/robust/test_handlers.py | 18 +++---- tests/robust/test_ops.py | 36 +++++++------- 4 files changed, 72 insertions(+), 77 deletions(-) diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py index 16e8ab227..ef8e944d2 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -1,6 +1,6 @@ -from typing import Any, Callable, TypeVar +from typing import TypeVar -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec from chirho.robust.ops import Functional, Point, influence_fn @@ -10,21 +10,17 @@ def one_step_correction( - model: Callable[P, Any], functional: Functional[P, S], + test_data: Point[T], **influence_kwargs, -) -> Callable[Concatenate[Point[T], P], S]: +) -> Functional[P, S]: """ Returns a function that computes the one-step correction for the functional at a specified set of test points as discussed in [1]. - :param model: Python callable containing Pyro primitives. - :type model: Callable[P, Any] :param functional: model summary of interest, which is a function of the model. - :type functional: Functional[P, S] - :return: function to compute the one-step correction - :rtype: Callable[Concatenate[Point[T], P], S] + :return: functional to compute the one-step correction **References** @@ -33,9 +29,4 @@ def one_step_correction( """ influence_kwargs_one_step = influence_kwargs.copy() influence_kwargs_one_step["pointwise_influence"] = False - eif_fn = influence_fn(model, functional, **influence_kwargs_one_step) - - def _one_step(test_data: Point[T], *args, **kwargs) -> S: - return eif_fn(test_data, *args, **kwargs) - - return _one_step + return influence_fn(functional, test_data, **influence_kwargs_one_step) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index b02bfa47e..f0e8c1422 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -1,8 +1,7 @@ -import functools from typing import Any, Callable, Mapping, TypeVar import torch -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec from chirho.observational.ops import Observation @@ -16,16 +15,16 @@ def influence_fn( - model: Callable[P, Any], functional: Functional[P, S], **linearize_kwargs -) -> Callable[Concatenate[Point[T], P], S]: + functional: Functional[P, S], points: Point[T], **linearize_kwargs +) -> Functional[P, S]: """ Returns the efficient influence function for ``functional`` with respect to the parameters of probabilistic program ``model``. - :param model: Python callable containing Pyro primitives. - :type model: Callable[P, Any] :param functional: model summary of interest, which is a function of ``model`` :type functional: Functional[P, S] + :param points: points at which to compute the efficient influence function + :type points: Point[T] :return: the efficient influence function for ``functional`` :rtype: Callable[Concatenate[Point[T], P], S] @@ -88,14 +87,13 @@ def forward(self): ) points = predictive() influence = influence_fn( - model, - guide, SimpleFunctional, + points, num_samples_outer=1000, num_samples_inner=1000, - ) + )(PredictiveModel(model, guide)) - influence(points) + influence() .. 
note:: @@ -111,31 +109,37 @@ def forward(self): from chirho.robust.internals.linearize import linearize from chirho.robust.internals.utils import make_functional_call - linearized = linearize(model, **linearize_kwargs) - target = functional(model) - - # TODO check that target_params == model_params - assert isinstance(target, torch.nn.Module) - target_params, func_target = make_functional_call(target) - - @functools.wraps(target) - def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: + def _influence_functional(model: Callable[P, Any]) -> Callable[P, S]: """ - Evaluates the efficient influence function for ``functional`` at each - point in ``points``. + Functional representing the efficient influence function of ``functional`` at ``points`` . - :param points: points at which to compute the efficient influence function - :type points: Point[T] - :return: efficient influence function evaluated at each point in ``points`` or averaged - :rtype: S + :param model: Python callable containing Pyro primitives. + :return: efficient influence function for ``functional`` evaluated at ``model`` and ``points`` """ - param_eif = linearized(points, *args, **kwargs) - return torch.vmap( - lambda d: torch.func.jvp( - lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) - )[1], - in_dims=0, - randomness="different", - )(param_eif) - - return _fn + linearized = linearize(model, **linearize_kwargs) + target = functional(model) + + # TODO check that target_params == model_params + assert isinstance(target, torch.nn.Module) + target_params, func_target = make_functional_call(target) + + def _fn(*args: P.args, **kwargs: P.kwargs) -> S: + """ + Evaluates the efficient influence function for ``functional`` at each + point in ``points``. + + :return: efficient influence function evaluated at each point in ``points`` or averaged + :rtype: S + """ + param_eif = linearized(points, *args, **kwargs) + return torch.vmap( + lambda d: torch.func.jvp( + lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) + )[1], + in_dims=0, + randomness="different", + )(param_eif) + + return _fn + + return _influence_functional diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py index f1849959a..a947ceb42 100644 --- a/tests/robust/test_handlers.py +++ b/tests/robust/test_handlers.py @@ -56,15 +56,6 @@ def test_one_step_correction_smoke( guide = guide(model) model(), guide() # initialize - one_step = one_step_correction( - PredictiveModel(model, guide), - functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), - max_plate_nesting=max_plate_nesting, - num_samples_outer=num_samples_outer, - num_samples_inner=num_samples_inner, - cg_iters=cg_iters, - ) - with torch.no_grad(): test_datum = { k: v[0] @@ -73,6 +64,15 @@ def test_one_step_correction_smoke( )().items() } + one_step = one_step_correction( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_datum, + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum) assert len(one_step_on_test) > 0 for k, v in one_step_on_test.items(): diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 1bdb2461b..e3d5e5290 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -56,15 +56,6 @@ def test_nmc_predictive_influence_smoke( guide = guide(model) model(), guide() # 
initialize - predictive_eif = influence_fn( - PredictiveModel(model, guide), - functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), - max_plate_nesting=max_plate_nesting, - num_samples_outer=num_samples_outer, - num_samples_inner=num_samples_inner, - cg_iters=cg_iters, - ) - with torch.no_grad(): test_datum = { k: v[0] @@ -73,7 +64,16 @@ def test_nmc_predictive_influence_smoke( )().items() } - test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum) + predictive_eif = influence_fn( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_datum, + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + + test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif() assert len(test_datum_eif) > 0 for k, v in test_datum_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" @@ -100,21 +100,21 @@ def test_nmc_predictive_influence_vmap_smoke( model(), guide() # initialize + with torch.no_grad(): + test_data = pyro.infer.Predictive( + model, num_samples=4, return_sites=obs_names, parallel=True + )() + predictive_eif = influence_fn( - PredictiveModel(model, guide), functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_data, max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, cg_iters=cg_iters, - ) - - with torch.no_grad(): - test_data = pyro.infer.Predictive( - model, num_samples=4, return_sites=obs_names, parallel=True - )() + )(PredictiveModel(model, guide)) - test_data_eif: Mapping[str, torch.Tensor] = predictive_eif(test_data) + test_data_eif: Mapping[str, torch.Tensor] = predictive_eif() assert len(test_data_eif) > 0 for k, v in test_data_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" From e57d06a382354e36c6cf1b0eef92ab8d979aa8c5 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 10 Jan 2024 14:51:31 -0500 Subject: [PATCH 2/6] fix test --- tests/robust/test_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py index a947ceb42..6168d563a 100644 --- a/tests/robust/test_handlers.py +++ b/tests/robust/test_handlers.py @@ -73,7 +73,7 @@ def test_one_step_correction_smoke( cg_iters=cg_iters, )(PredictiveModel(model, guide)) - one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum) + one_step_on_test: Mapping[str, torch.Tensor] = one_step() assert len(one_step_on_test) > 0 for k, v in one_step_on_test.items(): assert not torch.isnan(v).any(), f"one_step for {k} had nans" From 1673d6854bf4052de18a9d1df3afa1b039136524 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 10 Jan 2024 15:00:06 -0500 Subject: [PATCH 3/6] multiple arguments --- chirho/robust/internals/linearize.py | 8 ++++++-- chirho/robust/ops.py | 28 ++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 27ce8da39..cdfa04989 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -219,8 +219,7 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: def linearize( - model: Callable[P, Any], - *, + *models: Callable[P, Any], num_samples_outer: int, num_samples_inner: Optional[int] = None, max_plate_nesting: Optional[int] = None, @@ -328,6 +327,11 @@ def forward(self): This issue will be 
addressed in a future release: https://github.com/BasisResearch/chirho/issues/393. """ + if len(models) > 1: + raise NotImplementedError("Batched linearization not yet implemented") + else: + (model,) = models + assert isinstance(model, torch.nn.Module) if num_samples_inner is None: num_samples_inner = num_samples_outer**2 diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index f0e8c1422..0340664c8 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Mapping, TypeVar +from typing import Any, Callable, Mapping, Protocol, TypeVar import torch from typing_extensions import ParamSpec @@ -7,15 +7,19 @@ P = ParamSpec("P") Q = ParamSpec("Q") -S = TypeVar("S") +S = TypeVar("S", covariant=True) T = TypeVar("T") Point = Mapping[str, Observation[T]] -Functional = Callable[[Callable[P, Any]], Callable[P, S]] + + +class Functional(Protocol[P, S]): + def __call__(self, *models: Callable[P, Any]) -> Callable[P, S]: + ... def influence_fn( - functional: Functional[P, S], points: Point[T], **linearize_kwargs + functional: Functional[P, S], *points: Point[T], **linearize_kwargs ) -> Functional[P, S]: """ Returns the efficient influence function for ``functional`` @@ -109,15 +113,23 @@ def forward(self): from chirho.robust.internals.linearize import linearize from chirho.robust.internals.utils import make_functional_call - def _influence_functional(model: Callable[P, Any]) -> Callable[P, S]: + if len(points) != 1: + raise NotImplementedError( + "influence_fn currently only supports unary functionals" + ) + + def _influence_functional(*models: Callable[P, Any]) -> Callable[P, S]: """ Functional representing the efficient influence function of ``functional`` at ``points`` . :param model: Python callable containing Pyro primitives. :return: efficient influence function for ``functional`` evaluated at ``model`` and ``points`` """ - linearized = linearize(model, **linearize_kwargs) - target = functional(model) + if len(models) != len(points): + raise ValueError("mismatch between number of models and points") + + linearized = linearize(*models, **linearize_kwargs) + target = functional(*models) # TODO check that target_params == model_params assert isinstance(target, torch.nn.Module) @@ -131,7 +143,7 @@ def _fn(*args: P.args, **kwargs: P.kwargs) -> S: :return: efficient influence function evaluated at each point in ``points`` or averaged :rtype: S """ - param_eif = linearized(points, *args, **kwargs) + param_eif = linearized(*points, *args, **kwargs) return torch.vmap( lambda d: torch.func.jvp( lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) From 2aeef065e365ce3a0d3a168529050caae3517a2b Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 10 Jan 2024 15:00:57 -0500 Subject: [PATCH 4/6] doc --- chirho/robust/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 0340664c8..03ebabfe5 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -122,7 +122,7 @@ def _influence_functional(*models: Callable[P, Any]) -> Callable[P, S]: """ Functional representing the efficient influence function of ``functional`` at ``points`` . - :param model: Python callable containing Pyro primitives. + :param models: Python callables containing Pyro primitives. 
:return: efficient influence function for ``functional`` evaluated at ``model`` and ``points`` """ if len(models) != len(points): From b62d828bd87a7dab3c53ed6abb7b3a077e3715ab Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 11 Jan 2024 09:51:01 -0500 Subject: [PATCH 5/6] docstring --- chirho/robust/internals/linearize.py | 2 +- chirho/robust/ops.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index cdfa04989..29447c736 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -328,7 +328,7 @@ def forward(self): https://github.com/BasisResearch/chirho/issues/393. """ if len(models) > 1: - raise NotImplementedError("Batched linearization not yet implemented") + raise NotImplementedError("Only unary version of linearize is implemented.") else: (model,) = models diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index 03ebabfe5..86ed8f89d 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -14,7 +14,9 @@ class Functional(Protocol[P, S]): - def __call__(self, *models: Callable[P, Any]) -> Callable[P, S]: + def __call__( + self, __model: Callable[P, Any], *models: Callable[P, Any] + ) -> Callable[P, S]: ... @@ -22,15 +24,12 @@ def influence_fn( functional: Functional[P, S], *points: Point[T], **linearize_kwargs ) -> Functional[P, S]: """ - Returns the efficient influence function for ``functional`` - with respect to the parameters of probabilistic program ``model``. + Returns a new functional that computes the efficient influence function for ``functional`` + at the given ``points`` with respect to the parameters of its probabilistic program arguments. :param functional: model summary of interest, which is a function of ``model`` - :type functional: Functional[P, S] - :param points: points at which to compute the efficient influence function - :type points: Point[T] - :return: the efficient influence function for ``functional`` - :rtype: Callable[Concatenate[Point[T], P], S] + :param points: points for each input to ``functional`` at which to compute the efficient influence function + :return: functional that computes the efficient influence function for ``functional`` at ``points`` **Example usage**: @@ -141,7 +140,6 @@ def _fn(*args: P.args, **kwargs: P.kwargs) -> S: point in ``points``. :return: efficient influence function evaluated at each point in ``points`` or averaged - :rtype: S """ param_eif = linearized(*points, *args, **kwargs) return torch.vmap( From 3ea42748a29c550b5d8d332ec8047de1be3f66ed Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 11 Jan 2024 09:53:22 -0500 Subject: [PATCH 6/6] docstring --- chirho/robust/handlers/estimators.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py index ef8e944d2..4f60ddcd6 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -11,15 +11,15 @@ def one_step_correction( functional: Functional[P, S], - test_data: Point[T], + *test_points: Point[T], **influence_kwargs, ) -> Functional[P, S]: """ - Returns a function that computes the one-step correction for the - functional at a specified set of test points as discussed in - [1]. + Returns a functional that computes the one-step correction for the + functional at a specified set of test points as discussed in [1]. 
- :param functional: model summary of interest, which is a function of the model. + :param functional: model summary functional of interest + :param test_points: points at which to compute the one-step correction :return: functional to compute the one-step correction **References** @@ -29,4 +29,4 @@ def one_step_correction( """ influence_kwargs_one_step = influence_kwargs.copy() influence_kwargs_one_step["pointwise_influence"] = False - return influence_fn(functional, test_data, **influence_kwargs_one_step) + return influence_fn(functional, *test_points, **influence_kwargs_one_step)
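
A minimal usage sketch of the calling convention introduced by this series, assuming the SimpleModel/SimpleGuide fixtures and the PredictiveModel/PredictiveFunctional helpers used by the tests above (their exact import paths are not shown in this diff); the observed site name and hyperparameter values below are illustrative, not prescribed by these commits.

    import functools

    import pyro
    import torch

    from chirho.robust.handlers.estimators import one_step_correction
    from chirho.robust.ops import influence_fn

    # Assumed fixtures/helpers, imported as in tests/robust/test_ops.py:
    #   SimpleModel, SimpleGuide, PredictiveModel, PredictiveFunctional
    model, guide = SimpleModel(), SimpleGuide()
    model(), guide()  # initialize lazily-constructed parameters

    with torch.no_grad():
        # Draw a single test point from the model's prior predictive,
        # mirroring the smoke tests; "y" is an assumed observed site name.
        test_datum = {
            k: v[0]
            for k, v in pyro.infer.Predictive(
                model, num_samples=2, return_sites=["y"], parallel=True
            )().items()
        }

    # After this series, influence_fn takes the functional and the evaluation
    # points up front and returns a Functional; applying it to a (predictive)
    # model yields a nullary callable that evaluates the EIF at those points.
    predictive_eif = influence_fn(
        functools.partial(PredictiveFunctional, num_samples=100),
        test_datum,
        max_plate_nesting=1,
        num_samples_outer=1000,
        num_samples_inner=1000,
    )(PredictiveModel(model, guide))
    eif_at_test = predictive_eif()  # Mapping[str, torch.Tensor], one entry per site

    # one_step_correction follows the same two-stage convention, forwarding its
    # keyword arguments to influence_fn with pointwise_influence disabled.
    one_step = one_step_correction(
        functools.partial(PredictiveFunctional, num_samples=100),
        test_datum,
        max_plate_nesting=1,
        num_samples_outer=1000,
        num_samples_inner=1000,
    )(PredictiveModel(model, guide))
    correction = one_step()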