# Author: Jiranuwat
# Commit 20f33b6: "Create app.py"
import pandas as pd
import streamlit as st
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_percentage_error
import warnings
# NOTE(review): this silences *every* warning for the whole process, including
# ones raised later by scipy/matplotlib (e.g. too many open figures) -- consider
# narrowing it to the specific warnings that are actually noisy.
warnings.filterwarnings("ignore")

# Read the monkeypox case data and the population table.
data = pd.read_csv('owid-monkeypox-data.csv')
data = data[['location', 'iso_code', 'date', 'new_cases', 'total_cases',
             'new_deaths', 'total_deaths']]
pop = pd.read_csv('API_SP.POP.TOTL_DS2_en_csv_v2_4578059.csv')

# Preprocess the data.
# One DataFrame per ISO code, re-indexed from 0 so rows can be addressed
# positionally (SIR()/compare_plt() treat the row position as the day number).
# A single groupby pass replaces filtering the whole frame once per code;
# sort=False keeps the codes in order of first appearance, and rows inside each
# group keep their original order.
all_location = {
    iso: frame.reset_index(drop=True)
    for iso, frame in data.groupby('iso_code', sort=False)
}

# Country code -> population in 2021 (later duplicate codes overwrite earlier ones).
pop_dict = dict(zip(pop['Country Code'], pop['2021']))
# Hard-coded populations for codes that appear in the case data and are
# presumably absent from the population table (Guadeloupe, Martinique, and
# OWID's aggregate world code) -- verify against the CSV if it is updated.
pop_dict['GLP'] = 400000
pop_dict['MTQ'] = 376480
pop_dict['OWID_WRL'] = 7836630792

# Location name -> array of ISO codes recorded for it (main() uses element [0]).
code = dict(data.groupby('location')['iso_code'].unique())
# SIR model differential equations.
def deriv(x, t, beta, gamma):
    """Return the time derivatives of the SIR compartments.

    The signature matches what ``scipy.integrate.odeint`` calls: the state
    vector first, then the time, then the extra ``args``.

    Args:
        x: State ``(s, i, r)`` -- susceptible, infective and recovered fractions.
        t: Current time. It is unused because the system is autonomous, but
            odeint always passes it.
        beta: Transmission rate.
        gamma: Recovery rate.

    Returns:
        List ``[ds/dt, di/dt, dr/dt]``.
    """
    susceptible, infective, _recovered = x
    # Flow S -> I and flow I -> R; each flow leaves one compartment and enters the next.
    infection_flow = beta * susceptible * infective
    recovery_flow = gamma * infective
    return [-infection_flow, infection_flow - recovery_flow, recovery_flow]
