Spaces:
Sleeping
Sleeping
Update utils/inference_utils.py
Browse files- utils/inference_utils.py +194 -193
utils/inference_utils.py
CHANGED
@@ -1,193 +1,194 @@
|
|
1 |
-
import torch
|
2 |
-
import matplotlib.pyplot as plt
|
3 |
-
from torchvision import transforms
|
4 |
-
from PIL import Image
|
5 |
-
import os
|
6 |
-
import random
|
7 |
-
from utils.data import CLASS_NAMES
|
8 |
-
|
9 |
-
# Function to find correctly and incorrectly classified images
|
10 |
-
def find_images(dataloader, model, device, num_correct, num_incorrect):
|
11 |
-
correct_images = []
|
12 |
-
incorrect_images = []
|
13 |
-
correct_labels = []
|
14 |
-
incorrect_labels = []
|
15 |
-
correct_preds = []
|
16 |
-
incorrect_preds = []
|
17 |
-
|
18 |
-
model.eval()
|
19 |
-
with torch.no_grad():
|
20 |
-
for images, labels in dataloader:
|
21 |
-
images, labels = images.to(device), labels.to(device)
|
22 |
-
outputs = model(images)
|
23 |
-
_, preds = torch.max(outputs, 1)
|
24 |
-
|
25 |
-
for i in range(images.size(0)):
|
26 |
-
if preds[i] == labels[i] and len(correct_images) < num_correct:
|
27 |
-
correct_images.append(images[i].cpu())
|
28 |
-
correct_labels.append(labels[i].cpu())
|
29 |
-
correct_preds.append(preds[i].cpu())
|
30 |
-
elif preds[i] != labels[i] and len(incorrect_images) < num_incorrect:
|
31 |
-
incorrect_images.append(images[i].cpu())
|
32 |
-
incorrect_labels.append(labels[i].cpu())
|
33 |
-
incorrect_preds.append(preds[i].cpu())
|
34 |
-
|
35 |
-
if (
|
36 |
-
len(correct_images) >= num_correct
|
37 |
-
and len(incorrect_images) >= num_incorrect
|
38 |
-
):
|
39 |
-
break
|
40 |
-
if (
|
41 |
-
len(correct_images) >= num_correct
|
42 |
-
and len(incorrect_images) >= num_incorrect
|
43 |
-
):
|
44 |
-
break
|
45 |
-
|
46 |
-
return (
|
47 |
-
correct_images,
|
48 |
-
correct_labels,
|
49 |
-
correct_preds,
|
50 |
-
incorrect_images,
|
51 |
-
incorrect_labels,
|
52 |
-
incorrect_preds,
|
53 |
-
)
|
54 |
-
|
55 |
-
def find_images_from_path(data_path, model, device, num_correct=2, num_incorrect=2, label=None):
|
56 |
-
correct_images_paths = []
|
57 |
-
incorrect_images_paths = []
|
58 |
-
correct_labels = []
|
59 |
-
incorrect_labels = []
|
60 |
-
|
61 |
-
label_to_idx = {label: idx for idx, label in enumerate(CLASS_NAMES)}
|
62 |
-
|
63 |
-
model.eval()
|
64 |
-
# First collect available images for the specified label or all labels
|
65 |
-
label_images = {}
|
66 |
-
if label:
|
67 |
-
if os.path.isdir(os.path.join(data_path, label)):
|
68 |
-
label_path = os.path.join(data_path, label)
|
69 |
-
label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
|
70 |
-
else:
|
71 |
-
for label in os.listdir(data_path):
|
72 |
-
label_path = os.path.join(data_path, label)
|
73 |
-
if not os.path.isdir(label_path):
|
74 |
-
continue
|
75 |
-
label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
|
76 |
-
|
77 |
-
# Randomly process images until we have enough samples
|
78 |
-
with torch.no_grad():
|
79 |
-
while len(correct_images_paths) < num_correct or len(incorrect_images_paths) < num_incorrect:
|
80 |
-
# Randomly select a label that still has unprocessed images
|
81 |
-
available_labels = [l for l in label_images if label_images[l]]
|
82 |
-
if not available_labels:
|
83 |
-
break
|
84 |
-
|
85 |
-
selected_label = random.choice(available_labels)
|
86 |
-
image_path = random.choice(label_images[selected_label])
|
87 |
-
label_images[selected_label].remove(image_path) # Remove the selected image
|
88 |
-
|
89 |
-
image = preprocess_image(image_path, (224, 224)).to(device)
|
90 |
-
label_idx = label_to_idx[selected_label]
|
91 |
-
|
92 |
-
outputs = model(image)
|
93 |
-
_, pred = torch.max(outputs, 1)
|
94 |
-
|
95 |
-
if pred == label_idx and len(correct_images_paths) < num_correct:
|
96 |
-
correct_images_paths.append(image_path)
|
97 |
-
correct_labels.append(label_idx)
|
98 |
-
elif pred != label_idx and len(incorrect_images_paths) < num_incorrect:
|
99 |
-
incorrect_images_paths.append(image_path)
|
100 |
-
incorrect_labels.append(label_idx)
|
101 |
-
|
102 |
-
save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels)
|
103 |
-
|
104 |
-
def save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels):
|
105 |
-
# Create root directories for correct and incorrect classifications
|
106 |
-
for class_name in CLASS_NAMES:
|
107 |
-
os.makedirs(os.path.join('predictions', class_name, 'correct'), exist_ok=True)
|
108 |
-
os.makedirs(os.path.join('predictions', class_name, 'mistake'), exist_ok=True)
|
109 |
-
|
110 |
-
# Save correctly classified images
|
111 |
-
for img_path, label in zip(correct_images_paths, correct_labels):
|
112 |
-
class_name = CLASS_NAMES[label]
|
113 |
-
img_name = os.path.basename(img_path)
|
114 |
-
destination = os.path.join('predictions', class_name, 'correct', img_name)
|
115 |
-
os.makedirs(os.path.dirname(destination), exist_ok=True)
|
116 |
-
Image.open(img_path).save(destination)
|
117 |
-
|
118 |
-
# Save incorrectly classified images
|
119 |
-
for img_path, label in zip(incorrect_images_paths, incorrect_labels):
|
120 |
-
class_name = CLASS_NAMES[label]
|
121 |
-
img_name = os.path.basename(img_path)
|
122 |
-
destination = os.path.join('predictions', class_name, 'mistake', img_name)
|
123 |
-
os.makedirs(os.path.dirname(destination), exist_ok=True)
|
124 |
-
Image.open(img_path).save(destination)
|
125 |
-
|
126 |
-
def show_samples(dataloader, model, device, num_correct=3, num_incorrect=3):
|
127 |
-
# Get some correctly and incorrectly classified images
|
128 |
-
(
|
129 |
-
correct_images,
|
130 |
-
correct_labels,
|
131 |
-
correct_preds,
|
132 |
-
incorrect_images,
|
133 |
-
incorrect_labels,
|
134 |
-
incorrect_preds,
|
135 |
-
) = find_images(dataloader, model, device, num_correct, num_incorrect)
|
136 |
-
# Display the results in a grid
|
137 |
-
fig, axes = plt.subplots(
|
138 |
-
num_correct + num_incorrect, 1, figsize=(10, (num_correct + num_incorrect) * 5)
|
139 |
-
)
|
140 |
-
|
141 |
-
for i in range(num_correct):
|
142 |
-
axes[i].imshow(correct_images[i].permute(1, 2, 0))
|
143 |
-
axes[i].set_title(
|
144 |
-
f"Correctly Classified: True Label = {correct_labels[i]}, Predicted = {correct_preds[i]}"
|
145 |
-
)
|
146 |
-
axes[i].axis("off")
|
147 |
-
|
148 |
-
for i in range(num_incorrect):
|
149 |
-
axes[num_correct + i].imshow(incorrect_images[i].permute(1, 2, 0))
|
150 |
-
axes[num_correct + i].set_title(
|
151 |
-
f"Incorrectly Classified: True Label = {incorrect_labels[i]}, Predicted = {incorrect_preds[i]}"
|
152 |
-
)
|
153 |
-
axes[num_correct + i].axis("off")
|
154 |
-
|
155 |
-
plt.tight_layout()
|
156 |
-
plt.show()
|
157 |
-
|
158 |
-
|
159 |
-
# Function to preprocess image
|
160 |
-
def preprocess_image(
|
161 |
-
|
162 |
-
#
|
163 |
-
image
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
transforms.
|
169 |
-
transforms.
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
1 |
+
import torch
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from utils.data import CLASS_NAMES
|
8 |
+
|
9 |
+
# Function to find correctly and incorrectly classified images
|
10 |
+
def find_images(dataloader, model, device, num_correct, num_incorrect):
|
11 |
+
correct_images = []
|
12 |
+
incorrect_images = []
|
13 |
+
correct_labels = []
|
14 |
+
incorrect_labels = []
|
15 |
+
correct_preds = []
|
16 |
+
incorrect_preds = []
|
17 |
+
|
18 |
+
model.eval()
|
19 |
+
with torch.no_grad():
|
20 |
+
for images, labels in dataloader:
|
21 |
+
images, labels = images.to(device), labels.to(device)
|
22 |
+
outputs = model(images)
|
23 |
+
_, preds = torch.max(outputs, 1)
|
24 |
+
|
25 |
+
for i in range(images.size(0)):
|
26 |
+
if preds[i] == labels[i] and len(correct_images) < num_correct:
|
27 |
+
correct_images.append(images[i].cpu())
|
28 |
+
correct_labels.append(labels[i].cpu())
|
29 |
+
correct_preds.append(preds[i].cpu())
|
30 |
+
elif preds[i] != labels[i] and len(incorrect_images) < num_incorrect:
|
31 |
+
incorrect_images.append(images[i].cpu())
|
32 |
+
incorrect_labels.append(labels[i].cpu())
|
33 |
+
incorrect_preds.append(preds[i].cpu())
|
34 |
+
|
35 |
+
if (
|
36 |
+
len(correct_images) >= num_correct
|
37 |
+
and len(incorrect_images) >= num_incorrect
|
38 |
+
):
|
39 |
+
break
|
40 |
+
if (
|
41 |
+
len(correct_images) >= num_correct
|
42 |
+
and len(incorrect_images) >= num_incorrect
|
43 |
+
):
|
44 |
+
break
|
45 |
+
|
46 |
+
return (
|
47 |
+
correct_images,
|
48 |
+
correct_labels,
|
49 |
+
correct_preds,
|
50 |
+
incorrect_images,
|
51 |
+
incorrect_labels,
|
52 |
+
incorrect_preds,
|
53 |
+
)
|
54 |
+
|
55 |
+
def find_images_from_path(data_path, model, device, num_correct=2, num_incorrect=2, label=None):
|
56 |
+
correct_images_paths = []
|
57 |
+
incorrect_images_paths = []
|
58 |
+
correct_labels = []
|
59 |
+
incorrect_labels = []
|
60 |
+
|
61 |
+
label_to_idx = {label: idx for idx, label in enumerate(CLASS_NAMES)}
|
62 |
+
|
63 |
+
model.eval()
|
64 |
+
# First collect available images for the specified label or all labels
|
65 |
+
label_images = {}
|
66 |
+
if label:
|
67 |
+
if os.path.isdir(os.path.join(data_path, label)):
|
68 |
+
label_path = os.path.join(data_path, label)
|
69 |
+
label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
|
70 |
+
else:
|
71 |
+
for label in os.listdir(data_path):
|
72 |
+
label_path = os.path.join(data_path, label)
|
73 |
+
if not os.path.isdir(label_path):
|
74 |
+
continue
|
75 |
+
label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
|
76 |
+
|
77 |
+
# Randomly process images until we have enough samples
|
78 |
+
with torch.no_grad():
|
79 |
+
while len(correct_images_paths) < num_correct or len(incorrect_images_paths) < num_incorrect:
|
80 |
+
# Randomly select a label that still has unprocessed images
|
81 |
+
available_labels = [l for l in label_images if label_images[l]]
|
82 |
+
if not available_labels:
|
83 |
+
break
|
84 |
+
|
85 |
+
selected_label = random.choice(available_labels)
|
86 |
+
image_path = random.choice(label_images[selected_label])
|
87 |
+
label_images[selected_label].remove(image_path) # Remove the selected image
|
88 |
+
|
89 |
+
image = preprocess_image(image_path, (224, 224)).to(device)
|
90 |
+
label_idx = label_to_idx[selected_label]
|
91 |
+
|
92 |
+
outputs = model(image)
|
93 |
+
_, pred = torch.max(outputs, 1)
|
94 |
+
|
95 |
+
if pred == label_idx and len(correct_images_paths) < num_correct:
|
96 |
+
correct_images_paths.append(image_path)
|
97 |
+
correct_labels.append(label_idx)
|
98 |
+
elif pred != label_idx and len(incorrect_images_paths) < num_incorrect:
|
99 |
+
incorrect_images_paths.append(image_path)
|
100 |
+
incorrect_labels.append(label_idx)
|
101 |
+
|
102 |
+
save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels)
|
103 |
+
|
104 |
+
def save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels):
|
105 |
+
# Create root directories for correct and incorrect classifications
|
106 |
+
for class_name in CLASS_NAMES:
|
107 |
+
os.makedirs(os.path.join('predictions', class_name, 'correct'), exist_ok=True)
|
108 |
+
os.makedirs(os.path.join('predictions', class_name, 'mistake'), exist_ok=True)
|
109 |
+
|
110 |
+
# Save correctly classified images
|
111 |
+
for img_path, label in zip(correct_images_paths, correct_labels):
|
112 |
+
class_name = CLASS_NAMES[label]
|
113 |
+
img_name = os.path.basename(img_path)
|
114 |
+
destination = os.path.join('predictions', class_name, 'correct', img_name)
|
115 |
+
os.makedirs(os.path.dirname(destination), exist_ok=True)
|
116 |
+
Image.open(img_path).save(destination)
|
117 |
+
|
118 |
+
# Save incorrectly classified images
|
119 |
+
for img_path, label in zip(incorrect_images_paths, incorrect_labels):
|
120 |
+
class_name = CLASS_NAMES[label]
|
121 |
+
img_name = os.path.basename(img_path)
|
122 |
+
destination = os.path.join('predictions', class_name, 'mistake', img_name)
|
123 |
+
os.makedirs(os.path.dirname(destination), exist_ok=True)
|
124 |
+
Image.open(img_path).save(destination)
|
125 |
+
|
126 |
+
def show_samples(dataloader, model, device, num_correct=3, num_incorrect=3):
|
127 |
+
# Get some correctly and incorrectly classified images
|
128 |
+
(
|
129 |
+
correct_images,
|
130 |
+
correct_labels,
|
131 |
+
correct_preds,
|
132 |
+
incorrect_images,
|
133 |
+
incorrect_labels,
|
134 |
+
incorrect_preds,
|
135 |
+
) = find_images(dataloader, model, device, num_correct, num_incorrect)
|
136 |
+
# Display the results in a grid
|
137 |
+
fig, axes = plt.subplots(
|
138 |
+
num_correct + num_incorrect, 1, figsize=(10, (num_correct + num_incorrect) * 5)
|
139 |
+
)
|
140 |
+
|
141 |
+
for i in range(num_correct):
|
142 |
+
axes[i].imshow(correct_images[i].permute(1, 2, 0))
|
143 |
+
axes[i].set_title(
|
144 |
+
f"Correctly Classified: True Label = {correct_labels[i]}, Predicted = {correct_preds[i]}"
|
145 |
+
)
|
146 |
+
axes[i].axis("off")
|
147 |
+
|
148 |
+
for i in range(num_incorrect):
|
149 |
+
axes[num_correct + i].imshow(incorrect_images[i].permute(1, 2, 0))
|
150 |
+
axes[num_correct + i].set_title(
|
151 |
+
f"Incorrectly Classified: True Label = {incorrect_labels[i]}, Predicted = {incorrect_preds[i]}"
|
152 |
+
)
|
153 |
+
axes[num_correct + i].axis("off")
|
154 |
+
|
155 |
+
plt.tight_layout()
|
156 |
+
plt.show()
|
157 |
+
|
158 |
+
|
159 |
+
# Function to preprocess image
|
160 |
+
def preprocess_image(image, img_shape):
|
161 |
+
|
162 |
+
# If it's a path, load the image using PIL
|
163 |
+
if isinstance(image, str):
|
164 |
+
image = Image.open(image_path)
|
165 |
+
|
166 |
+
# Apply preprocessing transformations
|
167 |
+
preprocess = transforms.Compose([
|
168 |
+
transforms.Resize(img_shape),
|
169 |
+
transforms.ToTensor(),
|
170 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
171 |
+
])
|
172 |
+
image = preprocess(image).unsqueeze(0)
|
173 |
+
|
174 |
+
return image
|
175 |
+
|
176 |
+
|
177 |
+
# Function to predict
|
178 |
+
def predict(model, image):
|
179 |
+
model.eval()
|
180 |
+
with torch.no_grad():
|
181 |
+
outputs = model(image)
|
182 |
+
return outputs
|
183 |
+
|
184 |
+
|
185 |
+
# Function to get model predictions for LIME
|
186 |
+
def batch_predict(model, images, device):
|
187 |
+
model.eval()
|
188 |
+
batch = torch.stack(
|
189 |
+
tuple(preprocess_image(image, (224, 224)) for image in images), dim=0
|
190 |
+
)
|
191 |
+
batch = batch.to(device)
|
192 |
+
logits = model(batch)
|
193 |
+
probs = torch.nn.functional.softmax(logits, dim=1)
|
194 |
+
return probs.detach().cpu().numpy()
|