-
Notifications
You must be signed in to change notification settings - Fork 1
/
ODESolver.py
185 lines (154 loc) · 5.82 KB
/
ODESolver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import numpy as np
class ODESolver:
    """
    Superclass for numerical methods solving scalar and vector ODEs

      du/dt = f(u, t)

    Subclasses implement advance(), which computes the solution at the
    next time point from the data stored on the instance.

    Attributes:
    t: array of time values
    u: array of solution values (at time points t)
    k: step number of the most recently computed solution
    f: callable object implementing f(u, t)
    neq: number of equations (set by set_initial_condition)
    U0: initial condition (set by set_initial_condition)
    """

    def __init__(self, f):
        """Store the right-hand side f(u, t); raise TypeError if not callable."""
        if not callable(f):
            raise TypeError('f is %s, not a function' % type(f))
        # For ODE systems, f will often return a list, but
        # arithmetic operations with f in numerical methods
        # require that f is an array. Let self.f be a function
        # that first calls f(u,t) and then ensures that the
        # result is an array of floats.
        self.f = lambda u, t: np.asarray(f(u, t), float)

    def advance(self):
        """Advance solution one time step."""
        raise NotImplementedError

    def set_initial_condition(self, U0):
        """
        Store the initial condition U0 and the number of equations.

        A float/int U0 is treated as a scalar ODE (neq = 1); anything
        else is converted with np.asarray and neq is its size.
        A sample call f(U0, 0) is made to validate f.

        Raises:
        IndexError: if f indexes u out of bounds.
        ValueError: if f returns a different number of components
                    than U0 has.
        """
        if isinstance(U0, (float, int)):  # scalar ODE
            self.neq = 1
            U0 = float(U0)
        else:                             # system of ODEs
            U0 = np.asarray(U0)           # (assume U0 is sequence)
            self.neq = U0.size
        self.U0 = U0

        # Check that f returns correct length:
        try:
            f0 = self.f(self.U0, 0)
        except IndexError:
            # list(...) so the message shows the actual indices; a bare
            # range object prints as "range(0, n)" on Python 3.
            raise IndexError(
                'Index of u out of bounds in f(u,t) func. '
                'Legal indices are %s' % (str(list(range(self.neq)))))
        if f0.size != self.neq:
            raise ValueError(
                'f(u,t) returns %d components, while u has %d components'
                % (f0.size, self.neq))

    def solve(self, time_points, terminate=None):
        """
        Compute solution u for t values in the list/array
        time_points, as long as terminate(u,t,step_no) is False.
        terminate(u,t,step_no) is a user-given function
        returning True or False. By default, a terminate
        function which always returns False is used.

        terminate receives the full u and t arrays together with the
        index of the step that was just computed.

        Returns the tuple (u, t), both truncated to the steps that
        were actually computed.

        Raises:
        TypeError: if time_points is a single float/int.
        ValueError: if time_points holds fewer than 2 values.
        """
        if terminate is None:
            terminate = lambda u, t, step_no: False

        if isinstance(time_points, (float, int)):
            raise TypeError('solve: time_points is not a sequence')
        self.t = np.asarray(time_points)
        if self.t.size <= 1:
            raise ValueError('ODESolver.solve requires time_points array with at least 2 time points')

        n = self.t.size
        if self.neq == 1:  # scalar ODEs
            self.u = np.zeros(n)
        else:              # systems of ODEs
            self.u = np.zeros((n, self.neq))

        # Assume that self.t[0] corresponds to self.U0
        self.u[0] = self.U0

        # Time loop
        for k in range(n-1):
            self.k = k
            self.u[k+1] = self.advance()
            if terminate(self.u, self.t, self.k+1):
                break  # terminate loop over k
        # k is the last step index started, so k+2 entries are valid.
        return self.u[:k+2], self.t[:k+2]
class ForwardEuler(ODESolver):
    """Solver whose step is u[k+1] = u[k] + dt*f(u[k], t[k])."""

    def advance(self):
        """Return the solution value at t[k+1] from the data at step k."""
        step = self.k
        t_now = self.t[step]
        dt = self.t[step + 1] - t_now
        u_now = self.u[step]
        return u_now + dt*self.f(u_now, t_now)
class RungeKutta4(ODESolver):
    """Solver combining four evaluations of f per step (K1..K4)."""

    def advance(self):
        """Return the solution value at t[k+1] from the data at step k."""
        step = self.k
        rhs = self.f
        t_now = self.t[step]
        u_now = self.u[step]
        dt = self.t[step + 1] - t_now
        half_dt = dt/2.0
        t_mid = t_now + half_dt

        # Four increments: start, two at the interval midpoint, and end.
        K1 = dt*rhs(u_now, t_now)
        K2 = dt*rhs(u_now + 0.5*K1, t_mid)
        K3 = dt*rhs(u_now + 0.5*K2, t_mid)
        K4 = dt*rhs(u_now + K3, t_now + dt)

        # Weighted average of the increments with weights 1, 2, 2, 1.
        return u_now + (1/6.0)*(K1 + 2*K2 + 2*K3 + K4)
