rmdhirr's picture
Create app.py
eb30cad verified
raw
history blame
No virus
4.49 kB
import logging
import os
import pickle

import gradio as gr
import numpy as np
import requests
import tensorflow as tf

from ProGPT import Conversation
# Load the preprocessing artifacts that the rest of the app relies on.
# NOTE: pickle can execute arbitrary code while loading, so these files must
# come from a trusted source (presumably the files shipped with this app).
def _unpickle(path):
    """Open *path* in binary mode and return the object pickled inside it."""
    with open(path, 'rb') as stream:
        return pickle.load(stream)


preprocessing_params = _unpickle('preprocessing_params.pkl')
fisher_information = _unpickle('fisher_information.pkl')
label_encoder = _unpickle('label_encoder.pkl')
url_tokenizer = _unpickle('url_tokenizer.pkl')
html_tokenizer = _unpickle('html_tokenizer.pkl')
# Custom loss the saved model was compiled with; registered so Keras can
# resolve it by name.
@tf.keras.utils.register_keras_serializable()
class EWCLoss(tf.keras.losses.Loss):
    """Binary cross-entropy plus an Elastic Weight Consolidation penalty.

    penalty = importance / 2 * sum(F * (w - w_prev) ** 2), summed over the
    trainable weights ``w`` of ``model``, where ``F`` is the matching entry of
    ``fisher_information`` and ``w_prev`` is ``w``'s value when this loss
    object was constructed.

    Args:
        model: Model whose trainable weights are regularized. May be ``None``
            (``from_config`` passes ``None``); the penalty is then omitted and
            the loss is plain binary cross-entropy.
        fisher_information: One Fisher-information tensor per trainable
            weight, in ``model.trainable_weights`` order (presumably with
            matching shapes -- TODO confirm against the training script).
        importance: Scale factor applied to the penalty.
        reduction: Forwarded to ``tf.keras.losses.Loss``.
            NOTE(review): 'auto' is valid in tf.keras 2.x but rejected by
            Keras 3 -- confirm the installed version.
        name: Forwarded to ``tf.keras.losses.Loss``.
        **kwargs: Any further base-class arguments present in a serialized
            config (e.g. ``dtype`` on newer Keras), forwarded unchanged.
    """

    def __init__(self, model, fisher_information, importance=1.0,
                 reduction='auto', name=None, **kwargs):
        super().__init__(reduction=reduction, name=name, **kwargs)
        self.model = model
        self.fisher_information = fisher_information
        self.importance = importance
        # Snapshot current weights so the penalty measures drift from here.
        # Guard None: from_config builds the loss without a model, and
        # dereferencing None.trainable_weights used to raise AttributeError.
        self.prev_weights = (
            []
            if model is None
            else [weight.numpy() for weight in model.trainable_weights]
        )

    def call(self, y_true, y_pred):
        """Return per-sample BCE plus the (scalar) EWC penalty, if a model is bound."""
        standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        if self.model is None:
            # No weights to anchor against -- behave as plain BCE.
            return standard_loss
        ewc_loss = 0.0
        for weight, fisher_info, prev_weight in zip(
            self.model.trainable_weights,
            self.fisher_information,
            self.prev_weights,
        ):
            ewc_loss += tf.reduce_sum(fisher_info * tf.square(weight - prev_weight))
        return standard_loss + (self.importance / 2.0) * ewc_loss

    def get_config(self):
        """Serialize the scalar settings; model and Fisher data are not stored."""
        config = super().get_config()
        config.update({
            'importance': self.importance,
            'reduction': self.reduction,
            'name': self.name,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the loss without a model, reloading Fisher data from disk."""
        # Fisher information is not part of the config; reload it from the
        # same pickle the app loads at startup.
        with open('fisher_information.pkl', 'rb') as f:
            fisher_information = pickle.load(f)
        return cls(model=None, fisher_information=fisher_information, **config)
# Load the trained model. compile=False skips restoring the saved training
# configuration: that path may rebuild the loss through EWCLoss.from_config,
# which constructs EWCLoss(model=None) -- and the model is recompiled with a
# properly bound loss right below anyway.
model = tf.keras.models.load_model(
    'new_phishing_detection_model.keras',
    custom_objects={'EWCLoss': EWCLoss},
    compile=False,
)
# Recompile with an EWC loss anchored to the freshly loaded weights.
ewc_loss = EWCLoss(model=model, fisher_information=fisher_information, importance=1000)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss=ewc_loss,
    metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
)
# Chatbot setup. The token is read from the PRO_GPT_ACCESS_TOKEN environment
# variable (e.g. a Space secret) so a real credential never has to be
# committed to source; the old placeholder remains the fallback.
access_token = os.environ.get('PRO_GPT_ACCESS_TOKEN', 'your_pro_gpt_access_token')
chatbot = Conversation(access_token)
# Function to preprocess input
def preprocess_input(input_text, tokenizer, max_length):
    """Turn one raw string into a single-row batch of fixed-length token ids.

    Args:
        input_text: The URL or HTML text to encode.
        tokenizer: Object providing ``texts_to_sequences`` (the fitted
            URL or HTML tokenizer).
        max_length: Length every row is padded or truncated to.

    Returns:
        The output of ``pad_sequences`` for the single encoded text.
    """
    token_ids = tokenizer.texts_to_sequences([input_text])
    return tf.keras.preprocessing.sequence.pad_sequences(
        token_ids,
        maxlen=max_length,
        padding='post',     # pad after the real tokens
        truncating='post',  # drop overflow tokens from the end
    )
# Function to get prediction
def get_prediction(input_text, input_type):
    """Score *input_text* with the loaded model.

    ``input_type == "URL"`` selects the URL tokenizer and
    ``preprocessing_params['max_url_length']``; any other value (including
    ``"HTML"`` or ``None``) selects the HTML tokenizer and
    ``preprocessing_params['max_html_length']``.

    Returns:
        Element ``[0][0]`` of ``model.predict``'s output.
    """
    if input_type == "URL":
        tokenizer, length_key = url_tokenizer, 'max_url_length'
    else:
        tokenizer, length_key = html_tokenizer, 'max_html_length'
    encoded = preprocess_input(input_text, tokenizer, preprocessing_params[length_key])
    # NOTE(review): the same encoded array is passed to BOTH model inputs.
    # If the model expects a URL input and an HTML input with different
    # lengths, this only works when the two max lengths match -- confirm
    # against the model's input signature.
    batch_scores = model.predict([encoded, encoded])
    return batch_scores[0][0]
# Function to fetch latest phishing sites from PhishTank
def fetch_latest_phishing_sites(limit=5, timeout=10):
    """Return up to *limit* entries from PhishTank's verified-online feed.

    Best effort: network errors, HTTP error statuses, and undecodable
    responses are logged and yield an empty list instead of an exception.

    Args:
        limit: Maximum number of entries returned (default 5, as before).
        timeout: Seconds passed to ``requests.get``; without a timeout the
            request can block indefinitely.

    Returns:
        The first *limit* decoded JSON entries, or ``[]`` on any failure or
        when the payload is not a JSON array.
    """
    logger = logging.getLogger(__name__)
    try:
        response = requests.get(
            'https://data.phishtank.com/data/online-valid.json',
            timeout=timeout,
        )
        # Surface 4xx/5xx as an HTTP error rather than a confusing JSON
        # decode failure on an error page.
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers JSON decode errors across requests versions.
        logger.warning("Could not fetch PhishTank feed: %s", exc)
        return []
    if not isinstance(data, list):
        logger.warning("Unexpected PhishTank payload type: %s", type(data).__name__)
        return []
    return data[:limit]
# Gradio UI
def phishing_detection(input_text, input_type):
    """Return a user-facing verdict for *input_text*.

    A score strictly greater than 0.5 is reported as phishing; any other
    score is reported as safe. The score is shown with two decimals.
    """
    score = get_prediction(input_text, input_type)
    looks_malicious = score > 0.5
    if looks_malicious:
        return f"Warning: This site is likely a phishing site! ({score:.2f})"
    return f"Safe: This site is not likely a phishing site. ({score:.2f})"
def latest_phishing_sites():
    """Return the ``'url'`` field of each fetched PhishTank entry, formatted as a string.

    Raises KeyError if an entry has no ``'url'`` key.
    """
    recent_entries = fetch_latest_phishing_sites()
    urls = [f"{entry['url']}" for entry in recent_entries]
    return urls
def chatbot_response(user_input):
    """Send *user_input* to the ProGPT conversation and return its reply.

    NOTE(review): this is not connected to the Gradio interface in the code
    visible here.
    """
    return chatbot.prompt(user_input)
# Build and launch the web UI. Top-level gr.Textbox / gr.Radio replace the
# gr.inputs.* namespace, which Gradio deprecated in 3.x and removed in 4.0.
iface = gr.Interface(
    fn=phishing_detection,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter URL or HTML code"),
        # Default to "URL": with no selection the radio yields None, and
        # get_prediction treats every non-"URL" value as HTML.
        gr.Radio(["URL", "HTML"], type="value", label="Input Type", value="URL"),
    ],
    outputs="text",
    title="Phishing Detection with Enhanced EWC Model",
    description="Check if a URL or HTML is Phishing",
    theme="default",
)
iface.launch()