Spaces:
Runtime error
Runtime error
File size: 6,327 Bytes
4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 4abab34 3da0952 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import os
import cv2
import onnx
import torch
import argparse
import numpy as np
import torch.nn as nn
from models.TMC import ETMC
from models import image
from onnx2pytorch import ConvertModel
onnx_model = onnx.load('checkpoints/efficientnet.onnx')
pytorch_model = ConvertModel(onnx_model)
torch.manual_seed(42)
audio_args = {
'nb_samp': 64600,
'first_conv': 1024,
'in_channels': 1,
'filts': [20, [20, 20], [20, 128], [128, 128]],
'blocks': [2, 4],
'nb_fc_node': 1024,
'gru_node': 1024,
'nb_gru_layer': 3,
'nb_classes': 2
}
def get_args(parser):
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--data_dir", type=str, default="datasets/train/fakeavceleb*")
parser.add_argument("--LOAD_SIZE", type=int, default=256)
parser.add_argument("--FINE_SIZE", type=int, default=224)
parser.add_argument("--dropout", type=float, default=0.2)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--hidden", nargs="*", type=int, default=[])
parser.add_argument("--hidden_sz", type=int, default=768)
parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"])
parser.add_argument("--img_hidden_sz", type=int, default=1024)
parser.add_argument("--include_bn", type=int, default=True)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--lr_factor", type=float, default=0.3)
parser.add_argument("--lr_patience", type=int, default=10)
parser.add_argument("--max_epochs", type=int, default=500)
parser.add_argument("--n_workers", type=int, default=12)
parser.add_argument("--name", type=str, default="MMDF")
parser.add_argument("--num_image_embeds", type=int, default=1)
parser.add_argument("--patience", type=int, default=20)
parser.add_argument("--savedir", type=str, default="./savepath/")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--n_classes", type=int, default=2)
parser.add_argument("--annealing_epoch", type=int, default=10)
parser.add_argument("--device", type=str, default='cpu')
parser.add_argument("--pretrained_image_encoder", type=bool, default = False)
parser.add_argument("--freeze_image_encoder", type=bool, default = False)
parser.add_argument("--pretrained_audio_encoder", type = bool, default=False)
parser.add_argument("--freeze_audio_encoder", type = bool, default = False)
parser.add_argument("--augment_dataset", type = bool, default = True)
for key, value in audio_args.items():
parser.add_argument(f"--{key}", type=type(value), default=value)
def load_img_modality_model(args):
rgb_encoder = pytorch_model
ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
rgb_encoder.eval()
return rgb_encoder
def load_spec_modality_model(args):
spec_encoder = image.RawNet(args)
ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
spec_encoder.load_state_dict(ckpt['spec_encoder'], strict = True)
spec_encoder.eval()
return spec_encoder
parser = argparse.ArgumentParser(description="Inference models")
get_args(parser)
args, remaining_args = parser.parse_known_args()
assert remaining_args == [], remaining_args
spec_model = load_spec_modality_model(args)
img_model = load_img_modality_model(args)
def preprocess_img(face):
face = face / 255
face = cv2.resize(face, (256, 256))
face_pt = torch.unsqueeze(torch.Tensor(face), dim = 0)
return face_pt
def preprocess_audio(audio_file):
audio_pt = torch.unsqueeze(torch.Tensor(audio_file), dim = 0)
return audio_pt
def df_spec_pred(input_audio):
x, _ = input_audio
audio = preprocess_audio(x)
spec_grads = spec_model.forward(audio)
spec_grads_inv = np.exp(spec_grads.cpu().detach().numpy().squeeze())
max_value = np.argmax(spec_grads_inv)
if max_value > 0.5:
preds = round(100 - (max_value*100), 3)
text2 = f"The audio is REAL."
else:
preds = round(max_value*100, 3)
text2 = f"The audio is FAKE."
return text2
def df_img_pred(input_image):
face = preprocess_img(input_image)
print(f"Face shape is: {face.shape}")
img_grads = img_model.forward(face)
img_grads = img_grads.cpu().detach().numpy()
img_grads_np = np.squeeze(img_grads)
if img_grads_np[0] > 0.5:
preds = round(img_grads_np[0] * 100, 3)
text2 = f"The image is REAL. \nConfidence score is: {preds}"
else:
preds = round(img_grads_np[1] * 100, 3)
text2 = f"The image is FAKE. \nConfidence score is: {preds}"
return text2
def preprocess_video(input_video, n_frames = 3):
v_cap = cv2.VideoCapture(input_video)
v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Pick 'n_frames' evenly spaced frames to sample
if n_frames is None:
sample = np.arange(0, v_len)
else:
sample = np.linspace(0, v_len - 1, n_frames).astype(int)
#Loop through frames.
frames = []
for j in range(v_len):
success = v_cap.grab()
if j in sample:
# Load frame
success, frame = v_cap.retrieve()
if not success:
continue
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = preprocess_img(frame)
frames.append(frame)
v_cap.release()
return frames
def df_video_pred(input_video):
video_frames = preprocess_video(input_video)
real_faces_list = []
fake_faces_list = []
for face in video_frames:
img_grads = img_model.forward(face)
img_grads = img_grads.cpu().detach().numpy()
img_grads_np = np.squeeze(img_grads)
real_faces_list.append(img_grads_np[0])
fake_faces_list.append(img_grads_np[1])
real_faces_mean = np.mean(real_faces_list)
fake_faces_mean = np.mean(fake_faces_list)
if real_faces_mean > 0.5:
preds = round(real_faces_mean * 100, 3)
text2 = f"The video is REAL. \nConfidence score is: {preds}%"
else:
preds = round(fake_faces_mean * 100, 3)
text2 = f"The video is FAKE. \nConfidence score is: {preds}%"
return text2 |