# Emil25 - "Update app.py" (commit ee9b2e8, verified)
# Importing the necessary libraries
import streamlit as st
import pandas as pd
import pickle
# Configure the Streamlit page (browser-tab title and icon, wide layout,
# sidebar open on first load) before anything else is rendered.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="🍄",
    page_title=" :mushroom: Mushroom App",
)
# Integer codes for every categorical feature. The sidebar select boxes take
# their options straight from these mappings, so the labels shown to the user
# and the codes put into the feature row cannot drift apart.
# NOTE(review): presumably these codes mirror the encoding the model was
# trained with - confirm against the training pipeline.
_COLOR_CODES = {
    'brown': 0,
    'buff': 1,
    'gray': 2,
    'green': 3,
    'pink': 4,
    'purple': 5,
    'red': 6,
    'white': 7,
    'yellow': 8,
    'blue': 9,
    'orange': 10,
    'black': 11,
}
_CAP_SHAPE_CODES = {
    'bell': 0,
    'conical': 1,
    'convex': 2,
    'flat': 3,
    'sunken': 4,
    'spherical': 5,
    'other': 6,
}
_GILL_ATTACHMENT_CODES = {
    'adnate': 0,
    'adnexed': 1,
    'decurrent': 2,
    'free': 3,
    'sinuate': 4,
    'pores': 5,
    'none': 6,
}
_SEASON_CODES = {
    'spring': 0,
    'summer': 1,
    'autumn': 2,
    'winter': 3,
}


def _encode(codes, label):
    """Return the integer code of ``label`` in ``codes`` (case-insensitive).

    Raises:
        ValueError: if ``label`` is not a key of ``codes``. Failing loudly
            keeps a placeholder string from ending up in a numeric column.
    """
    try:
        return codes[label.lower()]
    except KeyError:
        raise ValueError(
            f"Unknown category {label!r}; expected one of {list(codes)}"
        ) from None


# Function for user input features
def user_input_features():
    """Collect mushroom features from the sidebar widgets.

    Numeric features come from sliders; categorical features come from
    select boxes and are converted to integer codes.

    Returns:
        pandas.DataFrame: a single row (index 0) with the columns
        ``cap-diameter``, ``cap-shape``, ``gill-attachment``, ``gill-color``,
        ``stem-height``, ``stem-width``, ``stem-color`` and ``season``.

    Raises:
        ValueError: if a selected label has no code (not reachable through
            the select boxes, whose options are the mapping keys).
    """
    # Widgets are created in the order they should appear in the sidebar.
    cap_diameter = st.sidebar.slider(
        'Cap Diameter',
        min_value=0.0,
        max_value=2000.0,
        value=1000.0,
        step=1.0,
    )
    cap_shape = st.sidebar.selectbox(
        'Cap Shape',
        options=tuple(_CAP_SHAPE_CODES),
    )
    gill_attachment = st.sidebar.selectbox(
        'Gill Attachment',
        options=tuple(_GILL_ATTACHMENT_CODES),
    )
    gill_color = st.sidebar.selectbox(
        'Gill Color',
        options=tuple(_COLOR_CODES),
    )
    stem_height = st.sidebar.slider(
        'Stem Height',
        min_value=0.0,
        max_value=4.0,
        value=2.0,
        step=0.1,
    )
    stem_width = st.sidebar.slider(
        'Stem Width',
        min_value=0.0,
        max_value=4000.0,
        value=2000.0,
        step=1.0,
    )
    stem_color = st.sidebar.selectbox(
        'Stem Color',
        options=tuple(_COLOR_CODES),
    )
    season = st.sidebar.selectbox(
        'Season',
        options=tuple(_SEASON_CODES),
    )

    # One value per feature; key order determines the DataFrame column order.
    data = {
        'cap-diameter': cap_diameter,
        'cap-shape': _encode(_CAP_SHAPE_CODES, cap_shape),
        'gill-attachment': _encode(_GILL_ATTACHMENT_CODES, gill_attachment),
        'gill-color': _encode(_COLOR_CODES, gill_color),
        'stem-height': stem_height,
        'stem-width': stem_width,
        'stem-color': _encode(_COLOR_CODES, stem_color),
        'season': _encode(_SEASON_CODES, season),
    }
    # index=[0] is required because every value is a scalar.
    features = pd.DataFrame(data, index=[0])
    return features
# Function to load the prediction model
def get_model():
    """Load and return the pickled model stored at ``models/rfc_model.pkl``.

    The file is read again on every call.

    Returns:
        The object that was pickled into the model file.

    Raises:
        OSError: if the model file cannot be opened.
        pickle.UnpicklingError: if the file is not a valid pickle.
    """
    # NOTE: unpickling executes code embedded in the file, so this path must
    # only ever point at a trusted, locally shipped model.
    # The context manager closes the file handle even when loading fails.
    with open("models/rfc_model.pkl", "rb") as model_file:
        model = pickle.load(model_file)
    return model
# Run the stored model on a prepared feature table
def make_prediction(data):
    """Return the model's predictions for ``data``.

    Args:
        data: feature table handed unchanged to the model's ``predict``.

    Returns:
        Whatever ``predict`` returns for ``data``.
    """
    classifier = get_model()
    predicted = classifier.predict(data)
    return predicted
# Batch prediction for an uploaded CSV file
def process_file(file):
    """Read a CSV file, predict every row and attach the results.

    Args:
        file: path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        pandas.DataFrame: the parsed CSV with an added ``prediction`` column
        holding the model output for each row.
    """
    frame = pd.read_csv(file)
    classifier = get_model()
    # Predict first, then attach, so the model never sees the new column.
    frame['prediction'] = classifier.predict(frame)
    return frame
# Main function
def main():
    """Render the app: title, sidebar inputs, single prediction, batch upload.

    The sidebar values are turned into a one-row feature table; pressing
    "Predict" classifies that row. Independently, an uploaded CSV file is
    classified row by row and offered back as a download.
    """
    st.write("""# :mushroom: Mushroom App""")
    st.sidebar.image("img/dataset-cover.jpg")
    user_data = user_input_features()

    # Keep the button state in session state as well; st.button reports True
    # only on the rerun triggered by the click.
    st.session_state['btn_predict'] = st.button("Predict")

    # Making prediction and showing result
    if st.session_state['btn_predict']:
        # The sidebar form is a single row, so take its one prediction as a
        # scalar instead of relying on the truth value of a whole array.
        prediction = make_prediction(user_data)[0]
        if prediction == 1:
            st.error("# Result: Poisonous :skull_and_crossbones: ")
        else:
            st.success("# Result: Edible :mushroom: ")

    # File upload for batch prediction
    # (heading is Russian for "Uploading a CSV file with mushroom data for
    # batch prediction")
    st.write("## Загрузка CSV-файла с данными о грибах для массового предсказания")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    if uploaded_file is not None:
        result_df = process_file(uploaded_file)
        st.write(result_df)
        # The download button is given bytes, so encode the CSV text.
        csv = result_df.to_csv(index=False).encode('utf-8')
        st.download_button(
            label="Download predictions as CSV",
            data=csv,
            file_name='predictions.csv',
            mime='text/csv',
        )
# Start the app only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()