spuuntries committed on
Commit
c9e9eb6
1 Parent(s): 20cf889

feat: add new model

Browse files
Files changed (3) hide show
  1. 3q7y4e.safetensors +3 -0
  2. app.py +37 -6
  3. models.py +184 -24
3q7y4e.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1646a218094821c8c0ca6df5c7f236bceb1aec6f4085d0a42f920bec6d53bb57
3
+ size 352409020
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from PIL import Image
4
  import torchvision.transforms as transforms
5
  import numpy as np
 
6
  from safetensors.torch import load_model, save_model
7
  from models import *
8
  import os
@@ -33,13 +34,24 @@ class WasteClassifier:
33
  img_tensor = self.transform(image).unsqueeze(0).to(self.device)
34
 
35
  with torch.no_grad():
36
- outputs = self.model(img_tensor)
37
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
38
 
39
  probs = probabilities[0].cpu().numpy()
40
  pred_class = self.class_names[np.argmax(probs)]
41
  confidence = np.max(probs)
42
 
 
 
 
 
 
 
 
 
 
 
 
43
  results = {
44
  "predicted_class": pred_class,
45
  "confidence": confidence,
@@ -47,6 +59,7 @@ class WasteClassifier:
47
  class_name: float(prob)
48
  for class_name, prob in zip(self.class_names, probs)
49
  },
 
50
  }
51
 
52
  return results
@@ -56,6 +69,16 @@ def interface(classifier):
56
  def process_image(image):
57
  results = classifier.predict(image)
58
 
 
 
 
 
 
 
 
 
 
 
59
  output_str = f"Predicted Class: {results['predicted_class']}\n"
60
  output_str += f"Confidence: {results['confidence']*100:.2f}%\n\n"
61
  output_str += "Class Probabilities:\n"
@@ -67,16 +90,23 @@ def interface(classifier):
67
  for class_name, prob in sorted_probs:
68
  output_str += f"{class_name}: {prob*100:.2f}%\n"
69
 
70
- return output_str
 
 
71
 
