"""Gradio front end for a phishing classifier trained with an EWC-style loss.

On import this script loads the pickled preprocessing artifacts and the saved
Keras model, recompiles the model with ``EWCLoss``, creates the chatbot
client, builds the Gradio interface and launches it.
"""
import logging
import os
import pickle

import gradio as gr
import numpy as np  # not referenced below; kept in case other code relies on it
import requests
import tensorflow as tf
from ProGPT import Conversation
from sklearn.preprocessing import LabelEncoder  # not referenced below; kept as in the original

logger = logging.getLogger(__name__)

PHISHTANK_FEED_URL = 'https://data.phishtank.com/data/online-valid.json'
PHISHTANK_TIMEOUT_SECONDS = 10
LATEST_SITES_LIMIT = 5
FISHER_INFORMATION_PATH = 'fisher_information.pkl'


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only point this at
    artifact files produced by a trusted training pipeline.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# Load saved components
preprocessing_params = _load_pickle('preprocessing_params.pkl')
fisher_information = _load_pickle(FISHER_INFORMATION_PATH)
label_encoder = _load_pickle('label_encoder.pkl')
url_tokenizer = _load_pickle('url_tokenizer.pkl')
html_tokenizer = _load_pickle('html_tokenizer.pkl')


# Load the model with custom loss
@tf.keras.utils.register_keras_serializable()
class EWCLoss(tf.keras.losses.Loss):
    """Binary cross-entropy plus a Fisher-weighted quadratic weight penalty.

    The penalty is ``importance / 2 * sum(F * (w - w_prev) ** 2)`` taken over
    the model's trainable weights, where ``w_prev`` is a snapshot of those
    weights captured when the loss object is constructed.
    """

    def __init__(self, model=None, fisher_information=None, importance=1.0,
                 reduction='auto', name=None):
        """Store the model, Fisher terms and penalty scale.

        Args:
            model: Model whose trainable weights are regularised, or ``None``.
            fisher_information: Iterable of per-weight Fisher terms, zipped
                against ``model.trainable_weights`` in ``call``.
            importance: Multiplier applied to the penalty term.
            reduction: Passed through to the base ``Loss``.
            name: Passed through to the base ``Loss``.
        """
        super().__init__(reduction=reduction, name=name)
        self.model = model
        self.fisher_information = fisher_information
        self.importance = importance
        # Snapshot of the weights at construction time; the penalty measures
        # drift away from these values.
        self.prev_weights = (
            [weight.numpy() for weight in model.trainable_weights]
            if model is not None else None
        )

    def call(self, y_true, y_pred):
        """Return the cross-entropy loss, plus the penalty when available."""
        standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        # A loss rebuilt through ``from_config`` has no model attached, so
        # there is nothing to penalise; fall back to plain cross-entropy
        # instead of failing on ``None.trainable_weights``.
        if (self.model is None or self.prev_weights is None
                or self.fisher_information is None):
            return standard_loss
        penalty = 0.0
        for weight, fisher_info, prev_weight in zip(
                self.model.trainable_weights,
                self.fisher_information,
                self.prev_weights):
            penalty += tf.reduce_sum(
                fisher_info * tf.square(weight - prev_weight))
        return standard_loss + (self.importance / 2.0) * penalty

    def get_config(self):
        """Return the serialisable configuration (model/Fisher data excluded)."""
        config = super().get_config()
        config.update({
            'importance': self.importance,
            'reduction': self.reduction,
            'name': self.name,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the loss from *config*, reloading Fisher data from disk.

        The model itself is not part of the config, so the returned instance
        has ``model=None``.
        """
        restored_fisher = _load_pickle(FISHER_INFORMATION_PATH)
        return cls(model=None, fisher_information=restored_fisher, **config)


# Load the model first without the custom loss
model = tf.keras.models.load_model('new_phishing_detection_model.keras',
                                   compile=False)

# Reconstruct the EWC loss
ewc_loss = EWCLoss(model=model, fisher_information=fisher_information,
                   importance=1000)

# Compile the model with EWC loss and metrics
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss=ewc_loss,
    metrics=['accuracy', tf.keras.metrics.Precision(),
             tf.keras.metrics.Recall()],
)

# Chatbot setup. The token can be supplied through the environment so it does
# not have to be committed to the source; the placeholder remains the default.
access_token = os.environ.get('PRO_GPT_ACCESS_TOKEN',
                              'your_pro_gpt_access_token')
chatbot = Conversation(access_token)


# Function to preprocess input
def preprocess_input(input_text, tokenizer, max_length):
    """Tokenise *input_text* and pad/truncate it (at the end) to *max_length*."""
    sequences = tokenizer.texts_to_sequences([input_text])
    padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
        sequences, maxlen=max_length, padding='post', truncating='post')
    return padded_sequences


# Function to get prediction
def get_prediction(input_text, input_type):
    """Return the model score for *input_text* as a float.

    ``input_type == "URL"`` selects the URL tokenizer and length; any other
    value selects the HTML tokenizer and length.
    """
    if input_type == "URL":
        input_data = preprocess_input(
            input_text, url_tokenizer,
            preprocessing_params['max_new_url_length'])
    else:
        input_data = preprocess_input(
            input_text, html_tokenizer,
            preprocessing_params['max_new_html_length'])
    # NOTE(review): the same tensor is fed to both model inputs. If the model
    # expects separate URL and HTML inputs of different lengths this needs a
    # dedicated tensor per input -- confirm against the training code.
    prediction = model.predict([input_data, input_data])[0][0]
    return float(prediction)


# Function to fetch latest phishing sites from PhishTank
def fetch_latest_phishing_sites():
    """Return up to five entries from the PhishTank feed, or ``[]`` on failure."""
    try:
        # The timeout keeps a slow or unreachable feed from blocking the UI
        # callback indefinitely.
        response = requests.get(PHISHTANK_FEED_URL,
                                timeout=PHISHTANK_TIMEOUT_SECONDS)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as exc:
        logger.warning("Could not fetch PhishTank feed: %s", exc)
        return []
    if not isinstance(data, list):
        logger.warning("Unexpected PhishTank payload type: %s",
                       type(data).__name__)
        return []
    return data[:LATEST_SITES_LIMIT]


# Gradio UI
def phishing_detection(input_text, input_type):
    """Return a human-readable verdict for *input_text*."""
    if not input_text or not input_text.strip():
        # Nothing to score; avoid running the model on an all-padding input.
        return "Please enter a URL or HTML code to analyze."
    prediction = get_prediction(input_text, input_type)
    if prediction > 0.5:
        return f"Warning: This site is likely a phishing site! ({prediction:.2f})"
    return f"Safe: This site is not likely a phishing site. ({prediction:.2f})"


def latest_phishing_sites():
    """Return the URLs of the latest feed entries as a list of strings."""
    sites = fetch_latest_phishing_sites()
    # Skip malformed entries rather than failing the whole request.
    return [
        f"{site['url']}"
        for site in sites
        if isinstance(site, dict) and 'url' in site
    ]


def chatbot_response(user_input):
    """Return the chatbot's reply to *user_input*."""
    response = chatbot.prompt(user_input)
    return response


def interface(input_text, input_type):
    """Gradio callback: return (verdict, latest sites text, chatbot reply)."""
    result = phishing_detection(input_text, input_type)
    # The output component is a Textbox, so render the list one URL per line.
    latest_sites = "\n".join(latest_phishing_sites())
    if input_text and input_text.strip():
        try:
            chatbot_res = chatbot_response(input_text)
        except Exception:  # UI boundary: keep the verdict even if chat fails
            logger.exception("Chatbot request failed")
            chatbot_res = "Chatbot is currently unavailable."
    else:
        chatbot_res = ""
    return result, latest_sites, chatbot_res


# ``gr.inputs`` / ``gr.outputs`` are deprecated namespaces that newer Gradio
# releases no longer provide; the top-level components are the replacement.
iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter URL or HTML code"),
        gr.Radio(["URL", "HTML"], type="value", label="Input Type"),
    ],
    outputs=[
        "text",
        gr.Textbox(label="Latest Phishing Sites"),
        gr.Textbox(label="Chatbot Response"),
    ],
    title="Phishing Detection with Enhanced EWC Model",
    description="Check if a URL or HTML is Phishing. Latest phishing sites from PhishTank and a chatbot assistant for phishing issues.",
    theme="default",
)

iface.launch()