Add mixture-of-gammas global prior, updated by EM #349

Merged: 4 commits, Jan 6, 2024

Conversation

nspope
Contributor

@nspope nspope commented Dec 11, 2023

  1. To use the coalescent prior to initialize a mixture with K components, pass prior_mixture_dim=K
  2. To disable the EM updates of the mixture components, pass em_iterations=0
  3. The "original" global prior for the variational gamma method can be used via prior_mixture_dim=1, em_iterations=0
  4. The default is to optimize a single component prior (prior_mixture_dim=1, em_iterations=10)
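A rough, self-contained sketch of how the two options interact. This is not the tsdate implementation: it fits point samples rather than variational posteriors, initializes from quantile blocks rather than the coalescent prior, and uses a moment-matching M-step in place of an exact maximum-likelihood update. The helper names (`init_mixture`, `em_step`, `fit_gamma_mixture`) are hypothetical; only `prior_mixture_dim` and `em_iterations` come from this PR.

```python
import math
import random


def gamma_logpdf(x, shape, rate):
    """Log density of a gamma distribution in shape/rate form."""
    return (
        shape * math.log(rate)
        - math.lgamma(shape)
        + (shape - 1.0) * math.log(x)
        - rate * x
    )


def moment_match(values, resp):
    """Weighted moment matching: return (total weight, shape, rate)."""
    total = max(sum(resp), 1e-12)
    mean = sum(r * x for r, x in zip(resp, values)) / total
    var = sum(r * (x - mean) ** 2 for r, x in zip(resp, values)) / total
    var = max(var, 1e-12)  # guard against a degenerate component
    return total, mean * mean / var, mean / var


def init_mixture(values, k):
    """Hypothetical initializer: one component per quantile block (k <= len(values))."""
    ordered = sorted(values)
    n = len(ordered)
    mixture = []
    for j in range(k):
        block = ordered[j * n // k:(j + 1) * n // k]
        total, shape, rate = moment_match(block, [1.0] * len(block))
        mixture.append((total / n, shape, rate))
    return mixture


def em_step(values, mixture):
    """One E-step followed by a moment-matching M-step."""
    resp = []
    for x in values:
        logp = [math.log(w) + gamma_logpdf(x, a, b) for w, a, b in mixture]
        top = max(logp)  # subtract the max before exponentiating
        p = [math.exp(v - top) for v in logp]
        norm = sum(p)
        resp.append([v / norm for v in p])
    updated = []
    for j in range(len(mixture)):
        total, shape, rate = moment_match(values, [r[j] for r in resp])
        updated.append((total / len(values), shape, rate))
    return updated


def fit_gamma_mixture(values, prior_mixture_dim=1, em_iterations=10):
    """Return a list of (weight, shape, rate); em_iterations=0 keeps the initial mixture."""
    mixture = init_mixture(values, prior_mixture_dim)
    for _ in range(em_iterations):
        mixture = em_step(values, mixture)
    return mixture


random.seed(1)
ages = [random.gammavariate(2.0, 1.0) for _ in range(200)]
ages += [random.gammavariate(20.0, 1.0) for _ in range(200)]
fitted = fit_gamma_mixture(ages, prior_mixture_dim=2, em_iterations=10)
fixed = fit_gamma_mixture(ages, prior_mixture_dim=2, em_iterations=0)
```

With `em_iterations=0` the loop body never runs, so the result is just the initial mixture, which mirrors point 2 above.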

@nspope
Contributor Author

nspope commented Dec 11, 2023

With and without EM-updated prior for equilibrium demography:

@hyanwong
Member

Wow, that's great. What other issues is this linked to? We still have to figure out if there are better ways to make the prior bespoke for each node, right (i.e. #292)?

Do we need to decide on an API yet? Or should we keep this undocumented? It would be good to release 0.1.6 before Xmas if the changes in #346 are all OK.

@hyanwong
Member

hyanwong commented Dec 11, 2023

Currently getting this:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[5], line 1
----> 1 import tsdate
      2 i_ts = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(sim_ts))
      3 s_ts = i_ts.simplify()

File ~/mambaforge/lib/python3.10/site-packages/tsdate/__init__.py:23
      1 # MIT License
      2 #
      3 # Copyright (c) 2020 University of Oxford
   (...)
     20 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     21 # SOFTWARE.
     22 from .cache import *  # NOQA: F401,F403
---> 23 from .core import date  # NOQA: F401
     24 from .core import get_dates  # NOQA: F401
     25 from .prior import build_grid as build_prior_grid  # NOQA: F401

File ~/mambaforge/lib/python3.10/site-packages/tsdate/core.py:924
    917             maximized_node_times[child] = np.argmax(
    918                 self.lik.combine(result[: youngest_par_index + 1], inside_val)
    919             )
    921         return self.lik.timepoints[np.array(maximized_node_times).astype("int")]
--> 924 class ExpectationPropagation(InOutAlgorithms):
    925     r"""
    926     Expectation propagation (EP) algorithm to infer approximate marginal
    927     distributions for node ages.
   (...)
    954     Bayesian Inference"
    955     """
    957     def __init__(self, *args, global_prior, **kwargs):

File ~/mambaforge/lib/python3.10/site-packages/tsdate/core.py:1119, in ExpectationPropagation()
   1113         scale[c] *= child_eta
   1115     return 0.0  # TODO, placeholder
   1117 @staticmethod
   1118 @numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8[:], f8, i4, f8)")
-> 1119 def propagate_prior(
   1120     nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol
   1121 ):
   1122     """TODO
   1123 
   1124     :param ndarray nodes: ids of nodes that should be updated
   (...)
   1137         log-likelihood
   1138     """
   1140     if global_prior.shape[0] == 0:

File ~/mambaforge/lib/python3.10/site-packages/numba/core/decorators.py:219, in _jit.<locals>.wrapper(func)
    217     with typeinfer.register_dispatcher(disp):
    218         for sig in sigs:
--> 219             disp.compile(sig)
    220         disp.disable_compile()
    221 return disp

File ~/mambaforge/lib/python3.10/site-packages/numba/core/dispatcher.py:965, in Dispatcher.compile(self, sig)
    963 with ev.trigger_event("numba:compile", data=ev_details):
    964     try:
--> 965         cres = self._compiler.compile(args, return_type)
    966     except errors.ForceLiteralArg as e:
    967         def folded(args, kws):

File ~/mambaforge/lib/python3.10/site-packages/numba/core/dispatcher.py:129, in _FunctionCompiler.compile(self, args, return_type)
    127     return retval
    128 else:
--> 129     raise retval

File ~/mambaforge/lib/python3.10/site-packages/numba/core/dispatcher.py:139, in _FunctionCompiler._compile_cached(self, args, return_type)
    136     pass
    138 try:
--> 139     retval = self._compile_core(args, return_type)
    140 except errors.TypingError as e:
    141     self._failed_cache[key] = e

File ~/mambaforge/lib/python3.10/site-packages/numba/core/dispatcher.py:152, in _FunctionCompiler._compile_core(self, args, return_type)
    149 flags = self._customize_flags(flags)
    151 impl = self._get_implementation(args, {})
--> 152 cres = compiler.compile_extra(self.targetdescr.typing_context,
    153                               self.targetdescr.target_context,
    154                               impl,
    155                               args=args, return_type=return_type,
    156                               flags=flags, locals=self.locals,
    157                               pipeline_class=self.pipeline_class)
    158 # Check typing error if object mode is used
    159 if cres.typing_error is not None and not flags.enable_pyobject:

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler.py:716, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    692 """Compiler entry point
    693 
    694 Parameter
   (...)
    712     compiler pipeline
    713 """
    714 pipeline = pipeline_class(typingctx, targetctx, library,
    715                           args, return_type, flags, locals)
--> 716 return pipeline.compile_extra(func)

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler.py:452, in CompilerBase.compile_extra(self, func)
    450 self.state.lifted = ()
    451 self.state.lifted_from = None
--> 452 return self._compile_bytecode()

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler.py:520, in CompilerBase._compile_bytecode(self)
    516 """
    517 Populate and run pipeline for bytecode input
    518 """
    519 assert self.state.func_ir is None
--> 520 return self._compile_core()

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler.py:499, in CompilerBase._compile_core(self)
    497         self.state.status.fail_reason = e
    498         if is_final_pipeline:
--> 499             raise e
    500 else:
    501     raise CompilerError("All available pipelines exhausted")

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler.py:486, in CompilerBase._compile_core(self)
    484 res = None
    485 try:
--> 486     pm.run(self.state)
    487     if self.state.cr is not None:
    488         break

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler_machinery.py:368, in PassManager.run(self, state)
    365 msg = "Failed in %s mode pipeline (step: %s)" % \
    366     (self.pipeline_name, pass_desc)
    367 patched_exception = self._patch_error(msg, e)
--> 368 raise patched_exception

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File ~/mambaforge/lib/python3.10/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File ~/mambaforge/lib/python3.10/site-packages/numba/core/typed_passes.py:105, in BaseTypeInference.run_pass(self, state)
     99 """
    100 Type inference and legalization
    101 """
    102 with fallback_context(state, 'Function "%s" failed type inference'
    103                       % (state.func_id.func_name,)):
    104     # Type inference
--> 105     typemap, return_type, calltypes, errs = type_inference_stage(
    106         state.typingctx,
    107         state.targetctx,
    108         state.func_ir,
    109         state.args,
    110         state.return_type,
    111         state.locals,
    112         raise_errors=self._raise_errors)
    113     state.typemap = typemap
    114     # save errors in case of partial typing

File ~/mambaforge/lib/python3.10/site-packages/numba/core/typed_passes.py:83, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     81     infer.build_constraint()
     82     # return errors in case of partial typing
---> 83     errs = infer.propagate(raise_errors=raise_errors)
     84     typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     86 # Output all Numba warnings

File ~/mambaforge/lib/python3.10/site-packages/numba/core/typeinfer.py:1086, in TypeInferer.propagate(self, raise_errors)
   1083 force_lit_args = [e for e in errors
   1084                   if isinstance(e, ForceLiteralArg)]
   1085 if not force_lit_args:
-> 1086     raise errors[0]
   1087 else:
   1088     raise reduce(operator.or_, force_lit_args)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(float64, 1d, A), Tuple(array(int32, 1d, A), none))
 
There are 22 candidate implementations:
      - Of which 20 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 1d, A), Tuple(array(int32, 1d, A), none))':
       No match.
      - Of which 2 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 166.
        With argument(s): '(array(float64, 1d, A), Tuple(array(int32, 1d, A), none))':
       Rejected as the implementation raised a specific error:
         NumbaTypeError: unsupported array index type none in Tuple(array(int32, 1d, A), none)
  raised from /Users/yan/mambaforge/lib/python3.10/site-packages/numba/core/typing/arraydecl.py:72

During: typing of intrinsic-call at /Users/yan/mambaforge/lib/python3.10/site-packages/tsdate/core.py (1159)

File "../../../../mambaforge/lib/python3.10/site-packages/tsdate/core.py", line 1159:
        def posterior_damping(x):
            <source elided>
        cavity = np.zeros(posterior.shape)
        cavity[nodes] = posterior[nodes] - messages[nodes] * scale[nodes, np.newaxis]
        ^

Maybe I need a more recent numba version (in which case we should specify a minimum version)

Edit - confirmed that upgrading to numba 0.58.1 fixes it.
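For readers unfamiliar with the failing expression: `scale[nodes, np.newaxis]` combines an integer-array index with `np.newaxis`, which the older numba rejected during type inference. In plain Python terms, the line appears to compute the following. This is a list-based sketch for illustration only; `cavity_update` is a hypothetical name and the real code operates on NumPy arrays inside a numba-compiled function.

```python
def cavity_update(nodes, posterior, messages, scale):
    """Subtract each selected node's scaled message from its posterior row.

    posterior and messages are lists of rows (one row per node);
    scale holds one multiplier per node.
    """
    # Start from zeros, matching np.zeros(posterior.shape)
    cavity = [[0.0] * len(row) for row in posterior]
    for n in nodes:
        # scale[n] is applied to every column of row n, which is what
        # the np.newaxis broadcast achieves in the array version
        cavity[n] = [p - m * scale[n] for p, m in zip(posterior[n], messages[n])]
    return cavity


posterior = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
messages = [[0.5, 0.5], [1.0, 1.0], [2.0, 2.0]]
scale = [2.0, 1.0, 0.5]
result = cavity_update([0, 2], posterior, messages, scale)
```

Rows for nodes not listed in `nodes` stay at zero.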

@hyanwong
Member

hyanwong commented Dec 11, 2023

To use the coalescent prior to initialize a mixture with K components, pass global_prior=K

This doesn't work, because the code that stores provenance assumes global_prior is an array. Using e.g. global_prior=2, I get "TypeError: 'int' object is not subscriptable" at line 1458: "weight": list(global_prior[:, 0]),

I presume we either store the value 2 in the metadata, or calculate the global prior array and store that?
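One way the first option could look, as a sketch only. The helper `prior_provenance` and the key names `prior_mixture_dim` and `parameters` are hypothetical; the only detail taken from the traceback is that column 0 of the array form holds the component weights. How the fix was actually made is not shown in this thread.

```python
def prior_provenance(global_prior):
    """Build a JSON-friendly record of the prior for provenance.

    Accepts either an integer number of mixture components or a
    sequence of rows whose first entry is the component weight.
    """
    if isinstance(global_prior, int):
        # Only the requested number of components is known at this point
        return {"prior_mixture_dim": global_prior}
    rows = [list(row) for row in global_prior]
    return {
        "weight": [row[0] for row in rows],
        # Assumption: remaining columns are the per-component parameters
        "parameters": [row[1:] for row in rows],
    }


from_int = prior_provenance(2)
from_rows = prior_provenance([[0.5, 1.0, 2.0], [0.5, 3.0, 4.0]])
```

Branching on the input type avoids subscripting an integer, which is what raised the `TypeError` above.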

@nspope
Contributor Author

nspope commented Dec 11, 2023

Fixed -- so numba 0.58.1 as a minimum version?

@hyanwong
Member

so numba 0.58.1 as a minimum version?

I guess so. Can't hurt, right?

@nspope
Contributor Author

nspope commented Dec 11, 2023

I just did this in the mom-fixup branch. Yea, there's enough code movement in numba that pinning the version is a good idea regardless

@jeromekelleher
Member

Pinning to recent numba is fine - I think in practice you end up needing the latest version pretty quickly anyway, what with numpy updates.
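In packaging terms a minimum version is usually written as a requirement specifier such as `numba>=0.58.1`. A runtime check is another option; the sketch below uses only the standard library, and the helper names are hypothetical. It ignores pre-release suffixes, so treat it as an illustration rather than a full version parser.

```python
from importlib import metadata

MINIMUM_NUMBA = "0.58.1"  # version reported to work in this thread


def parse_version(text, width=3):
    """Turn '0.58.1' into (0, 58, 1); stop at the first non-numeric piece."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    parts = parts[:width]
    # Pad so that '0.58' compares equal to '0.58.0'
    return tuple(parts + [0] * (width - len(parts)))


def meets_minimum(installed, minimum=MINIMUM_NUMBA):
    """True if the installed version is at least the minimum."""
    return parse_version(installed) >= parse_version(minimum)


def installed_numba_version():
    """Return the installed numba version string, or None if absent."""
    try:
        return metadata.version("numba")
    except metadata.PackageNotFoundError:
        return None
```

A caller could combine the two, for example warning when `installed_numba_version()` returns a string for which `meets_minimum` is false.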

@nspope nspope force-pushed the mix-prior branch 2 times, most recently from 2aaf053 to ced5d8b Compare December 29, 2023 19:40
@nspope nspope marked this pull request as ready for review December 29, 2023 19:49
@nspope
Contributor Author

nspope commented Dec 29, 2023

This is ready ... I'll add support for user-specified priors in a later PR.

I've checked that it works with a large (2k sample) mosquito tree sequence. It'd be great to make sure this runs through with global_prior set to 1 and then set to 5 on GEL or UKB, when you have the time, @hyanwong.

@nspope nspope mentioned this pull request Jan 4, 2024
Member

@hyanwong hyanwong left a comment

Seems good to me. I have made a version that works with the latest main branch (new API version) in main...hyanwong:tsdate:mix-prior-merge - feel free to steal stuff from that branch. I have set most of the defaults to None, then overridden them with a sensible default, which makes it easier to change the defaults later without causing an API change.
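The None-default pattern described here can be sketched as follows. `resolve_options` and the module-level constants are hypothetical names; the values 1 and 10 are the defaults stated in the PR description.

```python
DEFAULT_PRIOR_MIXTURE_DIM = 1
DEFAULT_EM_ITERATIONS = 10


def resolve_options(prior_mixture_dim=None, em_iterations=None):
    """Replace None with the current defaults and validate the result."""
    # Use 'is None' rather than 'or': em_iterations=0 is a meaningful
    # value (it disables the EM updates) and must not be overridden.
    if prior_mixture_dim is None:
        prior_mixture_dim = DEFAULT_PRIOR_MIXTURE_DIM
    if em_iterations is None:
        em_iterations = DEFAULT_EM_ITERATIONS
    if prior_mixture_dim < 1:
        raise ValueError("prior_mixture_dim must be at least 1")
    if em_iterations < 0:
        raise ValueError("em_iterations must be non-negative")
    return prior_mixture_dim, em_iterations
```

Because the signature only ever advertises `None`, changing a constant later alters behaviour without changing the documented function signature.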

@hyanwong
Member

hyanwong commented Jan 5, 2024

It'd be great to make sure this runs through with global_prior set to 1 and then set to 5 on GEL or UKB, when you have the time, @hyanwong.

I just ran it on the GEL dataset and it went through without problems with global_prior set to 1 and then 5. Can't vouch for the accuracy, of course!

nspope added 3 commits January 5, 2024 13:19
More streamlined numerical checks

Initialize gamma mixture from conditional coalescent prior

Add pdf

Update mixture.py to use natural parameterization

WIP

Moved fully into numba

Cleanup

Cleanup

More debugging

WIP

Working

wording

Add missing constant to loglikelihood

Skip prior update completely instead of components

Skip prior update completely instead of components

Remove verbose; use logweights in conditional posterior

Move mixture initialization to function

Docstrings and CLI

Remove some debugging inserts

Remove preemptive reference

Fix tests
tsdate/core.py Outdated
@@ -1682,7 +1790,8 @@ def variational_gamma(
     max_iterations=None,
     max_shape=None,
     match_central_moments=None,
-    global_prior=True,
+    prior_mixture_dim=1,
+    em_iterations=10,
Member

I think the prior_mixture_dim and em_iterations defaults need to be None (which then get overridden below)

@@ -954,32 +955,37 @@ class ExpectationPropagation(InOutAlgorithms):
Bayesian Inference"
"""

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, global_prior, **kwargs):
Member

I'm not sure I quite understand how the global prior is used differently from self.prior here. Is it worth adding a docstring to explain?

Member

@hyanwong hyanwong left a comment

LGTM, modulo two minor comments. I think I'm a bit confused about how the global_prior differs from the "normal" self.prior, so maybe worth a minor bit of documentation?

self.priors.grid_data[:, 0], self.priors.grid_data[:, 1]
)

self.global_prior = mixture.initialize_mixture(
Member

@hyanwong hyanwong Jan 6, 2024

Maybe a comment here about the difference between self.global_prior and self.priors would be helpful? They sound like the same sort of thing.

@nspope
Contributor Author

nspope commented Jan 6, 2024

done -- I think this is good to merge, but it'd be nice to check a few examples before putting a release out. I can do that tomorrow.

@hyanwong hyanwong merged commit a388028 into tskit-dev:main Jan 6, 2024
3 checks passed
@hyanwong
Member

hyanwong commented Jan 6, 2024

Great, thanks. Merged.
