-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsonar.py
83 lines (60 loc) · 2.33 KB
/
sonar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import List
import chex
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import pickle
import numpyro
import numpyro.distributions as pydist
from jax._src.flatten_util import ravel_pytree
from targets.base_target import Target
from utils.path_utils import project_path
def pad_with_const(X):
    """Return ``X`` with a column of ones prepended as its first column.

    ``load_model_sonar`` multiplies the padded matrix by a weight vector, so
    the first weight plays the role of an intercept term.

    Args:
        X: 2-D array of shape ``(n_rows, n_cols)``.

    Returns:
        New array of shape ``(n_rows, n_cols + 1)`` whose first column is 1.
    """
    n_rows = X.shape[0]
    ones_column = np.ones((n_rows, 1))
    # Joining along axis 1 is what np.hstack does for 2-D inputs.
    return np.concatenate((ones_column, X), axis=1)
def standardize_and_pad(X):
    """Z-score each column of ``X`` and then prepend a column of ones.

    Each column is centred on its mean and divided by its (population,
    ``ddof=0``) standard deviation. Constant columns have a standard
    deviation of 0; they are divided by 1 instead, so they become all
    zeros rather than NaN.

    Args:
        X: 2-D array of shape ``(n_rows, n_cols)``.

    Returns:
        Array of shape ``(n_rows, n_cols + 1)``.
    """
    column_mean = np.mean(X, axis=0)
    column_std = np.std(X, axis=0)
    # Guard against division by zero for constant columns.
    scale = np.where(column_std == 0, 1., column_std)
    standardized = (X - column_mean) / scale
    return pad_with_const(standardized)
def load_model_sonar():
    """Build the numpyro logistic-regression model for the Sonar dataset.

    The function does the following:
    - Reads ``(X, Y)`` from ``targets/data/sonar_full.pkl``.
    - Remaps the labels with ``(Y + 1) // 2``. This sends -1/+1 to 0/1 and
      leaves 0/1 unchanged.
    - Standardizes the features and prepends a column of ones.

    The model puts an independent standard-normal prior on each weight and
    uses a Bernoulli likelihood on the logits ``X @ w``.

    Returns:
        model: numpyro model taking the label array ``Y``. The design matrix
            and its sizes are captured by closure.
        model_args: ``(Y,)``, the positional arguments to pass to ``model``.
    """
    # NOTE(review): pickle can execute arbitrary code on load. This path is a
    # bundled project data file; never point it at untrusted input.
    data_path = project_path('targets/data/sonar_full.pkl')
    with open(data_path, 'rb') as handle:
        features, labels = pickle.load(handle)
    labels = (labels + 1) // 2
    design = standardize_and_pad(features)
    n_rows, n_cols = design.shape

    def model(Y):
        weights = numpyro.sample(
            "weights", pydist.Normal(jnp.zeros(n_cols), jnp.ones(n_cols)))
        logits = jnp.dot(design, weights)
        with numpyro.plate('J', n_rows):
            # Observed site; its return value is not needed.
            numpyro.sample("y", pydist.BernoulliLogits(logits), obs=Y)

    return model, (labels,)
class Sonar(Target):
    """Posterior of the Sonar logistic-regression model, exposed as a Target.

    The unnormalized log density is the negated numpyro potential of the
    model from ``load_model_sonar``. It is evaluated on a flat vector that is
    unraveled into the model's (unconstrained) parameter pytree. Exact
    sampling is not provided: ``sample`` returns ``None``.
    """

    def __init__(self, dim=61, log_Z=None, can_sample=False, sample_bounds=None) -> None:
        """Load the data, build the model and set up the flat log density.

        Args:
            dim: length of the flat parameter vector. It must equal the
                number of model weights (feature count plus the ones column).
            log_Z: forwarded to ``Target``.
            can_sample: forwarded to ``Target``.
            sample_bounds: accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``dim`` does not match the model's parameter count.
        """
        super().__init__(dim=dim, log_Z=log_Z, can_sample=can_sample)
        self.data_ndim = dim
        # Fixed key so construction is deterministic.
        rng_key = jax.random.PRNGKey(1)
        model, model_args = load_model_sonar()
        model_param_info, potential_fn, _, _ = numpyro.infer.util.initialize_model(
            rng_key, model, model_args=model_args)
        # First field of the param info is the initial parameter pytree. It is
        # used here only for its structure (ravel/unravel pair) and its size.
        params_flat, unflattener = ravel_pytree(model_param_info[0])
        n_params = params_flat.shape[0]
        if n_params != dim:
            # Fail early with a clear message instead of inside unflattener.
            raise ValueError(
                f"Sonar model has {n_params} parameters, but dim={dim} was given")
        self.log_prob_model = lambda z: -1. * potential_fn(unflattener(z))

    def get_dim(self):
        """Return the dimensionality stored by ``Target``."""
        return self.dim

    def log_prob(self, x: chex.Array):
        """Evaluate the unnormalized log density at ``x``.

        Args:
            x: array whose last axis is the flat parameter vector. Any leading
                axes are batch axes; a 1-D input is a single point.

        Returns:
            Array of shape ``x.shape[:-1]``, which is a scalar for 1-D input.
        """
        batch_shape = x.shape[:-1]
        # log_prob_model only handles one unbatched vector. Collapse every
        # batch axis into one, vmap over it, then restore the batch shape.
        flat_x = x.reshape((-1, x.shape[-1]))
        log_probs = jax.vmap(self.log_prob_model)(flat_x)
        return log_probs.reshape(batch_shape)

    def visualise(self, samples: chex.Array = None, axes=None, show=False, prefix='') -> dict:
        """Return no plots; this target has nothing to visualise."""
        return {}

    def sample(self, seed: chex.PRNGKey, sample_shape: chex.Shape) -> chex.Array:
        """Return ``None``; exact samples are not available for this target."""
        return None