diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py
index 62e10fdb08..b4771e369d 100644
--- a/pyro/params/param_store.py
+++ b/pyro/params/param_store.py
@@ -229,6 +229,7 @@ def get_param(
         init_tensor: Optional[torch.Tensor] = None,
         constraint: constraints.Constraint = constraints.real,
         event_dim: Optional[int] = None,
+        parametrization: Optional[str] = None,
     ) -> torch.Tensor:
         """
         Get parameter from its name. If it does not yet exist in the
@@ -246,9 +247,27 @@ def get_param(
+        :param str parametrization: optional parametrization to apply to the
+            returned value; currently only "orthogonal" is supported
         :rtype: torch.Tensor
         """
         if init_tensor is None:
-            return self[name]
+            param = self[name]
         else:
-            return self.setdefault(name, init_tensor, constraint)
+            param = self.setdefault(name, init_tensor, constraint)
+        if parametrization is not None:
+            if parametrization != "orthogonal":
+                raise ValueError(f"Unsupported parametrization: {parametrization}")
+            import torch.nn as nn
+            from torch.nn.utils.parametrizations import orthogonal
+
+            # torch's orthogonal() parametrizes an attribute of an nn.Module,
+            # not a bare tensor, so wrap the value in a throwaway module.
+            # NOTE(review): gradients flow into the wrapper's parameter rather
+            # than the store's unconstrained parameter; confirm this matches
+            # the intended training semantics before relying on it.
+            wrapper = nn.Module()
+            wrapper.weight = nn.Parameter(param.detach().clone())
+            orthogonal(wrapper, "weight")
+            param = wrapper.weight
+        return param
 
     def match(self, name: str) -> Dict[str, torch.Tensor]:
         """
diff --git a/pyro/primitives.py b/pyro/primitives.py
index 6ed6862c3c..6e514418d5 100644
--- a/pyro/primitives.py
+++ b/pyro/primitives.py
@@ -59,6 +59,7 @@ def param(
     init_tensor: Union[torch.Tensor, Callable[[], torch.Tensor], None] = None,
     constraint: constraints.Constraint = constraints.real,
     event_dim: Optional[int] = None,
+    parametrization: Optional[str] = None,
 ) -> torch.Tensor:
     """
     Saves the variable as a parameter in the param store.
@@ -86,7 +87,15 @@ def param(
+    :param str parametrization: optional parametrization to apply; currently
+        only "orthogonal" is supported
     """
     # Note effectful(-) requires the double passing of name below.
     args = (name,) if init_tensor is None else (name, init_tensor)
-    value = _param(*args, constraint=constraint, event_dim=event_dim, name=name)
+    value = _param(
+        *args,
+        constraint=constraint,
+        event_dim=event_dim,
+        name=name,
+        parametrization=parametrization,
+    )
     assert value is not None  # type narrowing guaranteed by _param
     return value
 
diff --git a/tests/params/test_param.py b/tests/params/test_param.py
index a5bfc6c494..bbdf2c02c9 100644
--- a/tests/params/test_param.py
+++ b/tests/params/test_param.py
@@ -243,3 +243,46 @@ def check_constraint(name):
     assert_equal(pyro.param("z"), z0)
     check_constraint("x0")
     check_constraint("z0")
+
+
+def test_get_param_behaviour():
+    """Tests for ParamStoreDict.get_param.
+
+    Covers:
+    - get_param without an init tensor on a missing name raises KeyError
+    - get_param with an init tensor registers the parameter and returns the
+      constrained value
+    - parametrization="orthogonal" returns a grad-enabled tensor of the
+      requested shape
+    """
+    param_store = pyro.get_param_store()
+    param_store.clear()
+
+    # Missing name without an init tensor should raise.
+    raised = False
+    try:
+        param_store.get_param("missing_without_init")
+    except KeyError:
+        raised = True
+    assert raised
+
+    # With an init tensor and a positive constraint the parameter is created
+    # and the constrained value is returned.
+    init = 2.0 * torch.ones(2, 3)
+    p = param_store.get_param(
+        "p_with_init", init_tensor=init, constraint=constraints.positive
+    )
+    assert "p_with_init" in param_store
+    assert_equal(p, init)
+    # The stored unconstrained value is log(init) for the positive constraint.
+    stored_unconstrained = param_store._params["p_with_init"]
+    assert torch.allclose(stored_unconstrained, torch.log(init))
+
+    # Requesting an orthogonal parametrization should not error and should
+    # return a grad-enabled tensor of the requested shape.
+    param_store.clear()
+    p2_init = torch.randn(4, 4)
+    p2 = param_store.get_param("p2", init_tensor=p2_init, parametrization="orthogonal")
+    assert isinstance(p2, torch.Tensor)
+    assert p2.shape == p2_init.shape
+    assert p2.requires_grad