Oussamahajoui commited on
Commit
cc4257f
β€’
1 Parent(s): 10bdb7e
Files changed (2) hide show
  1. Code/app.py +163 -0
  2. sample_submission.csv +0 -0
Code/app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import albumentations as A
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from albumentations.pytorch import ToTensorV2
11
+ from efficientnet_pytorch import EfficientNet
12
+ from PIL import Image
13
+ from sklearn import metrics
14
+ from torch import nn, optim
15
+ from torch.utils.data import DataLoader, Dataset
16
+ from torchvision import models
17
+ from tqdm import tqdm
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+
22
+ class Dataset(Dataset):
23
+ def __init__(self, root_images, root_file, transform=None):
24
+ self.root_images = root_images
25
+ self.root_file = root_file
26
+ self.transform = transform
27
+ self.file = pd.read_csv(root_file)
28
+
29
+ def __len__(self):
30
+ return self.file.shape[0]
31
+
32
+ def __getitem__(self, index):
33
+ img_path = os.path.join(self.root_images, self.file["id"][index])
34
+ image = np.array(Image.open(img_path).convert("RGB"))
35
+
36
+ if self.transform is not None:
37
+ augmentations = self.transform(image=image)
38
+ image = augmentations["image"]
39
+
40
+ return image
41
+
42
+
43
+ learning_rate = 0.0001
44
+ batch_size = 32
45
+ epochs = 10
46
+ height = 224
47
+ width = 224
48
+ IMG = "AI images or Not/test"
49
+ FILE = "Data/sample_submission.csv"
50
+
51
+
52
+ def get_loader(image, file, batch_size, test_transform):
53
+
54
+ test_ds = Dataset(image, file, test_transform)
55
+ test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
56
+
57
+ return test_loader
58
+
59
+
60
+ normalize = A.Normalize(
61
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.255], max_pixel_value=255.0
62
+ )
63
+
64
+
65
+ test_transform = A.Compose(
66
+ [A.Resize(width=width, height=height), normalize, ToTensorV2()]
67
+ )
68
+
69
+
70
+ class Net(nn.Module):
71
+ def __init__(self):
72
+ super().__init__()
73
+ self.model = EfficientNet.from_pretrained("efficientnet-b4")
74
+ self.fct = nn.Linear(1000, 1)
75
+
76
+ def forward(self, img):
77
+ x = self.model(img)
78
+ # print(x.shape)
79
+ x = self.fct(x)
80
+ return x
81
+
82
+
83
+ def load_checkpoint(checkpoint, model, optimizer):
84
+ print("====> Loading...")
85
+ model.load_state_dict(checkpoint["state_dict"])
86
+ optimizer.load_state_dict(checkpoint["optimizer"])
87
+
88
+
89
+ # test = pd.read_csv(FILE)
90
+ # test
91
+
92
+ model = Net().to(device)
93
+ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
94
+
95
+ checkpoint_file = "Checkpoint/baseline_V0.pth.tar"
96
+ test_loader = get_loader(IMG, FILE, batch_size, test_transform)
97
+ checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
98
+ load_checkpoint(checkpoint, model, optimizer)
99
+
100
+ model.eval()
101
+
102
+
103
+ # define the predict function
104
+ def predict(image):
105
+ # preprocess the image
106
+ image = np.array(image)
107
+ image = test_transform(image=image)["image"]
108
+ image = image.unsqueeze(0).to(device)
109
+
110
+ # get the model prediction
111
+ with torch.no_grad():
112
+ output = model(image)
113
+ pred = torch.sigmoid(output).cpu().numpy().squeeze()
114
+
115
+ # check if prediction is AI generated, not AI generated, or uncertain
116
+ if pred >= 0.6:
117
+ prediction = "AI generated"
118
+ confidence = pred
119
+ elif pred <= 0.4:
120
+ prediction = "NOT AI generated"
121
+ confidence = 1 - pred
122
+ else:
123
+ prediction = "uncertain"
124
+ confidence = abs(0.5 - pred) * 2
125
+
126
+ # return the prediction and confidence as a string
127
+ return f"This image is {prediction} with {confidence:.2%} confidence."
128
+
129
+
130
+ # define the input interface with examples
131
+ inputs = gr.inputs.Image(shape=(224, 224))
132
+ outputs = gr.outputs.Textbox()
133
+ examples = [
134
+ ["Data/train/3.jpg"],
135
+ ["Data/train/10.jpg"],
136
+ ["Data/train/14.jpg"],
137
+ ["Data/train/4515.jpg"],
138
+ ["Data/train/4518.jpg"],
139
+ ["Data/train/6122.jpg"],
140
+ ["Data/train/6123.jpg"],
141
+ ["Data/train/6124.jpg"],
142
+ ["Data/train/6125.jpg"],
143
+ ["Data/train/7461.jpg"],
144
+ ["Data/train/7462.jpg"],
145
+ ["Data/train/7463.jpg"],
146
+ ["Data/train/7464.jpg"],
147
+ ["Data/train/7465.jpg"],
148
+ ["Data/train/8546.jpg"],
149
+ ["Data/train/8543.jpg"],
150
+ ["Data/train/9120.jpg"],
151
+ ["Data/train/10120.jpg"],
152
+ ]
153
+ iface = gr.Interface(
154
+ fn=predict,
155
+ inputs=inputs,
156
+ outputs=outputs,
157
+ title="AI image detector πŸ”Ž",
158
+ description="Check if an image is AI generated or real.",
159
+ examples=examples,
160
+ )
161
+
162
+ # launch the gradio app
163
+ iface.launch()
sample_submission.csv DELETED
The diff for this file is too large to render. See raw diff