ENH: Option to adjust the uniform distribution bounds for jitter/initialisation #7555

Open
aphc14 opened this issue Nov 3, 2024 · 3 comments

Comments

@aphc14

aphc14 commented Nov 3, 2024

Before

def make_initial_point_expression(
    *,
    free_rvs: Sequence[TensorVariable],
    rvs_to_transforms: dict[TensorVariable, Transform],
    initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None],
    jitter_rvs: set[TensorVariable] | None = None,
    default_strategy: str = "support_point",
    return_transformed: bool = False,
) -> list[TensorVariable]:
    
    # ... existing code ...

        if variable in jitter_rvs:
            jitter = pt.random.uniform(-1, 1, size=value.shape)
            jitter.name = f"{variable.name}_jitter"
            value = value + jitter

After

def make_initial_point_fn(
    *,
    model,
    overrides: StartDict | None = None,
    jitter_rvs: set[TensorVariable] | None = None,
    jitter_bounds: tuple[float, float] = (-1, 1), # <---- new
    default_strategy: str = "support_point",
    return_transformed: bool = True,
) -> Callable:



    # ... existing code ...
    initial_values = make_initial_point_expression(
        free_rvs=model.free_RVs,
        rvs_to_transforms=model.rvs_to_transforms,
        initval_strategies=initval_strats,
        jitter_rvs=jitter_rvs,
        jitter_bounds=jitter_bounds, # <---- new
        default_strategy=default_strategy,
        return_transformed=return_transformed,
    )
    # ... rest of existing code ...


def make_initial_point_fns_per_chain(
    *,
    model,
    overrides: StartDict | Sequence[StartDict | None] | None,
    jitter_rvs: set[TensorVariable] | None = None,
    jitter_bounds: tuple[float, float] = (-1, 1),    # <---- new
    chains: int,
) -> list[Callable]:


    if isinstance(overrides, dict) or overrides is None:
        ipfns = [
            make_initial_point_fn(
                model=model,
                overrides=overrides,
                jitter_rvs=jitter_rvs,
                jitter_bounds=jitter_bounds,    # <---- new
                return_transformed=True,
            )
        ] * chains
    elif len(overrides) == chains:
        ipfns = [
            make_initial_point_fn(
                model=model,
                jitter_rvs=jitter_rvs,
                jitter_bounds=jitter_bounds,     # <---- new
                overrides=chain_overrides,
                return_transformed=True,
            )
            for chain_overrides in overrides
        ]



def make_initial_point_expression(
    *,
    free_rvs: Sequence[TensorVariable],
    rvs_to_transforms: dict[TensorVariable, Transform],
    initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None],
    jitter_rvs: set[TensorVariable] | None = None,
    jitter_bounds: tuple[float, float] = (-1, 1),  # <---- new
    default_strategy: str = "support_point",
    return_transformed: bool = False,
) -> list[TensorVariable]:
    
    # ... existing code ...

        if variable in jitter_rvs:
            jitter = pt.random.uniform(
                jitter_bounds[0],  # <---- new
                jitter_bounds[1],  # <---- new
                size=value.shape,
            )
            jitter.name = f"{variable.name}_jitter"
            value = value + jitter

    # ... existing code ...

Context for the issue:

To help multi-path Pathfinder explore complicated posteriors (e.g., multimodal posteriors, posteriors with flat or saddle-point regions, or posteriors with several local modes where optimisation can get stuck), each single-path Pathfinder needs to be initialised over a broader region. This requires random initial points drawn from a range wider than Uniform(-1, 1). Attached is an image comparing narrow and broad random initialisations, taken from Figure 11 of the paper.
[Image: Screenshot 2024-11-03 195921]

Exposing the uniform distribution bounds as input parameters would let users of the multi-path Pathfinder algorithm better adapt the initialisation to their scenario.
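
As a rough illustration of why the bounds matter, the sketch below adds uniform jitter to a flat list of unconstrained initial values using only the Python standard library. The helper `jitter_point` and its `bounds` argument are hypothetical names made up for this example; they are not part of PyMC, and the real change would go through `pt.random.uniform` as shown in the "After" code.

```python
import random


def jitter_point(values, bounds=(-1.0, 1.0), seed=None):
    """Return a copy of `values` with independent U(low, high) noise added.

    Hypothetical helper for illustration only; it is not part of PyMC.
    """
    low, high = bounds
    if low > high:
        raise ValueError("bounds must satisfy low <= high")
    rng = random.Random(seed)
    # One independent uniform draw per coordinate of the initial point.
    return [v + rng.uniform(low, high) for v in values]


start = [0.0, 0.0, 0.0]
narrow = jitter_point(start, bounds=(-1.0, 1.0), seed=1)    # current default range
wide = jitter_point(start, bounds=(-20.0, 20.0), seed=1)    # broader exploration
```

With the same seed, the wide draws are the narrow draws stretched by a factor of 20, so each Pathfinder start can land much further from the support point.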

I'm happy to work on this feature :) Any suggestions on how you'd like the changes to be made?

References:
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1–49.

@ricardoV94
Member

ricardoV94 commented Nov 3, 2024

If you pass default_strategy="prior", you'll get much more variation. Does that suffice?

Otherwise would a single jitter_scale argument suffice? Do you need to be able to specify both bounds? And will you need to vary this per variable?

@aphc14
Author

aphc14 commented Nov 4, 2024

Depending on the scenario, setting default_strategy="prior" gives too much variation. For some problems, jitter ~ U(-2, 2) may already be too much variation, while for others, jitter ~ U(-20, 20) would be better. An option to set the jitter range would be useful.

What's demonstrated in the Pathfinder paper is a jitter with a single bound that is the same for all variables. For now, a single bound applied to all variables should be okay.

I have implemented a workaround in this commit, pymc-devs/pymc-extras#386, which provides the initialisation needed for Pathfinder:

def make_initial_pathfinder_point(
    model,
    jitter: float = 2.0,
    random_seed: RandomSeed | None = None,
) -> DictToArrayBijection:
    """
    create jittered initial point for pathfinder

    Parameters
    ----------
    model : Model
        pymc model
    jitter : float
        initial values in the unconstrained space are jittered by the uniform distribution, U(-jitter, jitter). Set jitter to 0 for no jitter.
    random_seed : RandomSeed | None
        random seed for reproducibility

    Returns
    -------
    DictToArrayBijection
        bijection containing jittered initial point
    """
    ipfn = make_initial_point_fn(
        model=model,
    )
    ip = Point(ipfn(random_seed), model=model)
    ip_map = DictToArrayBijection.map(ip)

    rng = np.random.default_rng(random_seed)
    jitter_value = rng.uniform(-jitter, jitter, size=ip_map.data.shape)
    ip_map = ip_map._replace(data=ip_map.data + jitter_value)
    return ip_map

Happy for this request to be closed now, or to remain open if the jitter ~ U(-n, n) option should be available in the pymc library as well.

@ricardoV94
Member

ricardoV94 commented Nov 4, 2024

Single jitter for all variables sounds fine, feel free to open a PR. Perhaps call it jitter_scale?
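
For readers following along, here is a minimal sketch of how a single symmetric scale could behave, assuming the jitter stays uniform and is simply stretched from U(-1, 1) to U(-jitter_scale, jitter_scale). The function name `jitter_with_scale` is hypothetical and the code uses only the standard library; it is not the PyMC implementation, which had not been written at the time of this thread.

```python
import random


def jitter_with_scale(values, jitter_scale=1.0, seed=None):
    """Add U(-jitter_scale, jitter_scale) noise to every value.

    Hypothetical sketch of the proposed argument; not the PyMC implementation.
    """
    if jitter_scale < 0:
        raise ValueError("jitter_scale must be non-negative")
    rng = random.Random(seed)
    # Draw U(-1, 1) as the current code does, then stretch it by the scale.
    return [v + jitter_scale * rng.uniform(-1.0, 1.0) for v in values]


point = [1.0, 2.0]
default_jitter = jitter_with_scale(point, jitter_scale=1.0, seed=0)   # same range as today
broad_jitter = jitter_with_scale(point, jitter_scale=20.0, seed=0)    # U(-20, 20)
no_jitter = jitter_with_scale(point, jitter_scale=0.0, seed=0)        # jitter disabled
```

A single scale keeps the default behaviour at `jitter_scale=1.0`, covers the U(-n, n) case requested above, and lets `jitter_scale=0.0` switch the jitter off, at the cost of not supporting asymmetric bounds.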
