#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 28 16:12:17 2022

@author: phillips
"""

import os

# TF_XLA_FLAGS must be set before TensorFlow initialises its devices,
# so this assignment goes before the tensorflow import.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp

class MSSModel:
    """
    This is the superclass of Model1, Model2 and Model3. It contains methods
    for parameter optimisation and parameter sampling with MCMC that are used
    by all models. Not to be instantiated directly.
    """

    def find_MAP(self,
                 loss_and_grads,
                 N_starts=500,
                 seed=123,
                 max_iterations=1000,
                 max_line_search_iterations=500,
                 initial_inverse_hessian_scale=1e-4,
                 print_on=True):
        """Find the MAP parameter point estimate of the model using the
        BFGS algorithm, with multiple restarts initialised from samples
        of the prior distribution.

        :param loss_and_grads: callable that accepts a point as a real Tensor
                               and returns a tuple of Tensors containing the
                               value of the objective (the negative log
                               posterior, for MAP) and its gradient at that
                               point
        :param N_starts: the number of random initialisations for optimisation
        :param seed: the random seed used to draw the starting points
        :param max_iterations: the maximum number of BFGS updates
        :param max_line_search_iterations: the maximum number of iterations
                                           for the line search algorithm
        :param initial_inverse_hessian_scale: the starting estimate for the
                                              inverse of the Hessian at the
                                              initial point
        :param print_on: if True, prints the index of the current
                         initialisation
        """
        tf.random.set_seed(seed)
        MAP_list, param_list = [], []

        for i in range(N_starts):
            if print_on:
                print(i)
            start = self.sample_prior()
            optim_results = tfp.optimizer.bfgs_minimize(
                loss_and_grads,
                initial_position=start,
                max_iterations=max_iterations,
                max_line_search_iterations=max_line_search_iterations,
                initial_inverse_hessian_estimate=(
                    initial_inverse_hessian_scale * tf.eye(len(start))))
            MAP_list.append(optim_results.objective_value)
            param_list.append(optim_results.position)

        self.MAP_list = MAP_list
        self.param_list = param_list
        # nanargmin skips restarts whose objective diverged to NaN.
        best = np.nanargmin(MAP_list)
        self.MAP_opted = MAP_list[best]
        self.params_opted_map = param_list[best]

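    # Usage sketch (illustrative only, not part of the original module): a
    # concrete subclass supplies sample_prior() and unnormalized_posterior(),
    # and the BFGS objective is then the negative log posterior, e.g.
    #
    #   def loss_and_grads(params):
    #       return tfp.math.value_and_gradient(
    #           lambda p: -model.unnormalized_posterior(p), params)
    #
    #   model.find_MAP(loss_and_grads, N_starts=100)
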
    @tf.function(jit_compile=True)
    def run_chain(self,
                  initial_state,
                  num_results,
                  num_burnin_steps,
                  kernel,
                  seed=[1, 1]):
        """Run Markov chain Monte Carlo (MCMC), which is used to sample
        the posterior parameter distribution.

        :param initial_state: the starting position of the Markov chain
        :param num_results: number of Markov chain samples
        :param num_burnin_steps: number of chain steps to take before starting
                                 to collect results
        :param kernel: an instance of tfp.mcmc.TransitionKernel that
                       implements one step of the Markov chain
        :param seed: seed for the sampler
        """
        # trace_fn records the full kernel results, which are later used
        # for diagnostics (acceptance rate, target log probability).
        return tfp.mcmc.sample_chain(
            num_results=num_results,
            num_burnin_steps=num_burnin_steps,
            current_state=initial_state,
            kernel=kernel,
            seed=seed,
            trace_fn=lambda current_state, kernel_results: kernel_results)

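    # Direct-use sketch (hypothetical kernel choice; first_MCMC_run and
    # sample_posterior below wrap this with HMC plus step-size adaptation):
    #
    #   kernel = tfp.mcmc.RandomWalkMetropolis(model.unnormalized_posterior)
    #   samples, results = model.run_chain(initial_state, num_results=1000,
    #                                      num_burnin_steps=1000, kernel=kernel)
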
    def first_MCMC_run(self,
                       initial_state,
                       step_size=3e-3,
                       num_burnin_steps=10000,
                       num_results=10000):
        """Perform an initial MCMC run to estimate the standard deviation of
        the posterior parameter distribution. The standard deviation is then
        used to tune the per-dimension step size of the full MCMC run.

        :param initial_state: the starting position of the Markov chain
        :param step_size: the step size for the leapfrog integrator
        :param num_burnin_steps: number of chain steps to take before starting
                                 to collect results
        :param num_results: number of Markov chain samples
        """
        kernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=self.unnormalized_posterior,
            num_leapfrog_steps=5,
            step_size=step_size)

        kernel = tfp.mcmc.SimpleStepSizeAdaptation(
            inner_kernel=kernel,
            num_adaptation_steps=int(num_burnin_steps * 0.8))

        samples, kernel_results = self.run_chain(
            initial_state=initial_state,
            num_results=num_results,
            num_burnin_steps=num_burnin_steps,
            kernel=kernel)

        # Normalise the per-parameter standard deviation by its largest
        # component, then shrink it to give a per-dimension step size.
        approx_posterior_std = np.std(samples, axis=0)
        approx_posterior_std = approx_posterior_std / np.max(approx_posterior_std)
        self.step_size_optimised = approx_posterior_std * 1e-1

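    # Tuning handoff sketch (illustrative; assumes find_MAP has already been
    # run so params_opted_map exists):
    #
    #   model.first_MCMC_run(model.params_opted_map)
    #   model.sample_posterior(model.params_opted_map,
    #                          step_size=model.step_size_optimised)
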
    def sample_posterior(self,
                         initial_state,
                         step_size,
                         num_chains=4,
                         num_burnin_steps=10000,
                         num_results=10000):
        """Full HMC run to sample from the posterior parameter distribution.
        The MCMC samples are saved as a list, which is then used for the main
        results. MCMC diagnostics are also saved, including the acceptance
        rate, the potential scale reduction (rhat) and the log probability.

        :param initial_state: the starting position of the Markov chains
        :param step_size: the step size for the leapfrog integrator
        :param num_chains: the number of independent MCMC chains
        :param num_burnin_steps: number of chain steps to take before starting
                                 to collect results
        :param num_results: number of Markov chain samples per chain
        """
        kernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=self.unnormalized_posterior,
            num_leapfrog_steps=5,
            step_size=step_size)

        kernel = tfp.mcmc.SimpleStepSizeAdaptation(
            inner_kernel=kernel,
            num_adaptation_steps=int(num_burnin_steps * 0.8))

        samples_list = []
        kernel_results_list = []

        for i in range(num_chains):
            print(i)
            samples, kernel_results = self.run_chain(
                initial_state=initial_state,
                num_results=num_results,
                num_burnin_steps=num_burnin_steps,
                kernel=kernel,
                seed=[i, i])
            samples_list.append(samples)
            kernel_results_list.append(kernel_results)

        acceptance_rate = [
            results.inner_results.is_accepted.numpy().mean()
            for results in kernel_results_list]

        # Stack to shape [num_results, num_chains, num_params] for rhat.
        chain_states = tf.stack(samples_list, axis=1)
        rhat = tfp.mcmc.potential_scale_reduction(
            chain_states, independent_chain_ndims=1)

        target_log_prob = [
            results.inner_results.accepted_results.target_log_prob
            for results in kernel_results_list]

        mcmc_diagnostics = {'acceptance_rate': acceptance_rate,
                            'rhat': rhat,
                            'target_log_prob': target_log_prob}

        self.samples_list = samples_list
        self.mcmc_diagnostics = mcmc_diagnostics

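if __name__ == "__main__":
    # Smoke-test sketch (not part of the original module). MSSModel is
    # abstract, so this hypothetical toy subclass with a standard-normal
    # prior and posterior only illustrates the expected subclass interface
    # and the optimise-then-sample workflow; Model1-Model3 supply the real
    # densities.
    class ToyModel(MSSModel):
        def sample_prior(self):
            return tf.random.normal([2])

        def unnormalized_posterior(self, params):
            return -0.5 * tf.reduce_sum(tf.square(params))

    model = ToyModel()

    def loss_and_grads(params):
        # BFGS minimises, so optimise the negative log posterior.
        return tfp.math.value_and_gradient(
            lambda p: -model.unnormalized_posterior(p), params)

    model.find_MAP(loss_and_grads, N_starts=5, print_on=False)
    model.first_MCMC_run(model.params_opted_map,
                         num_burnin_steps=500, num_results=500)
    model.sample_posterior(model.params_opted_map,
                           step_size=model.step_size_optimised,
                           num_chains=2, num_burnin_steps=500,
                           num_results=500)
    print(model.mcmc_diagnostics['rhat'])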