#plot model
def plotdata(t, s, i, r, R0, e=None):
    """Plot a simulated SIR trajectory as a three-panel figure.

    Panels:
      * top-left: s, i and r against time;
      * bottom-left: the infective fraction, plus the exposed fraction ``e`` when given;
      * right: the trajectory in the (s, i) plane, the vertical line s = 1/R0
        labelled di/dt = 0, and the initial and final states.

    Args:
        t: Simulation time points (plotted as days).
        s: Susceptible fraction at each time point.
        i: Infective fraction at each time point.
        r: Recovered fraction at each time point.
        R0: Basic reproduction number used for the simulation.
        e: Optional exposed fraction, drawn on the infective panel.

    Returns:
        The matplotlib Figure that was drawn on.
    """
    fig = plt.figure(figsize=(12, 6))
    ax = [fig.add_subplot(221, axisbelow=True),
          fig.add_subplot(223),
          fig.add_subplot(122)]

    # Top-left: all compartments over time.
    ax[0].plot(t, s, lw=3, label='Fraction Susceptible')
    ax[0].plot(t, i, lw=3, label='Fraction Infective')
    ax[0].plot(t, r, lw=3, label='Recovered')
    ax[0].set_title('Susceptible and Recovered Populations')
    ax[0].set_xlabel('Time /days')
    ax[0].set_ylabel('Fraction')

    # Bottom-left: infective (and optionally exposed) fraction.
    ax[1].plot(t, i, lw=3, label='Infective')
    ax[1].set_title('Infectious Population')
    if e is not None:
        ax[1].plot(t, e, lw=3, label='Exposed')
    ax[1].set_ylim(0, 1.0)
    ax[1].set_xlabel('Time /days')
    ax[1].set_ylabel('Fraction')

    # Right: phase portrait in the (s, i) plane.
    ax[2].plot(s, i, lw=3, label='s, i trajectory')
    if R0 > 0:
        # With beta = R0 * gamma (as SIR() sets it), di/dt = i*(beta*s - gamma)
        # vanishes at s = 1/R0. R0 == 0 (no transmission) has no such line and
        # would otherwise raise ZeroDivisionError.
        ax[2].plot([1/R0, 1/R0], [0, 1], '--', lw=3, label='di/dt = 0')
    ax[2].plot(s[0], i[0], '.', ms=20, label='Initial Condition')
    ax[2].plot(s[-1], i[-1], '.', ms=20, label='Final Condition')
    ax[2].set_title('State Trajectory')
    ax[2].set_aspect('equal')
    ax[2].set_ylim(0, 1.05)
    ax[2].set_xlim(0, 1.05)
    ax[2].set_xlabel('Susceptible')
    ax[2].set_ylabel('Infectious')

    for a in ax:
        a.grid(True)
        a.legend()
    # Lay out *this* figure explicitly; plt.tight_layout() acts on pyplot's
    # global "current figure", which is shared state under Streamlit's threads.
    fig.tight_layout()
    return fig
def compare_plt(country, i, pop, xlim=200, ylim=0.00005):
    """Plot a country's confirmed cases and compare them with the SIR model.

    The left panel shows raw ``total_cases`` counts. The right panel shows
    those counts divided by ``pop``, overlaid on the model's infective
    fraction ``i``. Both panels use the row / array position as the day axis.

    Args:
        country: Key into the module-level ``all_location`` mapping (ISO code).
        i: Model infective fraction, indexed by day.
        pop: Population used to scale case counts to fractions.
        xlim: Upper x-axis limit (days) of the comparison panel. The default
            keeps the original Thailand-tuned view.
        ylim: Upper y-axis limit (fraction) of the comparison panel. The
            default keeps the original Thailand-tuned view.

    Returns:
        The matplotlib Figure that was drawn on.
    """
    cases = all_location[country]['total_cases']

    fig = plt.figure(figsize=(12, 6))
    ax = [fig.add_subplot(121, axisbelow=True), fig.add_subplot(122)]

    # Left: raw case counts.
    ax[0].set_title('Monkeypox confirmed cases')
    ax[0].plot(cases, lw=3, label='Infective')
    ax[0].set_xlabel('Days')
    ax[0].set_ylabel('Number of cases')
    ax[0].legend()

    # Right: cases as a fraction of the population vs. the model.
    # NOTE(review): 'total_cases' looks cumulative while ``i`` is the current
    # infective fraction; the cumulative model analogue would be i + r. Confirm
    # which comparison is intended.
    scaler = cases / pop
    ax[1].set_title('Monkeypox confirmed cases compare with model')
    ax[1].plot(scaler, lw=3, label='Real Infective')
    ax[1].plot(i, lw=3, label='SIR model Infective')
    ax[1].set_ylim(0, ylim)
    ax[1].set_xlim(0, xlim)
    ax[1].set_xlabel('Days')
    ax[1].set_ylabel('Fraction Number of cases')
    ax[1].legend()

    # Lay out this figure explicitly instead of pyplot's global current figure.
    fig.tight_layout()
    return fig
