Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
""" | |
Created on Fri Mar 31 17:45:36 2023 | |
@author: Gaspar Avit Ferrero | |
""" | |
import os

import pandas as pd
import streamlit as st
from catboost import CatBoostClassifier
from htbuilder import HtmlElement, div, hr, a, p, styles
from htbuilder.units import percent, px
from streamlit import session_state as session
############################### | |
## ------- FUNCTIONS ------- ## | |
############################### | |
def link(link, text, **style):
    """
    Build an anchor element pointing at a URL.

    -param link: target URL, used as the ``href`` attribute
    -param text: content placed inside the anchor
    -param style: CSS properties forwarded to ``styles`` for the inline style
    -return: anchor element with ``target="_blank"`` wrapping ``text``
    """
    anchor = a(_href=link, _target="_blank", style=styles(**style))
    return anchor(text)
def layout(*args):
    """
    Render a fixed footer at the bottom of the page.

    First injects a CSS snippet that hides the ``#MainMenu`` and ``footer``
    elements and sets a bottom offset on ``.stApp``; then renders a
    fixed-position ``div`` whose paragraph receives the given fragments in
    order.

    -param args: footer fragments; only ``str`` and ``HtmlElement`` values
        are passed to the footer paragraph, anything else is skipped.
    """
    # '#MainMenu' is an id selector: a space after '#' makes the rule invalid
    # CSS, so it would be dropped and the menu would stay visible.
    style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stApp { bottom: 105px; }
</style>
"""

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        height=px(10),
        color="black",
        text_align="center",
        opacity=1,
    )

    body = p()
    foot = div(style=style_div)(body)

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        # Plain text and HTML elements are handled the same way.
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """Render the page footer: the text "Made by " followed by a link to the
    author's LinkedIn profile."""
    author_link = link(
        "https://www.linkedin.com/in/gaspar-avit/", "Gaspar Avit")
    layout("Made by ", author_link)
def update_prediction(input_data):
    """Callback to automatically update prediction if button has already been
    clicked.

    Does nothing while the module-level ``is_clicked`` flag is false or has
    not been defined yet; otherwise generates a prediction for ``input_data``
    and writes it to the page, like the button branch of the main script.

    -param input_data: DataFrame with input data
    """
    # `is_clicked` is only assigned by the main script section, so look it up
    # defensively instead of raising NameError when it does not exist.
    if not globals().get("is_clicked", False):
        return

    # The previous implementation called `launch_prediction`, which is not
    # defined in this file; use the existing prediction helper instead.
    st.write(generate_prediction(input_data))
def get_input_data():
    """
    Generate input layout and get input values.

    Builds an expander with three rows of two widgets each (age / sex,
    height / weight, systolic / diastolic blood pressure), stores the
    collected values in ``session.input_data`` and returns them.

    -return: single-row DataFrame with columns 'age', 'sex', 'height',
        'weight', 'ap_hi' and 'ap_lo' (in that order).
    """
    # Collect the widget values first and build the DataFrame at the end:
    # assigning scalars to columns of an empty DataFrame creates the columns
    # but no rows, which would leave nothing to predict on.
    values = {}

    # No `on_change` callback is registered on the widgets. The previous
    # `on_change=update_prediction(...)` called the function immediately and
    # passed its `None` result, so no callback was ever registered either.
    with st.expander('Input parameters', True):

        # Row 1
        col_age, col_sex = st.columns(2)
        with col_age:
            values['age'] = st.slider('Age', 18, 75)
        with col_sex:
            values['sex'] = st.radio('Sex', ['Female', 'Male'])

        # Row 2
        col_height, col_weight = st.columns(2)
        with col_height:
            values['height'] = st.slider('Height', 140, 200)
        with col_weight:
            values['weight'] = st.slider('Weight', 40, 140)

        # Row 3
        col_ap_hi, col_ap_lo = st.columns(2)
        with col_ap_hi:
            values['ap_hi'] = st.slider('Systolic blood pressure', 90, 200)
        with col_ap_lo:
            values['ap_lo'] = st.slider('Diastolic blood pressure', 50, 120)

    st.write("")

    # One dict -> one row; dict insertion order gives the column order.
    session.input_data = pd.DataFrame([values])
    return session.input_data
def generate_prediction(input_data):
    """
    Run the module-level ``MODEL`` on the input data.

    -param input_data: DataFrame with input data
    -return: whatever ``MODEL.predict`` returns for ``input_data``

    NOTE(review): the app describes the result as a disease probability, but
    this calls ``predict`` rather than ``predict_proba`` -- confirm which
    output is intended.
    """
    prediction = MODEL.predict(input_data)
    return prediction
############################### | |
## --------- MAIN ---------- ## | |
############################### | |
if __name__ == "__main__":

    ## --- Page config ------------ ##
    # Page title followed by a short description of the app
    st.title("""
Cardiovascular Disease predictor
#### This app aims to give a scoring of how probable is that an individual \
would suffer from a cardiovascular disease given its physical \
characteristics
#### Just enter your info and get a prediction.
""")

    # Page footer (currently disabled)
    # footer()

    # update_prediction() reads this module-level flag, so it has to exist
    # before any input widget is built.
    is_clicked = False
    ## --------------------------- ##

    # Load the classification model; generate_prediction() reads MODEL
    MODEL = CatBoostClassifier()
    MODEL.load_model('./model.cbm')

    # Build the input widgets and collect their current values
    session.input_data = get_input_data()

    # Button that triggers the prediction, placed in the middle column
    left_pad, button_col, right_pad = st.columns([1.3, 1, 1])
    is_clicked = button_col.button(label="Generate predictions")

    for _spacer in range(2):
        st.text("")

    # Compute and display the prediction once the button has been pressed
    if is_clicked:
        prediction = generate_prediction(session.input_data)
        st.write(prediction)

    for _spacer in range(2):
        st.text("")