Team3.1 / app.py
Mojmir's picture
i
924d062
raw
history blame
No virus
3.26 kB
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import time
from concrete.ml.torch.compile import compile_torch_model
from custom_resnet import resnet18_custom # Assuming custom_resnet.py is in the same directory
# Human-readable label for each output index of the classifier:
# index 0 -> "Fake", index 1 -> "Real".
# The order was deliberately flipped to ["Fake", "Real"] by the original
# author to correct an earlier, incorrect index-to-label mapping.
class_names = ["Fake", "Real"]
def load_model(model_path, device):
    """Build the custom ResNet-18 classifier and restore its trained weights.

    Args:
        model_path: Path to the saved ``state_dict`` checkpoint.
        device: Device the checkpoint tensors are mapped to and the model is
            moved onto.

    Returns:
        The model, placed on ``device`` and switched to evaluation mode.
    """
    net = resnet18_custom(weights=None)

    # Swap the final fully connected layer for one with one output per
    # entry in ``class_names`` (Fake / Real).
    in_features = net.fc.in_features
    net.fc = nn.Linear(in_features, len(class_names))

    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should ever be passed here.
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)

    # Both ``to`` and ``eval`` return the module itself, so they can be chained.
    return net.to(device).eval()
def load_secure_model(model):
    """Compile ``model`` with Concrete ML's ``compile_torch_model``.

    The model is moved to the CPU before compilation; because ``Module.to``
    works in place, the caller's model object ends up on the CPU as well.

    Args:
        model: The trained PyTorch model to compile.

    Returns:
        Whatever ``compile_torch_model`` returns for the given settings.
    """
    print("Compiling secure model...")

    cpu_model = model.to("cpu")

    # Bit widths handed to the compiler for inputs, intermediate operations,
    # weights and outputs.
    bit_widths = {
        "model_inputs": 4,
        "op_inputs": 3,
        "op_weights": 3,
        "model_outputs": 5,
    }
    rounding = {"n_bits": 7}

    # Ten uniformly random 3x224x224 tensors serve as the compiler's input
    # set; these are synthetic values, not real images.
    sample_inputs = torch.rand(10, 3, 224, 224)

    return compile_torch_model(
        cpu_model,
        n_bits=bit_widths,
        rounding_threshold_bits=rounding,
        torch_inputset=sample_inputs,
    )
# Preprocessing pipeline applied to every uploaded image before inference.
# Per the original author's note, this must mirror the transforms used at
# training time. No normalization step is applied here.
data_transform = transforms.Compose(
    [
        # Force a fixed 224x224 input, regardless of the upload's aspect ratio.
        transforms.Resize(size=(224, 224)),
        # Convert the PIL image into a tensor.
        transforms.ToTensor(),
    ]
)
def predict(image, mode):
    """Classify an uploaded image as 'Fake' or 'Real'.

    Args:
        image: Filesystem path of the uploaded image (the Gradio input is
            configured with ``type="filepath"``), or ``None`` when nothing
            was uploaded.
        mode: ``"Fast"`` runs the PyTorch model directly on the selected
            device. ``"Secure"`` compiles the model via
            ``load_secure_model`` and runs it with ``fhe="simulate"``.

    Returns:
        A ``(prediction, timing)`` pair of display strings.

    Raises:
        ValueError: If ``mode`` is neither ``"Fast"`` nor ``"Secure"``.
        gr.Error: If no image was supplied.
    """
    # Reject unknown modes before doing any work. Previously such a value
    # skipped both branches and crashed later on an unbound ``outputs``.
    if mode not in ("Fast", "Secure"):
        raise ValueError(f"Unsupported inference mode: {mode!r}")
    if image is None:
        raise gr.Error("Please upload an image before running inference.")

    # Device configuration
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else
        "cpu"
    )
    print(f"Device: {device}")

    # NOTE: the checkpoint is re-read on every request. Caching it would need
    # care because load_secure_model() moves the model to the CPU in place.
    model_path = 'models/deepfake_detection_model.pth'
    model = load_model(model_path, device)

    # Close the image file deterministically once it has been decoded.
    with Image.open(image) as img:
        batch = data_transform(img.convert('RGB')).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        # As before, the timer also covers compilation in Secure mode.
        start_time = time.time()
        if mode == "Fast":
            outputs = model(batch.to(device))
        else:
            secure_model = load_secure_model(model)
            # The compiled model consumes NumPy arrays, so the batch is kept
            # on the host: ``.numpy()`` is not available for CUDA/MPS tensors.
            outputs = secure_model.forward(batch.cpu().numpy(), fhe="simulate")
            print(outputs)
            # Wrap the simulated result in a tensor so both modes share the
            # same arg-max code below.
            outputs = torch.as_tensor(outputs)
        preds = outputs.argmax(dim=1)
        elapsed_time = time.time() - start_time

    # Index the label list with a plain int rather than a tensor.
    predicted_class = class_names[int(preds[0])]
    return f"Predicted: {predicted_class}", f"Time taken: {elapsed_time:.2f} seconds"
# Gradio UI: an image upload plus a mode selector feed ``predict``, whose two
# returned strings are shown in the two text boxes.
iface = gr.Interface(
    fn=predict,
    inputs=[
        # ``type="filepath"`` means ``predict`` receives a path on disk.
        gr.Image(type="filepath", label="Upload an Image"),
        # Choice of inference path; "Fast" is preselected.
        gr.Radio(choices=["Fast", "Secure"], value="Fast", label="Inference Mode"),
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Time Taken"),
    ],
    title="Deepfake Detection Model",
    description="Upload an image and select the inference mode (Fast or Secure).",
)
if __name__ == "__main__":
    # Start the web UI only when run as a script. ``share=True`` asks Gradio
    # for a public share link in addition to the local server.
    iface.launch(share=True)