import gradio as gr
import tensorflow as tf
import pickle
import numpy as np
import requests
from ProGPT import Conversation

# Load saved components
with open('preprocessing_params.pkl', 'rb') as f:
    preprocessing_params = pickle.load(f)
with open('fisher_information.pkl', 'rb') as f:
    fisher_information = pickle.load(f)
with open('label_encoder.pkl', 'rb') as f:
    label_encoder = pickle.load(f)
with open('url_tokenizer.pkl', 'rb') as f:
    url_tokenizer = pickle.load(f)
with open('html_tokenizer.pkl', 'rb') as f:
    html_tokenizer = pickle.load(f)
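# NOTE: preprocessing_params is expected to be a dict containing at least
# 'max_url_length' and 'max_html_length', the keys used for padding below.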

# Load the model with custom loss
@tf.keras.utils.register_keras_serializable()
class EWCLoss(tf.keras.losses.Loss):
    def __init__(self, model, fisher_information, importance=1.0, reduction='auto', name=None):
        super().__init__(reduction=reduction, name=name)
        self.model = model
        self.fisher_information = fisher_information
        self.importance = importance
        # Snapshot the weights learned on the previous task. When the loss is
        # rebuilt via from_config during deserialization, no model is available
        # yet, so fall back to an empty list; the app recompiles with a fully
        # constructed loss after loading.
        self.prev_weights = ([w.numpy() for w in model.trainable_weights]
                             if model is not None else [])

    def call(self, y_true, y_pred):
        standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        # EWC penalty: (importance / 2) * sum_i F_i * (theta_i - theta_star_i)^2,
        # where F_i is the Fisher information for parameter i and theta_star_i
        # is that parameter's value after training on the previous task.
        ewc_loss = 0.0
        for weight, fisher_info, prev_weight in zip(self.model.trainable_weights,
                                                    self.fisher_information,
                                                    self.prev_weights):
            ewc_loss += tf.reduce_sum(fisher_info * tf.square(weight - prev_weight))
        return standard_loss + (self.importance / 2.0) * ewc_loss
    
    def get_config(self):
        config = super().get_config()
        # reduction and name are already handled by the base class config.
        config.update({'importance': self.importance})
        return config
    
    @classmethod
    def from_config(cls, config):
        # Fisher information is not serialized with the model, so reload it
        # from the external pickle; the model reference is reattached when the
        # loss is rebuilt and the model recompiled below.
        with open('fisher_information.pkl', 'rb') as f:
            fisher_information = pickle.load(f)
        return cls(model=None, fisher_information=fisher_information, **config)
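
# For reference, fisher_information.pkl is typically produced by averaging
# squared gradients of the log-likelihood over the previous task's data. A
# minimal sketch (hypothetical names: `old_model` and a tf.data `old_dataset`
# of (x, y) batches), not necessarily the script that built the shipped pickle:
#
#     fisher = [tf.zeros_like(w) for w in old_model.trainable_weights]
#     n_batches = 0
#     for x, y in old_dataset:
#         with tf.GradientTape() as tape:
#             log_lik = -tf.reduce_mean(
#                 tf.keras.losses.binary_crossentropy(y, old_model(x, training=False)))
#         grads = tape.gradient(log_lik, old_model.trainable_weights)
#         fisher = [f + tf.square(g) for f, g in zip(fisher, grads)]
#         n_batches += 1
#     fisher_information = [f / n_batches for f in fisher]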

# Load the model
model = tf.keras.models.load_model('new_phishing_detection_model.keras',
                                   custom_objects={'EWCLoss': EWCLoss})

# Recompile with a loss that holds a live reference to the loaded model
# (from_config deserializes EWCLoss without one)
ewc_loss = EWCLoss(model=model, fisher_information=fisher_information, importance=1000)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
              loss=ewc_loss,
              metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
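# importance=1000 weights the EWC penalty heavily, anchoring the fine-tuned
# weights close to those learned on the previous task.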

# Chatbot setup (replace the placeholder with a real ProGPT access token)
access_token = 'your_pro_gpt_access_token'
chatbot = Conversation(access_token)

# Function to preprocess input
def preprocess_input(input_text, tokenizer, max_length):
    sequences = tokenizer.texts_to_sequences([input_text])
    padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=max_length, padding='post', truncating='post')
    return padded_sequences
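# Example: preprocess_input("http://example.com/login", url_tokenizer,
# preprocessing_params['max_url_length']) returns an int array of shape
# (1, max_url_length), ready to feed into the model.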

# Function to get prediction
def get_prediction(input_text, input_type):
    is_url = input_type == "URL"
    if is_url:
        input_data = preprocess_input(input_text, url_tokenizer, preprocessing_params['max_url_length'])
    else:
        input_data = preprocess_input(input_text, html_tokenizer, preprocessing_params['max_html_length'])
    
    # The model has two input branches; the same preprocessed tensor is fed to
    # both, since only one input type is provided at a time.
    prediction = model.predict([input_data, input_data])[0][0]
    return prediction

# Function to fetch latest phishing sites from PhishTank
def fetch_latest_phishing_sites():
    try:
        response = requests.get('https://data.phishtank.com/data/online-valid.json', timeout=10)
        response.raise_for_status()
        data = response.json()
        return data[:5]
    except (requests.RequestException, ValueError):
        return []
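# Each entry in the PhishTank feed is a JSON object; only its 'url' field is
# read below.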

# Gradio UI
def phishing_detection(input_text, input_type):
    prediction = get_prediction(input_text, input_type)
    if prediction > 0.5:
        return f"Warning: This site is likely a phishing site! ({prediction:.2f})"
    else:
        return f"Safe: This site is not likely a phishing site. ({prediction:.2f})"

def latest_phishing_sites():
    sites = fetch_latest_phishing_sites()
    return [site['url'] for site in sites]

def chatbot_response(user_input):
    response = chatbot.prompt(user_input)
    return response
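
# NOTE: latest_phishing_sites and chatbot_response are defined above but not
# wired into the interface below. One way to expose them (a sketch, assuming a
# Gradio version that provides gr.TabbedInterface) is:
#
#     feed_iface = gr.Interface(fn=latest_phishing_sites, inputs=None,
#                               outputs=gr.JSON(label="Latest PhishTank URLs"))
#     chat_iface = gr.Interface(fn=chatbot_response, inputs="text", outputs="text")
#     demo = gr.TabbedInterface([iface, feed_iface, chat_iface],
#                               ["Detector", "Phishing Feed", "Chatbot"])
#     demo.launch()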

iface = gr.Interface(
    fn=phishing_detection,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter URL or HTML code", label="Input"),
        gr.Radio(["URL", "HTML"], label="Input Type"),
    ],
    outputs="text",
    title="Phishing Detection with Enhanced EWC Model",
    description="Check whether a URL or an HTML snippet is likely a phishing attempt.",
    theme="default"
)

iface.launch()