|
1 | 1 | import emcee
|
2 | 2 | import numpy as np
|
3 | 3 | from marcia import Likelihood as lk
|
4 |
| -from marcia import temporarily_false |
5 | 4 | from getdist import plots, MCSamples
|
6 | 5 | import scipy.optimize as op
|
7 | 6 | from chainconsumer import ChainConsumer
|
8 | 7 | import logging
|
9 | 8 | import os
|
10 | 9 |
|
11 |
| -logging.basicConfig(filename="sampler.log",format='%(asctime)s %(message)s',filemode='w') |
12 |
| -logger = logging.getLogger() |
13 |
| -logger.setLevel(logging.INFO) |
14 |
| - |
15 |
| - |
16 |
| - |
class Sampler:
    """MCMC sampler wrapper around emcee with a persistent HDF5 backend.

    The workflow is: find a maximum-likelihood starting point with
    scipy's Nelder-Mead optimiser, seed an ensemble of walkers in a small
    Gaussian ball around it, then run `emcee`, optionally monitoring the
    integrated autocorrelation time for convergence. The chain is stored
    in ``sampler_file`` so interrupted runs can be resumed.
    """

    def __init__(self, model, parameters, data, initial_guess, prior_file=None,
                 max_n=100000, nwalkers=100, sampler_file='sampler.h5', converge=False,):
        """
        Args:
            model: theory model, forwarded to the marcia Likelihood.
            parameters: parameter specification, forwarded to the Likelihood.
            data: data sets, forwarded to the Likelihood.
            initial_guess: starting point for the MLE optimisation.
            prior_file: optional prior specification, forwarded to the Likelihood.
            max_n: maximum number of MCMC iterations to run in total.
            nwalkers: number of ensemble walkers.
            sampler_file: HDF5 file used to persist (and resume) the chain.
            converge: if True, stop early once the autocorrelation-time
                convergence criterion is met.
        """
        self.likelihood = lk(model, parameters, data, prior_file)
        self.ndim = len(self.likelihood.priors)
        self.nwalkers = nwalkers
        self.initial_guess = initial_guess
        self.max_n = max_n
        self.HDFBackend = emcee.backends.HDFBackend(sampler_file)
        self.converge = converge
        self.mle = {}  # cache: MLE is computed once and reused

    def MLE(self, verbose=True):
        """Return (and cache) the maximum a-posteriori parameter vector.

        Minimises the negative log-probability (likelihood + prior) with
        Nelder-Mead; the result is cached in ``self.mle`` so repeated
        calls do not re-optimise.
        """
        if 'result' not in self.mle:
            nll = lambda x: -1 * self.likelihood.logProb(x)
            result = op.minimize(nll, x0=self.initial_guess, method='Nelder-Mead', options={'maxfev': None})
            if verbose:
                print(f'Best-fit values: {result.x}')
                print(f'Max-Likelihood value (including prior likelihood): {self.likelihood.logProb(result.x)}')
            self.mle['result'] = result.x
        return self.mle['result']

    def sampler_pos(self):
        """Initial walker positions: a 1e-4 Gaussian ball around the MLE."""
        mle = self.MLE()
        return [mle + 1e-4 * np.random.randn(self.ndim) for _ in range(self.nwalkers)]

    def sampler(self, reset=False):
        """Run, resume, or re-run the MCMC.

        Args:
            reset: if True, discard any stored chain and start over.

        Returns:
            The `emcee.EnsembleSampler` after sampling, or the
            `HDFBackend` directly when ``max_n`` iterations are already
            stored (both expose get_chain/get_log_prob/get_blobs).
        """
        try:
            # Accessing .iteration raises OSError when the HDF5 file
            # does not exist yet; initialise it in that case.
            self.HDFBackend.iteration
        except OSError:
            self.HDFBackend.reset(self.nwalkers, self.ndim)

        last_iteration = self.HDFBackend.iteration if self.HDFBackend.iteration is not None else 0

        # Finished run: either wipe and restart, or hand back the backend.
        if last_iteration >= self.max_n:
            if reset:
                print(f'Resetting sampling from iteration: {last_iteration}')
                self.HDFBackend.reset(self.nwalkers, self.ndim)
                return self.sampler()
            print(f'Already completed {last_iteration} iterations')
            return self.HDFBackend

        if last_iteration == 0:
            print('Sampling begins')
            initial_state = self.sampler_pos()
            iterations = self.max_n
        elif reset:
            print(f'Resetting sampling from iteration: {last_iteration}')
            self.HDFBackend.reset(self.nwalkers, self.ndim)
            initial_state = self.sampler_pos()
            iterations = self.max_n
        else:
            print(f'Sampling resuming from iteration: {last_iteration}')
            # Fix: resume from the last stored walker state instead of
            # re-seeding at the MLE, and run only the iterations still
            # needed to reach max_n (previously the walkers were reset to
            # the MLE ball and max_n *additional* steps were taken).
            initial_state = self.HDFBackend.get_last_sample()
            iterations = self.max_n - last_iteration

        sampler = emcee.EnsembleSampler(self.nwalkers, self.ndim, self.likelihood.logProb, backend=self.HDFBackend)

        if not self.converge:
            sampler.run_mcmc(initial_state, iterations, progress=True)
            return sampler

        # Convergence-monitored run: every 100 iterations check the
        # integrated autocorrelation time tau.
        autocorr = np.empty(self.max_n)
        index = 0
        old_tau = np.inf
        for _ in sampler.sample(initial_state, iterations=iterations, progress=True):
            if sampler.iteration % 100:
                continue
            tau = sampler.get_autocorr_time(tol=0)
            autocorr[index] = np.mean(tau)
            index += 1
            # Converged when the chain is longer than 100 tau for every
            # parameter AND tau changed by < 1% since the previous check.
            converged = np.all(tau * 100 < sampler.iteration)
            converged &= np.all(np.abs(old_tau - tau) / tau < 0.01)
            print(f'I:{sampler.iteration}, A:{(tau*100)-sampler.iteration}, T:{np.abs(old_tau - tau) / tau}')
            if converged:
                print(f'Converged at iteration {sampler.iteration}')
                break
            old_tau = tau
        return sampler

    def get_burnin(self):
        """Return (burnin, thin) lengths for post-processing the chain.

        With convergence monitoring the lengths are derived from the
        autocorrelation time; otherwise fall back to fixed defaults.
        """
        if self.converge:
            tau = self.sampler().get_autocorr_time()
            burnin = int(2 * np.max(tau))
            thin = int(0.5 * np.min(tau))
        else:
            # Short/unmonitored chains: conservative fixed values.
            burnin = 50
            thin = 1
        return burnin, thin

    def get_chain(self, getdist=False, reset=False):
        """Return the flattened, burned-in, thinned chain.

        Args:
            getdist: if True, prepend log-probability and log-prior
                columns so the array matches getdist's expected layout.
            reset: forwarded to ``sampler`` to discard any stored chain.
        """
        sampler = self.sampler(reset=reset)
        burnin, thin = self.get_burnin()
        samples = sampler.get_chain(discard=burnin, thin=thin, flat=True)
        if getdist:
            lnprob = sampler.get_log_prob(discard=burnin, thin=thin, flat=True)
            lnprior = sampler.get_blobs(discard=burnin, thin=thin, flat=True)
            if lnprior is None:
                # No blobs stored: pad the prior column with zeros.
                lnprior = np.zeros_like(lnprob)
            samples = np.concatenate((lnprob[:, None], lnprior[:, None], samples), axis=1)
        return samples

    def corner_plot(self, getdist=False):
        """Draw a triangle/corner plot of the posterior samples.

        Uses getdist when ``getdist`` is True, otherwise ChainConsumer
        with the MLE values marked as the truth lines.
        """
        chains = self.get_chain()
        names = self.likelihood.theory.param.parameters
        labels = [p.replace('$', '') for p in self.likelihood.theory.labels]
        if getdist:
            samples = MCSamples(samples=chains, names=names, labels=labels)
            g = plots.get_subplot_plotter()
            g.triangle_plot([samples], filled=True)
        else:
            c = ChainConsumer().add_chain(chains, parameters=self.likelihood.theory.labels)
            fig = c.plotter.plot(truth=list(self.MLE(False)))
            fig.set_size_inches(3 + fig.get_size_inches())

    def get_simple_stat(self):
        """Return {parameter name: {'mean': ..., 'std': ...}} for the chain."""
        samples = self.get_chain()
        names = self.likelihood.theory.param.parameters
        data = {}
        for i, name in enumerate(names):
            data[name] = {
                'mean': np.mean(samples[:, i]),
                'std': np.std(samples[:, i]),
            }
        return data
0 commit comments