
Implement generic O(2) population density support #81

Open
maedoc opened this issue Jan 30, 2025 · 0 comments


@maedoc
Member

maedoc commented Jan 30, 2025

From

it seems straightforward to use JAX autodiff to implement the second-order moment equations (means and covariances) for an arbitrary Langevin system with dX_i = f_i(X) dt + σ_i dW_i and D_i = σ_i²:

  • dμ_i/dt = f_i(μ) + (1/2)∑_{j,k} (∂²f_i/∂x_j∂x_k)|_μ C_jk
  • dC_ij/dt = ∑_k [(∂f_i/∂x_k)|_μ C_kj + (∂f_j/∂x_k)|_μ C_ik] + D_i δ_ij

with a quick first guess at an implementation:

import jax
import jax.numpy as jnp

def make_moment_odes(f, D):
    D = jnp.asarray(D)  # diffusion coefficients D_i = σ_i²

    def moment_odes(state):
        # TODO handle (μ, C) directly as a pytree instead of a flattened vector
        n = D.shape[0]
        μ = state[:n]
        C = state[n:].reshape(n, n)

        # Jacobian J[i, k] = ∂f_i/∂x_k and Hessian H[i, j, k] = ∂²f_i/∂x_j∂x_k at the current mean
        J = jax.jacfwd(f)(μ)
        H = jax.hessian(f)(μ)

        # Mean derivatives
        # dμ_i/dt = f_i(μ) + (1/2)∑_{j,k} (∂²f_i/∂x_j∂x_k)|_μ C_jk
        dμ = f(μ) + 0.5 * jnp.einsum('ijk,jk->i', H, C)

        # Covariance derivatives
        # dC_ij/dt = ∑_k [(∂f_i/∂x_k)|_μ C_kj + (∂f_j/∂x_k)|_μ C_ik] + D_i δ_ij,
        # i.e. dC = J C + C Jᵀ + diag(D), term by term, without assuming C is symmetric
        dC = J @ C + C @ J.T + jnp.diag(D)

        # Return the concatenated derivatives [dμ, vec(dC)] as one flat vector
        return jnp.concatenate([dμ, dC.ravel()])

    return moment_odes
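
To help verify the math by hand, one standard route to the two equations is sketched below (a sketch, not a full proof). Itô's formula, with independent W_i and D_i = σ_i², gives the first equality on each line exactly; the approximation comes only from Taylor-expanding f about μ, to second order for the mean and first order for the covariance. Every dropped term, written as ⋯, involves a central moment of X of order three or higher, so the mean equation is exact for quadratic f, and the covariance equation is exact for linear f (and for quadratic f under a Gaussian closure, where third central moments vanish).

```latex
\begin{aligned}
\frac{d\mu_i}{dt} &= \mathbb{E}\left[f_i(X)\right]
  = f_i(\mu) + \frac{1}{2}\sum_{j,k}
    \left.\frac{\partial^2 f_i}{\partial x_j\,\partial x_k}\right|_{\mu} C_{jk} + \cdots \\[4pt]
\frac{dC_{ij}}{dt} &= \mathbb{E}\left[(X_i-\mu_i)\,f_j(X)\right]
  + \mathbb{E}\left[f_i(X)\,(X_j-\mu_j)\right] + D_i\,\delta_{ij} \\
  &= \sum_k \left[\left.\frac{\partial f_i}{\partial x_k}\right|_{\mu} C_{kj}
  + \left.\frac{\partial f_j}{\partial x_k}\right|_{\mu} C_{ik}\right]
  + D_i\,\delta_{ij} + \cdots
\end{aligned}
```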

NB: both the math and the code were drafted by Claude and still require manual and numerical verification.
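
A rough numerical check of the kind the NB asks for: simulate the Langevin SDE with Euler–Maruyama over many paths and compare the sample mean and covariance with the moment ODEs integrated by forward Euler. This is a sketch under stated assumptions, not a validated test. The helper names (`make_moment_odes_pytree`, `integrate_euler`, `euler_maruyama`) are hypothetical, the state is kept as a `(μ, C)` tuple in the way the TODO suggests, and the linear test drift `A`, step size and sample count are arbitrary choices. For a linear drift the closure is exact, so the two estimates should agree up to time-discretization and Monte Carlo error.

```python
import jax
import jax.numpy as jnp


def make_moment_odes_pytree(f, D):
    """Hypothetical variant of make_moment_odes with a (μ, C) tuple as the state."""
    D = jnp.asarray(D)

    def moment_odes(state):
        mu, C = state
        J = jax.jacfwd(f)(mu)   # J[i, k] = ∂f_i/∂x_k at μ
        H = jax.hessian(f)(mu)  # H[i, j, k] = ∂²f_i/∂x_j∂x_k at μ
        dmu = f(mu) + 0.5 * jnp.einsum("ijk,jk->i", H, C)
        dC = J @ C + C @ J.T + jnp.diag(D)  # matrix form of the dC_ij/dt equation
        return dmu, dC

    return moment_odes


def integrate_euler(rhs, state, dt, n_steps):
    """Forward-Euler integration of an autonomous ODE whose state is any pytree."""
    def step(y, _):
        dy = rhs(y)
        return jax.tree_util.tree_map(lambda a, b: a + dt * b, y, dy), None

    y_final, _ = jax.lax.scan(step, state, None, length=n_steps)
    return y_final


def euler_maruyama(f, D, x0, dt, n_steps, key):
    """Euler–Maruyama for dX_i = f_i(X) dt + σ_i dW_i on a batch x0 of shape (n_samples, n)."""
    sigma = jnp.sqrt(jnp.asarray(D))  # σ_i = sqrt(D_i)
    f_batched = jax.vmap(f)

    def step(x, k):
        dW = jnp.sqrt(dt) * jax.random.normal(k, x.shape)  # independent Wiener increments
        return x + dt * f_batched(x) + sigma * dW, None

    x_final, _ = jax.lax.scan(step, x0, jax.random.split(key, n_steps))
    return x_final


# Assumed test problem: linear drift f(x) = A x, for which the closure is exact,
# so only time-discretization and Monte Carlo error should remain.
A = jnp.array([[-1.0, 0.5], [0.0, -2.0]])
D = jnp.array([1.0, 0.5])
mu0 = jnp.array([1.0, -1.0])
dt, n_steps, n_samples = 0.01, 100, 20_000


def f(x):
    return A @ x


rhs = make_moment_odes_pytree(f, D)
# Deterministic start: zero initial covariance, matching the tiled initial samples below
mu_ode, C_ode = integrate_euler(rhs, (mu0, jnp.zeros((2, 2))), dt, n_steps)

x = euler_maruyama(f, D, jnp.tile(mu0, (n_samples, 1)), dt, n_steps, jax.random.PRNGKey(0))
mu_mc, C_mc = x.mean(axis=0), jnp.cov(x, rowvar=False)

print("mean ODE:", mu_ode, "MC:", mu_mc)
print("cov ODE:\n", C_ode, "\ncov MC:\n", C_mc)
```

With a nonlinear drift in place of `A @ x`, the gap between the two estimates becomes a rough measure of the closure error plus sampling noise; how much disagreement to tolerate is a judgment call.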
