import gradio as gr
import tensorflow as tf
import pickle
import numpy as np
import requests
from ProGPT import Conversation
from sklearn.preprocessing import LabelEncoder
# Load saved components
with open('preprocessing_params.pkl', 'rb') as f:
    preprocessing_params = pickle.load(f)
with open('fisher_information.pkl', 'rb') as f:
    fisher_information = pickle.load(f)
with open('label_encoder.pkl', 'rb') as f:
    label_encoder = pickle.load(f)
with open('url_tokenizer.pkl', 'rb') as f:
    url_tokenizer = pickle.load(f)
with open('html_tokenizer.pkl', 'rb') as f:
    html_tokenizer = pickle.load(f)
# Custom EWC (Elastic Weight Consolidation) loss the model was trained with
@tf.keras.utils.register_keras_serializable()
class EWCLoss(tf.keras.losses.Loss):
    def __init__(self, model=None, fisher_information=None, importance=1.0, reduction='auto', name=None):
        super(EWCLoss, self).__init__(reduction=reduction, name=name)
        self.model = model
        self.fisher_information = fisher_information
        self.importance = importance
        # Snapshot of the weights learned on the previous task
        self.prev_weights = [layer.numpy() for layer in model.trainable_weights] if model else None

    def call(self, y_true, y_pred):
        standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        ewc_loss = 0.0
        for layer, fisher_info, prev_weight in zip(self.model.trainable_weights, self.fisher_information, self.prev_weights):
            ewc_loss += tf.reduce_sum(fisher_info * tf.square(layer - prev_weight))
        return standard_loss + (self.importance / 2.0) * ewc_loss

    def get_config(self):
        config = super().get_config()
        config.update({
            'importance': self.importance,
            'reduction': self.reduction,
            'name': self.name,
        })
        return config

    @classmethod
    def from_config(cls, config):
        with open('fisher_information.pkl', 'rb') as f:
            fisher_information = pickle.load(f)
        return cls(model=None, fisher_information=fisher_information, **config)
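
# For reference, the EWC penalty implemented above follows the standard form
#   L_total = L_task + (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2
# where F_i is the Fisher information for parameter i, theta*_i is that
# parameter's value after the previous task, and lambda is `importance`.
# Parameters the Fisher information marks as important are pulled back
# toward their old values, which mitigates catastrophic forgetting.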
# Load the model first without the custom loss
model = tf.keras.models.load_model('new_phishing_detection_model.keras', compile=False)
# Reconstruct the EWC loss
ewc_loss = EWCLoss(model=model, fisher_information=fisher_information, importance=1000)
# Compile the model with EWC loss and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
              loss=ewc_loss,
              metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
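# `importance` scales the EWC penalty relative to the binary cross-entropy
# term; 1000 anchors the weights strongly to the snapshot taken when the
# loss object was constructed above.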
# Chatbot setup
access_token = 'your_pro_gpt_access_token'
chatbot = Conversation(access_token)
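# The token above is a hard-coded placeholder; in a deployed Space it would
# more typically be read from a secret, e.g. os.environ.get("PROGPT_TOKEN")
# (variable name assumed for illustration).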
# Function to preprocess input
def preprocess_input(input_text, tokenizer, max_length):
    sequences = tokenizer.texts_to_sequences([input_text])
    padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
        sequences, maxlen=max_length, padding='post', truncating='post')
    return padded_sequences
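
# Illustrative usage (the max-length value is hypothetical; the real one
# comes from preprocessing_params):
#   preprocess_input("http://example.com/login", url_tokenizer, 180)
#   -> array of token ids shaped (1, 180), padded/truncated at the end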
# Function to get prediction
def get_prediction(input_text, input_type):
    is_url = input_type == "URL"
    if is_url:
        input_data = preprocess_input(input_text, url_tokenizer, preprocessing_params['max_new_url_length'])
    else:
        input_data = preprocess_input(input_text, html_tokenizer, preprocessing_params['max_new_html_length'])
    # The model has two input branches (URL and HTML); the same tensor is fed to both
    prediction = model.predict([input_data, input_data])[0][0]
    return prediction
# Function to fetch latest phishing sites from PhishTank
def fetch_latest_phishing_sites():
    try:
        response = requests.get('https://data.phishtank.com/data/online-valid.json', timeout=10)
        response.raise_for_status()
        data = response.json()
        return data[:5]
    except Exception:
        return []
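
# Each entry in the PhishTank feed is a JSON object that includes a 'url'
# field, which is all the UI below displays; the full feed can be large,
# so only the first five entries are kept.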
# Gradio UI
def phishing_detection(input_text, input_type):
    prediction = get_prediction(input_text, input_type)
    if prediction > 0.5:
        return f"Warning: This site is likely a phishing site! ({prediction:.2f})"
    else:
        return f"Safe: This site is not likely a phishing site. ({prediction:.2f})"
def latest_phishing_sites():
    sites = fetch_latest_phishing_sites()
    # Join into a single string, since the output component is a Textbox
    return "\n".join(site['url'] for site in sites)
def chatbot_response(user_input):
    response = chatbot.prompt(user_input)
    return response
def interface(input_text, input_type):
    result = phishing_detection(input_text, input_type)
    latest_sites = latest_phishing_sites()
    chatbot_res = chatbot_response(input_text)
    return result, latest_sites, chatbot_res
iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter URL or HTML code"),
        gr.Radio(["URL", "HTML"], label="Input Type")
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Latest Phishing Sites"),
        gr.Textbox(label="Chatbot Response")
    ],
    title="Phishing Detection with Enhanced EWC Model",
    description="Check whether a URL or HTML snippet is phishing, see the latest phishing sites from PhishTank, and ask a chatbot assistant about phishing issues."
)
iface.launch()