-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsbtm_sim.py
148 lines (118 loc) · 4.12 KB
/
sbtm_sim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Base class for score-based transport modeling.
Nicholas M. Boffi
10/27/22
"""
from dataclasses import dataclass
from typing import Callable, Tuple, Union

import haiku as hk
import jax
import jax.flatten_util  # jax.flatten_util is a submodule; `import jax` alone need not expose it
import numpy as onp
import optax
from jaxlib.xla_extension import Device

import networks
import rollouts
State = onp.ndarray
Time = float
@dataclass
class SBTMSim:
    """
    Base class for all SBTM simulations.

    Contains simulation parameters common to all SBTM approaches.

    NOTE: construction is non-standard for a dataclass. ``__init__`` replaces
    the instance ``__dict__`` wholesale with a copy of the supplied dict, so
    the field declarations below document the expected keys rather than act as
    generated-constructor parameters (dataclass does not overwrite a
    user-defined ``__init__``).
    """
    # ---- initial condition fitting ----
    n_max_init_opt_steps: int  # cap on optimizer steps when fitting the initial score
    init_learning_rate: float  # learning rate for the initial-condition optimizer
    init_ltol: float           # loss tolerance that terminates the initial fit
    sig0: float                # std. dev. of the Gaussian initial condition
    mu0: onp.ndarray           # mean of the Gaussian initial condition

    # ---- system parameters ----
    drift: Callable[[State, Time], State]  # drift field; invoked as drift(x, t, *force_args)
    force_args: Tuple          # extra arguments forwarded to ``drift``
    amp: Callable[[Time], float]  # time-dependent forcing amplitude
    freq: float                # forcing frequency
    dt: float                  # timestep
    D: onp.ndarray             # diffusion matrix
    D_sqrt: onp.ndarray        # square root of the diffusion matrix
    n: int                     # number of samples
    d: int                     # dimension per particle
    N: int                     # number of particles

    # ---- timestepping ----
    ltol: float                # local (per-step) loss tolerance
    gtol: float                # global loss tolerance
    n_opt_steps: int           # optimizer steps per timestep
    learning_rate: float       # learning rate during timestepping

    # ---- network parameters ----
    n_hidden: int              # number of hidden layers
    n_neurons: int             # neurons per hidden layer
    act: Callable[[State], State]  # activation function
    residual_blocks: bool      # whether the network uses residual blocks
    interacting_particle_system: bool  # selects which network constructor is used

    # ---- general simulation parameters ----
    key: onp.ndarray           # jax PRNG key (split as the simulation consumes randomness)
    params_list: list          # history of network parameter pytrees
    all_samples: dict          # sample trajectories keyed by method name

    # ---- output information ----
    output_folder: str
    output_name: str

    def __init__(self, data_dict: dict) -> None:
        """Populate every attribute directly from ``data_dict`` (shallow copy)."""
        self.__dict__ = data_dict.copy()

    def initialize_forcing(self) -> None:
        """Bind ``force_args`` into ``drift`` to obtain the forcing f(x, t)."""
        self.forcing = lambda x, t: self.drift(x, t, *self.force_args)

    def initialize_network_and_optimizer(self) -> None:
        """Initialize the network parameters and optimizer.

        Constructs the score network (architecture chosen by
        ``interacting_particle_system``), initializes its parameters from a
        fresh PRNG subkey, sets up the RAdam optimizer, and builds a vmapped
        score that batches over the leading sample axis.
        """
        if self.interacting_particle_system:
            self.score_network, self.potential_network = \
                networks.construct_interacting_particle_system_network(
                    self.n_hidden,
                    self.n_neurons,
                    self.N,
                    self.d,
                    self.act,
                    self.residual_blocks
                )
            # flattened state of all N particles
            example_x = onp.zeros(self.N*self.d)
        else:
            self.score_network, self.potential_network = \
                networks.construct_score_network(
                    self.d,
                    self.n_hidden,
                    self.n_neurons,
                    self.act,
                    is_gradient=True
                )
            example_x = onp.zeros(self.d)

        # BUG FIX: initialize with the freshly-split subkey ``sk``. The
        # original code split the key but then passed ``self.key`` (which it
        # also keeps for later use) to ``init``, leaving ``sk`` unused and
        # consuming-and-retaining the same key value — the standard JAX
        # discipline is split once, use the subkey, keep the carry key.
        self.key, sk = jax.random.split(self.key)
        init_params = self.score_network.init(sk, example_x)
        self.params_list = [init_params]

        # report parameter counts (requires `import jax.flatten_util`)
        network_size = jax.flatten_util.ravel_pytree(init_params)[0].size
        print(f'Number of parameters: {network_size}')
        print(f'Number of parameters needed for overparameterization: ' \
            + f'{self.n*example_x.size}')

        # set up the optimizer
        self.opt = optax.radam(self.learning_rate)
        self.opt_state = self.opt.init(init_params)

        # set up batching for the score: share params, map over samples
        self.batch_score = jax.vmap(self.score_network.apply, in_axes=(None, 0))

    def fit_init(self, cpu: Device, gpu: Device) -> None:
        """Fit the initial condition.

        Draws n Gaussian samples ~ N(mu0, sig0^2 I), fits the score network
        to the known initial score on the GPU, then stores the fitted
        parameters (moved to ``cpu``) and seeds both sample trajectories.
        """
        # draw samples from the Gaussian initial condition
        samples_shape = (self.n, self.N*self.d)
        init_samples = self.sig0*onp.random.randn(*samples_shape) + self.mu0[None, :]

        # set up the optimizer on a GPU copy of the parameters
        init_params = jax.device_put(self.params_list[0], gpu)
        opt = optax.adabelief(self.init_learning_rate)
        opt_state = opt.init(init_params)

        init_params = rollouts.fit_initial_condition(
            self.n_max_init_opt_steps,
            self.init_ltol,
            init_params,
            self.sig0,
            self.mu0,
            self.score_network,
            opt,
            opt_state,
            init_samples
        )

        # keep fitted params on the CPU; both trajectories ('SDE' and
        # 'learned') start from the same initial samples
        self.params_list = [jax.device_put(init_params, device=cpu)]
        self.all_samples = {'SDE': [init_samples], 'learned': [init_samples]}