File size: 3,258 Bytes
a52b4d4
924d062
 
 
 
 
 
a52b4d4
924d062
a52b4d4
924d062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import functools
import time

import gradio as gr
import torch
import torch.nn as nn
from concrete.ml.torch.compile import compile_torch_model
from PIL import Image
from torchvision import transforms

from custom_resnet import resnet18_custom  # Assuming custom_resnet.py is in the same directory

# Human-readable labels; predict() maps the argmax index of the model output
# to class_names[index], and load_model() sizes the final layer from len(class_names).
# NOTE(review): this order must match the label order used during training — confirm.
class_names = ["Fake", "Real"]

# Load the trained model
def load_model(model_path, device):
    model = resnet18_custom(weights=None)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))  # Assuming 2 classes: Fake and Real
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()  # Set model to evaluation mode
    return model


def load_secure_model(model, torch_inputset=None):
    """Compile a PyTorch model into a Concrete ML quantized (FHE-ready) module.

    Args:
        model: Trained ``nn.Module``. It is moved with ``model.to("cpu")``,
            which relocates an ``nn.Module`` in place, so the caller's model
            ends up on the CPU after this call.
        torch_inputset: Calibration inputs used by ``compile_torch_model`` to
            choose quantization ranges. Defaults to the original behaviour:
            10 uniform-random tensors of shape (3, 224, 224) in [0, 1).
            Passing real preprocessed images is likely to calibrate better.

    Returns:
        Whatever ``compile_torch_model`` returns (a Concrete ML quantized module).
    """
    print("Compiling secure model...")
    if torch_inputset is None:
        # Random inputs in [0, 1) match the range produced by ToTensor().
        torch_inputset = torch.rand(10, 3, 224, 224)
    secure_model = compile_torch_model(
        model.to("cpu"),
        n_bits={"model_inputs": 4, "op_inputs": 3, "op_weights": 3, "model_outputs": 5},
        rounding_threshold_bits={"n_bits": 7},
        torch_inputset=torch_inputset,
    )
    return secure_model

# Preprocessing applied to every uploaded image before inference; it should
# mirror the training pipeline.
# NOTE(review): there is no Normalize step here — confirm training used none either.
data_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed size; load_secure_model calibrates on 224x224
        transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    ]
)

# Prediction function
MODEL_PATH = 'models/deepfake_detection_model.pth'


def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    return torch.device(
        "cuda:0" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else
        "cpu"
    )


@functools.lru_cache(maxsize=4)
def _get_model(device_name):
    """Load the plain PyTorch classifier once per device and reuse it across requests."""
    return load_model(MODEL_PATH, torch.device(device_name))


@functools.lru_cache(maxsize=1)
def _get_secure_model():
    """Compile the Concrete ML model once; compilation is far costlier than inference."""
    # Use a dedicated CPU copy: load_secure_model calls model.to("cpu"), which
    # moves an nn.Module in place and would strand the cached Fast-mode model.
    return load_secure_model(load_model(MODEL_PATH, torch.device("cpu")))


def predict(image, mode):
    """Classify an uploaded image as one of ``class_names`` and report timing.

    Args:
        image: Filesystem path of the uploaded image (Gradio ``type="filepath"``),
            or ``None`` when nothing was uploaded.
        mode: ``"Fast"`` for plain PyTorch inference, or ``"Secure"`` for
            Concrete ML inference in FHE simulation mode.

    Returns:
        A ``(prediction_text, time_text)`` tuple for the two output textboxes.
        The reported time covers inference only; model loading/compilation is
        cached and happens on first use.

    Raises:
        ValueError: if ``mode`` is neither ``"Fast"`` nor ``"Secure"``.
    """
    if image is None:
        return "Predicted: no image provided", "Time taken: 0.00 seconds"
    if mode not in ("Fast", "Secure"):
        # Previously fell through and failed with UnboundLocalError on `outputs`.
        raise ValueError(f"Unknown inference mode: {mode!r}")

    device = _select_device()
    print(f"Device: {device}")

    # The context manager closes the uploaded file; convert() returns a new image.
    with Image.open(image) as pil_image:
        rgb_image = pil_image.convert('RGB')
    batch = data_transform(rgb_image).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        if mode == "Fast":
            model = _get_model(str(device))
            batch = batch.to(device)
            start_time = time.perf_counter()
            outputs = model(batch)
        else:
            secure_model = _get_secure_model()
            # Concrete ML consumes NumPy arrays, and .numpy() only works on CPU tensors.
            numpy_input = batch.cpu().numpy()
            start_time = time.perf_counter()
            # forward() is the documented entry point taking the `fhe` mode; it
            # returns a NumPy array, so convert it back for the torch argmax below.
            outputs = torch.as_tensor(secure_model.forward(numpy_input, fhe="simulate"))
        elapsed_time = time.perf_counter() - start_time

    print(outputs)
    predicted_index = int(outputs.argmax(dim=1)[0])
    predicted_class = class_names[predicted_index]
    return f"Predicted: {predicted_class}", f"Time taken: {elapsed_time:.2f} seconds"

# Gradio UI: an image upload and an inference-mode selector feed predict(),
# whose two return values fill the two text boxes.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath", label="Upload an Image"),  # predict() receives a file path
        gr.Radio(
            choices=["Fast", "Secure"],
            label="Inference Mode",
            value="Fast",
        ),
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Time Taken"),
    ],
    title="Deepfake Detection Model",
    description="Upload an image and select the inference mode (Fast or Secure).",
)

# Launch only when run as a script; share=True asks Gradio for a public link.
if __name__ == "__main__":
    iface.launch(share=True)