import time
from functools import lru_cache

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from concrete.ml.torch.compile import compile_torch_model
from custom_resnet import resnet18_custom  # Assuming custom_resnet.py is in the same directory

# Class index -> label. Order is ['Fake', 'Real'] (flipped to fix an earlier, incorrect
# mapping); it must match the label order used when the model was trained.
class_names = ['Fake', 'Real']

# Trained plaintext weights (a state_dict) used by both inference modes.
MODEL_PATH = 'models/deepfake_detection_model.pth'


def load_model(model_path, device):
    """Build a ResNet-18 with a len(class_names)-way head and load trained weights.

    Args:
        model_path: Path to a saved ``state_dict`` for ``resnet18_custom``.
        device: ``torch.device`` the weights are mapped onto and the model is moved to.

    Returns:
        The model on ``device``, in evaluation mode.
    """
    model = resnet18_custom(weights=None)
    num_ftrs = model.fc.in_features
    # Replace the classifier head so its output size matches class_names (Fake / Real).
    model.fc = nn.Linear(num_ftrs, len(class_names))
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()  # Set model to evaluation mode
    return model


def load_secure_model(model):
    """Compile ``model`` into a Concrete ML quantized module for FHE inference.

    Note: ``model.to("cpu")`` moves the passed-in module to the CPU *in place*; pass a
    dedicated copy if the original must stay on another device.

    Args:
        model: A plaintext ``torch.nn.Module`` taking (N, 3, 224, 224) inputs.

    Returns:
        The object returned by ``compile_torch_model`` (run it via ``.forward(x, fhe=...)``).
    """
    print("Compiling secure model...")
    secure_model = compile_torch_model(
        model.to("cpu"),
        n_bits={"model_inputs": 4, "op_inputs": 3, "op_weights": 3, "model_outputs": 5},
        rounding_threshold_bits={"n_bits": 7},
        # Calibration set: random values in [0, 1), the same range ToTensor() produces.
        # NOTE(review): real sample images would likely give better quantization ranges.
        torch_inputset=torch.rand(10, 3, 224, 224),
    )
    return secure_model


# Image preprocessing (match with the transforms used during training)
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


@lru_cache(maxsize=1)
def _get_device():
    """Pick the inference device once: CUDA, then Apple MPS, then CPU."""
    device = torch.device(
        "cuda:0" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Device: {device}")
    return device


@lru_cache(maxsize=4)
def _get_model(model_path, device):
    """Load the plaintext model once per (path, device) instead of on every request."""
    return load_model(model_path, device)


@lru_cache(maxsize=2)
def _get_secure_model(model_path):
    """Compile the FHE model once per path; compilation is the slowest step by far."""
    # Use a separate CPU-loaded copy: load_secure_model moves its argument to the CPU
    # in place, which must not affect the cached model that Fast mode runs on `device`.
    return load_secure_model(load_model(model_path, torch.device("cpu")))


def predict(image, mode):
    """Classify an uploaded image as Fake or Real.

    Args:
        image: Filesystem path to the uploaded image (Gradio ``type="filepath"``), or
            ``None`` when nothing was uploaded.
        mode: ``"Fast"`` for plaintext PyTorch inference, ``"Secure"`` for Concrete ML
            inference in FHE simulation mode.

    Returns:
        A ``(prediction_text, time_text)`` tuple for the two output textboxes. The time
        covers inference only; model loading/compilation happens (once) before timing.

    Raises:
        ValueError: If ``mode`` is not ``"Fast"`` or ``"Secure"``.
    """
    if image is None:
        return "Please upload an image.", ""
    if mode not in ("Fast", "Secure"):
        raise ValueError(f"Unknown inference mode: {mode!r}")

    # Apply transformations to the input image; close the file handle promptly.
    with Image.open(image) as img:
        rgb = img.convert('RGB')
    batch = data_transform(rgb).unsqueeze(0)  # Add batch dimension (tensor is on CPU)

    # Fetch models (loading/compiling on first use) before starting the timer.
    if mode == "Fast":
        device = _get_device()
        model = _get_model(MODEL_PATH, device)
        inputs = batch.to(device)
    else:
        # Concrete ML consumes NumPy arrays, so the input must stay on the CPU.
        model = _get_secure_model(MODEL_PATH)
        inputs = batch.numpy()

    with torch.no_grad():
        start_time = time.perf_counter()
        if mode == "Fast":
            outputs = model(inputs)
        else:
            # Secure mode: run the compiled quantized circuit with fhe="simulate".
            # The result is a NumPy array, so convert it back for argmax below.
            outputs = torch.as_tensor(model.forward(inputs, fhe="simulate"))
        preds = outputs.argmax(dim=1)
        elapsed_time = time.perf_counter() - start_time

    predicted_class = class_names[preds[0].item()]
    return f"Predicted: {predicted_class}", f"Time taken: {elapsed_time:.2f} seconds"


# Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath", label="Upload an Image"),
        gr.Radio(choices=["Fast", "Secure"], label="Inference Mode", value="Fast"),
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Time Taken"),
    ],
    title="Deepfake Detection Model",
    description="Upload an image and select the inference mode (Fast or Secure).",
)

if __name__ == "__main__":
    iface.launch(share=True)