#final model
def SIR(country, R0, t_infective, pop):
    """Simulate an SIR epidemic for ``country`` and score it against observed cases.

    The infective compartment is seeded with the first observed
    ``total_cases`` value divided by ``pop``. The model is integrated with a
    time step of exactly one day, so ``i[k]`` lines up with the k-th observed
    row. This assumes one row per day in the data -- TODO confirm.

    Args:
        country: Key into the module-level ``all_location`` mapping (ISO code).
        R0: Basic reproduction number. The author's note gives
            R0 = 0.57 - 1.25 for this disease.
        t_infective: Infectious period in days; gamma = 1 / t_infective.
        pop: Population used to scale case counts to fractions.

    Returns:
        Tuple ``(R0, t_infective, beta, gamma, mape, sir_figure, comparison_figure)``.
        ``mape`` is the mean absolute percentage error, in percent, between the
        observed case fraction and the model infective fraction over the
        observed days.
    """
    cases = all_location[country]['total_cases']
    observed = cases / pop
    n_obs = len(cases)

    # Initial number of infected and recovered individuals (as fractions).
    i_initial = cases.iloc[0] / pop
    r_initial = 0.00
    s_initial = 1 - i_initial - r_initial
    x_initial = s_initial, i_initial, r_initial

    # Rates: gamma from the infectious period, beta from R0 = beta / gamma.
    gamma = 1 / t_infective
    beta = R0 * gamma

    # Daily grid t = 0, 1, 2, ...; np.linspace(0, 3000, 3000) had a step of
    # 3000/2999 days and drifted against the daily observations. Simulate at
    # least as many days as there are observations so the slice below is full.
    n_days = max(3000, n_obs)
    t = np.arange(n_days, dtype=float)
    soln = odeint(deriv, x_initial, t, args=(beta, gamma))
    s, i, r = soln.T

    # NOTE(review): 'total_cases' looks cumulative while ``i`` is the current
    # infective fraction; i + r would be the cumulative analogue. Confirm intent.
    mape = mean_absolute_percentage_error(observed, i[:n_obs]) * 100
    return (R0, t_infective, beta, gamma, mape,
            plotdata(t, s, i, r, R0), compare_plt(country, i, pop))
def main():
    """Render the Streamlit page: parameter form, SIR plots and fit metrics for Thailand."""
    st.title("SIR Model for Monkeypox in Thailand")
    st.subheader("Latest updated : 10/02/2023")
    st.subheader("Reference : https://jckantor.github.io/CBE30338/03.09-COVID-19.html")
    st.caption("Display graph of SIR model of monkeypox and comparison between the model and actual data. Try to find the best R0 that fit for the actual data (lowest MAPE).")
    with st.form("questionaire"):
        recovery = st.slider("How long Monkeypox last until recovery(days)? ", 14, 31, 21)
        R0 = st.slider("Basic Reproduction Number (R0)", 0.57, 3.00, 0.57)  # user's input
        country_code = code["Thailand"][0]
        pop = pop_dict[country_code]
        # clicked==True only when the button is clicked
        clicked = st.form_submit_button("Show Graph")
        if clicked:
            # Show SIR
            model_R0, _, beta, gamma, mape, sir_fig, compare_fig = SIR(
                country_code, R0, recovery, pop)
            # In the SIR model the infective fraction can only grow when R0 > 1.
            verdict = 'No epidemic.' if model_R0 <= 1 else 'Epidemic has began.'
            st.pyplot(sir_fig)
            st.pyplot(compare_fig)
            # st.pyplot does not close the figures it is given. Release them so
            # pyplot's figure registry does not grow on every click. The global
            # warnings filter hides matplotlib's "too many figures" warning.
            plt.close(sir_fig)
            plt.close(compare_fig)
            st.success("SIR model parameters of Thailand "+" is")
            st.success("R0 (Basic Reproduction Number) = "+str(model_R0)+' '+verdict)
            st.success("Beta (Rate of transmission) = "+str(round(beta, 3)))
            st.success("Gamma (Rate of Recovery) = "+str(round(gamma, 3)))
            st.success("MAPE = "+str(round(mape, 3))+"%")


# Run main()
if __name__ == "__main__":
    main()