Integrates FD for Squared Density Into Experiment (#510)
* added robust folder

* uncommitted scratch work for log prob

* untested variational log prob

* uncommitted changes

* uncommitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearize

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncommitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* add abstractions and a simple temporary scratch test for the squared unit-normal functional with perturbation.

* removes old scratch notebook

* gets squared density running under an abstraction that couples functionals and models

* gets quad and MC approximations to match; vectorization is still hacky.

* adds plotting and comparison to the analytic solution.

* adds scratch experiment comparing the analytic squared-density EIF vs the FD approximation across various epsilon/lambda settings

* fixes dataset splitting, breaks analytic eif

* unfixes an incorrect fix, working now.

* refactors finite difference machinery to fit experimental specs.

* switches to existing rng seed context manager.

* reverts to what turns out to be a slightly different seeding context.

* gets FD integrated into experiment execution and running.

* adds perturbable normal model to statics listing

* switches back to mean not mu

* lines up mean/mu/loc naming correctly.

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
6 people authored Jan 24, 2024
1 parent 72452d1 commit bd63526
Showing 4 changed files with 37 additions and 8 deletions.
@@ -91,6 +91,12 @@ def influence_approx_experiment_expected_density():
             "cg_iters": None,
             "residual_tol": 1e-4,
         },
+        "fd_influence_estimator_kwargs": {
+            "lambdas": [0.1, 0.01, 0.001],
+            "epss": [0.1, 0.01, 0.001, 0.0001],
+            "num_samples_scaling": 100,
+            "seed": 0,
+        },
         "data_config": data_config,
     }
     save_experiment_config(experiment_config)
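
For orientation, here is a minimal, self-contained sketch of the finite-difference scheme these settings are assumed to configure: mix a narrow Gaussian kernel of width lambda centered at a test point into the fitted distribution with weight eps, re-evaluate the expected-density functional, and take the difference quotient. The helper names, the kernel form, and the reading of num_samples_scaling are illustrative assumptions, not the repository's actual API.

import torch
import torch.distributions as dist

def expected_density_mc(d, num_samples):
    # Monte Carlo estimate of T(P) = E_{X ~ P}[p(X)], i.e. the integral of p squared.
    x = d.sample((num_samples,))
    return d.log_prob(x).exp().mean()

def mixture(p, kernel, eps):
    # (1 - eps) * P + eps * kernel, both MultivariateNormal over the same event shape.
    locs = torch.stack([p.loc, kernel.loc])
    covs = torch.stack([p.covariance_matrix, kernel.covariance_matrix])
    return dist.MixtureSameFamily(
        dist.Categorical(torch.tensor([1.0 - eps, eps])),
        dist.MultivariateNormal(locs, covs),
    )

def fd_correction(p, x, eps, lam, num_samples):
    # Finite-difference influence estimate at x: (T(Q_eps) - T(P)) / eps.
    kernel = dist.MultivariateNormal(x, lam**2 * torch.eye(x.shape[-1]))
    q = mixture(p, kernel, eps)
    return (expected_density_mc(q, num_samples) - expected_density_mc(p, num_samples)) / eps

torch.manual_seed(0)
p = dist.MultivariateNormal(torch.zeros(2), torch.eye(2))
x_test = torch.tensor([0.5, -0.5])
for lam in [0.1, 0.01, 0.001]:
    for eps in [0.1, 0.01, 0.001]:
        n = int(100 / eps)  # assumed role of num_samples_scaling: more samples for smaller eps
        print(lam, eps, fd_correction(p, x_test, eps, lam, n).item())
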
docs/examples/robust_paper/scripts/fd_influence_approx.py (6 additions & 6 deletions)

@@ -43,8 +43,8 @@ def __init__(self, *args, **kwargs):


 def compute_fd_correction_sqd_mvn_quad(*, theta_hat: Point[T], **kwargs) -> List[Dict]:
-    mean = theta_hat['mean']
-    scale_tril = theta_hat['scale_tril']
+    mean = theta_hat['mu'].detach()
+    scale_tril = theta_hat['scale_tril'].detach()

     fd_coupling = ExpectedNormalDensityQuadFunctional(
         # TODO agnostic to names
@@ -57,8 +57,8 @@ def compute_fd_correction_sqd_mvn_mc(*, theta_hat: Point[T], **kwargs) -> List[Dict]:


 def compute_fd_correction_sqd_mvn_mc(*, theta_hat: Point[T], **kwargs) -> List[Dict]:
-    mean = theta_hat['mean']
-    scale_tril = theta_hat['scale_tril']
+    mean = theta_hat['mu'].detach()
+    scale_tril = theta_hat['scale_tril'].detach()

     fd_coupling = ExpectedNormalDensityMCFunctional(
         # TODO agnostic to names
@@ -136,12 +136,12 @@ def smoke_test():
     # Runtime
     for ndim in [1, 2]:
         theta_hat = th = dict(
-            mean=torch.zeros(ndim),
+            mu=torch.zeros(ndim),
             scale_tril=torch.linalg.cholesky(torch.eye(ndim))
         )

         test_data = dict(
-            x=dist.MultivariateNormal(loc=th['mean'], scale_tril=th['scale_tril']).sample((20,))
+            x=dist.MultivariateNormal(loc=th['mu'], scale_tril=th['scale_tril']).sample((20,))
         )

         mc_correction = compute_fd_correction_sqd_mvn_mc(
docs/examples/robust_paper/scripts/influence_approx.py (23 additions & 2 deletions)

@@ -16,6 +16,10 @@
     analytic_eif_expected_density,
     analytic_eif_ate_causal_glm,
 )
+from fd_influence_approx import (
+    compute_fd_correction_sqd_mvn_mc,
+    compute_fd_correction_sqd_mvn_quad
+)


 def run_experiment(exp_config):
@@ -55,7 +59,8 @@ def run_experiment(exp_config):
     conditioned_model = MODELS[model_str]["conditioned_model"](D_train, **model_kwargs)

     # Load in functional
-    functional_class = FUNCTIONALS_DICT[exp_config["functional_str"]]
+    functional_str = exp_config["functional_str"]
+    functional_class = FUNCTIONALS_DICT[functional_str]
     functional = functools.partial(functional_class, **exp_config["functional_kwargs"])

     # Fit MLE
@@ -139,7 +144,23 @@ def run_experiment(exp_config):
     results["all_monte_carlo_eif_results"] = all_monte_carlo_eif_results

     ### Finite Difference EIF ###
-    # TODO: Andy to add here
+    if model_str == "MultivariateNormalModel" and functional_str == "expected_density":
+        fd_kwargs = exp_config["fd_influence_estimator_kwargs"]
+        fd_mc_eif_results = compute_fd_correction_sqd_mvn_mc(
+            theta_hat=theta_hat,
+            test_data=D_test,
+            **fd_kwargs
+        )
+        results["fd_mc_eif_results"] = fd_mc_eif_results
+
+        # Maybe run this for 2d if you're having too good of a day, and need it to get worse.
+        if theta_hat["mu"].shape[-1] == 1:
+            fd_quad_eif_results = compute_fd_correction_sqd_mvn_quad(
+                theta_hat=theta_hat,
+                test_data=D_test,
+                **fd_kwargs
+            )
+            results["fd_quad_eif_results"] = fd_quad_eif_results

     ### Analytic EIF ###
     if model_str == "CausalGLM":
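
As a cross-check for the two finite-difference helpers wired in above: the expected-density functional they are assumed to target, T(P) = E_{X ~ P}[p(X)], has a closed form when P is multivariate normal, which makes a quick sanity check possible. This sketch is illustrative only and is not part of the committed code.

import math
import torch
import torch.distributions as dist

def expected_density_closed_form(scale_tril):
    # For P = N(mu, Sigma): E_P[p(X)] = integral of p(x)^2 dx = (4 * pi)^(-d/2) * |Sigma|^(-1/2).
    d = scale_tril.shape[-1]
    sqrt_det = torch.diagonal(scale_tril, dim1=-2, dim2=-1).prod(-1)  # |Sigma|^(1/2) = prod(diag(L))
    return 1.0 / ((4 * math.pi) ** (d / 2) * sqrt_det)

torch.manual_seed(0)
p = dist.MultivariateNormal(torch.zeros(2), scale_tril=torch.eye(2))
mc_estimate = p.log_prob(p.sample((200_000,))).exp().mean()
print(expected_density_closed_form(torch.eye(2)).item(), mc_estimate.item())  # both approx 1 / (4 pi) ~ 0.0796
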
docs/examples/robust_paper/scripts/statics.py (2 additions & 0 deletions)

@@ -3,6 +3,7 @@
 import pyro.distributions as dist
 from docs.examples.robust_paper.models import *
 from docs.examples.robust_paper.functionals import *
+from docs.examples.robust_paper.finite_difference_eif.distributions import PerturbableNormal


 MODELS = {
@@ -15,6 +16,7 @@
         "data_generator": DataGeneratorMultivariateNormalModel,
         "model": MultivariateNormalModel,
         "conditioned_model": ConditionedMultivariateNormalModel,
+        "fd_perturbable_model": PerturbableNormal
     },
 }
