nimocodes committed · verified
Commit 4abab34 · 1 Parent(s): 120e1b3

Update inference_2.py

Files changed (1):
  1. inference_2.py +40 -9
inference_2.py CHANGED
@@ -1,22 +1,32 @@
+import os
 import cv2
 import onnx
 import torch
 import argparse
 import numpy as np
+import torch.nn as nn
+from models.TMC import ETMC
 from models import image

-import warnings
 from onnx2pytorch import ConvertModel

-warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
-with warnings.catch_warnings():
-    warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
-    onnx_model = onnx.load('models/efficientnet.onnx')
-    pytorch_model = ConvertModel(onnx_model)
+onnx_model = onnx.load('checkpoints/efficientnet.onnx')
+pytorch_model = ConvertModel(onnx_model)
+
 torch.manual_seed(42)


-audio_args = { 'nb_samp': 64600, 'first_conv': 1024, 'in_channels': 1, 'filts': [20, [20, 20], [20, 128], [128, 128]], 'blocks': [2, 4],'nb_fc_node': 1024,'gru_node': 1024, 'nb_gru_layer': 3, 'nb_classes': 2}
+audio_args = {
+    'nb_samp': 64600,
+    'first_conv': 1024,
+    'in_channels': 1,
+    'filts': [20, [20, 20], [20, 128], [128, 128]],
+    'blocks': [2, 4],
+    'nb_fc_node': 1024,
+    'gru_node': 1024,
+    'nb_gru_layer': 3,
+    'nb_classes': 2
+}


 def get_args(parser):
@@ -55,14 +65,14 @@ def get_args(parser):

 def load_img_modality_model(args):
     rgb_encoder = pytorch_model
-    ckpt = torch.load('models/model.pth', map_location = torch.device('cpu'))
+    ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
     rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
     rgb_encoder.eval()
     return rgb_encoder

 def load_spec_modality_model(args):
     spec_encoder = image.RawNet(args)
-    ckpt = torch.load('models/model.pth', map_location = torch.device('cpu'))
+    ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
     spec_encoder.load_state_dict(ckpt['spec_encoder'], strict = True)
     spec_encoder.eval()
     return spec_encoder
@@ -71,9 +81,12 @@ parser = argparse.ArgumentParser(description="Inference models")
 get_args(parser)
 args, remaining_args = parser.parse_known_args()
 assert remaining_args == [], remaining_args
+
 spec_model = load_spec_modality_model(args)
+
 img_model = load_img_modality_model(args)

+
 def preprocess_img(face):
     face = face / 255
     face = cv2.resize(face, (256, 256))
@@ -90,12 +103,15 @@ def df_spec_pred(input_audio):
     spec_grads = spec_model.forward(audio)
     spec_grads_inv = np.exp(spec_grads.cpu().detach().numpy().squeeze())
     max_value = np.argmax(spec_grads_inv)
+
     if max_value > 0.5:
         preds = round(100 - (max_value*100), 3)
         text2 = f"The audio is REAL."
+
     else:
         preds = round(max_value*100, 3)
         text2 = f"The audio is FAKE."
+
     return text2

 def df_img_pred(input_image):
@@ -104,25 +120,34 @@ def df_img_pred(input_image):
     img_grads = img_model.forward(face)
     img_grads = img_grads.cpu().detach().numpy()
     img_grads_np = np.squeeze(img_grads)
+
     if img_grads_np[0] > 0.5:
         preds = round(img_grads_np[0] * 100, 3)
         text2 = f"The image is REAL. \nConfidence score is: {preds}"
+
     else:
         preds = round(img_grads_np[1] * 100, 3)
         text2 = f"The image is FAKE. \nConfidence score is: {preds}"
+
     return text2

+
 def preprocess_video(input_video, n_frames = 3):
     v_cap = cv2.VideoCapture(input_video)
     v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Pick 'n_frames' evenly spaced frames to sample
     if n_frames is None:
         sample = np.arange(0, v_len)
     else:
         sample = np.linspace(0, v_len - 1, n_frames).astype(int)
+
+    # Loop through frames.
     frames = []
     for j in range(v_len):
         success = v_cap.grab()
         if j in sample:
+            # Load frame
             success, frame = v_cap.retrieve()
             if not success:
                 continue
@@ -132,22 +157,28 @@ def preprocess_video(input_video, n_frames = 3):
     v_cap.release()
     return frames

+
 def df_video_pred(input_video):
     video_frames = preprocess_video(input_video)
     real_faces_list = []
     fake_faces_list = []
+
     for face in video_frames:
         img_grads = img_model.forward(face)
         img_grads = img_grads.cpu().detach().numpy()
         img_grads_np = np.squeeze(img_grads)
         real_faces_list.append(img_grads_np[0])
         fake_faces_list.append(img_grads_np[1])
+
     real_faces_mean = np.mean(real_faces_list)
     fake_faces_mean = np.mean(fake_faces_list)
+
     if real_faces_mean > 0.5:
         preds = round(real_faces_mean * 100, 3)
         text2 = f"The video is REAL. \nConfidence score is: {preds}%"
+
     else:
         preds = round(fake_faces_mean * 100, 3)
         text2 = f"The video is FAKE. \nConfidence score is: {preds}%"
+
     return text2
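
For reference, a minimal sketch of how the updated entry points might be exercised after this commit. This is not part of the diff: the module name is inferred from the file name, the sample paths are hypothetical, and it assumes checkpoints/efficientnet.onnx and checkpoints/model.pth exist so the module-level model loading succeeds on import.

# Usage sketch (hypothetical inputs; not part of the commit).
import cv2
from inference_2 import df_img_pred, df_video_pred

frame = cv2.imread('samples/face.jpg')      # hypothetical path; a BGR NumPy array
print(df_img_pred(frame))                   # e.g. "The image is REAL. \nConfidence score is: ..."

print(df_video_pred('samples/clip.mp4'))    # samples n_frames=3 frames and averages per-frame scores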