Fix for Robust Estimation Memory Issue (#548)
* messy experiments and profiling, with fix drafted and working using basic reverse diff and manual mat vec product.

* first passing shape test of generalized pytree jacobian vector product.

* passes shape and equivalence test with vmapped forward jvp.

* passes robust module tests, still needs cleanup.

* refactors tests a bit to include separate smoke tests.

* cleans up and comments.

* cleans up scratch files.

* lints and adds comment in mceif handler

* removes unused imports

* adds explicit test asserting manageable memory usage.

* updates comment.
azane authored Jul 6, 2024
1 parent c9a8734 commit be862d1
Showing 3 changed files with 312 additions and 10 deletions.
21 changes: 12 additions & 9 deletions chirho/robust/handlers/estimators.py
@@ -4,6 +4,7 @@
import pyro
import torch

from chirho.robust.internals.utils import pytree_generalized_manual_revjvp
from chirho.robust.ops import Functional, P, Point, S, T, influence_fn


@@ -73,15 +74,17 @@ def _pyro_influence(self, msg) -> None:
"Please use torch.no_grad() to avoid this issue. See example in the docstring."
)
param_eif = linearized(*points, *args, **kwargs)
msg["value"] = torch.vmap(
lambda d: torch.func.jvp(
lambda p: func_target(p, *args, **kwargs),
(target_params,),
(d,),
)[1],
in_dims=0,
randomness="different",
)(param_eif)

# Compute the jacobian-vector product of the functional's jacobian wrt the parameters with the
# fisher matrix x log-prob-of-data product. This implementation uses reverse-mode autodiff for the
# jacobian and then manually right-multiplies by param_eif. Unfortunately, torch.func.jvp uses very
# large amounts of memory, which is exacerbated when using vmap to broadcast, as opposed to the
# manual implementation used herein.
msg["value"] = pytree_generalized_manual_revjvp(
fn=lambda p: func_target(p, *args, **kwargs),
params=target_params,
batched_vector=param_eif,
)

msg["done"] = True


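A minimal sketch (not part of the diff) of the memory trade-off the new helper exploits, shown for a single-tensor case: compute the Jacobian once with reverse-mode autodiff and right-multiply the batched vector manually, rather than vmapping torch.func.jvp over the batch. The names fn, params, and batched_vector here are illustrative.

import torch

def fn(p: torch.Tensor) -> torch.Tensor:
    # Arbitrary differentiable map from (5,) params to a (1,) output.
    return torch.sin(p).sum(-1, keepdim=True)

params = torch.randn(5)
batched_vector = torch.randn(3, 5)  # (*batch_shape, *param_shape)

# Baseline: one forward-mode jvp per batch element, broadcast with vmap.
vmap_result = torch.vmap(
    lambda v: torch.func.jvp(fn, (params,), (v,))[1]
)(batched_vector)

# The fix's approach: one reverse-mode Jacobian, then a single batched matmul.
jac = torch.func.jacrev(fn)(params)  # (*output_shape, *param_shape) == (1, 5)
manual_result = batched_vector @ jac.T  # (3, 1)

assert torch.allclose(vmap_result, manual_result, atol=1e-5)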
150 changes: 149 additions & 1 deletion chirho/robust/internals/utils.py
@@ -1,10 +1,19 @@
import contextlib
import functools
import math
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, TypeVar
from math import prod
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar

import pyro
import torch
from torch.utils._pytree import (
SUPPORTED_NODES,
PyTree,
TreeSpec,
_get_node_type,
tree_flatten,
tree_unflatten,
)
from typing_extensions import Concatenate, ParamSpec

from chirho.indexed.handlers import add_indices
@@ -14,6 +23,7 @@
Q = ParamSpec("Q")
S = TypeVar("S")
T = TypeVar("T")
U = TypeVar("U")

ParamDict = Mapping[str, torch.Tensor]

@@ -93,6 +103,144 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]:
return flatten, unflatten


SPyTree = TypeVar("SPyTree", bound=PyTree)
TPyTree = TypeVar("TPyTree", bound=PyTree)
UPyTree = TypeVar("UPyTree", bound=PyTree)


def pytree_generalized_manual_revjvp(
fn: Callable[[TPyTree], SPyTree], params: TPyTree, batched_vector: UPyTree
) -> SPyTree:
"""
Computes the jacobian-vector product using backward differentiation for the jacobian, and then manually
right-multiplying by the batched vector. This supports pytree-structured inputs, outputs, and params.
:param fn: function to compute the jacobian of
:param params: parameters to compute the jacobian at
:param batched_vector: batched vector to right multiply the jacobian by
:raises ValueError: if params and batched_vector do not have the same tree structure
:return: jacobian-vector product
"""

# Assumptions (in terms of elements of the referenced pytrees):
# 1. params is not batched, and represents just the inputs to the fn that we'll take the jac wrt.
# - params.shape == (*param_shape)
# 2. batched_vector is the batched vector component of the jv product. Its rightmost dims match those of params.
# - batched_vector.shape == (*batch_shape, *param_shape)
# 3. The output of the function will have some output_shape, which gives the jacobian the shape below.
# - jac.shape == (*output_shape, *param_shape)
# So the task is to infer these shapes and line everything up correctly. As a general approach, we'll flatten
# the inputs and output shapes in order to apply a standard batched matrix multiplication operation.
# The output will have shape (*batch_shape, *output_shape).

# The shaping is complicated by the fact that we aren't working with tensors, but PyTrees instead, and we want
# to perform the same inner product wrt the tree structure. This mainly shows up in that the jacobian will
# return a pytree with a "root" structure matching that of SPyTree (the return of the fn), but at each leaf
# of that tree, we have a pytree matching the structure of TPyTree (the params). This is the tree-structured
# equivalent of the jac shape matching output on the left and params on the right.

jac_fn = torch.func.jacrev(fn)
jac = jac_fn(params)

flat_params, param_tspec = tree_flatten(params)

flat_batched_vector, batched_vector_tspec = tree_flatten(batched_vector)

if param_tspec != batched_vector_tspec:
# This is also required by pytorch's jvp implementation.
raise ValueError(
"params and batched_vector must have the same tree structure. This requirement generalizes"
" the notion that the batched_vector must be the correct shape to right multiply the "
"jacobian."
)

# In order to map the param shapes together, we need to iterate through the output tree structure and map each
# subtree (corresponding to params) onto the params and batched_vector tree structures, which are both structured
# according to the parameters.
def recurse_to_flattened_sub_tspec(
pytree: PyTree, sub_tspec: TreeSpec, tspec: Optional[TreeSpec] = None
):
# Default to passed treespec, otherwise compute here.
_, tspec = tree_flatten(pytree) if tspec is None else (None, tspec)

# If fn returns a tensor straight away, then the subtree will match at the root node. Check for that here.
if tspec == sub_tspec:
flattened, _ = tree_flatten(pytree)
yield flattened
return

# Extract child trees in a node-type agnostic way.
node_type = _get_node_type(pytree)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children_pytrees, _ = flatten_fn(pytree)
children_tspecs = tspec.children_specs

# Iterate through children and their specs.
for child_pytree, child_tspec in zip(children_pytrees, children_tspecs):
# If we've landed on the target subtree...
if child_tspec == sub_tspec:
child_flattened, _ = tree_flatten(child_pytree)
yield child_flattened # ...yield the flat child for that subtree.
else: # otherwise, recurse to the next level.
yield from recurse_to_flattened_sub_tspec(
child_pytree, sub_tspec, tspec=child_tspec
)
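
# For intuition (hypothetical shapes, not from the diff): if params is {"a": p1, "b": p2} and fn returns
# {"out": t}, then jac is {"out": {"a": dt/dp1, "b": dt/dp2}}, and the generator above yields the single
# flat list [dt/dp1, dt/dp2] corresponding to the one output tensor t.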

flat_out: List[PyTree] = []

# Recurse into the jacobian tree to find the subtree corresponding to the sub-jacobian for each
# individual output tensor in that tree.
for flat_jac_output_subtree in recurse_to_flattened_sub_tspec(
pytree=jac, sub_tspec=param_tspec
):

flat_sub_out: List[torch.Tensor] = []

# Then map that subtree (with tree structure matching that of params) onto the params and batched_vector.
for p, j, v in zip(
flat_params, flat_jac_output_subtree, flat_batched_vector
):
# Infer the parameter shapes directly from passed parameters.
og_param_shape = p.shape
param_shape = og_param_shape if len(og_param_shape) else (1,)
param_numel = prod(param_shape)
og_param_ndim = len(og_param_shape)

# Infer the batch shape by subtracting off the param shape on the right.
og_batch_shape = v.shape[:-og_param_ndim] if og_param_ndim else v.shape
batch_shape = og_batch_shape if len(og_batch_shape) else (1,)
batch_ndim = len(batch_shape)

# Infer the output shape by subtracting off the param shape from the jacobian.
og_output_shape = j.shape[:-og_param_ndim] if og_param_ndim else j.shape
output_shape = og_output_shape if len(og_output_shape) else (1,)
output_numel = prod(output_shape)

# Reshape for matmul s.t. the jacobian can be broadcast over the batch dims.
j_bm = j.reshape(*(1,) * batch_ndim, output_numel, param_numel)
v_bm = v.reshape(*batch_shape, param_numel, 1)
jv = j_bm @ v_bm

# Reshape result back to the original output shape, with support for empty scalar shapes.
og_res_shape = (*og_batch_shape, *og_output_shape)
jv = jv.reshape(*og_res_shape) if len(og_res_shape) else jv.squeeze()

flat_sub_out.append(jv)

# The inner product operates over the parameters in the parameter subtree we just iterated over,
# so stack these and sum.
flat_out.append(torch.stack(flat_sub_out, dim=0).sum(0))

# flat_out is now the flattened version of the tree returned by fn, with each contained tensor having the same
# batch dimensions (matching the batching of the batched vector).
# TODO get out_treespec from the jacobian treespec instead, and don't have an extra forward eval of fn.
# Jacobian tree has this structure but its leaves have params treespec.
out = fn(params)
_, out_treespec = tree_flatten(out)

return tree_unflatten(flat_out, out_treespec)


def make_functional_call(
mod: Callable[P, T]
) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]:
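A short usage sketch of the new helper, mirroring the keyword-argument call in estimators.py above; the pytree shapes here are illustrative, not taken from the diff.

import torch

from chirho.robust.internals.utils import pytree_generalized_manual_revjvp

params = dict(w=torch.randn(4), b=torch.randn(2))

def fn(p):
    # Arbitrary pytree-to-pytree function of the parameters.
    return dict(out=p["w"].sum() + p["b"].sum())

# The batched vector shares the params' tree structure, with leading batch dims.
batched_vector = dict(w=torch.randn(10, 4), b=torch.randn(10, 2))

jv = pytree_generalized_manual_revjvp(
    fn=fn, params=params, batched_vector=batched_vector
)
assert jv["out"].shape == (10,)  # (*batch_shape, *output_shape)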
151 changes: 151 additions & 0 deletions tests/robust/test_utils.py
@@ -0,0 +1,151 @@
from math import prod

import pytest
import torch

from chirho.robust.internals.utils import pytree_generalized_manual_revjvp

_shapes = [tuple(), (1,), (1, 1), (2,), (2, 3)]


def _exec_pytree_generalized_manual_revjvp(
batch_shape, output_shape1, output_shape2, param_shape1, param_shape2
):

# TODO add tests of subdicts and sublists to really exercise the pytree structure.
# TODO add permutations for single-tensor params/batch_vector/outputs (i.e. not in an explicit tree structure).

params = dict(
params1=torch.randn(param_shape1),
params2=torch.randn(param_shape2),
)

batch_vector = dict(  # this tree is mapped onto the params structure in the right multiplication w/ the jacobian.
params1=torch.randn(batch_shape + param_shape1),
params2=torch.randn(batch_shape + param_shape2),
)

weights1 = torch.randn(prod(output_shape1), prod(param_shape1))
weights2 = torch.randn(prod(output_shape2), prod(param_shape2))

def fn_inner(p: torch.Tensor, weights: torch.Tensor):
# Arbitrary function that maps the param shape to the output shape implicit in weights.
p = p.flatten()
out = weights @ p
return out

def fn(p):
return dict(
out1=fn_inner(p["params1"], weights1).reshape(output_shape1),
out2=fn_inner(p["params2"], weights2).reshape(output_shape2),
)

for (k, v), output_shape in zip(fn(params).items(), (output_shape1, output_shape2)):
assert v.shape == output_shape

broadcasted_reverse_jvp_result = pytree_generalized_manual_revjvp(
fn, params, batch_vector
)

return broadcasted_reverse_jvp_result, (fn, params, batch_vector)


@pytest.mark.parametrize("batch_shape", _shapes)
@pytest.mark.parametrize("output_shape1", _shapes)
@pytest.mark.parametrize("output_shape2", _shapes)
@pytest.mark.parametrize("param_shape1", _shapes)
@pytest.mark.parametrize("param_shape2", _shapes)
def test_smoke_pytree_generalized_manual_revjvp(
batch_shape, output_shape1, output_shape2, param_shape1, param_shape2
):

broadcasted_reverse_jvp_result, _ = _exec_pytree_generalized_manual_revjvp(
batch_shape, output_shape1, output_shape2, param_shape1, param_shape2
)

assert broadcasted_reverse_jvp_result["out1"].shape == batch_shape + output_shape1
assert broadcasted_reverse_jvp_result["out2"].shape == batch_shape + output_shape2

assert not torch.isnan(broadcasted_reverse_jvp_result["out1"]).any()
assert not torch.isnan(broadcasted_reverse_jvp_result["out2"]).any()


# Standard vmap-and-jvp application doesn't support multiple batch dims or scalar shapes, so we manually spec
# single batch dims and drop the tuple() scalar shape via _shapes[1:].
@pytest.mark.parametrize("batch_shape", [(1,), (3,)])
@pytest.mark.parametrize("output_shape1", _shapes[1:])
@pytest.mark.parametrize("output_shape2", _shapes[1:])
@pytest.mark.parametrize("param_shape1", _shapes[1:])
@pytest.mark.parametrize("param_shape2", _shapes[1:])
def test_pytree_generalized_manual_revjvp(
batch_shape, output_shape1, output_shape2, param_shape1, param_shape2
):

broadcasted_reverse_jvp_result, (fn, params, batch_vector) = (
_exec_pytree_generalized_manual_revjvp(
batch_shape, output_shape1, output_shape2, param_shape1, param_shape2
)
)

vmapped_forward_jvp_result = torch.vmap(
lambda d: torch.func.jvp(
fn,
(params,),
(d,),
)[1],
in_dims=0,
randomness="different",
)(batch_vector)

# When using standard precision, this test has some stochastic failures (around 1/3000) that pass on rerun.
# This is probably due to floating point mismatch induced by the lower precision of the separate jacobian
# computation and manual matmul.
assert torch.allclose(
broadcasted_reverse_jvp_result["out1"],
vmapped_forward_jvp_result["out1"],
atol=1e-5,
)
assert torch.allclose(
broadcasted_reverse_jvp_result["out2"],
vmapped_forward_jvp_result["out2"],
atol=1e-5,
)


def test_memory_pytree_generalized_manual_revjvp():
# vmap over jvp cannot handle 1000 batch x 1000 params (10s of gigabytes used).
batch_shape = (10000,)
output_shape1 = (2,)
output_shape2 = (2,)
params_shape1 = (10000,)
params_shape2 = (10000,)
# Also works with these, but runtime is too long for CI. Runs locally at a little over 7GB.
# params_shape1 = (100000,)
# params_shape2 = (100000,)

with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU],
profile_memory=True,
with_stack=False,
) as prof:

broadcasted_reverse_jvp_result, _ = _exec_pytree_generalized_manual_revjvp(
batch_shape, output_shape1, output_shape2, params_shape1, params_shape2
)

assert broadcasted_reverse_jvp_result["out1"].shape == batch_shape + output_shape1
assert broadcasted_reverse_jvp_result["out2"].shape == batch_shape + output_shape2

assert not torch.isnan(broadcasted_reverse_jvp_result["out1"]).any()
assert not torch.isnan(broadcasted_reverse_jvp_result["out2"]).any()

# Summing up the self CPU memory usage
total_memory_allocated = sum(
[item.self_cpu_memory_usage for item in prof.key_averages()]
)
total_gb_allocated = total_memory_allocated / (1024**3)

# Locally, this runs at slightly over 1.0 GB.
assert (
total_gb_allocated < 3.0
), f"Memory usage was {total_gb_allocated} GB, which is too high."
