leannebriffa's picture
app.py fix
9aff8dd
import gradio as gr
import pandas as pd
import numpy as np
from tensorflow import keras
import pickle
# Configuration section START
# Fitted preprocessing objects (encoders, scaler, label encoder), looked up by
# string key throughout the app.
# NOTE(review): pickle.load can execute arbitrary code stored in the file; this
# is only acceptable because model_config.pkl is a trusted artifact shipped
# with the app - never point it at user-supplied data.
with open('model_config.pkl', 'rb') as config_file:
    tf = pickle.load(config_file)

# Trained Keras classifier used by make_prediction.
model = keras.models.load_model('model.keras')

# Every model input column, in the exact order make_prediction builds its row.
all_selected_features = [
    'Alcohol', 'Arrest Type', 'Belts', 'Contributed To Accident', 'Disobedience',
    'Driver City', 'Gender', 'Invalid Documentation', 'Make', 'Mobile Phone',
    'Negligent Driving', 'Number Of Offences', 'Personal Injury', 'Property Damage',
    'Race', 'Road Signs And Markings', 'Search Outcome', 'Speeding',
    'Stop Hour', 'Stop Year', 'SubAgency', 'Vehicle Safety And Standards',
    'VehicleType', 'Year',
]

# Ordinal-encoded with tf["OrdinalEncoder_VeryHighCardinality"].
very_high_cardinality_features = ['Driver City']

# Ordinal-encoded with tf["OrdinalEncoder_HighCardinality"].
high_cardinality_features = ['Make', 'VehicleType']

# Boolean flags, cast to int8 before one-hot encoding.
bools = [
    'Alcohol', 'Belts', 'Contributed To Accident', 'Disobedience',
    'Invalid Documentation', 'Mobile Phone', 'Negligent Driving',
    'Personal Injury', 'Property Damage', 'Road Signs And Markings',
    'Speeding', 'Vehicle Safety And Standards',
]

# Columns scaled by tf["StandardScaler"]; includes the ordinal-encoded ones.
num_cols = [
    'Driver City', 'Make', 'Number Of Offences', 'Stop Hour',
    'Stop Year', 'VehicleType', 'Year',
]
# Configuration section END
def make_prediction(alcohol, arrest_type, belts, contributed_to_accident, disobedience, driver_city, gender,
                    invalid_documentation, make, mobile_phone, negligent_driving, number_of_offences, personal_injury,
                    property_damage, race, road_signs_and_markings, search_outcome, speeding, stop_hour, stop_year,
                    subagency, vehicle_safety_and_standards, vehicletype, year):
    """
    Predict the 'Violation Type' of a single traffic-stop sample.

    The arguments are collected into a one-row DataFrame whose columns follow
    ``all_selected_features``. The row is preprocessed with the fitted
    transformers stored in ``tf`` and classified by the Keras ``model``.
    Argument order must match the order of the Gradio input components.

    :param alcohol: Boolean
    :param arrest_type: String
    :param belts: Boolean
    :param contributed_to_accident: Boolean
    :param disobedience: Boolean
    :param driver_city: String
    :param gender: 'M', 'F' or 'N'
    :param invalid_documentation: Boolean
    :param make: String
    :param mobile_phone: Boolean
    :param negligent_driving: Boolean
    :param number_of_offences: Integer
    :param personal_injury: Boolean
    :param property_damage: Boolean
    :param race: One of 'HISPANIC', 'BLACK', 'WHITE', 'OTHER', 'ASIAN', or 'NATIVE AMERICAN'
    :param road_signs_and_markings: Boolean
    :param search_outcome: String
    :param speeding: Boolean
    :param stop_hour: Integer, hour of the stop in 24-hour format
    :param stop_year: Integer, year the stop occurred
    :param subagency: String
    :param vehicle_safety_and_standards: Boolean
    :param vehicletype: String
    :param year: Integer, year of manufacture of the vehicle
    :return: The predicted violation type: the class label that
        ``tf["LabelEncoder"]`` decodes from the highest-scoring model output.
    """
    # One row, with columns in the same order as the training features.
    X = pd.DataFrame([[alcohol, arrest_type, belts, contributed_to_accident, disobedience, driver_city, gender,
                       invalid_documentation, make, mobile_phone, negligent_driving, number_of_offences,
                       personal_injury, property_damage, race, road_signs_and_markings, search_outcome, speeding,
                       stop_hour, stop_year, subagency, vehicle_safety_and_standards, vehicletype, year]],
                     columns=all_selected_features)

    # Driver City uses an ordinal encoder fitted on the whole dataset. Cities it
    # has not seen are mapped to a fixed code instead of raising an error.
    # NOTE(review): 6714 presumably equals len(categories_[0]), i.e. one past the
    # last known code - confirm against the fitted encoder.
    unknown_city_code = 6714
    city_encoder = tf["OrdinalEncoder_VeryHighCardinality"]
    city_encoder.set_params(handle_unknown='use_encoded_value', unknown_value=unknown_city_code)
    X[very_high_cardinality_features] = city_encoder.transform(X[very_high_cardinality_features])

    # Make and VehicleType use an ordinal encoder fitted on the training split only.
    X[high_cardinality_features] = tf["OrdinalEncoder_HighCardinality"].transform(X[high_cardinality_features])

    # Scale every numeric column, including the ordinal codes produced above.
    X[num_cols] = tf["StandardScaler"].transform(X[num_cols])

    # Boolean flags become 0/1 integers.
    X[bools] = X[bools].astype('int8')

    # The stored transformer one-hot encodes the remaining low-cardinality features.
    X = tf['OneHotEncoder'].transform(X)

    # Pick the highest-scoring class per row and decode it back to its label.
    scores = model.predict(X, verbose=0)
    prediction = tf["LabelEncoder"].inverse_transform(np.argmax(scores, axis=1))
    return prediction[0]