72
  demo = gr.Interface(
73
  fn=process_image,
74
  inputs=[gr.Image(type="pil", label="Upload Image")],
75
- outputs=[gr.Textbox(label="Classification Results")],
 
 
 
 
76
  title="Waste Classification System",
77
  description="""
78
  Upload an image of waste to classify it into different categories.
79
- The model will predict the type of waste and show confidence scores for each category.
 
80
  """,
81
  examples=(
82
  [["example1.jpg"], ["example2.jpg"], ["example3.jpg"]]
@@ -102,11 +132,12 @@ class_names = [
102
  "Textile Trash",
103
  "Vegetation",
104
  ]
105
- best_model = ResNet50(num_classes=len(class_names))
 
106
  best_model = best_model.to(device)
107
  load_model(
108
  best_model,
109
- os.path.join(os.path.dirname(os.path.abspath(__file__)), "bjf8fp.safetensors"),
110
  )
111
 
112
  classifier = WasteClassifier(best_model, class_names, device)
 
3
  from PIL import Image
4
  import torchvision.transforms as transforms
5
  import numpy as np
6
+ import torch.nn.functional as F
7
  from safetensors.torch import load_model, save_model
8
  from models import *
9
  import os
 
34
  img_tensor = self.transform(image).unsqueeze(0).to(self.device)
35
 
36
  with torch.no_grad():
37
+ outputs, seg_mask = self.model(img_tensor) # Handle both outputs
38
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
39
 
40
  probs = probabilities[0].cpu().numpy()
41
  pred_class = self.class_names[np.argmax(probs)]
42
  confidence = np.max(probs)
43
 
44
+ # Process segmentation mask
45
+ seg_mask = (
46
+ seg_mask[0, 0].cpu().numpy().astype(np.float32)
47
+ ) # Get first image, first channel
48
+ # seg_mask = (seg_mask >= 0.2).astype(np.float32) # Threshold at 0.2
49
+
50
+ # Resize mask back to original image size
51
+ seg_mask = Image.fromarray(seg_mask)
52
+ seg_mask = seg_mask.resize(original_size, Image.NEAREST)
53
+ seg_mask = np.array(seg_mask)
54
+
55
  results = {
56
  "predicted_class": pred_class,
57
  "confidence": confidence,
 
59
  class_name: float(prob)
60
  for class_name, prob in zip(self.class_names, probs)
61
  },
62
+ "segmentation_mask": seg_mask,
63
  }
64
 
65
  return results
 
69
  def process_image(image):
70
  results = classifier.predict(image)
71
 
72
+ if isinstance(image, Image.Image):
73
+ image_np = np.array(image)
74
+ else:
75
+ image_np = image
76
+
77
+ mask = results["segmentation_mask"]
78
+
79
+ overlay = image_np.copy()
80
+ overlay[mask < 0.2] = overlay[mask < 0.2] * 0
81
+
82
  output_str = f"Predicted Class: {results['predicted_class']}\n"
83
  output_str += f"Confidence: {results['confidence']*100:.2f}%\n\n"
84
  output_str += "Class Probabilities:\n"
 
90
  for class_name, prob in sorted_probs:
91
  output_str += f"{class_name}: {prob*100:.2f}%\n"
92
 
93
+ mask_viz = (mask * 255).astype(np.uint8)
94
+
95
+ return [output_str, overlay, mask_viz]
96
 
97
  demo = gr.Interface(
98
  fn=process_image,
99
  inputs=[gr.Image(type="pil", label="Upload Image")],
100
+ outputs=[
101
+ gr.Textbox(label="Classification Results"),
102
+ gr.Image(label="Segmented Object"),
103
+ gr.Image(label="Segmentation Mask"),
104
+ ],
105
  title="Waste Classification System",
106
  description="""
107
  Upload an image of waste to classify it into different categories.
108
+ The model will predict the type of waste, show confidence scores for each category,
109
+ and display the segmented object along with its mask.
110
  """,
111
  examples=(
112
  [["example1.jpg"], ["example2.jpg"], ["example3.jpg"]]
 
132
  "Textile Trash",
133
  "Vegetation",
134
  ]
135
+
136
+ best_model = ResNet101UNet(num_classes=len(class_names))
137
  best_model = best_model.to(device)
138
  load_model(
139
  best_model,
140
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "3q7y4e.safetensors"),
141
  )
142
 
143
  classifier = WasteClassifier(best_model, class_names, device)
models.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import torch.nn as nn
 
3
 
4
 
5
  class BasicBlock(nn.Module):
@@ -76,19 +77,20 @@ class Bottleneck(nn.Module):
76
 
77
 
78
  class ResNet(nn.Module):
79
- def __init__(self, block, num_blocks, num_classes=1000, K=10, T=0.5):
80
  super(ResNet, self).__init__()
81
  self.in_planes = 64
82
- self.K = K
83
- self.T = T
84
 
85
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
86
  self.bn1 = nn.BatchNorm2d(64)
87
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
88
  self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
89
  self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
90
  self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
91
  self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
 
 
92
  self.fc = nn.Linear(512 * block.expansion, num_classes)
93
 
94
  def _make_layer(self, block, planes, num_blocks, stride):
@@ -99,43 +101,201 @@ class ResNet(nn.Module):
99
  self.in_planes = planes * block.expansion
100
  return nn.Sequential(*layers)
101
 
102
- def t_max_avg_pooling(self, x):
103
- B, C, H, W = x.shape
104
- x_flat = x.view(B, C, -1)
105
- top_k_values, _ = torch.topk(x_flat, self.K, dim=2)
106
- max_values = top_k_values.max(dim=2)[0]
107
- avg_values = top_k_values.mean(dim=2)
108
- output = torch.where(max_values >= self.T, max_values, avg_values)
109
- return output
110
-
111
  def forward(self, x):
112
  out = torch.relu(self.bn1(self.conv1(x)))
113
  out = self.maxpool(out)
 
114
  out = self.layer1(out)
115
  out = self.layer2(out)
116
  out = self.layer3(out)
117
  out = self.layer4(out)
118
- out = self.t_max_avg_pooling(out)
119
- out = out.view(out.size(0), -1)
 
120
  out = self.fc(out)
121
  return out
122
 
123
 
124
- def ResNet18(num_classes=1000, K=10, T=0.5):
125
- return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, K, T)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
 
128
- def ResNet34(num_classes=1000, K=10, T=0.5):
129
- return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, K, T)
130
 
131
 
132
- def ResNet50(num_classes=1000, K=10, T=0.5):
133
- return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, K, T)
134
 
135
 