import sys, os
class BackwardEuler(ODESolver):
    """Backward Euler solver for scalar ODEs."""

    def __init__(self, f):
        """
        Store f, probe it once, and load the Newton solver.

        Raises:
        ValueError: if the probe call f(array([1]), 1) hits an IndexError.
        ImportError: if the module Newton cannot be imported.
        """
        ODESolver.__init__(self, f)

        # Sample call with a one-element array: an f that indexes u
        # beyond position 0 is not a scalar function.
        try:
            f(np.array([1]), 1)
        except IndexError:  # index out of bounds for u
            raise ValueError('f(u,t) must return float/int')

        # The nonlinear equation of each step is solved by the function
        # Newton from Newton.py; it is imported here, on construction.
        try:
            from Newton import Newton
        except ImportError:
            raise ImportError('''
Could not import module "Newton". Place Newton.py in this directory
(%s)
''' % (os.path.dirname(os.path.abspath(__file__))))
        self.Newton = Newton

    def advance(self):
        """
        Return the solution value at t[k+1].

        Solves F(w) = w - dt*f(w, t[k+1]) - u[k] = 0 with self.Newton,
        starting from a Forward Euler step. The iteration count of each
        step is appended to self.Newton_iter (reset when k == 0).
        """
        step = self.k
        rhs = self.f
        t_now, t_next = self.t[step], self.t[step + 1]
        u_now = self.u[step]
        dt = t_next - t_now
        max_iterations = 30

        def residual(w):
            return w - dt*rhs(w, t_next) - u_now

        slope = Derivative(residual)
        guess = u_now + dt*rhs(u_now, t_now)  # Forward Euler step

        # The result is unpacked as (solution, iteration count, F value);
        # presumably that is Newton's return contract -- confirm in Newton.py.
        u_new, n_iter, _ = self.Newton(residual, guess, slope,
                                       max_n=max_iterations)

        if step == 0:
            self.Newton_iter = []
        self.Newton_iter.append(n_iter)

        if n_iter >= max_iterations:
            print("Newton's failed to converge at t=%g "
                  "(%d iterations)" % (t_next, n_iter))
        return u_new
class Derivative:
    """
    Callable approximating the derivative of f by the centered
    difference (f(x+h) - f(x-h))/(2*h).
    """

    def __init__(self, f, h=1E-9):
        """Store the function f and the step h (converted to float)."""
        self.f = f
        self.h = float(h)

    def __call__(self, x):
        """Return the centered-difference approximation at x."""
        step = self.h
        ahead = self.f(x + step)
        behind = self.f(x - step)
        return (ahead - behind)/(2*step)
# Solver classes that test_exact_numerical_solution loops over.
registered_solver_classes = [
    ForwardEuler,
    RungeKutta4,
    BackwardEuler,
]
def test_exact_numerical_solution():
    """
    Check each registered solver against a known linear solution.

    With f(u, t) = a + (u - u_exact(t))**5 and U0 = u_exact(0), the
    function u_exact(t) = a*t + b satisfies du/dt = f(u, t). The test
    expects every class in registered_solver_classes to match it at
    all time points to within tol.

    Raises:
    AssertionError: if a solver's largest absolute error reaches tol.
    """
    a = 0.2
    b = 3

    def u_exact(t):
        """Exact u(t) corresponding to f below."""
        return a*t + b

    def f(u, t):
        return a + (u - u_exact(t))**5

    U0 = u_exact(0)
    T = 8
    n = 10
    # The error is now measured in absolute value, so rounding errors of
    # either sign count; allow a few units of round-off.
    tol = 1E-14
    t_points = np.linspace(0, T, n)
    for solver_class in registered_solver_classes:
        solver = solver_class(f)
        solver.set_initial_condition(U0)
        u, t = solver.solve(t_points)
        u_e = u_exact(t)
        # Absolute value: without it a solver overshooting the exact
        # solution (negative u_e - u) would pass regardless of size.
        max_error = np.abs(u_e - u).max()
        msg = '%s failed with max_error=%g' % \
              (solver.__class__.__name__, max_error)
        assert max_error < tol, msg
# Run the solver self-check when this file is executed as a script.
if __name__ == "__main__":
    test_exact_numerical_solution()