-
Notifications
You must be signed in to change notification settings - Fork 0
/
odesolvers.py
84 lines (64 loc) · 2.36 KB
/
odesolvers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
###################################################################################################
# Fixed-step ODE solvers adapted from https://github.com/amirgholami/anode #
###################################################################################################
class Time_Stepper(object):
    """Base class for fixed-step explicit ODE integrators.

    Integrates ``dy/dt = func(t, y)`` from t = 0 to t = 1 using ``Nt``
    uniform steps of size ``dt = 1 / Nt``.  Subclasses implement
    :meth:`step`, which advances the state by a single step.

    Attributes:
        func: right-hand side, called as ``func(t, y)``.
        Nt: number of uniform time steps taken by :meth:`integrate`.
        dt_next: step sizes used; one entry is appended per step taken by
            :meth:`integrate`.  It is never cleared, so repeated calls to
            :meth:`integrate` on the same instance keep appending.
    """

    def __init__(self, func, y0, Nt = 2):
        # y0 is accepted for interface compatibility but is not stored;
        # the initial state is supplied to integrate() instead.
        self.func = func
        self.Nt = Nt
        self.dt_next = []

    def step(self, func, t, dt, y):
        """Return the state at ``t + dt`` given state ``y`` at time ``t``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        # Previously a silent ``pass``, which made integrate() return None
        # without any error for an incomplete subclass.
        raise NotImplementedError(
            '{} must implement step()'.format(type(self).__name__))

    def integrate(self, y0):
        """Integrate from t = 0 to t = 1 in ``self.Nt`` uniform steps.

        Args:
            y0: state at t = 0.

        Returns:
            The state after ``self.Nt`` steps, i.e. at t = 1.

        Raises:
            ValueError: if ``self.Nt`` is less than 1 (zero steps would
                divide by zero; negative counts would silently return y0).
        """
        if self.Nt < 1:
            raise ValueError(
                'Nt must be a positive number of steps, got {!r}'.format(self.Nt))
        y = y0
        dt = 1. / float(self.Nt)
        for n in range(self.Nt):
            t = n * dt  # start time of step n
            self.dt_next.append(dt)
            y = self.step(self.func, t, dt, y)
        return y
#############################################################
class Euler(Time_Stepper):
    """Forward (explicit) Euler method: ``y_next = y + dt * f(t, y)``."""

    def step(self, func, t, dt, y):
        """Advance ``y`` from ``t`` to ``t + dt`` with a single slope evaluation."""
        slope = func(t, y)
        return y + dt * slope
class RK2(Time_Stepper):
    """Second-order explicit midpoint Runge-Kutta method."""

    def step(self, func, t, dt, y):
        """Advance ``y`` from ``t`` to ``t + dt`` using the slope at the midpoint."""
        half_dt = dt / 2.0
        k1 = dt * func(t, y)
        # Slope at the midpoint, reached by a half Euler step along k1.
        k2 = dt * func(t + half_dt, y + 0.5 * k1)
        return y + k2
class RK4(Time_Stepper):
    """Classical fourth-order Runge-Kutta method (weights 1/6, 1/3, 1/3, 1/6)."""

    def step(self, func, t, dt, y):
        """Advance ``y`` from ``t`` to ``t + dt`` using four slope evaluations."""
        t_mid = t + dt / 2.0
        k1 = dt * func(t, y)
        k2 = dt * func(t_mid, y + 0.5 * k1)
        k3 = dt * func(t_mid, y + 0.5 * k2)
        k4 = dt * func(t + dt, y + k3)
        sixth = 1.0 / 6.0
        third = 1.0 / 3.0
        # Summed strictly left to right so the rounding matches the textbook form.
        return y + sixth * k1 + third * k2 + third * k3 + sixth * k4
#############################################################
def odesolver(func, z0, options=None):
    """Integrate ``dz/dt = func(t, z)`` from t = 0 to t = 1 with a fixed-step solver.

    Args:
        func: right-hand side, called as ``func(t, z)``.  If ``func`` has a
            ``base_func`` attribute, the list of step sizes used by this call
            is appended to ``func.base_func.dt`` when that attribute exists;
            otherwise it is appended to ``func.dt`` when that exists.
        z0: initial state at t = 0.
        options: optional dict with keys ``'method'`` (``'Euler'``, ``'RK2'``
            or ``'RK4'``; default ``'Euler'``) and ``'Nt'`` (number of
            uniform steps; default 2).  Missing keys fall back to these
            defaults.

    Returns:
        The state at t = 1.

    Raises:
        ValueError: if ``options['method']`` is not a supported method.
    """
    # None sentinel instead of a mutable dict default shared across calls.
    if options is None:
        options = {}
    method = options.get('method', 'Euler')
    Nt = options.get('Nt', 2)

    solvers = {'Euler': Euler, 'RK2': RK2, 'RK4': RK4}
    if method not in solvers:
        # Previously printed a message and returned None, which only
        # surfaced later as a confusing error in the caller.
        raise ValueError('unsupported method {!r}; expected one of {}'.format(
            method, sorted(solvers)))
    solver = solvers[method](func, z0, Nt=Nt)
    z1 = solver.integrate(z0)

    # Record the step sizes used.  Note the elif binds to the outer if: when
    # func has base_func but base_func has no dt, nothing is recorded even if
    # func itself has a dt attribute (original behavior, kept as-is).
    if hasattr(func, 'base_func'):
        if hasattr(func.base_func, 'dt'):
            func.base_func.dt.append(solver.dt_next)
    elif hasattr(func, 'dt'):
        func.dt.append(solver.dt_next)
    return z1
#############################################################
if __name__ == '__main__':
    # Demo: solve dx/dt = t * x with x(0) = 3 over t in [0, 1] using two
    # forward-Euler steps.  The exact value at t = 1 is 3 * exp(0.5) ~ 4.946;
    # the two Euler steps give 3.75.
    def rhs(t, x):
        return t * x

    demo_options = {'method': 'Euler', 'Nt': 2}
    print(odesolver(rhs, 3, demo_options))