136
- def ResNet101(num_classes=1000, K=10, T=0.5):
137
- return ResNet(Bottleneck, [3, 4, 23, 3], num_classes, K, T)
138
 
139
 
140
- def ResNet152(num_classes=1000, K=10, T=0.5):
141
- return ResNet(Bottleneck, [3, 8, 36, 3], num_classes, K, T)
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
 
5
 
6
  class BasicBlock(nn.Module):
 
77
 
78
 
79
  class ResNet(nn.Module):
80
+ def __init__(self, block, num_blocks, num_classes=1000):
81
  super(ResNet, self).__init__()
82
  self.in_planes = 64
 
 
83
 
84
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
85
  self.bn1 = nn.BatchNorm2d(64)
86
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
87
+
88
  self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
89
  self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
90
  self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
91
  self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
92
+
93
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
94
  self.fc = nn.Linear(512 * block.expansion, num_classes)
95
 
96
  def _make_layer(self, block, planes, num_blocks, stride):
 
101
  self.in_planes = planes * block.expansion
102
  return nn.Sequential(*layers)
103
 
 
 
 
 
 
 
 
 
 
104
  def forward(self, x):
105
  out = torch.relu(self.bn1(self.conv1(x)))
106
  out = self.maxpool(out)
107
+
108
  out = self.layer1(out)
109
  out = self.layer2(out)
110
  out = self.layer3(out)
111
  out = self.layer4(out)
112
+
113
+ out = self.avgpool(out)
114
+ out = torch.flatten(out, 1)
115
  out = self.fc(out)
116
  return out
117
 
118
 
119
def ResNet18(num_classes=1000):
    """Build a ``ResNet`` from ``BasicBlock`` with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(block=BasicBlock, num_blocks=stage_depths, num_classes=num_classes)
121
+
122
+
123
def ResNet34(num_classes=1000):
    """Build a ``ResNet`` from ``BasicBlock`` with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=BasicBlock, num_blocks=stage_depths, num_classes=num_classes)
125
+
126
+
127
def ResNet50(num_classes=1000):
    """Build a ``ResNet`` from ``Bottleneck`` with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=Bottleneck, num_blocks=stage_depths, num_classes=num_classes)
129
+
130
+
131
def ResNet101(num_classes=1000):
    """Build a ``ResNet`` from ``Bottleneck`` with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(block=Bottleneck, num_blocks=stage_depths, num_classes=num_classes)
