# app.py — Hugging Face Space by rmdhirr
# Last commit: 2f8164c (verified), "Update app.py" · file size 2.52 kB · scan: no virus
import gradio as gr
import tensorflow as tf
import pickle
import numpy as np
from sklearn.preprocessing import LabelEncoder
# Load the saved preprocessing components.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; these
# files must come from a trusted source (this app's own training artifacts),
# never from user-supplied uploads.
def _load_pickle(path):
    """Return the object unpickled from the binary file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


preprocessing_params = _load_pickle('preprocessing_params.pkl')
# NOTE(review): label_encoder is loaded but not referenced anywhere else in
# the code visible here — confirm whether it is still needed.
label_encoder = _load_pickle('label_encoder.pkl')
url_tokenizer = _load_pickle('url_tokenizer.pkl')
html_tokenizer = _load_pickle('html_tokenizer.pkl')

# Load the trained Keras model, then (re)compile it with a fixed optimizer,
# loss and metric set. These settings configure training/evaluation; Keras'
# predict() does not depend on them.
model = tf.keras.models.load_model('new_phishing_detection_model.keras')
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
)
# Function to preprocess input
def preprocess_input(input_text, tokenizer, max_length):
    """Tokenize one string and pad/truncate it to a fixed length.

    Args:
        input_text: The raw text to encode (a URL or HTML source).
        tokenizer: An object exposing ``texts_to_sequences`` (a fitted Keras
            tokenizer is assumed).
        max_length: Target sequence length; longer sequences are cut at the
            end and shorter ones are zero-padded at the end.

    Returns:
        The batch-of-one array produced by Keras ``pad_sequences``.
    """
    # texts_to_sequences works on a batch, so wrap the single text in a list.
    token_ids = tokenizer.texts_to_sequences([input_text])
    return tf.keras.preprocessing.sequence.pad_sequences(
        token_ids,
        maxlen=max_length,
        padding='post',
        truncating='post',
    )
# Function to get prediction
def get_prediction(input_text, input_type):
    """Return the model's score for a single URL or HTML document.

    The model is called with two inputs, ``[url_sequence, html_sequence]``.
    Only the branch selected by *input_type* receives real tokens; the other
    branch gets an all-zero placeholder of its configured length.

    Args:
        input_text: Raw URL string or HTML source.
        input_type: ``"URL"`` to encode *input_text* with the URL tokenizer;
            any other value encodes it with the HTML tokenizer.

    Returns:
        The first output value of the model for this sample, as a Python
        float (presumably a phishing probability in [0, 1] — confirm against
        the model's output layer).
    """
    url_length = preprocessing_params['max_url_length']
    html_length = preprocessing_params['max_html_length']
    if input_type == "URL":
        url_seq = preprocess_input(input_text, url_tokenizer, url_length)
        # Dummy HTML input, with the same dtype as the real token sequence so
        # both model inputs are consistent.
        html_seq = np.zeros((1, html_length), dtype=url_seq.dtype)
    else:
        html_seq = preprocess_input(input_text, html_tokenizer, html_length)
        # Dummy URL input, with the same dtype as the real token sequence.
        url_seq = np.zeros((1, url_length), dtype=html_seq.dtype)
    # verbose=0 stops predict() from printing a progress bar to the server
    # log on every single request.
    prediction = model.predict([url_seq, html_seq], verbose=0)[0][0]
    # Convert the NumPy scalar to a plain float for callers/serialization.
    return float(prediction)
# Gradio UI
def phishing_detection(input_text, input_type):
    """Gradio callback: score the input and return a human-readable verdict.

    Args:
        input_text: Contents of the textbox (a URL or HTML source); may be
            empty or None if the user submitted nothing.
        input_type: Value of the "Input Type" radio, ``"URL"`` or ``"HTML"``;
            may be None when no option has been selected.

    Returns:
        A prompt string when the input is unusable (no input type selected,
        or blank text). Otherwise, a warning string if the score exceeds 0.5
        and a "safe" string if it does not, each showing the score to two
        decimal places.
    """
    # Without these guards, an unselected radio (None) would silently be
    # scored as HTML by get_prediction, and blank text would be scored as an
    # all-padding sequence, yielding a meaningless verdict.
    if input_type not in ("URL", "HTML"):
        return "Please select an input type (URL or HTML)."
    if not input_text or not input_text.strip():
        return "Please enter a URL or HTML code to check."
    prediction = get_prediction(input_text, input_type)
    if prediction > 0.5:
        return f"Warning: This site is likely a phishing site! ({prediction:.2f})"
    return f"Safe: This site is not likely a phishing site. ({prediction:.2f})"
# Build the Gradio interface: a multi-line text box and a URL/HTML selector
# feed phishing_detection, whose string result is shown in an output box.
text_input = gr.components.Textbox(lines=5, placeholder="Enter URL or HTML code")
type_selector = gr.components.Radio(["URL", "HTML"], type="value", label="Input Type")
result_box = gr.components.Textbox(label="Phishing Detection Result")

iface = gr.Interface(
    fn=phishing_detection,
    inputs=[text_input, type_selector],
    outputs=result_box,
    title="Phishing Detection Model",
    description="Check if a URL or HTML is Phishing.",
    theme="default",
)
# Start the web server (runs at import time, as in the original script).
iface.launch()