Skip to content

Commit 9149568

Browse files
committed
sort out chapters
1 parent 1d3bf4e commit 9149568

File tree

5 files changed

+181
-120
lines changed

5 files changed

+181
-120
lines changed

_toc.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ parts:
1414
- caption: Physics examples
1515
chapters:
1616
- file: physics/supernovae.ipynb
17-
# - caption: Supernovae Light Curves
18-
# chapters:
19-
# - file: scripts/supernovae.ipynb
2017
- caption: Appendix
2118
chapters:
2219
- file: contributors.md

advanced/GP.ipynb

Lines changed: 71 additions & 76 deletions
Large diffs are not rendered by default.

contributors.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Contributors and Citations
22

3-
# Contributors
3+
## Contributors
44

55
- Basic
66
- Quickstart - https://github.com/yallup
@@ -11,7 +11,7 @@
1111
- Supernovae Light Curve Fittting - https://github.com/samleeney
1212

1313

14-
# Citations
14+
## Citations
1515

1616
```{bibliography}
1717
```

physics/supernovae.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"\n",
99
"In this example we will explore a physical example of fitting a supernova light curve model to data\n",
1010
"\n",
11-
"# Nested Sampling with JAX-bandflux\n",
11+
"## Nested Sampling with JAX-bandflux\n",
1212
"\n",
1313
"This notebook demonstrates how to run the nested sampling procedure for supernovae SALT model fitting using the JAX-bandflux package (as implemented in `ns.py`). We will install the package, load the data, set up and run the nested sampling algorithm, and finally produce a corner plot of the posterior samples.\n",
1414
"\n",

scripts/GW.py

Lines changed: 107 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# | # Minimal example of GW parameter estimation
2-
# |
2+
# |
33
# | This script performs Bayesian inference on LIGO data (from GW150914) using
4-
# | a nested sampling algorithm implemented with BlackJAX. It loads the
5-
# | detector data and sets up a gravitational-wave waveform model, defines a
6-
# | prior and likelihood for the model parameters, then runs nested sampling
7-
# | to sample from the posterior. Finally, it processes the samples with
4+
# | a nested sampling algorithm implemented with BlackJAX. It loads the
5+
# | detector data and sets up a gravitational-wave waveform model, defines a
6+
# | prior and likelihood for the model parameters, then runs nested sampling
7+
# | to sample from the posterior. Finally, it processes the samples with
88
# | anesthetic and writes them to a CSV file.
99
# |
1010
# | ## Installation
@@ -26,10 +26,12 @@
2626

2727
from astropy.time import Time
2828
from jimgw.single_event.detector import H1, L1
29-
from jimgw.single_event.likelihood import original_likelihood as likelihood_function
29+
from jimgw.single_event.likelihood import (
30+
original_likelihood as likelihood_function,
31+
)
3032
from jimgw.single_event.waveform import RippleIMRPhenomD
3133

32-
jax.config.update('jax_enable_x64', True)
34+
jax.config.update("jax_enable_x64", True)
3335

3436
# | Define LIGO event data
3537

@@ -42,22 +44,50 @@
4244
waveform = RippleIMRPhenomD(f_ref=20)
4345
detectors = [H1, L1]
4446
frequencies = H1.frequencies
45-
duration=4
46-
post_trigger_duration=2
47+
duration = 4
48+
post_trigger_duration = 2
4749
epoch = duration - post_trigger_duration
4850
gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad
4951

50-
columns = ["M_c", "q", "s1_z", "s2_z", "iota", "d_L", "t_c", "phase_c", "psi", "ra", "dec"]
51-
labels = [r"$M_c$", r"$q$", r"$s_{1z}$", r"$s_{2z}$", r"$\iota$", r"$d_L$", r"$t_c$", r"$\phi_c$", r"$\psi$", r"$\alpha$", r"$\delta$"]
52+
columns = [
53+
"M_c",
54+
"q",
55+
"s1_z",
56+
"s2_z",
57+
"iota",
58+
"d_L",
59+
"t_c",
60+
"phase_c",
61+
"psi",
62+
"ra",
63+
"dec",
64+
]
65+
labels = [
66+
r"$M_c$",
67+
r"$q$",
68+
r"$s_{1z}$",
69+
r"$s_{2z}$",
70+
r"$\iota$",
71+
r"$d_L$",
72+
r"$t_c$",
73+
r"$\phi_c$",
74+
r"$\psi$",
75+
r"$\alpha$",
76+
r"$\delta$",
77+
]
78+
5279

5380
@jax.jit
5481
def loglikelihood_fn(params):
5582
p = params.copy()
56-
p["eta"] = p["q"] / (1 + p["q"])**2
83+
p["eta"] = p["q"] / (1 + p["q"]) ** 2
5784
p["gmst"] = gmst
5885
waveform_sky = waveform(frequencies, p)
5986
align_time = jnp.exp(-1j * 2 * jnp.pi * frequencies * (epoch + p["t_c"]))
60-
return likelihood_function(p, waveform_sky, detectors, frequencies, align_time)
87+
return likelihood_function(
88+
p, waveform_sky, detectors, frequencies, align_time
89+
)
90+
6191

6292
# | Define the prior function
6393

@@ -66,10 +96,17 @@ def loglikelihood_fn(params):
6696
d_L_min, d_L_max = 1.0, 2000.0
6797
t_c_min, t_c_max = -0.05, 0.05
6898

69-
cosine_logprob = lambda x: jnp.where(jnp.abs(x) < jnp.pi/2, jnp.log(jnp.cos(x)/2.0), -jnp.inf)
70-
sine_logprob = lambda x: jnp.where((x >= 0.0) & (x <= jnp.pi), jnp.log(jnp.sin(x)/2.0), -jnp.inf)
71-
uniform_logprob = lambda x, a, b: jax.scipy.stats.uniform.logpdf(x, a, b-a)
72-
power_logprob = lambda x, n, a, b: jax.scipy.stats.beta.logpdf(x, n, n, loc=a, scale=b-a)
99+
cosine_logprob = lambda x: jnp.where(
100+
jnp.abs(x) < jnp.pi / 2, jnp.log(jnp.cos(x) / 2.0), -jnp.inf
101+
)
102+
sine_logprob = lambda x: jnp.where(
103+
(x >= 0.0) & (x <= jnp.pi), jnp.log(jnp.sin(x) / 2.0), -jnp.inf
104+
)
105+
uniform_logprob = lambda x, a, b: jax.scipy.stats.uniform.logpdf(x, a, b - a)
106+
power_logprob = lambda x, n, a, b: jax.scipy.stats.beta.logpdf(
107+
x, n, n, loc=a, scale=b - a
108+
)
109+
73110

74111
def logprior_fn(p):
75112
logprob = 0.0
@@ -86,6 +123,7 @@ def logprior_fn(p):
86123
logprob += cosine_logprob(p["dec"])
87124
return logprob
88125

126+
89127
# | Define the Nested Sampling algorithm
90128
n_dims = len(columns)
91129
n_live = 1000
@@ -97,20 +135,41 @@ def logprior_fn(p):
97135
rng_key, init_key = jax.random.split(rng_key, 2)
98136
init_keys = jax.random.split(init_key, n_dims)
99137
particles = {
100-
"M_c": jax.random.uniform(init_keys[0], (n_live,), minval=M_c_min, maxval=M_c_max),
101-
"q": jax.random.uniform(init_keys[1], (n_live,), minval=q_min, maxval=q_max),
102-
"s1_z": jax.random.uniform(init_keys[2], (n_live,), minval=-1.0, maxval=1.0),
103-
"s2_z": jax.random.uniform(init_keys[3], (n_live,), minval=-1.0, maxval=1.0),
104-
"iota": 2 * jnp.arcsin(jax.random.uniform(init_keys[4], (n_live,))**0.5),
105-
"d_L": jax.random.beta(init_keys[5], 2.0, 2.0, shape=(n_live,)) * (d_L_max - d_L_min) + d_L_min,
106-
"t_c": jax.random.uniform(init_keys[6], (n_live,), minval=t_c_min, maxval=t_c_max),
107-
"phase_c": jax.random.uniform(init_keys[7], (n_live,), minval=0.0, maxval=2 * jnp.pi),
108-
"psi": jax.random.uniform(init_keys[8], (n_live,), minval=0.0, maxval=2 * jnp.pi),
109-
"ra": jax.random.uniform(init_keys[9], (n_live,), minval=0.0, maxval=2 * jnp.pi),
110-
"dec": 2 * jnp.arcsin(jax.random.uniform(init_keys[10], (n_live,))**0.5) - jnp.pi/2.0,
138+
"M_c": jax.random.uniform(
139+
init_keys[0], (n_live,), minval=M_c_min, maxval=M_c_max
140+
),
141+
"q": jax.random.uniform(
142+
init_keys[1], (n_live,), minval=q_min, maxval=q_max
143+
),
144+
"s1_z": jax.random.uniform(
145+
init_keys[2], (n_live,), minval=-1.0, maxval=1.0
146+
),
147+
"s2_z": jax.random.uniform(
148+
init_keys[3], (n_live,), minval=-1.0, maxval=1.0
149+
),
150+
"iota": 2 * jnp.arcsin(jax.random.uniform(init_keys[4], (n_live,)) ** 0.5),
151+
"d_L": jax.random.beta(init_keys[5], 2.0, 2.0, shape=(n_live,))
152+
* (d_L_max - d_L_min)
153+
+ d_L_min,
154+
"t_c": jax.random.uniform(
155+
init_keys[6], (n_live,), minval=t_c_min, maxval=t_c_max
156+
),
157+
"phase_c": jax.random.uniform(
158+
init_keys[7], (n_live,), minval=0.0, maxval=2 * jnp.pi
159+
),
160+
"psi": jax.random.uniform(
161+
init_keys[8], (n_live,), minval=0.0, maxval=2 * jnp.pi
162+
),
163+
"ra": jax.random.uniform(
164+
init_keys[9], (n_live,), minval=0.0, maxval=2 * jnp.pi
165+
),
166+
"dec": 2 * jnp.arcsin(jax.random.uniform(init_keys[10], (n_live,)) ** 0.5)
167+
- jnp.pi / 2.0,
111168
}
112169

113-
_, ravel_fn = jax.flatten_util.ravel_pytree({k: v[0] for k, v in particles.items()})
170+
_, ravel_fn = jax.flatten_util.ravel_pytree(
171+
{k: v[0] for k, v in particles.items()}
172+
)
114173

115174
# | Initialize the Nested Sampling algorithm
116175
nested_sampler = blackjax.ns.adaptive.nss(
@@ -123,13 +182,15 @@ def logprior_fn(p):
123182

124183
state = nested_sampler.init(particles, loglikelihood_fn)
125184

185+
126186
@jax.jit
127187
def one_step(carry, xs):
128188
state, k = carry
129189
k, subk = jax.random.split(k, 2)
130190
state, dead_point = nested_sampler.step(subk, state)
131191
return (state, k), dead_point
132192

193+
133194
# | Run Nested Sampling
134195
dead = []
135196
with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
@@ -141,18 +202,26 @@ def one_step(carry, xs):
141202
# | anesthetic post-processing
142203
from anesthetic import NestedSamples
143204
import numpy as np
205+
144206
dead = jax.tree.map(
145-
lambda *args: jnp.reshape(jnp.stack(args, axis=0),
146-
(-1,) + args[0].shape[1:]),
147-
*dead)
207+
lambda *args: jnp.reshape(
208+
jnp.stack(args, axis=0), (-1,) + args[0].shape[1:]
209+
),
210+
*dead,
211+
)
148212
live = state.sampler_state
149213

150214
logL = np.concatenate((dead.logL, live.logL), dtype=float)
151215
logL_birth = np.concatenate((dead.logL_birth, live.logL_birth), dtype=float)
152-
data = np.concatenate([
153-
np.column_stack([v for v in dead.particles.values()]),
154-
np.column_stack([v for v in live.particles.values()])
155-
], axis=0)
216+
data = np.concatenate(
217+
[
218+
np.column_stack([v for v in dead.particles.values()]),
219+
np.column_stack([v for v in live.particles.values()]),
220+
],
221+
axis=0,
222+
)
156223

157-
samples = NestedSamples(data, logL=logL, logL_birth=logL_birth, columns=columns, labels=labels)
158-
samples.to_csv('GW.csv')
224+
samples = NestedSamples(
225+
data, logL=logL, logL_birth=logL_birth, columns=columns, labels=labels
226+
)
227+
samples.to_csv("GW.csv")

0 commit comments

Comments
 (0)