|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
from tensorflow import keras |
|
import pickle |
|
|
|
|
|
|
|
# Load the pickled preprocessing objects used by make_prediction.
# NOTE: despite the name, ``tf`` is NOT TensorFlow. It is a dict-like mapping
# whose entries are looked up by keys such as "OrdinalEncoder_VeryHighCardinality",
# "OrdinalEncoder_HighCardinality", "StandardScaler", "OneHotEncoder" and
# "LabelEncoder".
# NOTE(review): pickle.load can execute arbitrary code; only ship a
# model_config.pkl that comes from a trusted source.
with open('model_config.pkl', 'rb') as config_file:
    tf = pickle.load(config_file)

# Keras classifier whose per-class outputs are decoded via tf["LabelEncoder"].
model = keras.models.load_model('model.keras')
|
|
|
# Column names of the one-row DataFrame built in make_prediction, in the same
# order as that function's positional arguments.
# These stay lists (not tuples): pandas treats a tuple key as a single label.
all_selected_features = [
    'Alcohol',
    'Arrest Type',
    'Belts',
    'Contributed To Accident',
    'Disobedience',
    'Driver City',
    'Gender',
    'Invalid Documentation',
    'Make',
    'Mobile Phone',
    'Negligent Driving',
    'Number Of Offences',
    'Personal Injury',
    'Property Damage',
    'Race',
    'Road Signs And Markings',
    'Search Outcome',
    'Speeding',
    'Stop Hour',
    'Stop Year',
    'SubAgency',
    'Vehicle Safety And Standards',
    'VehicleType',
    'Year',
]

# Encoded with tf["OrdinalEncoder_VeryHighCardinality"].
very_high_cardinality_features = ['Driver City']

# Encoded with tf["OrdinalEncoder_HighCardinality"].
high_cardinality_features = ['Make', 'VehicleType']

# Yes/no flags; cast to int8 before the final tf['OneHotEncoder'] transform.
bools = [
    'Alcohol',
    'Belts',
    'Contributed To Accident',
    'Disobedience',
    'Invalid Documentation',
    'Mobile Phone',
    'Negligent Driving',
    'Personal Injury',
    'Property Damage',
    'Road Signs And Markings',
    'Speeding',
    'Vehicle Safety And Standards',
]

# Scaled with tf["StandardScaler"]; includes the ordinal codes of the
# cardinality-encoded columns above.
num_cols = [
    'Driver City',
    'Make',
    'Number Of Offences',
    'Stop Hour',
    'Stop Year',
    'VehicleType',
    'Year',
]
|
|
|
|
|
|
|
|
|
|
|
def make_prediction(alcohol, arrest_type, belts, contributed_to_accident, disobedience, driver_city, gender,
                    invalid_documentation, make, mobile_phone, negligent_driving, number_of_offences, personal_injury,
                    property_damage, race, road_signs_and_markings, search_outcome, speeding, stop_hour, stop_year,
                    subagency, vehicle_safety_and_standards, vehicletype, year):
    """
    Predict the 'Violation Type' of a single traffic-stop sample.

    The arguments are placed into a one-row DataFrame whose columns are
    ``all_selected_features`` (same order as the parameters). The frame is then
    transformed with the pickled preprocessing objects in ``tf`` and classified
    by ``model``.

    :param alcohol: Boolean
    :param arrest_type: String
    :param belts: Boolean
    :param contributed_to_accident: Boolean
    :param disobedience: Boolean
    :param driver_city: String
    :param gender: 'M', 'F' or 'N'
    :param invalid_documentation: Boolean
    :param make: String
    :param mobile_phone: Boolean
    :param negligent_driving: Boolean
    :param number_of_offences: Integer
    :param personal_injury: Boolean
    :param property_damage: Boolean
    :param race: One of 'HISPANIC', 'BLACK', 'WHITE', 'OTHER', 'ASIAN', or 'NATIVE AMERICAN'
    :param road_signs_and_markings: Boolean
    :param search_outcome: String
    :param speeding: Boolean
    :param stop_hour: Integer
    :param stop_year: Integer
    :param subagency: String
    :param vehicle_safety_and_standards: Boolean
    :param vehicletype: String
    :param year: Integer
    :return: The predicted violation-type label. It is the column with the
        highest model output, mapped back to its original label through
        ``tf["LabelEncoder"].inverse_transform``.
    """
    import copy  # local import keeps this block self-contained

    # Code assigned to driver cities unknown to the very-high-cardinality encoder.
    # NOTE(review): presumably one past the last known city code; confirm it
    # equals len(tf["OrdinalEncoder_VeryHighCardinality"].categories_[0]).
    unknown_city_code = 6714

    X = pd.DataFrame(
        [[alcohol, arrest_type, belts, contributed_to_accident, disobedience, driver_city, gender,
          invalid_documentation, make, mobile_phone, negligent_driving, number_of_offences, personal_injury,
          property_damage, race, road_signs_and_markings, search_outcome, speeding, stop_hour, stop_year,
          subagency, vehicle_safety_and_standards, vehicletype, year]],
        columns=all_selected_features,
    )

    # Configure unknown-value handling on a shallow copy so that each request
    # does not mutate the shared, module-level encoder stored in ``tf``.
    # The fitted attributes (categories_ etc.) are shared with the original.
    city_encoder = copy.copy(tf["OrdinalEncoder_VeryHighCardinality"]).set_params(
        handle_unknown='use_encoded_value', unknown_value=unknown_city_code)
    X[very_high_cardinality_features] = city_encoder.transform(X[very_high_cardinality_features])

    X[high_cardinality_features] = tf["OrdinalEncoder_HighCardinality"].transform(X[high_cardinality_features])

    # Scale numeric columns; num_cols includes the ordinal codes produced above.
    X[num_cols] = tf["StandardScaler"].transform(X[num_cols])

    X[bools] = X[bools].astype('int8')

    X = tf['OneHotEncoder'].transform(X)

    scores = model.predict(X, verbose=0)

    # Take the highest-scoring column per row and decode it to a label.
    labels = tf["LabelEncoder"].inverse_transform(np.argmax(scores, axis=1))

    return labels[0]