133
+
134
+
135
def ResNet152(num_classes=1000):
    """Build a ``ResNet`` from ``Bottleneck`` with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(block=Bottleneck, num_blocks=stage_depths, num_classes=num_classes)
137
+
138
+
139
class ClassifierHead(nn.Module):
    """MLP classifier over concatenated global average- and max-pooled features.

    The input feature map is reduced to 1x1 by both an adaptive average pool
    and an adaptive max pool; the two flattened results are concatenated
    (hence ``in_features * 2``) and passed through a three-layer MLP with
    batch norm, ReLU and dropout between the linear layers.

    Args:
        in_features: Channel count of the incoming feature map.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, in_features, num_classes):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))

        # Both pooled vectors are concatenated, doubling the feature width.
        pooled_width = in_features * 2
        mlp_layers = [
            nn.Linear(pooled_width, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Return class logits for feature map ``x``."""
        pooled = [
            torch.flatten(self.avg_pool(x), 1),
            torch.flatten(self.max_pool(x), 1),
        ]
        return self.classifier(torch.cat(pooled, dim=1))
162
+
163
+
164
class ResNetUNet(ResNet):
    """ResNet encoder with a U-Net-style decoder and a mask-attended classifier.

    The forward pass produces a single-channel segmentation map (sigmoid
    output, resized to the input resolution) and class logits computed from
    the deepest encoder features after they are re-weighted by that map.

    The ``avgpool`` and ``fc`` modules created by the parent ``ResNet``
    constructor are not used by this forward pass; they are left in place so
    the module's parameter names stay unchanged.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute.
        num_blocks: Number of blocks in each of the four encoder stages.
        num_classes: Size of the classification output.
    """

    def __init__(self, block, num_blocks, num_classes=1000):
        super().__init__(block, num_blocks, num_classes)

        # Channel counts of the stem output and of layer1..layer4.
        self.enc_channels = [
            64,
            64 * block.expansion,
            128 * block.expansion,
            256 * block.expansion,
            512 * block.expansion,
        ]
        c1, c2, c3, c4, c5 = self.enc_channels

        # Classification head fed by the mask-attended layer4 features.
        self.classifier_head = ClassifierHead(c5, num_classes)

        # Decoder input widths are derived from the encoder widths rather than
        # hard-coded. With expansion == 4 these equal the previous constants
        # (2048 + 1024, 512 + 512, 256 + 256, 128 + 64), so parameter shapes
        # are unchanged for that case; other expansions now get matching
        # skip-connection widths instead of a channel mismatch at concat time.
        self.decoder5 = self._decoder_stage(c5 + c4, 1024, 512)
        self.decoder4 = self._decoder_stage(512 + c3, 512, 256)
        self.decoder3 = self._decoder_stage(256 + c2, 256, 128)
        self.decoder2 = self._decoder_stage(128 + c1, 128, 64)

        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _decoder_stage(in_channels, mid_channels, out_channels):
        """Return one decoder stage: two conv-BN-ReLU layers then 2x upsampling."""
        return nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    @staticmethod
    def _match_size(skip, reference):
        """Bilinearly resize ``skip`` to the spatial size of ``reference``."""
        return F.interpolate(
            skip, size=reference.shape[-2:], mode="bilinear", align_corners=True
        )

    def forward(self, x):
        """Run the encoder, decoder and classifier.

        Args:
            x: Input image batch with 3 channels (the stem conv takes 3
                input channels).

        Returns:
            Tuple ``(cls_out, seg_out)``: class logits from the classifier
            head, and the sigmoid segmentation map resized to the spatial
            size of ``x``.
        """
        input_size = x.shape[-2:]

        # Encoder path
        x = torch.relu(self.bn1(self.conv1(x)))
        e1 = self.maxpool(x)

        e2 = self.layer1(e1)
        e3 = self.layer2(e2)
        e4 = self.layer3(e3)
        e5 = self.layer4(e4)

        # Decoder path: each skip tensor is resized to the current decoder
        # resolution before concatenation along the channel axis.
        d5 = self.decoder5(torch.cat([e5, self._match_size(e4, e5)], dim=1))
        d4 = self.decoder4(torch.cat([d5, self._match_size(e3, d5)], dim=1))
        d3 = self.decoder3(torch.cat([d4, self._match_size(e2, d4)], dim=1))
        d2 = self.decoder2(torch.cat([d3, self._match_size(e1, d3)], dim=1))

        seg_out = self.final_conv(d2)
        seg_out = F.interpolate(
            seg_out, size=input_size, mode="bilinear", align_corners=True
        )

        # Resize the segmentation map to the spatial size of the deepest
        # encoder features so it can be used as an attention weight.
        attention_mask = F.interpolate(
            seg_out, size=e5.shape[2:], mode="bilinear", align_corners=True
        )

        # The 0.25 offset keeps a fraction of every feature even where the
        # sigmoid mask is near zero (weights lie between 0.25 and 1.25).
        attended_features = e5 * (0.25 + attention_mask)

        cls_out = self.classifier_head(attended_features)

        return cls_out, seg_out
281
+
282
+
283
+ # Helper functions without K and T parameters
284
def ResNet18UNet(num_classes=1000):
    """Build a ``ResNetUNet`` from ``BasicBlock`` with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNetUNet(
        block=BasicBlock, num_blocks=stage_depths, num_classes=num_classes
    )
286
 
287
 
288
def ResNet34UNet(num_classes=1000):
    """Build a ``ResNetUNet`` from ``BasicBlock`` with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNetUNet(
        block=BasicBlock, num_blocks=stage_depths, num_classes=num_classes
    )
290
 
291
 
292
def ResNet50UNet(num_classes=1000):
    """Build a ``ResNetUNet`` from ``Bottleneck`` with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNetUNet(
        block=Bottleneck, num_blocks=stage_depths, num_classes=num_classes
    )
294
 
295
 
296
def ResNet101UNet(num_classes=1000):
    """Build a ``ResNetUNet`` from ``Bottleneck`` with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNetUNet(
        block=Bottleneck, num_blocks=stage_depths, num_classes=num_classes
    )
298
 
299
 
300
def ResNet152UNet(num_classes=1000):
    """Build a ``ResNetUNet`` from ``Bottleneck`` with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNetUNet(
        block=Bottleneck, num_blocks=stage_depths, num_classes=num_classes
    )