def _build_interface():
    """Assemble the Gradio form; its inputs map, in order, onto make_prediction's parameters."""
    # Categories of the fitted transformers, reused as dropdown choices so the
    # offered values match what the encoders were fitted on.
    ohe_categories = tf['OneHotEncoder'].transformers_[0][1].categories_
    city_categories = tf['OrdinalEncoder_VeryHighCardinality'].categories_
    vehicle_categories = tf['OrdinalEncoder_HighCardinality'].categories_

    # NOTE(review): several checkboxes are phrased as "was it done correctly?"
    # (seatbelts, vehicle safety) while the other flags describe violations -
    # confirm the wording matches how those features were encoded in training.
    components = [
        gr.components.Checkbox(label='Was the driver under the influence of alcohol?'),
        gr.components.Dropdown(label='Choose the arrest type',
                               choices=list(ohe_categories[0]),
                               value='A - Marked Patrol'),
        gr.components.Checkbox(label='Were seatbelts used appropriately?'),
        gr.components.Checkbox(label='Did the driver actions contribute to an accident?'),
        gr.components.Checkbox(
            label='Was the driver disobedient? (such as failing to display documentation upon request)?'),
        gr.components.Dropdown(label='Choose the driver city',
                               choices=list(city_categories[0]),
                               value='SILVER SPRING'),
        gr.components.Dropdown(label='Driver Gender', choices=['M', 'F', 'N'], value='M'),
        gr.components.Checkbox(
            label='Was the driver driving with Invalid Documentation (such as suspended registration, suspended license, expired registration plates and validation tabs or expired license plate)?'),
        gr.components.Dropdown(label='Vehicle Make',
                               choices=list(vehicle_categories[0]),
                               value='TOYOTA'),
        gr.components.Checkbox(label='Was the driver using a mobile phone while driving?'),
        gr.components.Checkbox(
            label='Was the driver caught driving with negligence (example switching lanes in an unsafe manner)?'),
        gr.components.Slider(minimum=1, step=1, label='Number of offences committed'),
        gr.components.Checkbox(label='Did the violation involve any personal injury?'),
        gr.components.Checkbox(label='Did the violation involve any property damage?'),
        gr.components.Dropdown(label='Choose the race of the driver',
                               choices=list(ohe_categories[2]),
                               value='WHITE'),
        gr.components.Checkbox(
            label='Did the driver fail to obey signs and markings (such as traffic control device instructions, stop lights, red signal and stop sign lines)?'),
        gr.components.Dropdown(label='What was the outcome of the search (if conducted)?',
                               choices=list(ohe_categories[3]),
                               value='NO SEARCH CONDUCTED'),
        gr.components.Checkbox(label='Was the driver caught speeding?'),
        gr.components.Slider(maximum=23, step=1, label='Time HOUR when stop occurred in 24-hour format'),
        gr.components.Slider(minimum=2012, maximum=2024, step=1, label='Year when stop occurred'),
        gr.components.Dropdown(label='What is the name of the subagency that conducted the traffic stop?',
                               choices=list(ohe_categories[1]),
                               value='4th District, Wheaton'),
        gr.components.Checkbox(
            label='Was the vehicle safe and up to standards (lights properly switched, registration plates attached etc.)?'),
        gr.components.Dropdown(label='What is the vehicle type?',
                               choices=list(vehicle_categories[1]),
                               value='02 - Automobile'),
        gr.components.Slider(minimum=1970, maximum=2023, step=1, label='Year of manufacture of the vehicle:'),
    ]
    return gr.Interface(fn=make_prediction, inputs=components, outputs=["text"])


# Built and launched at import time, exactly as before.
iface = _build_interface()
iface.launch(debug=True)