Sifal commited on
Commit
e1828d2
·
verified ·
1 Parent(s): 7c3ed50

Update utils/inference_utils.py

Browse files
Files changed (1) hide show
  1. utils/inference_utils.py +194 -193
utils/inference_utils.py CHANGED
@@ -1,193 +1,194 @@
1
- import torch
2
- import matplotlib.pyplot as plt
3
- from torchvision import transforms
4
- from PIL import Image
5
- import os
6
- import random
7
- from utils.data import CLASS_NAMES
8
-
9
- # Function to find correctly and incorrectly classified images
10
- def find_images(dataloader, model, device, num_correct, num_incorrect):
11
- correct_images = []
12
- incorrect_images = []
13
- correct_labels = []
14
- incorrect_labels = []
15
- correct_preds = []
16
- incorrect_preds = []
17
-
18
- model.eval()
19
- with torch.no_grad():
20
- for images, labels in dataloader:
21
- images, labels = images.to(device), labels.to(device)
22
- outputs = model(images)
23
- _, preds = torch.max(outputs, 1)
24
-
25
- for i in range(images.size(0)):
26
- if preds[i] == labels[i] and len(correct_images) < num_correct:
27
- correct_images.append(images[i].cpu())
28
- correct_labels.append(labels[i].cpu())
29
- correct_preds.append(preds[i].cpu())
30
- elif preds[i] != labels[i] and len(incorrect_images) < num_incorrect:
31
- incorrect_images.append(images[i].cpu())
32
- incorrect_labels.append(labels[i].cpu())
33
- incorrect_preds.append(preds[i].cpu())
34
-
35
- if (
36
- len(correct_images) >= num_correct
37
- and len(incorrect_images) >= num_incorrect
38
- ):
39
- break
40
- if (
41
- len(correct_images) >= num_correct
42
- and len(incorrect_images) >= num_incorrect
43
- ):
44
- break
45
-
46
- return (
47
- correct_images,
48
- correct_labels,
49
- correct_preds,
50
- incorrect_images,
51
- incorrect_labels,
52
- incorrect_preds,
53
- )
54
-
55
- def find_images_from_path(data_path, model, device, num_correct=2, num_incorrect=2, label=None):
56
- correct_images_paths = []
57
- incorrect_images_paths = []
58
- correct_labels = []
59
- incorrect_labels = []
60
-
61
- label_to_idx = {label: idx for idx, label in enumerate(CLASS_NAMES)}
62
-
63
- model.eval()
64
- # First collect available images for the specified label or all labels
65
- label_images = {}
66
- if label:
67
- if os.path.isdir(os.path.join(data_path, label)):
68
- label_path = os.path.join(data_path, label)
69
- label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
70
- else:
71
- for label in os.listdir(data_path):
72
- label_path = os.path.join(data_path, label)
73
- if not os.path.isdir(label_path):
74
- continue
75
- label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
76
-
77
- # Randomly process images until we have enough samples
78
- with torch.no_grad():
79
- while len(correct_images_paths) < num_correct or len(incorrect_images_paths) < num_incorrect:
80
- # Randomly select a label that still has unprocessed images
81
- available_labels = [l for l in label_images if label_images[l]]
82
- if not available_labels:
83
- break
84
-
85
- selected_label = random.choice(available_labels)
86
- image_path = random.choice(label_images[selected_label])
87
- label_images[selected_label].remove(image_path) # Remove the selected image
88
-
89
- image = preprocess_image(image_path, (224, 224)).to(device)
90
- label_idx = label_to_idx[selected_label]
91
-
92
- outputs = model(image)
93
- _, pred = torch.max(outputs, 1)
94
-
95
- if pred == label_idx and len(correct_images_paths) < num_correct:
96
- correct_images_paths.append(image_path)
97
- correct_labels.append(label_idx)
98
- elif pred != label_idx and len(incorrect_images_paths) < num_incorrect:
99
- incorrect_images_paths.append(image_path)
100
- incorrect_labels.append(label_idx)
101
-
102
- save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels)
103
-
104
- def save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels):
105
- # Create root directories for correct and incorrect classifications
106
- for class_name in CLASS_NAMES:
107
- os.makedirs(os.path.join('predictions', class_name, 'correct'), exist_ok=True)
108
- os.makedirs(os.path.join('predictions', class_name, 'mistake'), exist_ok=True)
109
-
110
- # Save correctly classified images
111
- for img_path, label in zip(correct_images_paths, correct_labels):
112
- class_name = CLASS_NAMES[label]
113
- img_name = os.path.basename(img_path)
114
- destination = os.path.join('predictions', class_name, 'correct', img_name)
115
- os.makedirs(os.path.dirname(destination), exist_ok=True)
116
- Image.open(img_path).save(destination)
117
-
118
- # Save incorrectly classified images
119
- for img_path, label in zip(incorrect_images_paths, incorrect_labels):
120
- class_name = CLASS_NAMES[label]
121
- img_name = os.path.basename(img_path)
122
- destination = os.path.join('predictions', class_name, 'mistake', img_name)
123
- os.makedirs(os.path.dirname(destination), exist_ok=True)
124
- Image.open(img_path).save(destination)
125
-
126
- def show_samples(dataloader, model, device, num_correct=3, num_incorrect=3):
127
- # Get some correctly and incorrectly classified images
128
- (
129
- correct_images,
130
- correct_labels,
131
- correct_preds,
132
- incorrect_images,
133
- incorrect_labels,
134
- incorrect_preds,
135
- ) = find_images(dataloader, model, device, num_correct, num_incorrect)
136
- # Display the results in a grid
137
- fig, axes = plt.subplots(
138
- num_correct + num_incorrect, 1, figsize=(10, (num_correct + num_incorrect) * 5)
139
- )
140
-
141
- for i in range(num_correct):
142
- axes[i].imshow(correct_images[i].permute(1, 2, 0))
143
- axes[i].set_title(
144
- f"Correctly Classified: True Label = {correct_labels[i]}, Predicted = {correct_preds[i]}"
145
- )
146
- axes[i].axis("off")
147
-
148
- for i in range(num_incorrect):
149
- axes[num_correct + i].imshow(incorrect_images[i].permute(1, 2, 0))
150
- axes[num_correct + i].set_title(
151
- f"Incorrectly Classified: True Label = {incorrect_labels[i]}, Predicted = {incorrect_preds[i]}"
152
- )
153
- axes[num_correct + i].axis("off")
154
-
155
- plt.tight_layout()
156
- plt.show()
157
-
158
-
159
- # Function to preprocess image
160
- def preprocess_image(image_path, img_shape):
161
-
162
- # Load the image using PIL
163
- image = Image.open(image_path)
164
-
165
- # Apply preprocessing transformations
166
- preprocess = transforms.Compose([
167
- transforms.Resize(img_shape),
168
- transforms.ToTensor(),
169
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
170
- ])
171
- image = preprocess(image).unsqueeze(0)
172
-
173
- return image
174
-
175
-
176
- # Function to predict
177
- def predict(model, image):
178
- model.eval()
179
- with torch.no_grad():
180
- outputs = model(image)
181
- return outputs
182
-
183
-
184
- # Function to get model predictions for LIME
185
- def batch_predict(model, images, device):
186
- model.eval()
187
- batch = torch.stack(
188
- tuple(preprocess_image(image, (224, 224)) for image in images), dim=0
189
- )
190
- batch = batch.to(device)
191
- logits = model(batch)
192
- probs = torch.nn.functional.softmax(logits, dim=1)
193
- return probs.detach().cpu().numpy()
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import os
6
+ import random
7
+ from utils.data import CLASS_NAMES
8
+
9
+ # Function to find correctly and incorrectly classified images
10
+ def find_images(dataloader, model, device, num_correct, num_incorrect):
11
+ correct_images = []
12
+ incorrect_images = []
13
+ correct_labels = []
14
+ incorrect_labels = []
15
+ correct_preds = []
16
+ incorrect_preds = []
17
+
18
+ model.eval()
19
+ with torch.no_grad():
20
+ for images, labels in dataloader:
21
+ images, labels = images.to(device), labels.to(device)
22
+ outputs = model(images)
23
+ _, preds = torch.max(outputs, 1)
24
+
25
+ for i in range(images.size(0)):
26
+ if preds[i] == labels[i] and len(correct_images) < num_correct:
27
+ correct_images.append(images[i].cpu())
28
+ correct_labels.append(labels[i].cpu())
29
+ correct_preds.append(preds[i].cpu())
30
+ elif preds[i] != labels[i] and len(incorrect_images) < num_incorrect:
31
+ incorrect_images.append(images[i].cpu())
32
+ incorrect_labels.append(labels[i].cpu())
33
+ incorrect_preds.append(preds[i].cpu())
34
+
35
+ if (
36
+ len(correct_images) >= num_correct
37
+ and len(incorrect_images) >= num_incorrect
38
+ ):
39
+ break
40
+ if (
41
+ len(correct_images) >= num_correct
42
+ and len(incorrect_images) >= num_incorrect
43
+ ):
44
+ break
45
+
46
+ return (
47
+ correct_images,
48
+ correct_labels,
49
+ correct_preds,
50
+ incorrect_images,
51
+ incorrect_labels,
52
+ incorrect_preds,
53
+ )
54
+
55
+ def find_images_from_path(data_path, model, device, num_correct=2, num_incorrect=2, label=None):
56
+ correct_images_paths = []
57
+ incorrect_images_paths = []
58
+ correct_labels = []
59
+ incorrect_labels = []
60
+
61
+ label_to_idx = {label: idx for idx, label in enumerate(CLASS_NAMES)}
62
+
63
+ model.eval()
64
+ # First collect available images for the specified label or all labels
65
+ label_images = {}
66
+ if label:
67
+ if os.path.isdir(os.path.join(data_path, label)):
68
+ label_path = os.path.join(data_path, label)
69
+ label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
70
+ else:
71
+ for label in os.listdir(data_path):
72
+ label_path = os.path.join(data_path, label)
73
+ if not os.path.isdir(label_path):
74
+ continue
75
+ label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
76
+
77
+ # Randomly process images until we have enough samples
78
+ with torch.no_grad():
79
+ while len(correct_images_paths) < num_correct or len(incorrect_images_paths) < num_incorrect:
80
+ # Randomly select a label that still has unprocessed images
81
+ available_labels = [l for l in label_images if label_images[l]]
82
+ if not available_labels:
83
+ break
84
+
85
+ selected_label = random.choice(available_labels)
86
+ image_path = random.choice(label_images[selected_label])
87
+ label_images[selected_label].remove(image_path) # Remove the selected image
88
+
89
+ image = preprocess_image(image_path, (224, 224)).to(device)
90
+ label_idx = label_to_idx[selected_label]
91
+
92
+ outputs = model(image)
93
+ _, pred = torch.max(outputs, 1)
94
+
95
+ if pred == label_idx and len(correct_images_paths) < num_correct:
96
+ correct_images_paths.append(image_path)
97
+ correct_labels.append(label_idx)
98
+ elif pred != label_idx and len(incorrect_images_paths) < num_incorrect:
99
+ incorrect_images_paths.append(image_path)
100
+ incorrect_labels.append(label_idx)
101
+
102
+ save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels)
103
+
104
+ def save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels):
105
+ # Create root directories for correct and incorrect classifications
106
+ for class_name in CLASS_NAMES:
107
+ os.makedirs(os.path.join('predictions', class_name, 'correct'), exist_ok=True)
108
+ os.makedirs(os.path.join('predictions', class_name, 'mistake'), exist_ok=True)
109
+
110
+ # Save correctly classified images
111
+ for img_path, label in zip(correct_images_paths, correct_labels):
112
+ class_name = CLASS_NAMES[label]
113
+ img_name = os.path.basename(img_path)
114
+ destination = os.path.join('predictions', class_name, 'correct', img_name)
115
+ os.makedirs(os.path.dirname(destination), exist_ok=True)
116
+ Image.open(img_path).save(destination)
117
+
118
+ # Save incorrectly classified images
119
+ for img_path, label in zip(incorrect_images_paths, incorrect_labels):
120
+ class_name = CLASS_NAMES[label]
121
+ img_name = os.path.basename(img_path)
122
+ destination = os.path.join('predictions', class_name, 'mistake', img_name)
123
+ os.makedirs(os.path.dirname(destination), exist_ok=True)
124
+ Image.open(img_path).save(destination)
125
+
126
+ def show_samples(dataloader, model, device, num_correct=3, num_incorrect=3):
127
+ # Get some correctly and incorrectly classified images
128
+ (
129
+ correct_images,
130
+ correct_labels,
131
+ correct_preds,
132
+ incorrect_images,
133
+ incorrect_labels,
134
+ incorrect_preds,
135
+ ) = find_images(dataloader, model, device, num_correct, num_incorrect)
136
+ # Display the results in a grid
137
+ fig, axes = plt.subplots(
138
+ num_correct + num_incorrect, 1, figsize=(10, (num_correct + num_incorrect) * 5)
139
+ )
140
+
141
+ for i in range(num_correct):
142
+ axes[i].imshow(correct_images[i].permute(1, 2, 0))
143
+ axes[i].set_title(
144
+ f"Correctly Classified: True Label = {correct_labels[i]}, Predicted = {correct_preds[i]}"
145
+ )
146
+ axes[i].axis("off")
147
+
148
+ for i in range(num_incorrect):
149
+ axes[num_correct + i].imshow(incorrect_images[i].permute(1, 2, 0))
150
+ axes[num_correct + i].set_title(
151
+ f"Incorrectly Classified: True Label = {incorrect_labels[i]}, Predicted = {incorrect_preds[i]}"
152
+ )
153
+ axes[num_correct + i].axis("off")
154
+
155
+ plt.tight_layout()
156
+ plt.show()
157
+
158
+
159
+ # Function to preprocess image
160
+ def preprocess_image(image, img_shape):
161
+
162
+ # If it's a path, load the image using PIL
163
+ if isinstance(image, str):
164
+ image = Image.open(image_path)
165
+
166
+ # Apply preprocessing transformations
167
+ preprocess = transforms.Compose([
168
+ transforms.Resize(img_shape),
169
+ transforms.ToTensor(),
170
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
171
+ ])
172
+ image = preprocess(image).unsqueeze(0)
173
+
174
+ return image
175
+
176
+
177
+ # Function to predict
178
+ def predict(model, image):
179
+ model.eval()
180
+ with torch.no_grad():
181
+ outputs = model(image)
182
+ return outputs
183
+
184
+
185
+ # Function to get model predictions for LIME
186
+ def batch_predict(model, images, device):
187
+ model.eval()
188
+ batch = torch.stack(
189
+ tuple(preprocess_image(image, (224, 224)) for image in images), dim=0
190
+ )
191
+ batch = batch.to(device)
192
+ logits = model(batch)
193
+ probs = torch.nn.functional.softmax(logits, dim=1)
194
+ return probs.detach().cpu().numpy()