Skip to content

Commit

Permalink
Add Exponential distribution to model/transforms/autoreparam.py (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine authored Jul 19, 2024
1 parent 99170df commit d50742d
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 17 deletions.
38 changes: 38 additions & 0 deletions pymc_experimental/model/transforms/autoreparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,44 @@ def _(
return vip_rep


@_vip_reparam_node.register
def _(
op: pm.Exponential,
node: Apply,
name: str,
dims: List[Variable],
transform: Optional[Transform],
lam: pt.TensorVariable,
) -> ModelDeterministic:
rng, size, scale = node.inputs
scale_centered = scale**lam
scale_noncentered = scale ** (1 - lam)
vip_rv_ = pm.Exponential.dist(
scale=scale_centered,
size=size,
rng=rng,
)
vip_rv_value_ = vip_rv_.clone()
vip_rv_.name = f"{name}::tau_"
if transform is not None:
vip_rv_value_.name = f"{vip_rv_.name}_{transform.name}__"
else:
vip_rv_value_.name = vip_rv_.name
vip_rv = model_free_rv(
vip_rv_,
vip_rv_value_,
transform,
*dims,
)

vip_rep_ = scale_noncentered * vip_rv

vip_rep_.name = name

vip_rep = model_deterministic(vip_rep_, *dims)
return vip_rep


def vip_reparametrize(
model: pm.Model,
var_names: Sequence[str],
Expand Down
40 changes: 23 additions & 17 deletions tests/model/transforms/test_autoreparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def model_c():
m = pm.Normal("m")
s = pm.LogNormal("s")
pm.Normal("g", m, s, shape=5)
pm.Exponential("e", scale=s, shape=7)
return mod


Expand All @@ -20,31 +21,34 @@ def model_nc():
m = pm.Normal("m")
s = pm.LogNormal("s")
pm.Deterministic("g", pm.Normal("z", shape=5) * s + m)
pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
return mod


def test_reparametrize_created(model_c: pm.Model):
model_reparam, vip = vip_reparametrize(model_c, ["g"])
assert "g" in vip.get_lambda()
assert "g::lam_logit__" in model_reparam.named_vars
assert "g::tau_" in model_reparam.named_vars
@pytest.mark.parametrize("var", ["g", "e"])
def test_reparametrize_created(model_c: pm.Model, var):
model_reparam, vip = vip_reparametrize(model_c, [var])
assert f"{var}" in vip.get_lambda()
assert f"{var}::lam_logit__" in model_reparam.named_vars
assert f"{var}::tau_" in model_reparam.named_vars
vip.set_all_lambda(1)
assert ~np.isfinite(model_reparam["g::lam_logit__"].get_value()).any()
assert ~np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any()


def test_random_draw(model_c: pm.Model, model_nc):
@pytest.mark.parametrize("var", ["g", "e"])
def test_random_draw(model_c: pm.Model, model_nc, var):
model_c = pm.do(model_c, {"m": 3, "s": 2})
model_nc = pm.do(model_nc, {"m": 3, "s": 2})
model_v, vip = vip_reparametrize(model_c, ["g"])
assert "g" in [v.name for v in model_v.deterministics]
c = pm.draw(model_c["g"], random_seed=42, draws=1000)
nc = pm.draw(model_nc["g"], random_seed=42, draws=1000)
model_v, vip = vip_reparametrize(model_c, [var])
assert var in [v.name for v in model_v.deterministics]
c = pm.draw(model_c[var], random_seed=42, draws=1000)
nc = pm.draw(model_nc[var], random_seed=42, draws=1000)
vip.set_all_lambda(1)
v_1 = pm.draw(model_v["g"], random_seed=42, draws=1000)
v_1 = pm.draw(model_v[var], random_seed=42, draws=1000)
vip.set_all_lambda(0)
v_0 = pm.draw(model_v["g"], random_seed=42, draws=1000)
v_0 = pm.draw(model_v[var], random_seed=42, draws=1000)
vip.set_all_lambda(0.5)
v_05 = pm.draw(model_v["g"], random_seed=42, draws=1000)
v_05 = pm.draw(model_v[var], random_seed=42, draws=1000)
np.testing.assert_allclose(c.mean(), nc.mean())
np.testing.assert_allclose(c.mean(), v_0.mean())
np.testing.assert_allclose(v_05.mean(), v_1.mean())
Expand All @@ -57,10 +61,12 @@ def test_random_draw(model_c: pm.Model, model_nc):


def test_reparam_fit(model_c):
model_v, vip = vip_reparametrize(model_c, ["g"])
vars = ["g", "e"]
model_v, vip = vip_reparametrize(model_c, ["g", "e"])
with model_v:
vip.fit(random_seed=42)
np.testing.assert_allclose(vip.get_lambda()["g"], 0, atol=0.01)
vip.fit(50000, random_seed=42)
for var in vars:
np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)


def test_multilevel():
Expand Down

0 comments on commit d50742d

Please sign in to comment.