# Spaces: Sleeping  (status text captured from the hosting page, not code)
# Importing the necessary libraries | |
import streamlit as st | |
import pandas as pd | |
import pickle | |
# Configure the Streamlit page: browser-tab title and icon, wide layout,
# and an initially expanded sidebar.
st.set_page_config(
    # The tab title is plain text; emoji shortcodes such as ":mushroom:" are
    # not rendered there, so the emoji is supplied through page_icon instead.
    page_title="Mushroom App",
    page_icon="🍄",  # U+1F344 MUSHROOM
    layout="wide",
    initial_sidebar_state="expanded",
)
# Integer codes fed to the model for each categorical feature. These
# presumably mirror the label encoding used when models/rfc_model.pkl was
# trained -- TODO confirm against the training pipeline. Dict insertion order
# also defines the order of the sidebar options, so the UI choices and the
# encoding share one source of truth.
COLOR_CODES = {
    'brown': 0,
    'buff': 1,
    'gray': 2,
    'green': 3,
    'pink': 4,
    'purple': 5,
    'red': 6,
    'white': 7,
    'yellow': 8,
    'blue': 9,
    'orange': 10,
    'black': 11,
}
CAP_SHAPE_CODES = {
    'bell': 0,
    'conical': 1,
    'convex': 2,
    'flat': 3,
    'sunken': 4,
    'spherical': 5,
    'other': 6,
}
GILL_ATTACHMENT_CODES = {
    'adnate': 0,
    'adnexed': 1,
    'decurrent': 2,
    'free': 3,
    'sinuate': 4,
    'pores': 5,
    'none': 6,
}
SEASON_CODES = {
    'spring': 0,
    'summer': 1,
    'autumn': 2,
    'winter': 3,
}


def encode_category(codes, value):
    """Return the integer code for ``value`` in ``codes`` (case-insensitive).

    Args:
        codes: mapping from lower-case category name to integer code.
        value: category name as chosen in the UI.

    Returns:
        The integer code for ``value``.

    Raises:
        ValueError: if ``value`` is not a known category. Failing loudly is
            preferable to letting a string placeholder reach a numeric
            model feature.
    """
    try:
        return codes[value.lower()]
    except KeyError:
        raise ValueError(
            f"unknown category {value!r}; expected one of {list(codes)}"
        ) from None


def user_input_features():
    """Collect one mushroom's features from the sidebar widgets.

    Categorical selections are converted to integer codes via the
    ``*_CODES`` tables; slider values are passed through unchanged. Widgets
    are created in the order they appear in the sidebar.

    Returns:
        A one-row DataFrame (index ``[0]``) with the columns cap-diameter,
        cap-shape, gill-attachment, gill-color, stem-height, stem-width,
        stem-color and season.
    """
    cap_diameter = st.sidebar.slider(
        'Cap Diameter', min_value=0.0, max_value=2000.0, value=1000.0, step=1.0,
    )
    cap_shape = st.sidebar.selectbox('Cap Shape', options=tuple(CAP_SHAPE_CODES))
    gill_attachment = st.sidebar.selectbox(
        'Gill Attachment', options=tuple(GILL_ATTACHMENT_CODES),
    )
    gill_color = st.sidebar.selectbox('Gill Color', options=tuple(COLOR_CODES))
    stem_height = st.sidebar.slider(
        'Stem Height', min_value=0.0, max_value=4.0, value=2.0, step=0.1,
    )
    stem_width = st.sidebar.slider(
        'Stem Width', min_value=0.0, max_value=4000.0, value=2000.0, step=1.0,
    )
    stem_color = st.sidebar.selectbox('Stem Color', options=tuple(COLOR_CODES))
    season = st.sidebar.selectbox('Season', options=tuple(SEASON_CODES))

    data = {
        'cap-diameter': cap_diameter,
        'cap-shape': encode_category(CAP_SHAPE_CODES, cap_shape),
        'gill-attachment': encode_category(GILL_ATTACHMENT_CODES, gill_attachment),
        'gill-color': encode_category(COLOR_CODES, gill_color),
        'stem-height': stem_height,
        'stem-width': stem_width,
        'stem-color': encode_category(COLOR_CODES, stem_color),
        'season': encode_category(SEASON_CODES, season),
    }
    return pd.DataFrame(data, index=[0])
