Implement (Smoothed) Finite Difference Approximation of Influence Function #501

Closed
wants to merge 72 commits into from
38b6158
added robust folder
agrawalraj Nov 8, 2023
c09dcab
uncommited scratch work for log prob
agrawalraj Nov 9, 2023
21e31bf
untested variational log prob
agrawalraj Nov 9, 2023
faed235
uncomitted changes
agrawalraj Nov 13, 2023
fac98cd
uncomitted changes
agrawalraj Nov 16, 2023
4edcb5e
pair coding w/ eli
agrawalraj Nov 16, 2023
fe17403
added tests w/ Eli
agrawalraj Nov 17, 2023
b159687
eif
eb8680 Nov 17, 2023
33f4811
linting
agrawalraj Nov 18, 2023
8e171f4
moving test autograd to internals and deleted old utils file
agrawalraj Nov 20, 2023
93cc014
sketch influence implementation
eb8680 Nov 21, 2023
9bc704c
fix more args
eb8680 Nov 21, 2023
cedb818
ops file
eb8680 Nov 21, 2023
418f792
file
eb8680 Nov 21, 2023
f792ddf
format
eb8680 Nov 21, 2023
88a100b
lint
eb8680 Nov 21, 2023
94c2fc6
clean up influence and tests
eb8680 Nov 21, 2023
da0bc5c
make tests more generic
eb8680 Nov 22, 2023
4d027e4
guess max plate nesting
eb8680 Nov 22, 2023
e85e33f
linearize
eb8680 Nov 22, 2023
1734191
rename file
eb8680 Nov 22, 2023
f46556b
tensor flatten
eb8680 Nov 22, 2023
1abc5e0
predictive eif
eb8680 Nov 22, 2023
9c80b60
jvp type
eb8680 Nov 22, 2023
931da4f
reorganize files
eb8680 Nov 22, 2023
dc63f31
shrink test case
eb8680 Nov 22, 2023
be3bc8d
move guess_max_plate_nesting
eb8680 Nov 22, 2023
9ce164a
move cg solver to linearze
eb8680 Nov 22, 2023
81196d4
type alias
eb8680 Nov 22, 2023
30cb2e7
test_ops
eb8680 Nov 22, 2023
21cf2d7
basic cg tests
eb8680 Nov 22, 2023
720661f
remove failing test case
eb8680 Nov 22, 2023
91833da
format
eb8680 Nov 22, 2023
548069a
move paramdict up
eb8680 Nov 22, 2023
12b22c0
remove obsolete test files
eb8680 Nov 22, 2023
d2bbf9d
Merge branch 'master' into staging-robust
eb8680 Nov 22, 2023
3b72bb0
add empty handlers
eb8680 Nov 22, 2023
89d9f6b
add chirho.robust to docs
eb8680 Nov 22, 2023
7582c22
fix memory leak in tests
eb8680 Nov 27, 2023
82c23e8
make typing compatible with python 3.8
eb8680 Nov 27, 2023
e08d9d6
typing_extensions
eb8680 Nov 27, 2023
22eae09
add branch to ci
eb8680 Nov 27, 2023
d0014db
predictive
eb8680 Nov 27, 2023
e5342dc
remove imprecise annotation
eb8680 Nov 27, 2023
be13ac5
Merge branch 'master' into staging-robust
SamWitty Nov 28, 2023
c5fe64b
Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)
agrawalraj Dec 6, 2023
117d645
Add upper bound on number of CG steps (#404)
eb8680 Dec 7, 2023
8fe1b25
fixed test for non-symmetric matrix (#437)
agrawalraj Dec 7, 2023
3f0c83d
Make `NMCLogPredictiveLikelihood` seeded (#408)
agrawalraj Dec 8, 2023
4d41807
Use Hessian formulation of Fisher information in `make_empirical_fish…
agrawalraj Dec 8, 2023
2e01b7b
Add new `SimpleModel` and `SimpleGuide` (#440)
agrawalraj Dec 8, 2023
538cef8
Batching in `linearize` and `influence` (#465)
agrawalraj Dec 22, 2023
6bba70b
batched cg (#466)
agrawalraj Dec 22, 2023
f143d3a
One step correction implemented (#467)
agrawalraj Dec 22, 2023
878eb0d
Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLog…
eb8680 Jan 2, 2024
3cfe319
Added documentation for `chirho.robust` (#470)
agrawalraj Jan 2, 2024
5d77fe0
Make functional argument to influence_fn required (#487)
eb8680 Jan 9, 2024
013d518
Remove guide argument from `influence_fn` and `linearize` (#489)
eb8680 Jan 9, 2024
c4346c8
Make influence_fn a higher-order Functional (#492)
eb8680 Jan 11, 2024
9207e3e
Add full corrected one step estimator (#476)
SamWitty Jan 12, 2024
ca916cd
Merge branch 'master' into staging-robust
eb8680 Jan 12, 2024
a7875c6
add abstractions and simple temp scratch to test with squared unit no…
azane Jan 12, 2024
ad519be
removes old scratch notebook
azane Jan 12, 2024
127a4a4
Merge branch 'staging-robust' into az-influence-finite-difference-2
azane Jan 12, 2024
1efe6ea
gets squared density running under abstraction that couples functiona…
azane Jan 12, 2024
44785d8
gets quad and mc approximations to match, vectorization hacky.
azane Jan 12, 2024
5a11a7a
Merge branch 'staging-robust-icml' into az-influence-finite-difference-2
azane Jan 16, 2024
31cc9ac
adds plotting and comparative to analytic.
azane Jan 16, 2024
f867f2a
adds scratch experiment comparing squared density analytic vs fd appr…
azane Jan 17, 2024
7f10667
fixes dataset splitting, breaks analytic eif
azane Jan 17, 2024
094562a
unfixes an incorrect fix, working now.
azane Jan 17, 2024
0556543
Merge branch 'staging-robust-icml' into az-influence-finite-difference-2
azane Jan 17, 2024
169 changes: 169 additions & 0 deletions chirho/robust/handlers/fd_model.py
@@ -0,0 +1,169 @@
import torch
import pyro
import pyro.distributions as dist
from typing import Dict, Optional
from contextlib import contextmanager
from chirho.robust.ops import Functional, Point, T
import numpy as np


class ModelWithMarginalDensity(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def density(self, *args, **kwargs):
# TODO this can probably default to using BatchedNMCLogMarginalLikelihood applied to self,
# but providing here to avail of analytic densities. Or have a constructor that takes a
# regular model and puts the marginal density here.
raise NotImplementedError()

def forward(self, *args, **kwargs):
raise NotImplementedError()


class PrefixMessenger(pyro.poutine.messenger.Messenger):

def __init__(self, prefix: str):
self.prefix = prefix

def _pyro_sample(self, msg) -> None:
msg["name"] = f"{self.prefix}{msg['name']}"


class FDModelFunctionalDensity(ModelWithMarginalDensity):
"""
This class serves to couple the forward sampling model, density, and functional. Finite differencing
operates in the space of densities, and therefore requires of its functionals that they "know about"
the causal structure of the generative model. Thus, the three components are coupled together here.

"""

model: ModelWithMarginalDensity

# TODO These context managers are a bit awkward, but they let you define a valid model at init time and then
# temporarily modify the perturbation later, e.g. in the influence function approximation.
# TODO pull out boilerplate
@contextmanager
def set_eps(self, eps):
original_eps = self._eps
self._eps = eps
try:
yield
finally:
self._eps = original_eps

@contextmanager
def set_lambda(self, lambda_):
original_lambda = self._lambda
self._lambda = lambda_
try:
yield
finally:
self._lambda = original_lambda

@contextmanager
def set_kernel_point(self, kernel_point: Dict):
original_kernel_point = self._kernel_point
self._kernel_point = kernel_point
try:
yield
finally:
self._kernel_point = original_kernel_point

@property
def kernel(self) -> ModelWithMarginalDensity:
# TODO implementation of a kernel could be brought up to this level. User would need to pass a kernel type
# that's parameterized by the kernel point and lambda.
"""
Inheritors should construct the kernel here as a function of self._kernel_point and self._lambda.
:return:
"""
raise NotImplementedError()

def __init__(self, default_kernel_point: Dict, *args, default_eps=0., default_lambda=0.1, **kwargs):
super().__init__(*args, **kwargs)
self._eps = default_eps
self._lambda = default_lambda
self._kernel_point = default_kernel_point
# TODO don't assume .shape[-1]
self.ndims = np.sum([v.shape[-1] for v in self._kernel_point.values()])

@property
def mixture_weights(self):
return torch.tensor([1. - self._eps, self._eps])

def density(self, model_kwargs: Dict, kernel_kwargs: Dict):
mpart = self.mixture_weights[0] * self.model.density(**model_kwargs)
kpart = self.mixture_weights[1] * self.kernel.density(**kernel_kwargs)
return mpart + kpart

def forward(self, model_kwargs: Optional[Dict] = None, kernel_kwargs: Optional[Dict] = None):
# _from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.mixture_weights))
#
# if _from_kernel:
# return self.kernel(**(kernel_kwargs or dict()))
# else:
# return self.model(**(model_kwargs or dict()))

_from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.mixture_weights))

kernel_mask = _from_kernel.bool() # Convert to boolean mask

# Apply the respective functions using the masks
with PrefixMessenger('kernel_'), pyro.poutine.trace() as kernel_tr:
kernel_result = self.kernel(**(kernel_kwargs or dict()))
with PrefixMessenger('model_'), pyro.poutine.trace() as model_tr:
model_result = self.model(**(model_kwargs or dict()))

# FIXME To make log likelihoods work properly, the log likelihoods need to be masked/not added
# for particular elements. See e.g. MaskedMixture for a non-general example of how to do this (it
# uses torch distributions instead of arbitrary probabilistic programs):
# https://docs.pyro.ai/en/stable/distributions.html?highlight=MaskedMixture#maskedmixture
# FIXME Ideally the trace would have elements of the same name as well here.

# FIXME The torch.where below isn't shape agnostic.

# Use masks to select the appropriate result for each sample
result = torch.where(kernel_mask[:, None], kernel_result, model_result)

return result

def functional(self, *args, **kwargs):
# TODO update docstring to this being build_functional instead of just functional
"""
The functional target for this model. This is tightly coupled to a particular
pyro model because finite differencing operates in the space of densities, and
cannot automatically exploit any structure of the pyro model that the functional
is being evaluated with respect to. As such, the functional must be implemented
with the specific structure of the coupled pyro model in mind.
:param args:
:param kwargs:
:return: An estimate of the functional for this model.
"""
raise NotImplementedError()


# TODO move this to chirho/robust/ops.py and resolve signature mismatches? Maybe. The problem is that the ops
# signature (rightly) decouples models and functionals, whereas for finite differencing they must be coupled
# because the functional (in many cases) must know about the causal structure of the model.
def fd_influence_fn(model: FDModelFunctionalDensity, points: Point[T], eps: float, lambda_: float):

def _influence_fn(*args, **kwargs):

# Number of datapoints: the length of the first value in the points mapping.
len_points = len(list(points.values())[0])
eif_vals = []
for i in range(len_points):
kernel_point = {k: v[i] for k, v in points.items()}

psi_p = model.functional(*args, **kwargs)

with model.set_eps(eps), model.set_lambda(lambda_), model.set_kernel_point(kernel_point):
psi_p_eps = model.functional(*args, **kwargs)

eif_vals.append((psi_p_eps - psi_p) / eps)
return eif_vals

return _influence_fn

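`fd_influence_fn` approximates the efficient influence function at each datapoint by a one-sided finite difference of the functional along a smoothed mixture path: for a kernel K_lambda centered at the point, it evaluates (psi((1 - eps) * P + eps * K_lambda) - psi(P)) / eps. A minimal usage sketch follows; the `MyFDModel` subclass and its constructor arguments are hypothetical stand-ins for a concrete implementation such as the ones in docs/source/robust_fd/squared_normal_density.py.

import torch
from chirho.robust.handlers.fd_model import FDModelFunctionalDensity, fd_influence_fn

class MyFDModel(FDModelFunctionalDensity):
    # A concrete subclass must provide `self.model`, the `kernel` property,
    # and `functional`; see docs/source/robust_fd/squared_normal_density.py.
    ...

model = MyFDModel(default_kernel_point=dict(x=torch.zeros(1)))
points = dict(x=torch.randn(50, 1))  # datapoints at which to evaluate the EIF

# List of finite-difference EIF estimates, one per datapoint.
eif_vals = fd_influence_fn(model, points, eps=0.01, lambda_=0.001)()

# Averaging the estimates over the datapoints gives the usual one-step correction term.
correction = sum(eif_vals) / len(eif_vals)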

Empty file.
80 changes: 80 additions & 0 deletions docs/source/robust_fd/squared_normal_density.py
@@ -0,0 +1,80 @@
from chirho.robust.handlers.fd_model import FDModelFunctionalDensity, ModelWithMarginalDensity
import pyro
import pyro.distributions as dist
import torch
from scipy.stats import multivariate_normal
from scipy.integrate import nquad
import numpy as np


class MultivariateNormalwDensity(ModelWithMarginalDensity):

def __init__(self, mean, cov, *args, **kwargs):
super().__init__(*args, **kwargs)

self.mean = mean
self.cov = cov

def density(self, x):
return multivariate_normal.pdf(x, mean=self.mean, cov=self.cov)

def forward(self):
return pyro.sample("x", dist.MultivariateNormal(self.mean, self.cov))


class NormalKernel(FDModelFunctionalDensity):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@property
def kernel(self):
# TODO agnostic to names.
mean = self._kernel_point['x']
return MultivariateNormalwDensity(mean, torch.eye(self.ndims) * self._lambda)


class PerturbableNormal(FDModelFunctionalDensity):

def __init__(self, *args, mean, cov, **kwargs):
super().__init__(*args, **kwargs)

self.ndims = mean.shape[-1]
self.model = MultivariateNormalwDensity(mean, cov)

self.mean = mean
self.cov = cov


class ExpectedDensityQuadFunctional(FDModelFunctionalDensity):
"""
Compute the expected density functional (the integral of the squared density) using quadrature.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def functional(self):
def integrand(*args):
# TODO agnostic to kwarg names.
model_kwargs = kernel_kwargs = dict(x=np.array(args))
return self.density(model_kwargs, kernel_kwargs) ** 2

ndim = self._kernel_point['x'].shape[-1]

return nquad(integrand, [[-np.inf, np.inf]] * ndim)[0]


class ExpectedDensityMCFunctional(FDModelFunctionalDensity):
"""
Compute the expected density functional (the integral of the squared density) using Monte Carlo.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def functional(self, nmc=1000):
# TODO agnostic to kwarg names
with pyro.plate('samples', nmc):
points = self()
return torch.mean(self.density(model_kwargs=dict(x=points), kernel_kwargs=dict(x=points)))
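
The classes in this file are building blocks meant to be combined by multiple inheritance: a concrete finite-difference model couples a perturbable base model, a smoothing kernel, and a functional. A minimal sketch of such a composition is below; the `ExpectedNormalDensityQuad` name mirrors the class imported by the scratch script that follows, but the composite itself is an assumption here, since it is not defined in this file.

import torch
from robust_fd.squared_normal_density import (
    ExpectedDensityQuadFunctional,
    NormalKernel,
    PerturbableNormal,
)

# Hypothetical composite: normal base model + normal smoothing kernel + quadrature functional.
class ExpectedNormalDensityQuad(NormalKernel, PerturbableNormal, ExpectedDensityQuadFunctional):
    pass

end_quad = ExpectedNormalDensityQuad(
    mean=torch.zeros(1),
    cov=torch.eye(1),
    default_kernel_point=dict(x=torch.zeros(1)),
)

# With eps = 0 (the default) this is just the integral of the squared standard
# normal density, 1 / (2 * sqrt(pi)) ~= 0.282.
print(end_quad.functional())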
126 changes: 126 additions & 0 deletions docs/source/robust_fd_scratch.py
@@ -0,0 +1,126 @@
raise NotImplementedError()

from robust_fd.squared_normal_density import ExpectedNormalDensityQuad, ExpectedNormalDensityMC, _ExpectedNormalDensity
from chirho.robust.handlers.fd_model import fd_influence_fn
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal, norm

ndim = 1
eps = 0.01
mean = torch.tensor([0.,] * ndim)
cov = torch.eye(ndim)
lambda_ = 0.001

end_quad = ExpectedNormalDensityQuad(
mean=mean,
cov=cov,
default_kernel_point=dict(x=torch.tensor([0.,] * ndim)),
)

guess = end_quad.functional()
print(f"Guess: {guess}")

if ndim == 1:
xx = np.linspace(-5, 5, 1000)
with (end_quad.set_kernel_point(dict(x=torch.tensor([1., ] * ndim))),
end_quad.set_lambda(lambda_),
end_quad.set_eps(eps)):
yy = [end_quad.density(
{'x': torch.tensor([x])},
{'x': torch.tensor([x])})
for x in xx
]

plt.plot(xx, yy)

yy = [end_quad.density(
{'x': torch.tensor([x])},
{'x': torch.tensor([x])})
for x in xx
]

plt.plot(xx, yy)

# Sample points from a slightly more entropic model.
# FIXME not generalized for ndim > 1
points = dict(x=torch.linspace(-3, 3, 50)[:, None])

print(f"Analytic: {((1./(3. - -3.))**2) * (3. - -3.)}")

target_quad = fd_influence_fn(
model=end_quad,
points=points,
eps=eps,
lambda_=lambda_,
)

correction_quad_eif = np.array(target_quad())

if ndim == 1:
plt.figure()
plt.plot(points['x'].numpy(), correction_quad_eif, label='quad eif')

correction_quad = np.mean(correction_quad_eif)

print(f"Correction (Quad): {correction_quad}")

end_mc = ExpectedNormalDensityMC(
mean=mean,
cov=cov,
default_kernel_point=dict(x=torch.tensor([0.,] * ndim)),
)

target_mc = fd_influence_fn(
model=end_mc,
points=points,
eps=eps,
lambda_=lambda_,
)

correction_mc_eif = np.array(target_mc(nmc=4000))

if ndim == 1:
plt.plot(points['x'].numpy(), correction_mc_eif, linewidth=0.3, alpha=0.8)

correction_mc = np.mean(correction_mc_eif)

print(f"Correction (MC): {correction_mc}")


def compute_analytic_eif(model: _ExpectedNormalDensity, points):
funcval = model.functional()
density = model.density(points, points)

return 2. * (density - funcval)


analytic_eif = compute_analytic_eif(end_quad, points).numpy()

analytic = np.mean(analytic_eif)

print(f"Analytic: {analytic}")

print(f"Analytic Corrected: {guess - analytic}")


if ndim == 1:

plt.suptitle(f"ndim={ndim}, eps={eps}, lambda={lambda_}")

pxsamps = points['x'].numpy().squeeze()

plt.plot(pxsamps, analytic_eif, label="analytic")

# Plot the corresponding uniform and normal densities.
plt.plot(points['x'].numpy(), [1./(3. - -3.)] * len(points['x']), color='black', label='uniform')

# plt.plot(xx, norm.pdf(xx, loc=0, scale=1), color='green', label='normal')
plt.plot(pxsamps, norm.pdf(pxsamps, loc=0, scale=1), color='green', label='normal')
# Plot the correction, just quad.
plt.plot(pxsamps, norm.pdf(pxsamps, loc=0, scale=1) - 0.1 * np.array(correction_quad_eif),
linestyle='--', color='green', label='normal (corrected)')

plt.legend()
plt.show()
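
For reference, the analytic influence function implemented in `compute_analytic_eif` above is the Gateaux derivative of the expected-density functional along the mixture path used by the finite-difference approximation, with the smoothing kernel taken to the point-mass limit (lambda -> 0):

\psi(p) = \int p(t)^2 \, dt, \qquad p_\epsilon = (1 - \epsilon)\, p + \epsilon\, \delta_x,

\frac{d}{d\epsilon} \psi(p_\epsilon) \Big|_{\epsilon = 0}
  = 2 \int p(t)\, \big(\delta_x(t) - p(t)\big)\, dt
  = 2\, \big(p(x) - \psi(p)\big),

which is exactly the `2. * (density - funcval)` returned by `compute_analytic_eif`.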