Skip to content

Commit

Permalink
Merge pull request #31 from antolonappan/dynesty
Browse files Browse the repository at this point in the history
dynesty
  • Loading branch information
antolonappan authored Oct 7, 2023
2 parents dae9cdc + 6d01c71 commit adebb97
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 4 deletions.
212 changes: 212 additions & 0 deletions Notebooks/dynesty.ipynb

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions marcia/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,19 @@ def chisq(self,theta):
return chi2

def logPrior(self, theta):
# if self.priors[0][0] < theta[0] < self.priors[0][1] and self.priors[1][0] < theta[1] < self.priors[1][1]:
# logPrior is independent of data for the most of it, unless otherwise some strange functions are defined
#if all((np.array(theta)-self.priors[:, 0])>0) and all((self.priors[:, 1]-np.array(theta))>0):
if all(self.priors[i][0] < theta[i] < self.priors[i][1] for i in range(len(theta))):
return 0.0
return -np.inf

def prior_transform(self,utheta):
# Assuming self.priors is a list of tuples [(lower1, upper1), (lower2, upper2), ...]
theta = []
for i, (lower, upper) in enumerate(self.priors):
# Transform the uniform random variable utheta[i] to the prior range [lower, upper]
theta_i = lower + (upper - lower) * utheta[i]
theta.append(theta_i)

return theta


def logLike(self,theta):
Expand Down
70 changes: 69 additions & 1 deletion marcia/sampler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
import emcee
import numpy as np
from marcia import Likelihood as lk
Expand All @@ -7,8 +8,75 @@
from chainconsumer import ChainConsumer
import logging
import os
import dynesty
import pickle as pl
import matplotlib.pyplot as plt
from dynesty import plotting as dyplot

class Sampler:

def Sampler(which,*args,**kwargs):
if which == 'mcmc':
return SamplerMCMC(*args,**kwargs)
elif which == 'nest':
return SamplerNEST(*args,**kwargs)
else:
raise ValueError(f'Unknown sampler: {which}')




class SamplerNEST:

def __init__(self,model,parameters,data,bound='multi',sample='rwalk',sampler_file='sampler.pkl'):
self.data = data
self.model = model
self.likelihood = lk(model,parameters,data)
self.ndim = len(self.likelihood.priors)
self.bound = bound
self.sample = sample
self.sampler_file = sampler_file


def dynesty_sampler(self):
dsampler = dynesty.DynamicNestedSampler(self.likelihood.logLike,
self.likelihood.prior_transform, ndim=self.ndim,
bound=self.bound, sample=self.sample)
dsampler.run_nested()
return dsampler

def Sampler(self):
if os.path.isfile(self.sampler_file):
return pl.load(open(self.sampler_file,'rb'))
else:
dsampler = self.dynesty_sampler()
res = dsampler.results
pl.dump(res,open(self.sampler_file,'wb'))
return res

def get_chain(self):
res = self.Sampler()
samples = res.samples
return samples

def trace_plot(self):
res = self.Sampler()
labels = self.likelihood.theory.labels
fig, axes = dyplot.traceplot(res,labels=labels,
fig=plt.subplots(self.ndim, self.ndim, figsize=(8, 6)))
fig.tight_layout()

def corner_plot(self):
res = self.Sampler()
labels = self.likelihood.theory.labels
fig, axes = dyplot.cornerplot(res, show_titles=True,
title_kwargs={'y': 1.04}, labels=labels,
fig=plt.subplots(self.ndim,self.ndim, figsize=(8, 8)))
fig.tight_layout()




class SamplerMCMC:

def __init__(self, model, parameters, data, initial_guess, prior_file=None,
max_n=100000, nwalkers=100, sampler_file='sampler.h5', converge=False, prior_dist = 'uniform', burnin = 0.3, thin = 1):
Expand Down

0 comments on commit adebb97

Please sign in to comment.