rmdhirr's picture
Update app.py
3d7830a verified
raw
history blame
No virus
5 kB
import os
import pickle

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder

from ProGPT import Conversation
# Load saved components.
# All artifacts are read by relative path, i.e. from the current working
# directory. pickle can execute arbitrary code while loading, so these paths
# must only ever point at trusted files.
def _load_pickle(path):
    """Unpickle and return the single object stored in the file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


preprocessing_params = _load_pickle('preprocessing_params.pkl')
fisher_information = _load_pickle('fisher_information.pkl')
label_encoder = _load_pickle('label_encoder.pkl')
url_tokenizer = _load_pickle('url_tokenizer.pkl')
html_tokenizer = _load_pickle('html_tokenizer.pkl')
# Load the model with custom loss
@tf.keras.utils.register_keras_serializable()
class EWCLoss(tf.keras.losses.Loss):
    """Binary cross-entropy plus an Elastic Weight Consolidation (EWC) penalty.

    The penalty is ``importance / 2 * sum(F * (w - w_prev) ** 2)``, summed over
    every trainable weight ``w`` of ``model``. ``w_prev`` is the weight value
    captured when this loss object is constructed, and ``F`` is the matching
    entry of ``fisher_information``.

    Args:
        model: Model whose ``trainable_weights`` are regularized. May be
            ``None`` (this is what ``from_config`` passes); in that case no
            penalty can be computed and plain binary cross-entropy is returned.
        fisher_information: Sequence iterated in lockstep with
            ``model.trainable_weights``. Assumed to be in the same order and
            shape-compatible with each weight (TODO confirm against training).
        importance: Multiplier applied to the EWC penalty.
        reduction: Forwarded to ``tf.keras.losses.Loss``.
        name: Forwarded to ``tf.keras.losses.Loss``.
        **kwargs: Any further base-class options (e.g. keys that newer Keras
            versions emit in ``get_config``), forwarded unchanged.
    """

    def __init__(self, model=None, fisher_information=None, importance=1.0, reduction='auto', name=None, **kwargs):
        super().__init__(reduction=reduction, name=name, **kwargs)
        self.model = model
        self.fisher_information = fisher_information
        self.importance = importance
        # Snapshot of the weights at construction time; the penalty pulls the
        # current weights back toward these values.
        self.prev_weights = (
            [weight.numpy() for weight in model.trainable_weights]
            if model is not None
            else None
        )

    def call(self, y_true, y_pred):
        standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        # Without a model snapshot (e.g. an instance rebuilt via from_config)
        # or Fisher values there is nothing to anchor to, so fall back to plain
        # BCE instead of failing on None.
        if self.model is None or self.fisher_information is None or self.prev_weights is None:
            return standard_loss
        ewc_loss = 0.0
        for weight, fisher_info, prev_weight in zip(
            self.model.trainable_weights, self.fisher_information, self.prev_weights
        ):
            ewc_loss += tf.reduce_sum(fisher_info * tf.square(weight - prev_weight))
        return standard_loss + (self.importance / 2.0) * ewc_loss

    def get_config(self):
        # model / fisher_information are not serializable config; only the
        # scalar settings are stored.
        config = super().get_config()
        config.update({
            'importance': self.importance,
            'reduction': self.reduction,
            'name': self.name,
        })
        return config

    @classmethod
    def from_config(cls, config):
        # Fisher values are reloaded from the local pickle (trusted file only:
        # pickle can execute code). The model is not available here, so the
        # resulting loss computes plain BCE until rebuilt with a model.
        with open('fisher_information.pkl', 'rb') as f:
            fisher_information = pickle.load(f)
        return cls(model=None, fisher_information=fisher_information, **config)
# Load the trained network with compile=False, so no saved training
# configuration is restored; the loss is attached explicitly below instead.
model = tf.keras.models.load_model(
    'new_phishing_detection_model.keras',
    compile=False,
)

# Rebuild the EWC loss around the freshly loaded weights, then compile with it.
# (Compilation only matters for further training/evaluation; predict() does
# not use the loss.)
ewc_loss = EWCLoss(
    model=model,
    fisher_information=fisher_information,
    importance=1000,
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss=ewc_loss,
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
    ],
)
# Chatbot setup.
# The ProGPT access token is read from the PROGPT_ACCESS_TOKEN environment
# variable so real credentials never need to be committed to source. The
# original placeholder is kept as the fallback, so behavior is unchanged when
# the variable is not set.
access_token = os.environ.get('PROGPT_ACCESS_TOKEN', 'your_pro_gpt_access_token')
chatbot = Conversation(access_token)
# Function to preprocess input
def preprocess_input(input_text, tokenizer, max_length):
    """Encode a single text with *tokenizer* and pad/truncate it to *max_length*.

    Args:
        input_text: Raw string to encode.
        tokenizer: Object exposing ``texts_to_sequences`` (presumably a fitted
            Keras ``Tokenizer`` — confirm against the pickled artifacts).
        max_length: Target sequence length.

    Returns:
        The padded batch produced by ``pad_sequences`` for a one-element
        input, i.e. a single row of length ``max_length``.
    """
    token_ids = tokenizer.texts_to_sequences([input_text])
    # Pad and truncate at the end of the sequence.
    return tf.keras.preprocessing.sequence.pad_sequences(
        token_ids,
        maxlen=max_length,
        padding='post',
        truncating='post',
    )
# Function to get prediction
def get_prediction(input_text, input_type):
    """Return the model's raw output score for a URL or HTML input.

    ``input_type == "URL"`` selects the URL tokenizer and maximum length; any
    other value uses the HTML tokenizer and maximum length.
    """
    if input_type == "URL":
        tokenizer, length_key = url_tokenizer, 'max_new_url_length'
    else:
        tokenizer, length_key = html_tokenizer, 'max_new_html_length'
    encoded = preprocess_input(input_text, tokenizer, preprocessing_params[length_key])
    # NOTE(review): the same encoded tensor is passed to both model inputs, so
    # one input always receives data from the other tokenizer/length. Confirm
    # this matches how the two-input model was trained.
    scores = model.predict([encoded, encoded])
    return scores[0][0]
# Function to fetch latest phishing sites from PhishTank
def fetch_latest_phishing_sites(limit=5, timeout=10):
    """Return up to *limit* entries from PhishTank's ``online-valid.json`` feed.

    Best-effort: network errors, HTTP error statuses, undecodable bodies and
    non-list payloads all yield ``[]``, so the UI can still render.

    Args:
        limit: Maximum number of entries to return (default 5).
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The first *limit* decoded feed entries, or ``[]`` on any failure.
    """
    try:
        # Without a timeout the request can block indefinitely and hang the
        # Gradio request that triggered it.
        response = requests.get(
            'https://data.phishtank.com/data/online-valid.json',
            timeout=timeout,
        )
        # Don't try to parse an error page as the feed.
        response.raise_for_status()
        # JSON decode failures raise a ValueError subclass.
        data = response.json()
    except (requests.RequestException, ValueError):
        return []
    # NOTE(review): the full feed is downloaded on every call; consider caching.
    if not isinstance(data, list):
        return []
    return data[:limit]
# Gradio UI
def phishing_detection(input_text, input_type, threshold=0.5):
    """Classify the input and return a human-readable verdict.

    Args:
        input_text: URL or HTML entered by the user.
        input_type: "URL" or "HTML"; forwarded to ``get_prediction``.
        threshold: Scores strictly above this are reported as phishing
            (default 0.5).

    Returns:
        A warning or "safe" message that includes the score to two decimals,
        or a prompt to enter something when *input_text* is empty or blank.
    """
    # Blank input would be tokenized to an empty sequence, padded, and still
    # receive a confident-looking verdict; ask for input instead.
    if not input_text or not input_text.strip():
        return "Please enter a URL or HTML code to check."
    prediction = get_prediction(input_text, input_type)
    if prediction > threshold:
        return f"Warning: This site is likely a phishing site! ({prediction:.2f})"
    return f"Safe: This site is not likely a phishing site. ({prediction:.2f})"
def latest_phishing_sites():
    """Return the ``'url'`` field of each recent PhishTank entry as a string.

    Entries that are not dicts or that lack a ``'url'`` key are skipped, so
    one malformed record cannot fail the whole UI request with a KeyError.
    """
    sites = fetch_latest_phishing_sites()
    return [
        str(site['url'])
        for site in sites
        if isinstance(site, dict) and 'url' in site
    ]
def chatbot_response(user_input):
    """Send *user_input* to the module-level ProGPT ``chatbot`` and return its reply."""
    return chatbot.prompt(user_input)
def interface(input_text, input_type):
    """Gradio callback: run detection, list recent PhishTank URLs, ask the chatbot.

    Returns:
        ``(verdict, latest_sites, chatbot_reply)``, one value per output
        component. ``latest_sites`` is joined into a newline-separated string
        so its textbox shows one URL per line rather than a list repr.
    """
    result = phishing_detection(input_text, input_type)
    latest_sites = "\n".join(latest_phishing_sites())
    chatbot_res = chatbot_response(input_text)
    return result, latest_sites, chatbot_res
# Build and launch the Gradio UI. Components come from the top-level ``gr``
# namespace: the legacy ``gr.inputs`` / ``gr.outputs`` modules were deprecated
# in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter URL or HTML code"),
        gr.Radio(["URL", "HTML"], type="value", label="Input Type"),
    ],
    outputs=[
        "text",
        gr.Textbox(label="Latest Phishing Sites"),
        gr.Textbox(label="Chatbot Response"),
    ],
    title="Phishing Detection with Enhanced EWC Model",
    description="Check if a URL or HTML is Phishing. Latest phishing sites from PhishTank and a chatbot assistant for phishing issues.",
    theme="default",
)
iface.launch()