rmdhirr commited on
Commit
3d7830a
1 Parent(s): 4896a5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -13
app.py CHANGED
@@ -4,6 +4,7 @@ import pickle
4
  import numpy as np
5
  import requests
6
  from ProGPT import Conversation
 
7
 
8
  # Load saved components
9
  with open('preprocessing_params.pkl', 'rb') as f:
@@ -20,12 +21,12 @@ with open('html_tokenizer.pkl', 'rb') as f:
20
  # Load the model with custom loss
21
  @tf.keras.utils.register_keras_serializable()
22
  class EWCLoss(tf.keras.losses.Loss):
23
- def __init__(self, model, fisher_information, importance=1.0, reduction='auto', name=None):
24
  super(EWCLoss, self).__init__(reduction=reduction, name=name)
25
  self.model = model
26
  self.fisher_information = fisher_information
27
  self.importance = importance
28
- self.prev_weights = [layer.numpy() for layer in model.trainable_weights]
29
 
30
  def call(self, y_true, y_pred):
31
  standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
@@ -45,17 +46,17 @@ class EWCLoss(tf.keras.losses.Loss):
45
 
46
  @classmethod
47
  def from_config(cls, config):
48
- # Load fisher information from external file
49
  with open('fisher_information.pkl', 'rb') as f:
50
  fisher_information = pickle.load(f)
51
  return cls(model=None, fisher_information=fisher_information, **config)
52
 
53
- # Load the model
54
- model = tf.keras.models.load_model('new_phishing_detection_model.keras',
55
- custom_objects={'EWCLoss': EWCLoss})
56
 
57
- # Recompile the model
58
  ewc_loss = EWCLoss(model=model, fisher_information=fisher_information, importance=1000)
 
 
59
  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
60
  loss=ewc_loss,
61
  metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
@@ -74,9 +75,9 @@ def preprocess_input(input_text, tokenizer, max_length):
74
  def get_prediction(input_text, input_type):
75
  is_url = input_type == "URL"
76
  if is_url:
77
- input_data = preprocess_input(input_text, url_tokenizer, preprocessing_params['max_url_length'])
78
  else:
79
- input_data = preprocess_input(input_text, html_tokenizer, preprocessing_params['max_html_length'])
80
 
81
  prediction = model.predict([input_data, input_data])[0][0]
82
  return prediction
@@ -106,12 +107,25 @@ def chatbot_response(user_input):
106
  response = chatbot.prompt(user_input)
107
  return response
108
 
 
 
 
 
 
 
109
  iface = gr.Interface(
110
- fn=phishing_detection,
111
- inputs=[gr.inputs.Textbox(lines=5, placeholder="Enter URL or HTML code"), gr.inputs.Radio(["URL", "HTML"], type="value", label="Input Type")],
112
- outputs="text",
 
 
 
 
 
 
 
113
  title="Phishing Detection with Enhanced EWC Model",
114
- description="Check if a URL or HTML is Phishing",
115
  theme="default"
116
  )
117
 
 
4
  import numpy as np
5
  import requests
6
  from ProGPT import Conversation
7
+ from sklearn.preprocessing import LabelEncoder
8
 
9
  # Load saved components
10
  with open('preprocessing_params.pkl', 'rb') as f:
 
21
  # Load the model with custom loss
22
  @tf.keras.utils.register_keras_serializable()
23
  class EWCLoss(tf.keras.losses.Loss):
24
+ def __init__(self, model=None, fisher_information=None, importance=1.0, reduction='auto', name=None):
25
  super(EWCLoss, self).__init__(reduction=reduction, name=name)
26
  self.model = model
27
  self.fisher_information = fisher_information
28
  self.importance = importance
29
+ self.prev_weights = [layer.numpy() for layer in model.trainable_weights] if model else None
30
 
31
  def call(self, y_true, y_pred):
32
  standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
 
46
 
47
  @classmethod
48
  def from_config(cls, config):
 
49
  with open('fisher_information.pkl', 'rb') as f:
50
  fisher_information = pickle.load(f)
51
  return cls(model=None, fisher_information=fisher_information, **config)
52
 
53
+ # Load the model first without the custom loss
54
+ model = tf.keras.models.load_model('new_phishing_detection_model.keras', compile=False)
 
55
 
56
+ # Reconstruct the EWC loss
57
  ewc_loss = EWCLoss(model=model, fisher_information=fisher_information, importance=1000)
58
+
59
+ # Compile the model with EWC loss and metrics
60
  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
61
  loss=ewc_loss,
62
  metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
 
75
  def get_prediction(input_text, input_type):
76
  is_url = input_type == "URL"
77
  if is_url:
78
+ input_data = preprocess_input(input_text, url_tokenizer, preprocessing_params['max_new_url_length'])
79
  else:
80
+ input_data = preprocess_input(input_text, html_tokenizer, preprocessing_params['max_new_html_length'])
81
 
82
  prediction = model.predict([input_data, input_data])[0][0]
83
  return prediction
 
107
  response = chatbot.prompt(user_input)
108
  return response
109
 
110
+ def interface(input_text, input_type):
111
+ result = phishing_detection(input_text, input_type)
112
+ latest_sites = latest_phishing_sites()
113
+ chatbot_res = chatbot_response(input_text)
114
+ return result, latest_sites, chatbot_res
115
+
116
  iface = gr.Interface(
117
+ fn=interface,
118
+ inputs=[
119
+ gr.inputs.Textbox(lines=5, placeholder="Enter URL or HTML code"),
120
+ gr.inputs.Radio(["URL", "HTML"], type="value", label="Input Type")
121
+ ],
122
+ outputs=[
123
+ "text",
124
+ gr.outputs.Textbox(label="Latest Phishing Sites"),
125
+ gr.outputs.Textbox(label="Chatbot Response")
126
+ ],
127
  title="Phishing Detection with Enhanced EWC Model",
128
+ description="Check if a URL or HTML is Phishing. Latest phishing sites from PhishTank and a chatbot assistant for phishing issues.",
129
  theme="default"
130
  )
131