-
Notifications
You must be signed in to change notification settings - Fork 0
/
seir_func.py
64 lines (50 loc) · 1.47 KB
/
seir_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import matplotlib.pyplot as plt
def SEIR(u, t, *, beta=0.5, r_ia=0.1, r_e2=1.25, lmbda_1=0.33,
         lmbda_2=0.5, p_a=0.4, mu=0.2):
    """Return the time derivatives of the six-compartment SEIR model.

    Compartments: susceptible (S), two exposed stages (E1, E2),
    symptomatic infected (I), asymptomatic infected (Ia) and removed (R).
    New infections are driven by contact with I, Ia and E2. Ia and E2 are
    weighted by their relative infectiousness.

    Parameters
    ----------
    u : sequence of 6 numbers
        Current state ``[S, E1, E2, I, Ia, R]``. ``sum(u)`` (the total
        population) must be non-zero.
    t : float
        Current time. It is not used by the equations, but is part of the
        ``f(u, t)`` signature passed to the ODE solver.
    beta : float, keyword-only
        Transmission rate for contacts with symptomatic infected (I).
    r_ia : float, keyword-only
        Infectiousness of Ia relative to I.
    r_e2 : float, keyword-only
        Infectiousness of E2 relative to I.
    lmbda_1 : float, keyword-only
        Rate at which individuals leave E1.
    lmbda_2 : float, keyword-only
        Rate at which individuals move from E2 to I.
    p_a : float, keyword-only
        Fraction of those leaving E1 who become asymptomatic (Ia). The
        remainder go to E2.
    mu : float, keyword-only
        Removal rate from both I and Ia into R.

    Returns
    -------
    list of 6 numbers
        ``[dS, dE1, dE2, dI, dIa, dR]``. The components sum to zero, so
        the total population is conserved.
    """
    S, E1, E2, I, Ia, R = u
    N = sum(u)
    # New exposures per unit time. This is shared by dS (loss) and
    # dE1 (gain), so it is computed once.
    infection = beta * S * (I + r_ia * Ia + r_e2 * E2) / N
    dS = -infection
    dE1 = infection - lmbda_1 * E1
    # Outflow from E1 splits: a fraction p_a goes to Ia, the rest to E2.
    dE2 = lmbda_1 * (1 - p_a) * E1 - lmbda_2 * E2
    dI = lmbda_2 * E2 - mu * I
    dIa = lmbda_1 * p_a * E1 - mu * Ia
    dR = mu * (I + Ia)
    return [dS, dE1, dE2, dI, dIa, dR]
def test_SEIR():
    """Check SEIR against hand-computed derivatives for an all-ones state.

    With ``u = [1] * 6`` the total population is N = 6, so each expected
    derivative is a simple combination of SEIR's default rate constants.
    """
    t = 0
    u = [1, 1, 1, 1, 1, 1]
    computed = SEIR(u, t)
    expected = [-0.19583333333333333, -0.13416666666666668, -0.302, 0.3, -0.068, 0.4]
    tol = 1e-10
    # zip() silently truncates, so check the lengths explicitly first.
    assert len(computed) == len(expected), \
        f'Expected {len(expected)} derivatives, but got {len(computed)}'
    for i, (x, exp) in enumerate(zip(computed, expected)):
        assert abs(x - exp) < tol, \
            f'Failed for component {i}: expected {exp}, but got {x}'
from ODESolver import *
def solve_SEIR(T, dt, S_0, E2_0):
    """Integrate the SEIR model from t = 0 to t = T with RungeKutta4.

    Parameters
    ----------
    T : float
        Final time.
    dt : float
        Requested step size. The interval is split into ``round(T / dt)``
        equal steps (at least one). When T is not a multiple of dt, the
        actual step is ``T / round(T / dt)`` and so differs slightly from dt.
    S_0 : float
        Initial susceptible population.
    E2_0 : float
        Initial population in the second exposed stage. All other
        compartments start at zero.

    Returns
    -------
    (u, t)
        Whatever ``solver.solve`` returns. ``plot_SEIR`` indexes ``u`` as
        ``u[:, k]``, so presumably ``u`` is an array of shape
        ``(len(t), 6)`` — confirm against ODESolver.

    Raises
    ------
    ValueError
        If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f'dt must be positive, got {dt}')
    # The original np.linspace(0, T) always used 50 points and ignored dt.
    n_steps = max(1, int(round(T / dt)))
    time_points = np.linspace(0, T, n_steps + 1)
    solver = RungeKutta4(SEIR)
    solver.set_initial_condition([S_0, 0, E2_0, 0, 0, 0])
    u, t = solver.solve(time_points)
    return u, t
def plot_SEIR(u, t):
    """Plot S, I, Ia and R against time, then show the figure.

    ``u`` is indexed as ``u[:, column]``, with columns ordered as in
    SEIR's state ``[S, E1, E2, I, Ia, R]``. The exposed stages E1 and E2
    are not plotted.
    """
    curves = (
        (0, 'S(t)'),
        (3, 'I(t)'),
        (4, 'Ia(t)'),
        (5, 'R(t)'),
    )
    for column, name in curves:
        plt.plot(t, u[:, column], label=name)
    plt.legend()
    plt.show()
# Run example:
#     user$ python3 seir_func.py
# This solves the model with the parameters below and displays a plot
# of S, I, Ia and R.
#
# The guard keeps a plain import (for example, pytest collecting
# test_SEIR) from running the simulation and blocking on plt.show().
if __name__ == '__main__':
    u, t = solve_SEIR(T=100, dt=1.0, S_0=5e6, E2_0=100)
    plot_SEIR(u, t)