mmpose-webui / calculate_masks.py
Chris
More WIP. Returns a sensible result now!
c9ec478
raw
history blame
676 Bytes
import functools

import torch
import torch.nn as nn
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
# Hugging Face checkpoint used for both the image processor and the model.
_SEGFORMER_CHECKPOINT = "mattmdjaga/segformer_b2_clothes"


@functools.lru_cache(maxsize=1)
def _load_segformer():
    """Load the SegFormer processor and model once and reuse them afterwards.

    ``from_pretrained`` reads configuration and weights each time it is
    called, so the pair is cached instead of being rebuilt for every image.

    Returns:
        A ``(processor, model)`` tuple for ``_SEGFORMER_CHECKPOINT``.
    """
    processor = SegformerImageProcessor.from_pretrained(_SEGFORMER_CHECKPOINT)
    model = AutoModelForSemanticSegmentation.from_pretrained(_SEGFORMER_CHECKPOINT)
    return processor, model


def calculate_seg_mask(image):
    """Compute a per-pixel segmentation mask for ``image``.

    Args:
        image: The input image. It must expose a ``size`` attribute whose
            reversed value is the target output size for upsampling
            (assumes a PIL image, where ``size`` is ``(width, height)`` --
            TODO confirm against callers).

    Returns:
        A tensor holding, for every pixel of the first (only) batch element,
        the index of the highest-scoring class. The meaning of each index is
        defined by the checkpoint's label map (presumably clothing / body-part
        categories -- verify against the model card).
    """
    processor, model = _load_segformer()
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the forward pass and the
    # upsampling so no computation graph is kept alive.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits.cpu()

        # The model's logits are at a different resolution from the input;
        # resize them to the image size before taking the per-pixel argmax.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # reversed so the order matches (H, W)
            mode="bilinear",
            align_corners=False,
        )

    # argmax over the class dimension, then drop the batch dimension.
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg