11# | # Minimal example of GW parameter estimation
2- # |
2+ # |
33# | This script performs Bayesian inference on LIGO data (from GW150914) using
4- # | a nested sampling algorithm implemented with BlackJAX. It loads the
5- # | detector data and sets up a gravitational-wave waveform model, defines a
6- # | prior and likelihood for the model parameters, then runs nested sampling
7- # | to sample from the posterior. Finally, it processes the samples with
4+ # | a nested sampling algorithm implemented with BlackJAX. It loads the
5+ # | detector data and sets up a gravitational-wave waveform model, defines a
6+ # | prior and likelihood for the model parameters, then runs nested sampling
7+ # | to sample from the posterior. Finally, it processes the samples with
88# | anesthetic and writes them to a CSV file.
99# |
1010# | ## Installation
2626
2727from astropy .time import Time
2828from jimgw .single_event .detector import H1 , L1
29- from jimgw .single_event .likelihood import original_likelihood as likelihood_function
29+ from jimgw .single_event .likelihood import (
30+ original_likelihood as likelihood_function ,
31+ )
3032from jimgw .single_event .waveform import RippleIMRPhenomD
3133
32- jax .config .update (' jax_enable_x64' , True )
34+ jax .config .update (" jax_enable_x64" , True )
3335
3436# | Define LIGO event data
3537
4244waveform = RippleIMRPhenomD (f_ref = 20 )
4345detectors = [H1 , L1 ]
4446frequencies = H1 .frequencies
45- duration = 4
46- post_trigger_duration = 2
47+ duration = 4
48+ post_trigger_duration = 2
4749epoch = duration - post_trigger_duration
4850gmst = Time (gps , format = "gps" ).sidereal_time ("apparent" , "greenwich" ).rad
4951
50- columns = ["M_c" , "q" , "s1_z" , "s2_z" , "iota" , "d_L" , "t_c" , "phase_c" , "psi" , "ra" , "dec" ]
51- labels = [r"$M_c$" , r"$q$" , r"$s_{1z}$" , r"$s_{2z}$" , r"$\iota$" , r"$d_L$" , r"$t_c$" , r"$\phi_c$" , r"$\psi$" , r"$\alpha$" , r"$\delta$" ]
52+ columns = [
53+ "M_c" ,
54+ "q" ,
55+ "s1_z" ,
56+ "s2_z" ,
57+ "iota" ,
58+ "d_L" ,
59+ "t_c" ,
60+ "phase_c" ,
61+ "psi" ,
62+ "ra" ,
63+ "dec" ,
64+ ]
65+ labels = [
66+ r"$M_c$" ,
67+ r"$q$" ,
68+ r"$s_{1z}$" ,
69+ r"$s_{2z}$" ,
70+ r"$\iota$" ,
71+ r"$d_L$" ,
72+ r"$t_c$" ,
73+ r"$\phi_c$" ,
74+ r"$\psi$" ,
75+ r"$\alpha$" ,
76+ r"$\delta$" ,
77+ ]
78+
5279
5380@jax .jit
5481def loglikelihood_fn (params ):
5582 p = params .copy ()
56- p ["eta" ] = p ["q" ] / (1 + p ["q" ])** 2
83+ p ["eta" ] = p ["q" ] / (1 + p ["q" ]) ** 2
5784 p ["gmst" ] = gmst
5885 waveform_sky = waveform (frequencies , p )
5986 align_time = jnp .exp (- 1j * 2 * jnp .pi * frequencies * (epoch + p ["t_c" ]))
60- return likelihood_function (p , waveform_sky , detectors , frequencies , align_time )
87+ return likelihood_function (
88+ p , waveform_sky , detectors , frequencies , align_time
89+ )
90+
6191
6292# | Define the prior function
6393
@@ -66,10 +96,17 @@ def loglikelihood_fn(params):
6696d_L_min , d_L_max = 1.0 , 2000.0
6797t_c_min , t_c_max = - 0.05 , 0.05
6898
69- cosine_logprob = lambda x : jnp .where (jnp .abs (x ) < jnp .pi / 2 , jnp .log (jnp .cos (x )/ 2.0 ), - jnp .inf )
70- sine_logprob = lambda x : jnp .where ((x >= 0.0 ) & (x <= jnp .pi ), jnp .log (jnp .sin (x )/ 2.0 ), - jnp .inf )
71- uniform_logprob = lambda x , a , b : jax .scipy .stats .uniform .logpdf (x , a , b - a )
72- power_logprob = lambda x , n , a , b : jax .scipy .stats .beta .logpdf (x , n , n , loc = a , scale = b - a )
99+ cosine_logprob = lambda x : jnp .where (
100+ jnp .abs (x ) < jnp .pi / 2 , jnp .log (jnp .cos (x ) / 2.0 ), - jnp .inf
101+ )
102+ sine_logprob = lambda x : jnp .where (
103+ (x >= 0.0 ) & (x <= jnp .pi ), jnp .log (jnp .sin (x ) / 2.0 ), - jnp .inf
104+ )
105+ uniform_logprob = lambda x , a , b : jax .scipy .stats .uniform .logpdf (x , a , b - a )
106+ power_logprob = lambda x , n , a , b : jax .scipy .stats .beta .logpdf (
107+ x , n , n , loc = a , scale = b - a
108+ )
109+
73110
74111def logprior_fn (p ):
75112 logprob = 0.0
@@ -86,6 +123,7 @@ def logprior_fn(p):
86123 logprob += cosine_logprob (p ["dec" ])
87124 return logprob
88125
126+
89127# | Define the Nested Sampling algorithm
90128n_dims = len (columns )
91129n_live = 1000
@@ -97,20 +135,41 @@ def logprior_fn(p):
97135rng_key , init_key = jax .random .split (rng_key , 2 )
98136init_keys = jax .random .split (init_key , n_dims )
99137particles = {
100- "M_c" : jax .random .uniform (init_keys [0 ], (n_live ,), minval = M_c_min , maxval = M_c_max ),
101- "q" : jax .random .uniform (init_keys [1 ], (n_live ,), minval = q_min , maxval = q_max ),
102- "s1_z" : jax .random .uniform (init_keys [2 ], (n_live ,), minval = - 1.0 , maxval = 1.0 ),
103- "s2_z" : jax .random .uniform (init_keys [3 ], (n_live ,), minval = - 1.0 , maxval = 1.0 ),
104- "iota" : 2 * jnp .arcsin (jax .random .uniform (init_keys [4 ], (n_live ,))** 0.5 ),
105- "d_L" : jax .random .beta (init_keys [5 ], 2.0 , 2.0 , shape = (n_live ,)) * (d_L_max - d_L_min ) + d_L_min ,
106- "t_c" : jax .random .uniform (init_keys [6 ], (n_live ,), minval = t_c_min , maxval = t_c_max ),
107- "phase_c" : jax .random .uniform (init_keys [7 ], (n_live ,), minval = 0.0 , maxval = 2 * jnp .pi ),
108- "psi" : jax .random .uniform (init_keys [8 ], (n_live ,), minval = 0.0 , maxval = 2 * jnp .pi ),
109- "ra" : jax .random .uniform (init_keys [9 ], (n_live ,), minval = 0.0 , maxval = 2 * jnp .pi ),
110- "dec" : 2 * jnp .arcsin (jax .random .uniform (init_keys [10 ], (n_live ,))** 0.5 ) - jnp .pi / 2.0 ,
138+ "M_c" : jax .random .uniform (
139+ init_keys [0 ], (n_live ,), minval = M_c_min , maxval = M_c_max
140+ ),
141+ "q" : jax .random .uniform (
142+ init_keys [1 ], (n_live ,), minval = q_min , maxval = q_max
143+ ),
144+ "s1_z" : jax .random .uniform (
145+ init_keys [2 ], (n_live ,), minval = - 1.0 , maxval = 1.0
146+ ),
147+ "s2_z" : jax .random .uniform (
148+ init_keys [3 ], (n_live ,), minval = - 1.0 , maxval = 1.0
149+ ),
150+ "iota" : 2 * jnp .arcsin (jax .random .uniform (init_keys [4 ], (n_live ,)) ** 0.5 ),
151+ "d_L" : jax .random .beta (init_keys [5 ], 2.0 , 2.0 , shape = (n_live ,))
152+ * (d_L_max - d_L_min )
153+ + d_L_min ,
154+ "t_c" : jax .random .uniform (
155+ init_keys [6 ], (n_live ,), minval = t_c_min , maxval = t_c_max
156+ ),
157+ "phase_c" : jax .random .uniform (
158+ init_keys [7 ], (n_live ,), minval = 0.0 , maxval = 2 * jnp .pi
159+ ),
160+ "psi" : jax .random .uniform (
161+ init_keys [8 ], (n_live ,), minval = 0.0 , maxval = 2 * jnp .pi
162+ ),
163+ "ra" : jax .random .uniform (
164+ init_keys [9 ], (n_live ,), minval = 0.0 , maxval = 2 * jnp .pi
165+ ),
166+ "dec" : 2 * jnp .arcsin (jax .random .uniform (init_keys [10 ], (n_live ,)) ** 0.5 )
167+ - jnp .pi / 2.0 ,
111168}
112169
113- _ , ravel_fn = jax .flatten_util .ravel_pytree ({k : v [0 ] for k , v in particles .items ()})
170+ _ , ravel_fn = jax .flatten_util .ravel_pytree (
171+ {k : v [0 ] for k , v in particles .items ()}
172+ )
114173
115174# | Initialize the Nested Sampling algorithm
116175nested_sampler = blackjax .ns .adaptive .nss (
@@ -123,13 +182,15 @@ def logprior_fn(p):
123182
124183state = nested_sampler .init (particles , loglikelihood_fn )
125184
185+
126186@jax .jit
127187def one_step (carry , xs ):
128188 state , k = carry
129189 k , subk = jax .random .split (k , 2 )
130190 state , dead_point = nested_sampler .step (subk , state )
131191 return (state , k ), dead_point
132192
193+
133194# | Run Nested Sampling
134195dead = []
135196with tqdm .tqdm (desc = "Dead points" , unit = " dead points" ) as pbar :
@@ -141,18 +202,26 @@ def one_step(carry, xs):
141202# | anesthetic post-processing
142203from anesthetic import NestedSamples
143204import numpy as np
205+
144206dead = jax .tree .map (
145- lambda * args : jnp .reshape (jnp .stack (args , axis = 0 ),
146- (- 1 ,) + args [0 ].shape [1 :]),
147- * dead )
207+ lambda * args : jnp .reshape (
208+ jnp .stack (args , axis = 0 ), (- 1 ,) + args [0 ].shape [1 :]
209+ ),
210+ * dead ,
211+ )
148212live = state .sampler_state
149213
150214logL = np .concatenate ((dead .logL , live .logL ), dtype = float )
151215logL_birth = np .concatenate ((dead .logL_birth , live .logL_birth ), dtype = float )
152- data = np .concatenate ([
153- np .column_stack ([v for v in dead .particles .values ()]),
154- np .column_stack ([v for v in live .particles .values ()])
155- ], axis = 0 )
216+ data = np .concatenate (
217+ [
218+ np .column_stack ([v for v in dead .particles .values ()]),
219+ np .column_stack ([v for v in live .particles .values ()]),
220+ ],
221+ axis = 0 ,
222+ )
156223
157- samples = NestedSamples (data , logL = logL , logL_birth = logL_birth , columns = columns , labels = labels )
158- samples .to_csv ('GW.csv' )
224+ samples = NestedSamples (
225+ data , logL = logL , logL_birth = logL_birth , columns = columns , labels = labels
226+ )
227+ samples .to_csv ("GW.csv" )
0 commit comments