From bd63526f9d8efbbab38a787c0336a3f97c9ebf37 Mon Sep 17 00:00:00 2001
From: Andy Zane
Date: Wed, 24 Jan 2024 12:48:12 -0800
Subject: [PATCH] Integrates FD for Squared Density Into Experiment (#510)

* added robust folder
* uncommitted scratch work for log prob
* untested variational log prob
* uncommitted changes
* uncommitted changes
* pair coding w/ eli
* added tests w/ Eli
* eif
* linting
* moving test autograd to internals and deleted old utils file
* sketch influence implementation
* fix more args
* ops file
* file
* format
* lint
* clean up influence and tests
* make tests more generic
* guess max plate nesting
* linearize
* rename file
* tensor flatten
* predictive eif
* jvp type
* reorganize files
* shrink test case
* move guess_max_plate_nesting
* move cg solver to linearize
* type alias
* test_ops
* basic cg tests
* remove failing test case
* format
* move paramdict up
* remove obsolete test files
* add empty handlers
* add chirho.robust to docs
* fix memory leak in tests
* make typing compatible with python 3.8
* typing_extensions
* add branch to ci
* predictive
* remove imprecise annotation
* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)
* initial test against analytic fisher vp (pair coded w/ sam)
* linting
* added check against analytic ate
* added vmap and grad smoke tests
* added missing init
* linting and consolidated fisher tests to one file
* fixed types
* fixing linting errors
* trying to fix type error for python 3.8
* fixing test errors
* added patch to test to prevent from failing when denom is small
* composition issue
* removed missing import
* fixed failing test with seeding
* addressing Eli's comments
* Add upper bound on number of CG steps (#404)
* upper bound on cg_iters
* address comment
* fixed test for non-symmetric matrix (#437)
* Make `NMCLogPredictiveLikelihood` seeded (#408)
* initial test against analytic fisher vp (pair coded w/ sam)
* linting
* added check against analytic ate
* added vmap and grad smoke tests
* added missing init
* linting and consolidated fisher tests to one file
* fixed types
* fixing linting errors
* trying to fix type error for python 3.8
* fixing test errors
* added patch to test to prevent from failing when denom is small
* composition issue
* seeded NMC implementation
* linting
* removed missing import
* changed to eli's seedmessenger suggestion
* added failing edge case
* explicitly add max plate argument
* added warning message
* fixed linting error and test failure case from too many cg iters
* eli's contextlib seeding strategy
* removed seedmessenger from test
* randomness should be shared across calls
* switched back to different
* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)
* hessian vector product formulation for fisher
* ignoring small type error
* fixed linting error
* Add new `SimpleModel` and `SimpleGuide` (#440)
* initial test against analytic fisher vp (pair coded w/ sam)
* linting
* added check against analytic ate
* added vmap and grad smoke tests
* added missing init
* linting and consolidated fisher tests to one file
* fixed types
* fixing linting errors
* trying to fix type error for python 3.8
* fixing test errors
* added patch to test to prevent from failing when denom is small
* composition issue
* seeded NMC implementation
* linting
* removed missing import
* changed to eli's seedmessenger suggestion
* added failing edge case
* explicitly add max plate argument
* added warning message
* fixed linting error and test failure case from too many cg iters
* eli's contextlib seeding strategy
* removed seedmessenger from test
* randomness should be shared across calls
* uncommitted change before branch switch
* switched back to different
* added revised simple model and guide
* added multiple link functions in test
* linting
* Batching in `linearize` and `influence` (#465)
* batching in linearize and influence
* addressing eli's review
* added optimization for pointwise false case
* fixing lint error
* batched cg (#466)
* One step correction implemented (#467)
* one step correction
* increased tolerance
* fixing lint issue
* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)
* sketch batched nmc lpd
* nits
* fix type
* format
* comment
* comment
* comment
* typo
* typo
* add condition to help guarantee idempotence
* simplify edge case
* simplify plate_name
* simplify batchedobservation logic
* factorize
* simplify batched
* reorder
* comment
* remove plate_names
* types
* formatting and type
* move unbind to utils
* remove max_plate_nesting arg from get_traces
* comment
* nit
* move get_importance_traces to utils
* fix types
* generic obs type
* lint
* format
* handle observe in batchedobservations
* event dim
* move batching handlers to utils
* replace 2/3 vmaps, tests pass
* remove dead code
* format
* name args
* lint
* shuffle code
* try an extra optimization in batchedlatents
* add another optimization
* undo changes to test
* remove inplace adds
* add performance test showing speedup
* document internal helpers
* batch latents test
* move batch handlers to predictive
* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel
* use bind_leftmost_dim in log prob
* Added documentation for `chirho.robust` (#470)
* documentation
* documentation clean up w/ eli
* fix lint issue
* Make functional argument to influence_fn required (#487)
* Make functional argument required
* estimator
* docstring
* Remove guide argument from `influence_fn` and `linearize` (#489)
* Make functional argument required
* estimator
* docstring
* Remove guide, make tests pass
* rename internals.predictive to internals.nmc
* expose handlers.predictive
* expose handlers.predictive
* docstrings
* fix doc build
* fix equation
* docstring import

---------

Co-authored-by: Sam Witty

* Make influence_fn a higher-order Functional (#492)
* make influence a functional
* fix test
* multiple arguments
* doc
* docstring
* docstring
* Add full corrected one step estimator (#476)
* added scaffolding to one step estimator
* kept signature the same as one_step_correction
* lint
* refactored test to include multiple estimators
* typo
* revise error
* added dict handling
* remove assert
* more informative error message
* replace dispatch with pytree flatten and unflatten
* revert arg for influence_function_estimator
* docs and lint
* lingering influence_fn
* fixed missing return
* rename
* lint
* add *model to appease the linter
* add abstractions and simple temp scratch to test with squared unit normal functional with perturbation
* removes old scratch notebook
* gets squared density running under abstraction that couples functionals and models
* gets quad and mc approximations to match; vectorization still hacky
* adds plotting and comparison to the analytic solution
* adds scratch experiment comparing squared density analytic vs. fd approx across various epsilons and lambdas
* fixes dataset splitting, breaks analytic eif
* unfixes an incorrect fix; working now
* refactors finite difference machinery to fit experimental specs
* switches to existing rng seed context manager
* reverts back to what turns out to be a slightly different seeding context
* gets fd integrated into experiment exec and running
* adds perturbable normal model to statics listing
* switches back to mean not mu
* lines up mean/mu/loc naming correctly
---------

Co-authored-by: Raj Agrawal
Co-authored-by: Eli
Co-authored-by: Sam Witty
Co-authored-by: Raj Agrawal
Co-authored-by: eb8680
---
 .../scripts/create_experiment_configs.py      |  6 +++++
 .../scripts/fd_influence_approx.py            | 12 ++++-----
 .../robust_paper/scripts/influence_approx.py  | 25 +++++++++++++++++--
 docs/examples/robust_paper/scripts/statics.py |  2 ++
 4 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/docs/examples/robust_paper/scripts/create_experiment_configs.py b/docs/examples/robust_paper/scripts/create_experiment_configs.py
index 67410f3cc..965819947 100644
--- a/docs/examples/robust_paper/scripts/create_experiment_configs.py
+++ b/docs/examples/robust_paper/scripts/create_experiment_configs.py
@@ -91,6 +91,12 @@ def influence_approx_experiment_expected_density():
             "cg_iters": None,
             "residual_tol": 1e-4,
         },
+        "fd_influence_estimator_kwargs": {
+            "lambdas": [0.1, 0.01, 0.001],
+            "epss": [0.1, 0.01, 0.001, 0.0001],
+            "num_samples_scaling": 100,
+            "seed": 0,
+        },
         "data_config": data_config,
     }
     save_experiment_config(experiment_config)
diff --git a/docs/examples/robust_paper/scripts/fd_influence_approx.py b/docs/examples/robust_paper/scripts/fd_influence_approx.py
index 2fd3580d3..d499389b9 100644
--- a/docs/examples/robust_paper/scripts/fd_influence_approx.py
+++ b/docs/examples/robust_paper/scripts/fd_influence_approx.py
@@ -43,8 +43,8 @@ def __init__(self, *args, **kwargs):
 
 
 def compute_fd_correction_sqd_mvn_quad(*, theta_hat: Point[T], **kwargs) -> List[Dict]:
-    mean = theta_hat['mean']
-    scale_tril = theta_hat['scale_tril']
+    mean = theta_hat['mu'].detach()
+    scale_tril = theta_hat['scale_tril'].detach()
 
     fd_coupling = ExpectedNormalDensityQuadFunctional(
         # TODO agnostic to names
@@ -57,8 +57,8 @@
 
 
 def compute_fd_correction_sqd_mvn_mc(*, theta_hat: Point[T], **kwargs) -> List[Dict]:
-    mean = theta_hat['mean']
-    scale_tril = theta_hat['scale_tril']
+    mean = theta_hat['mu'].detach()
+    scale_tril = theta_hat['scale_tril'].detach()
 
     fd_coupling = ExpectedNormalDensityMCFunctional(
         # TODO agnostic to names
@@ -136,12 +136,12 @@ def smoke_test():
     # Runtime
     for ndim in [1, 2]:
         theta_hat = th = dict(
-            mean=torch.zeros(ndim),
+            mu=torch.zeros(ndim),
             scale_tril=torch.linalg.cholesky(torch.eye(ndim))
         )
 
         test_data = dict(
-            x=dist.MultivariateNormal(loc=th['mean'], scale_tril=th['scale_tril']).sample((20,))
+            x=dist.MultivariateNormal(loc=th['mu'], scale_tril=th['scale_tril']).sample((20,))
         )
 
         mc_correction = compute_fd_correction_sqd_mvn_mc(
diff --git a/docs/examples/robust_paper/scripts/influence_approx.py b/docs/examples/robust_paper/scripts/influence_approx.py
index b59d2560b..7881438a7 100644
--- a/docs/examples/robust_paper/scripts/influence_approx.py
+++ b/docs/examples/robust_paper/scripts/influence_approx.py
@@ -16,6 +16,10 @@
     analytic_eif_expected_density,
     analytic_eif_ate_causal_glm,
 )
+from fd_influence_approx import (
+    compute_fd_correction_sqd_mvn_mc,
+    compute_fd_correction_sqd_mvn_quad
+)
 
 
 def run_experiment(exp_config):
@@ -55,7 +59,8 @@ def run_experiment(exp_config):
     conditioned_model = MODELS[model_str]["conditioned_model"](D_train, **model_kwargs)
 
     # Load in functional
-    functional_class = FUNCTIONALS_DICT[exp_config["functional_str"]]
+    functional_str = exp_config["functional_str"]
+    functional_class = FUNCTIONALS_DICT[functional_str]
     functional = functools.partial(functional_class, **exp_config["functional_kwargs"])
 
     # Fit MLE
@@ -139,7 +144,23 @@ def run_experiment(exp_config):
     results["all_monte_carlo_eif_results"] = all_monte_carlo_eif_results
 
     ### Finite Difference EIF ###
-    # TODO: Andy to add here
+    if model_str == "MultivariateNormalModel" and functional_str == "expected_density":
+        fd_kwargs = exp_config["fd_influence_estimator_kwargs"]
+        fd_mc_eif_results = compute_fd_correction_sqd_mvn_mc(
+            theta_hat=theta_hat,
+            test_data=D_test,
+            **fd_kwargs
+        )
+        results["fd_mc_eif_results"] = fd_mc_eif_results
+
+        # Quadrature scales poorly with dimension, so only run it in 1d.
+        if theta_hat["mu"].shape[-1] == 1:
+            fd_quad_eif_results = compute_fd_correction_sqd_mvn_quad(
+                theta_hat=theta_hat,
+                test_data=D_test,
+                **fd_kwargs
+            )
+            results["fd_quad_eif_results"] = fd_quad_eif_results
 
     ### Analytic EIF ###
     if model_str == "CausalGLM":
diff --git a/docs/examples/robust_paper/scripts/statics.py b/docs/examples/robust_paper/scripts/statics.py
index a08388200..32aae1b5d 100644
--- a/docs/examples/robust_paper/scripts/statics.py
+++ b/docs/examples/robust_paper/scripts/statics.py
@@ -3,6 +3,7 @@
 import pyro.distributions as dist
 from docs.examples.robust_paper.models import *
 from docs.examples.robust_paper.functionals import *
+from docs.examples.robust_paper.finite_difference_eif.distributions import PerturbableNormal
 
 
 MODELS = {
@@ -15,6 +16,7 @@
         "data_generator": DataGeneratorMultivariateNormalModel,
         "model": MultivariateNormalModel,
         "conditioned_model": ConditionedMultivariateNormalModel,
+        "fd_perturbable_model": PerturbableNormal
     },
 }
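
Aside (not part of the applied patch): the finite-difference scripts wired in above approximate an influence function by mixing the fitted distribution with a point mass and differencing the functional, roughly IF(x) ~ (T((1 - eps) * P + eps * delta_x) - T(P)) / eps, swept over the configured `epss` and `lambdas`. A minimal, hypothetical sketch of that recipe, using the mean functional instead of the squared-density functional from this patch (all names below are illustrative and do not come from the repository):

```python
# Illustrative-only sketch of the finite-difference influence recipe.

def mean_functional(weighted_samples):
    """T(P) = E_P[X] for a discrete distribution given as [(weight, x), ...]."""
    return sum(w * x for w, x in weighted_samples)

def fd_influence(functional, samples, x, eps=1e-3):
    """Finite-difference quotient (T((1-eps)*P_hat + eps*delta_x) - T(P_hat)) / eps."""
    n = len(samples)
    base = [(1.0 / n, xi) for xi in samples]             # empirical P_hat
    perturbed = [((1.0 - eps) * w, xi) for w, xi in base]
    perturbed.append((eps, x))                           # mix in the point mass at x
    return (functional(perturbed) - functional(base)) / eps

samples = [0.0, 2.0, 4.0]  # empirical mean is 2.0
print(fd_influence(mean_functional, samples, x=5.0))  # ~ 3.0, the exact IF x - mean
```

For the mean functional the mixture is linear in eps, so the quotient matches the exact influence function x - mean for any eps; the squared-density functional targeted by this patch is nonlinear in P, which is why the experiment sweeps several eps values rather than relying on one.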