mrrandom123 commited on
Commit
99226b3
1 Parent(s): 7113bf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -1
app.py CHANGED
@@ -1,3 +1,32 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- gr.Interface.load("models/mattmdjaga/segformer_b2_clothes").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
3
+ from PIL import Image
4
+ import requests
5
+ import matplotlib.pyplot as plt
6
+ import torch.nn as nn
7
 
8
+ extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
9
+ model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
10
+
11
+ def predict(inp):
12
+ inputs = extractor(images=inp, return_tensors="pt")
13
+ outputs = model(**inputs)
14
+ logits = outputs.logits.cpu()
15
+ upsampled_logits = nn.functional.interpolate(
16
+ logits,
17
+ size=image.size[::-1],
18
+ mode="bilinear",
19
+ align_corners=False,
20
+ )
21
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
22
+ pred_seg[pred_seg != 4] = 0
23
+ arr_seg = pred_seg.cpu().numpy().astype("uint8")
24
+ arr_seg *= 255
25
+ pil_seg = Image.fromarray(arr_seg)
26
+
27
+ return pil_seg
28
+
29
+ gr.Interface(fn=predict,
30
+ inputs=gr.Image(type="pil"),
31
+ outputs="image",
32
+ ).launch()