Commit b02613c

Code for pre-print
1 parent a855604 commit b02613c

Some content is hidden: large commits have some file diffs hidden by default.

48 files changed: 20,470 additions and 2 deletions

.DS_Store

10 KB
Binary file not shown.

.gitignore.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+*.egg-info

LICENSE

Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

-   Copyright [yyyy] [name of copyright owner]
+   Copyright [2021] Nicholas Phillips

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.

MSS/.DS_Store

6 KB
Binary file not shown.

MSS/__init__.py

Whitespace-only changes.

MSS/abstractmodel.py

Lines changed: 185 additions & 0 deletions

@@ -0,0 +1,185 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 28 16:12:17 2022

@author: phillips
"""

import os
import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

class MSSModel:
    """
    This is the superclass of Model1, Model2 and Model3. It contains methods
    for parameter optimisation and parameter sampling with MCMC that are used
    by all models. Not to be instantiated directly.
    """

    def find_MAP(self,
                 loss_and_grads,
                 N_starts=500,
                 seed=123,
                 max_iterations=1000,
                 max_line_search_iterations=500,
                 initial_inverse_hessian_scale=1e-4,
                 print_on=True):
        """Find the MAP parameter point estimate of the model using the
        BFGS algorithm, with multiple restarts using samples from the prior
        distribution as initial starting points.

        :param loss_and_grads: callable that accepts a point as a real Tensor
            and returns a tuple of Tensors containing the value of the
            objective function and its gradient at that point
        :param N_starts: the number of random initialisations for optimisation
        :param seed: the seed of the optimiser
        :param max_iterations: the maximum number of iterations for BFGS updates
        :param max_line_search_iterations: the maximum number of iterations
            for the line search algorithm
        :param initial_inverse_hessian_scale: the starting estimate for the
            inverse of the Hessian at the initial point
        :param print_on: if True, prints the index of the current
            initialisation from the prior samples
        """
        tf.random.set_seed(seed)
        MAP_list, param_list = [], []

        for i in range(N_starts):
            if print_on:
                print(i)
            start = self.sample_prior()
            optim_results = tfp.optimizer.bfgs_minimize(
                loss_and_grads,
                initial_position=start,
                max_iterations=max_iterations,
                max_line_search_iterations=max_line_search_iterations,
                initial_inverse_hessian_estimate=initial_inverse_hessian_scale * tf.eye(len(start))
            )
            MAP_list.append(optim_results.objective_value)
            param_list.append(optim_results.position)

        self.MAP_list = MAP_list
        self.param_list = param_list
        self.MAP_opted = MAP_list[np.nanargmin(MAP_list)]
        self.params_opted_map = param_list[np.nanargmin(MAP_list)]

    @tf.function(jit_compile=True)
    def run_chain(self,
                  initial_state,
                  num_results,
                  num_burnin_steps,
                  kernel,
                  seed=[1, 1]):
        """Run Markov chain Monte Carlo (MCMC), which is used to sample
        the posterior parameter distribution.

        :param initial_state: the starting position of the Markov chain
        :param num_results: number of Markov chain samples
        :param num_burnin_steps: number of chain steps to take before starting
            to collect results
        :param kernel: an instance of tfp.mcmc.TransitionKernel that
            implements one step of the Markov chain
        :param seed: seed for the sampler
        """
        return tfp.mcmc.sample_chain(
            num_results=num_results,
            num_burnin_steps=num_burnin_steps,
            current_state=initial_state,
            kernel=kernel,
            seed=seed,
            trace_fn=lambda current_state, kernel_results: kernel_results)

    def first_MCMC_run(self,
                       initial_state,
                       step_size=3e-3,
                       num_burnin_steps=10000,
                       num_results=10000):
        """Perform an initial MCMC run to estimate the standard deviation of
        the posterior parameter distribution. The standard deviation is then
        used to tune the step size of the full MCMC run.

        :param initial_state: the starting position of the Markov chain
        :param step_size: the step size for the leapfrog integrator
        :param num_burnin_steps: number of chain steps to take before starting
            to collect results
        :param num_results: number of Markov chain samples
        """
        kernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=self.unnormalized_posterior,
            num_leapfrog_steps=5,
            step_size=step_size)

        kernel = tfp.mcmc.SimpleStepSizeAdaptation(
            inner_kernel=kernel, num_adaptation_steps=int(num_burnin_steps * 0.8))

        samples, kernel_results = self.run_chain(initial_state=initial_state,
                                                 num_results=num_results,
                                                 num_burnin_steps=num_burnin_steps,
                                                 kernel=kernel)

        approx_posterior_std = np.std(samples, axis=0)
        approx_posterior_std = approx_posterior_std / np.max(approx_posterior_std)
        self.step_size_optimised = approx_posterior_std * 1e-1

    def sample_posterior(self,
                         initial_state,
                         step_size,
                         num_chains=4,
                         num_burnin_steps=10000,
                         num_results=10000):
        """Full HMC to sample from the posterior parameter distribution.
        The MCMC samples are saved as a list, which is then used for the
        main results. MCMC diagnostics are also saved, including the
        acceptance rate, the potential scale reduction (rhat) and the
        log probability.

        :param initial_state: the starting position of the Markov chain
        :param step_size: the step size for the leapfrog integrator
        :param num_chains: the number of independent MCMC chains
        :param num_burnin_steps: number of chain steps to take before starting
            to collect results
        :param num_results: number of Markov chain samples
        """
        kernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=self.unnormalized_posterior,
            num_leapfrog_steps=5,
            step_size=step_size)

        kernel = tfp.mcmc.SimpleStepSizeAdaptation(
            inner_kernel=kernel, num_adaptation_steps=int(num_burnin_steps * 0.8))

        samples_list = []
        kernel_results_list = []

        for i in range(num_chains):
            print(i)
            samples, kernel_results = self.run_chain(initial_state=initial_state,
                                                     num_results=num_results,
                                                     num_burnin_steps=num_burnin_steps,
                                                     kernel=kernel,
                                                     seed=[i, i])
            samples_list.append(samples)
            kernel_results_list.append(kernel_results)

        acceptance_rate = [kernel_results_list[chain].inner_results.is_accepted.numpy().mean() for chain in range(len(kernel_results_list))]

        chain_states = tf.stack(samples_list, 1)
        rhat = tfp.mcmc.potential_scale_reduction(chain_states,
                                                  independent_chain_ndims=1)

        target_log_prob = [kernel_results_list[chain].inner_results.accepted_results.target_log_prob for chain in range(len(kernel_results_list))]

        mcmc_diagnostics = {'acceptance_rate': acceptance_rate,
                            'rhat': rhat,
                            'target_log_prob': target_log_prob}

        self.samples_list = samples_list
        self.mcmc_diagnostics = mcmc_diagnostics

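To show how these pieces fit together, here is a minimal usage sketch. It assumes a hypothetical subclass (Model1, named in the class docstring) that provides the sample_prior and unnormalized_posterior members the superclass expects; the loss_and_grads wrapper and all settings below are illustrative, not taken from this commit.

import tensorflow_probability as tfp

model = Model1()  # hypothetical subclass of MSSModel; constructor not shown in this commit

# Negative log-posterior and its gradient, in the (value, gradient)
# form that find_MAP passes to tfp.optimizer.bfgs_minimize.
def loss_and_grads(params):
    return tfp.math.value_and_gradient(
        lambda p: -model.unnormalized_posterior(p), params)

# 1. Multi-start BFGS for the MAP point estimate.
model.find_MAP(loss_and_grads, N_starts=50)

# 2. Pilot HMC run from the MAP point to tune per-parameter step sizes.
model.first_MCMC_run(initial_state=model.params_opted_map)

# 3. Full HMC run with the tuned step sizes, over four chains.
model.sample_posterior(initial_state=model.params_opted_map,
                       step_size=model.step_size_optimised,
                       num_chains=4)

print(model.mcmc_diagnostics['rhat'])

find_MAP stores the best of the multi-start runs in params_opted_map, first_MCMC_run derives per-parameter step sizes from the pilot chain, and sample_posterior records the samples and diagnostics on the instance.
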
MSS/config.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 29 10:43:49 2022

@author: phillips
"""
import matplotlib.pyplot as plt

COLOR_GLUC = plt.cm.tab20(0)
COLOR_HR = plt.cm.tab20(6)
COLOR_HRV = plt.cm.tab20(8)
COLOR_ACTI = plt.cm.tab20(4)
COLOR_CIRC = 'k'

COLOR_PRED1 = plt.cm.tab20c(4)
COLOR_PRED2 = plt.cm.tab20c(5)
COLOR_PRED3 = plt.cm.tab20c(6)

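The constants above are plain matplotlib colour values, so plotting code can use them directly. A minimal sketch follows; the import path matches the package layout in this commit, but the sample data and signal labels are invented for illustration (reading GLUC and HR as glucose and heart rate):

import numpy as np
import matplotlib.pyplot as plt
from MSS.config import COLOR_GLUC, COLOR_HR

# Invented 24-hour traces, purely to demonstrate the shared palette.
t = np.linspace(0, 24, 200)
plt.plot(t, 5 + np.sin(2 * np.pi * t / 24), color=COLOR_GLUC, label='glucose')
plt.plot(t, 60 + 10 * np.cos(2 * np.pi * t / 24), color=COLOR_HR, label='heart rate')
plt.xlabel('time (h)')
plt.legend()
plt.show()

Keeping the palette in one module means every figure can draw the same signal in the same colour.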