Emil25 commited on
Commit
6e8c2e3
1 Parent(s): a5e9ec9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing the necessary libraries
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import pickle
5
+
6
+ # Setting up the page configuration for Streamlit App
7
+ st.set_page_config(
8
+ page_title=" :mushroom: Mushroom App",
9
+ page_icon="🍄",
10
+ layout="wide",
11
+ initial_sidebar_state="expanded"
12
+ )
13
+
14
+ # Function for user input features
15
+ def user_input_features():
16
+ # Creating sliders and select boxes for user input in the sidebar
17
+ cap_diameter = st.sidebar.slider('Cap Diameter',
18
+ min_value=0.0,
19
+ max_value=2000.0,
20
+ value = 1000.0,
21
+ step=1.0,
22
+ )
23
+ cap_shape = st.sidebar.selectbox('Cap Shape',
24
+ options=('bell',
25
+ 'conical',
26
+ 'convex',
27
+ 'flat',
28
+ 'sunken',
29
+ 'spherical',
30
+ 'other',)
31
+ )
32
+ gill_attachment = st.sidebar.selectbox('Gill Attachment',
33
+ options=('adnate',
34
+ 'adnexed',
35
+ 'decurrent',
36
+ 'free',
37
+ 'sinuate',
38
+ 'pores',
39
+ 'none',)
40
+ )
41
+ gill_color = st.sidebar.selectbox('Gill Color',
42
+ options=('brown',
43
+ 'buff',
44
+ 'gray',
45
+ 'green',
46
+ 'pink',
47
+ 'purple',
48
+ 'red',
49
+ 'white',
50
+ 'yellow',
51
+ 'blue',
52
+ 'orange',
53
+ 'black',)
54
+ )
55
+ stem_height = st.sidebar.slider('Stem Height',
56
+ min_value=0.0,
57
+ max_value=4.0,
58
+ value=2.0,
59
+ step=0.1,
60
+ )
61
+ stem_width = st.sidebar.slider('Stem Width',
62
+ min_value=0.0,
63
+ max_value=4000.0,
64
+ value=2000.0,
65
+ step=1.0,
66
+ )
67
+ stem_color = st.sidebar.selectbox('Stem Color',
68
+ options=('brown',
69
+ 'buff',
70
+ 'gray',
71
+ 'green',
72
+ 'pink',
73
+ 'purple',
74
+ 'red',
75
+ 'white',
76
+ 'yellow',
77
+ 'blue',
78
+ 'orange',
79
+ 'black',)
80
+ )
81
+ season = st.sidebar.selectbox('Season',
82
+ options=('spring',
83
+ 'summer',
84
+ 'autumn',
85
+ 'winter',)
86
+ )
87
+
88
+ # Function to get the color code
89
+ def get_color(color_name):
90
+ color_dict = {
91
+ 'brown': 0,
92
+ 'buff': 1,
93
+ 'gray': 2,
94
+ 'green': 3,
95
+ 'pink': 4,
96
+ 'purple': 5,
97
+ 'red': 6,
98
+ 'white': 7,
99
+ 'yellow': 8,
100
+ 'blue': 9,
101
+ 'orange': 10,
102
+ 'black': 11,
103
+ }
104
+ return color_dict.get(color_name.lower(), "not found")
105
+
106
+ # Function to get the cap shape code
107
+ def get_cap_shape(cap_shape):
108
+ shape_dict = {
109
+ 'bell': 0,
110
+ 'conical': 1,
111
+ 'convex': 2,
112
+ 'flat': 3,
113
+ 'sunken': 4,
114
+ 'spherical': 5,
115
+ 'other': 6,
116
+ }
117
+ return shape_dict.get(cap_shape.lower(), "not found")
118
+
119
+ # Function to get gill attachment code
120
+ def get_gill_attachment(gill_attachment):
121
+ gill_attachment_dict = {
122
+ 'adnate': 0,
123
+ 'adnexed': 1,
124
+ 'decurrent': 2,
125
+ 'free': 3,
126
+ 'sinuate': 4,
127
+ 'pores': 5,
128
+ 'none': 6,
129
+ }
130
+ return gill_attachment_dict.get(gill_attachment.lower(), "not found")
131
+
132
+ # Function to get season code
133
+ def get_season(season):
134
+ season_dict = {
135
+ 'spring': 0,
136
+ 'summer': 1,
137
+ 'autumn': 2,
138
+ 'winter': 3,
139
+ }
140
+ return season_dict.get(season.lower(), "not found")
141
+
142
+ # Creating a data dictionary to store the user input data
143
+ data = {'cap-diameter': cap_diameter,
144
+ 'cap-shape': get_cap_shape(cap_shape),
145
+ 'gill-attachment': get_gill_attachment(gill_attachment),
146
+ 'gill-color': get_color(gill_color),
147
+ 'stem-height': stem_height,
148
+ 'stem-width': stem_width,
149
+ 'stem-color': get_color(stem_color),
150
+ 'season': get_season(season),
151
+ }
152
+
153
+ # Creating a DataFrame from the data dictionary
154
+ features = pd.DataFrame(data, index=[0])
155
+ return features
156
+
157
+ # Function to load the prediction model
158
+ #@st.cache_data()
159
+ def get_model():
160
+ model = pickle.load(open("models/rfc_model.pkl", "rb"))
161
+ return model
162
+
163
+ # Function to make prediction using the model and input data
164
+ def make_prediction(data):
165
+ model = get_model()
166
+ return model.predict(data)
167
+
168
+ # Function to process uploaded CSV file and make predictions
169
+ def process_file(file):
170
+ data = pd.read_csv(file)
171
+ model = get_model()
172
+ predictions = model.predict(data)
173
+ data['prediction'] = predictions
174
+ return data
175
+
176
+ # Main function
177
+ def main():
178
+ st.write("""# :mushroom: Mushroom App""")
179
+ st.sidebar.image("img/dataset-cover.jpg")
180
+ user_data = user_input_features()
181
+
182
+ # Creating a session state button for prediction
183
+ if 'btn_predict' not in st.session_state:
184
+ st.session_state['btn_predict'] = False
185
+
186
+ st.session_state['btn_predict'] = st.button("Predict")
187
+
188
+ # Making prediction and showing result
189
+ if st.session_state['btn_predict'] == True:
190
+ if make_prediction(user_data) == 1:
191
+ st.error("# Result: Poisonous :skull_and_crossbones: ")
192
+ else:
193
+ st.success("# Result: Edible :mushroom: ")
194
+
195
+ # File upload for batch prediction
196
+ st.write("## Batch Prediction")
197
+ uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
198
+ if uploaded_file is not None:
199
+ result_df = process_file(uploaded_file)
200
+ st.write(result_df)
201
+ csv = result_df.to_csv(index=False).encode('utf-8')
202
+ st.download_button(
203
+ label="Download predictions as CSV",
204
+ data=csv,
205
+ file_name='predictions.csv',
206
+ mime='text/csv',
207
+ )
208
+
209
+ # Running the main function
210
+ if __name__ == "__main__":
211
+ main()