Spaces:
Runtime error
Runtime error
from PIL import Image | |
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation | |
import torch.nn as nn | |
# Hub checkpoint used for clothes / body-part segmentation.
_SEGFORMER_CHECKPOINT = "mattmdjaga/segformer_b2_clothes"

# Meaning of each integer label found in the mask returned by
# ``calculate_seg_mask``.
CLASS_NAMES = {
    0: "Background", 1: "Hat", 2: "Hair", 3: "Sunglasses",
    4: "Upper-clothes", 5: "Skirt", 6: "Pants", 7: "Dress",
    8: "Belt", 9: "Left-shoe", 10: "Right-shoe", 11: "Face",
    12: "Left-leg", 13: "Right-leg", 14: "Left-arm", 15: "Right-arm",
    16: "Bag", 17: "Scarf",
}

# checkpoint name -> (processor, model). Filled lazily so importing this
# module stays cheap and the weights are loaded at most once per process.
_SEGFORMER_CACHE = {}


def _load_segformer(checkpoint=_SEGFORMER_CHECKPOINT):
    """Return the ``(processor, model)`` pair for *checkpoint*, loading it once."""
    if checkpoint not in _SEGFORMER_CACHE:
        processor = SegformerImageProcessor.from_pretrained(checkpoint)
        model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint)
        _SEGFORMER_CACHE[checkpoint] = (processor, model)
    return _SEGFORMER_CACHE[checkpoint]


def calculate_seg_mask(image):
    """Segment a photo into clothing / body-part classes.

    Args:
        image: Anything ``PIL.Image.open`` accepts (the original code passes
            it straight to ``Image.open``), e.g. a filesystem path.

    Returns:
        The per-pixel argmax over the class logits for the first (only)
        batch item, upsampled to the input image's height and width. The
        integer labels are described by ``CLASS_NAMES``.
    """
    # The file only binds ``torch.nn`` as ``nn``; ``torch`` itself is needed
    # for ``no_grad`` below.
    import torch

    # Close the underlying file as soon as the pixels have been copied out.
    with Image.open(image) as source:
        rgb_image = source.convert("RGB")

    processor, model = _load_segformer()
    inputs = processor(images=rgb_image, return_tensors="pt")

    # Pure inference: skip autograd bookkeeping for the forward pass and
    # the resize.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()
        upsampled_logits = nn.functional.interpolate(
            logits,
            # PIL reports (width, height); interpolate expects (height, width).
            size=rgb_image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )

    return upsampled_logits.argmax(dim=1)[0]