|
|
|
|
|
# Category arrays read from the fitted preprocessing objects; they populate the
# dropdown choices below.
_ohe_categories = tf['OneHotEncoder'].transformers_[0][1].categories_
_city_categories = tf['OrdinalEncoder_VeryHighCardinality'].categories_
_high_card_categories = tf['OrdinalEncoder_HighCardinality'].categories_

# One component per make_prediction parameter, in the same positional order.
_input_components = [
    # alcohol
    gr.components.Checkbox(label='Was the driver under the influence of alcohol?'),
    # arrest_type
    gr.components.Dropdown(label='Choose the arrest type',
                           choices=list(_ohe_categories[0]),
                           value='A - Marked Patrol'),
    # belts
    gr.components.Checkbox(label='Were seatbelts used appropriately?'),
    # contributed_to_accident
    gr.components.Checkbox(label='Did the driver actions contribute to an accident?'),
    # disobedience
    gr.components.Checkbox(label='Was the driver disobedient? (such as failing to display documentation upon request)?'),
    # driver_city
    gr.components.Dropdown(label='Choose the driver city',
                           choices=list(_city_categories[0]),
                           value='SILVER SPRING'),
    # gender
    gr.components.Dropdown(label='Driver Gender', choices=['M', 'F', 'N'], value='M'),
    # invalid_documentation
    gr.components.Checkbox(label='Was the driver driving with Invalid Documentation (such as suspended registration, suspended license, expired registration plates and validation tabs or expired license plate)?'),
    # make
    gr.components.Dropdown(label='Vehicle Make',
                           choices=list(_high_card_categories[0]),
                           value='TOYOTA'),
    # mobile_phone
    gr.components.Checkbox(label='Was the driver using a mobile phone while driving?'),
    # negligent_driving
    gr.components.Checkbox(label='Was the driver caught driving with negligence (example switching lanes in an unsafe manner)?'),
    # number_of_offences
    gr.components.Slider(minimum=1, step=1, label='Number of offences committed'),
    # personal_injury
    gr.components.Checkbox(label='Did the violation involve any personal injury?'),
    # property_damage
    gr.components.Checkbox(label='Did the violation involve any property damage?'),
    # race
    gr.components.Dropdown(label='Choose the race of the driver',
                           choices=list(_ohe_categories[2]),
                           value='WHITE'),
    # road_signs_and_markings
    gr.components.Checkbox(label='Did the driver fail to obey signs and markings (such as traffic control device instructions, stop lights, red signal and stop sign lines)?'),
    # search_outcome
    gr.components.Dropdown(label='What was the outcome of the search (if conducted)?',
                           choices=list(_ohe_categories[3]),
                           value='NO SEARCH CONDUCTED'),
    # speeding
    gr.components.Checkbox(label='Was the driver caught speeding?'),
    # stop_hour
    gr.components.Slider(maximum=23, step=1, label='Time HOUR when stop occurred in 24-hour format'),
    # stop_year
    gr.components.Slider(minimum=2012, maximum=2024, step=1, label='Year when stop occurred'),
    # subagency
    gr.components.Dropdown(label='What is the name of the subagency that conducted the traffic stop?',
                           choices=list(_ohe_categories[1]),
                           value='4th District, Wheaton'),
    # vehicle_safety_and_standards
    gr.components.Checkbox(label='Was the vehicle safe and up to standards (lights properly switched, registration plates attached etc.)?'),
    # vehicletype
    gr.components.Dropdown(label='What is the vehicle type?',
                           choices=list(_high_card_categories[1]),
                           value='02 - Automobile'),
    # year
    gr.components.Slider(minimum=1970, maximum=2023, step=1, label='Year of manufacture of the vehicle:'),
]

iface = gr.Interface(fn=make_prediction, inputs=_input_components, outputs=["text"])

# NOTE(review): debug=True is a development setting; confirm it is intended
# for deployment.
iface.launch(debug=True)