rmdhirr commited on
Commit
eb30cad
1 Parent(s): d5217c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import pickle
4
+ import numpy as np
5
+ import requests
6
+ from ProGPT import Conversation
7
+
8
+ # Load saved components
9
+ with open('preprocessing_params.pkl', 'rb') as f:
10
+ preprocessing_params = pickle.load(f)
11
+ with open('fisher_information.pkl', 'rb') as f:
12
+ fisher_information = pickle.load(f)
13
+ with open('label_encoder.pkl', 'rb') as f:
14
+ label_encoder = pickle.load(f)
15
+ with open('url_tokenizer.pkl', 'rb') as f:
16
+ url_tokenizer = pickle.load(f)
17
+ with open('html_tokenizer.pkl', 'rb') as f:
18
+ html_tokenizer = pickle.load(f)
19
+
20
+ # Load the model with custom loss
21
+ @tf.keras.utils.register_keras_serializable()
22
+ class EWCLoss(tf.keras.losses.Loss):
23
+ def __init__(self, model, fisher_information, importance=1.0, reduction='auto', name=None):
24
+ super(EWCLoss, self).__init__(reduction=reduction, name=name)
25
+ self.model = model
26
+ self.fisher_information = fisher_information
27
+ self.importance = importance
28
+ self.prev_weights = [layer.numpy() for layer in model.trainable_weights]
29
+
30
+ def call(self, y_true, y_pred):
31
+ standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
32
+ ewc_loss = 0.0
33
+ for layer, fisher_info, prev_weight in zip(self.model.trainable_weights, self.fisher_information, self.prev_weights):
34
+ ewc_loss += tf.reduce_sum(fisher_info * tf.square(layer - prev_weight))
35
+ return standard_loss + (self.importance / 2.0) * ewc_loss
36
+
37
+ def get_config(self):
38
+ config = super().get_config()
39
+ config.update({
40
+ 'importance': self.importance,
41
+ 'reduction': self.reduction,
42
+ 'name': self.name,
43
+ })
44
+ return config
45
+
46
+ @classmethod
47
+ def from_config(cls, config):
48
+ # Load fisher information from external file
49
+ with open('fisher_information.pkl', 'rb') as f:
50
+ fisher_information = pickle.load(f)
51
+ return cls(model=None, fisher_information=fisher_information, **config)
52
+
53
+ # Load the model
54
+ model = tf.keras.models.load_model('new_phishing_detection_model.keras',
55
+ custom_objects={'EWCLoss': EWCLoss})
56
+
57
+ # Recompile the model
58
+ ewc_loss = EWCLoss(model=model, fisher_information=fisher_information, importance=1000)
59
+ model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
60
+ loss=ewc_loss,
61
+ metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
62
+
63
+ # Chatbot setup
64
+ access_token = 'your_pro_gpt_access_token'
65
+ chatbot = Conversation(access_token)
66
+
67
+ # Function to preprocess input
68
+ def preprocess_input(input_text, tokenizer, max_length):
69
+ sequences = tokenizer.texts_to_sequences([input_text])
70
+ padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=max_length, padding='post', truncating='post')
71
+ return padded_sequences
72
+
73
+ # Function to get prediction
74
+ def get_prediction(input_text, input_type):
75
+ is_url = input_type == "URL"
76
+ if is_url:
77
+ input_data = preprocess_input(input_text, url_tokenizer, preprocessing_params['max_url_length'])
78
+ else:
79
+ input_data = preprocess_input(input_text, html_tokenizer, preprocessing_params['max_html_length'])
80
+
81
+ prediction = model.predict([input_data, input_data])[0][0]
82
+ return prediction
83
+
84
+ # Function to fetch latest phishing sites from PhishTank
85
+ def fetch_latest_phishing_sites():
86
+ try:
87
+ response = requests.get('https://data.phishtank.com/data/online-valid.json')
88
+ data = response.json()
89
+ return data[:5]
90
+ except Exception as e:
91
+ return []
92
+
93
+ # Gradio UI
94
+ def phishing_detection(input_text, input_type):
95
+ prediction = get_prediction(input_text, input_type)
96
+ if prediction > 0.5:
97
+ return f"Warning: This site is likely a phishing site! ({prediction:.2f})"
98
+ else:
99
+ return f"Safe: This site is not likely a phishing site. ({prediction:.2f})"
100
+
101
+ def latest_phishing_sites():
102
+ sites = fetch_latest_phishing_sites()
103
+ return [f"{site['url']}" for site in sites]
104
+
105
+ def chatbot_response(user_input):
106
+ response = chatbot.prompt(user_input)
107
+ return response
108
+
109
+ iface = gr.Interface(
110
+ fn=phishing_detection,
111
+ inputs=[gr.inputs.Textbox(lines=5, placeholder="Enter URL or HTML code"), gr.inputs.Radio(["URL", "HTML"], type="value", label="Input Type")],
112
+ outputs="text",
113
+ title="Phishing Detection with Enhanced EWC Model",
114
+ description="Check if a URL or HTML is Phishing",
115
+ theme="default"
116
+ )
117
+
118
+ iface.launch()