@st.cache_resource
def get_model():
    """Load the trained classifier from ``models/rfc_model.pkl``.

    ``st.cache_resource`` keeps one shared model object across reruns and
    sessions, so the pickle is read from disk once instead of on every
    prediction. (``st.cache_data`` would copy the returned object on each
    access, which is unsuitable for a model.) The path is resolved against
    the working directory the app is started from.

    Security: unpickling can execute arbitrary code -- only ever load a
    trusted, locally shipped model file here, never user-supplied data.

    Returns:
        The unpickled model object.
    """
    # `with` closes the file handle, which the previous one-liner leaked.
    with open("models/rfc_model.pkl", "rb") as model_file:
        return pickle.load(model_file)
def make_prediction(data):
    """Classify ``data`` with the loaded model.

    Args:
        data: feature DataFrame, e.g. the single-row frame produced by
            ``user_input_features``.

    Returns:
        The result of the model's ``predict`` call on ``data``.
    """
    classifier = get_model()
    predictions = classifier.predict(data)
    return predictions
def process_file(file):
    """Read a CSV of mushroom features and append the model's predictions.

    Args:
        file: path or file-like object accepted by ``pandas.read_csv``
            (in this app, the object returned by ``st.file_uploader``).

    Returns:
        The uploaded rows as a DataFrame with a ``prediction`` column holding
        the model output for each row. An existing ``prediction`` column is
        overwritten.

    Raises:
        ValueError: propagated from ``pandas.read_csv`` (empty or unparsable
            CSV) or from the model when the features do not fit it.
        KeyError: if the model records its fitted feature names and some of
            them are missing from the CSV.
    """
    data = pd.read_csv(file)
    model = get_model()
    # A previously downloaded predictions.csv already has a 'prediction'
    # column. That column is output, not a feature, so exclude it from the
    # model input to let users re-upload results.
    features = data.drop(columns=['prediction'], errors='ignore')
    # Estimators fitted on a DataFrame (e.g. scikit-learn) expose the fitted
    # column names as `feature_names_in_`. When present, select and order the
    # columns accordingly so extra or reordered CSV columns do not break
    # prediction. Otherwise pass the frame through unchanged.
    feature_names = getattr(model, 'feature_names_in_', None)
    if feature_names is not None:
        features = features[list(feature_names)]
    data['prediction'] = model.predict(features)
    return data
def main():
    """Render the Mushroom App page.

    The sidebar collects one mushroom's features. Pressing "Predict" reports
    "Poisonous" when the model's prediction equals 1 and "Edible" otherwise.
    Below that, an uploaded CSV is scored row by row, displayed, and offered
    for download with an added ``prediction`` column.
    """
    st.write("""# :mushroom: Mushroom App""")
    st.sidebar.image("img/dataset-cover.jpg")
    user_data = user_input_features()

    # st.button returns True only on the rerun triggered by the click. The
    # value is also stored in session_state under 'btn_predict'.
    st.session_state['btn_predict'] = st.button("Predict")
    if st.session_state['btn_predict']:
        # user_data is a single-row frame, so compare the single prediction
        # rather than relying on the truthiness of a one-element array.
        if make_prediction(user_data)[0] == 1:
            st.error("# Result: Poisonous :skull_and_crossbones: ")
        else:
            st.success("# Result: Edible :mushroom: ")

    # Batch prediction from an uploaded CSV. The header text (Russian) reads:
    # "Upload a CSV file with mushroom data for batch prediction".
    st.write("## Загрузка CSV-файла с данными о грибах для массового предсказания")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    if uploaded_file is None:
        return

    try:
        result_df = process_file(uploaded_file)
    except (KeyError, ValueError) as err:
        # pandas' EmptyDataError/ParserError subclass ValueError, as do model
        # feature-mismatch errors. Show the problem instead of a raw traceback.
        st.error(f"Could not score the uploaded file: {err}")
        return

    st.write(result_df)
    csv = result_df.to_csv(index=False).encode('utf-8')
    st.download_button(
        label="Download predictions as CSV",
        data=csv,
        file_name='predictions.csv',
        mime='text/csv',
    )
# Entry point: build the page only when this file is executed as the main
# script, not when it is imported for its helpers.
if __name__ == "__main__":
    main()