-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add vanilla HMC method #75
base: main
Are you sure you want to change the base?
Conversation
momentum, _ = jax.flatten_util.ravel_pytree(momentum) | ||
kinetic = 0.5 * jnp.dot(momentum, momentum) | ||
hamiltonian = kinetic + state.log_prob | ||
accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed, you can avoid the minimum and the exponential here. You can define
log_accept_ratio = hamiltonian - state.hamiltonian
See later for the accept/reject part.
return revert_updates, state.params, state.hamiltonian | ||
|
||
updates, new_params, new_hamiltonian = jax.lax.cond( | ||
jax.random.uniform(uniform_key) < accept_prob, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Following the comment above, this line should become
jnp.log(jax.random.uniform(uniform_key)) < log_accept_ratio
.
This is equivalent to what you have written but with one operation less. Alternatively, notice that -log(U) ~ Exponential(1))
if U~Uniform(0, 1)
. This means that you can also write
-jax.random.exponential(uniform_key)) < log_accept_ratio
.
All of these should be equivalent. Please check that the lines I wrote are correct :-)
""" | ||
|
||
encoded_name: jnp.ndarray = convert_string_to_jnp_array("HMCState") | ||
_encoded_which_params: Optional[Dict[str, List[Array]]] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was expecting to see the stored _hamiltonian
here too?
**kwargs, | ||
) | ||
state = state.replace( | ||
opt_state=state.opt_state._replace(log_prob=aux["loss"]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should opt_state
be added to the parameters of HMCState
?
Add full-batch Hamiltonian Monte Carlo implementation.
Pull request type
Please check the type of change your PR introduces: