# Part of the code is from: fashn-ai/sapiens-body-part-segmentation
import os

import gradio as gr
import numpy as np
import spaces
import torch
from gradio.themes.utils import sizes
from PIL import Image
from torchvision import transforms

from utils.vis_utils import get_palette, visualize_mask_with_overlay

# Assumption: LABELS_TO_IDS (label name -> class index) is defined alongside the
# model-path table; point this import at wherever the project actually keeps it.
from config import SAPIENS_LITE_MODELS_PATH, LABELS_TO_IDS
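
# For reference, a minimal sketch of the config shape this app assumes; the real
# table lives in config.py, and the version keys / filenames here are illustrative:
#   SAPIENS_LITE_MODELS_PATH = {
#       "seg":    {"0.3B": "<seg checkpoint filename>", ...},
#       "pose":   {...},
#       "depth":  {...},
#       "normal": {...},
#   }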
# Enable TF32 matmuls/convolutions on Ampere (compute capability 8.x) and newer GPUs.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
CHECKPOINTS_DIR = "checkpoints"
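
# Assumption: each per-task table in SAPIENS_LITE_MODELS_PATH maps a model version
# ("0.3B", ...) to a checkpoint filename under CHECKPOINTS_DIR; this app's Run
# button only segments, so the "seg" table is used as the checkpoint registry.
CHECKPOINTS = SAPIENS_LITE_MODELS_PATH["seg"]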
def load_model(checkpoint_name: str):
    checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
    model = torch.jit.load(checkpoint_path)  # TorchScript archive: no Python class needed
    model.eval()
    model.to("cuda")
    return model
# Eager-loading every checkpoint at startup was disabled; models are loaded
# lazily on first use instead (see get_model below).
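
# A minimal lazy-loading cache, assuming model versions ("0.3B", ...) are the
# keys of CHECKPOINTS: each model is loaded on first request and reused afterwards.
MODELS = {}


def get_model(checkpoint_name: str):
    if checkpoint_name not in MODELS:
        MODELS[checkpoint_name] = load_model(checkpoint_name)
    return MODELS[checkpoint_name]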
@torch.inference_mode()  # no gradients needed at inference time
def run_model(model, input_tensor, height, width):
    output = model(input_tensor)
    # Upsample the logits back to the original image size before taking the argmax.
    output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
    _, preds = torch.max(output, 1)
    return preds
# Resize to a fixed 1024x768 input and normalize with ImageNet statistics.
transform_fn = transforms.Compose(
    [
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
@spaces.GPU  # on a ZeroGPU Space this requests a GPU per call; drop it on a dedicated GPU
def segment(image: Image.Image, model_name: str) -> Image.Image:
    image = image.convert("RGB")  # guard against RGBA/grayscale uploads
    input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
    model = get_model(model_name)
    preds = run_model(model, input_tensor, height=image.height, width=image.width)
    mask = preds.squeeze(0).cpu().numpy()
    mask_image = Image.fromarray(mask.astype("uint8"))
    blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
    return blended_image
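
# Quick manual check (assumes a CUDA GPU, a downloaded "0.3B" seg checkpoint,
# and a local image; "person.jpg" is a placeholder path):
#   overlay = segment(Image.open("person.jpg"), "0.3B")
#   overlay.save("overlay.png")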
def update_model_choices(task):
    # Repopulate the version dropdown with the models available for the selected task.
    model_choices = list(SAPIENS_LITE_MODELS_PATH[task.lower()].keys())
    return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)
with gr.Blocks() as demo:
    gr.Markdown("# Sapiens Arena 🤸🏽‍♂️ - WIP devmode - Not yet available")
    with gr.Tabs():
        with gr.TabItem("Image"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png")
                    # Task choices come from the config so they stay in sync with
                    # update_model_choices, which looks tasks up by lowercased name.
                    select_task = gr.Radio(
                        choices=list(SAPIENS_LITE_MODELS_PATH.keys()),
                        value="seg",
                        label="Task",
                        info="Choose the task to perform",
                    )
                    model_name = gr.Dropdown(
                        label="Model Version",
                        choices=list(SAPIENS_LITE_MODELS_PATH["seg"].keys()),
                        value="0.3B",
                    )
                    # example_model = gr.Examples(
                    #     inputs=input_image,
                    #     examples_per_page=10,
                    #     examples=[
                    #         os.path.join(ASSETS_DIR, "examples", img)
                    #         for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
                    #     ],
                    # )
                    # NOTE: ASSETS_DIR is not defined in this file; define it before re-enabling.
                with gr.Column():
                    result_image = gr.Image(label="Segmentation Result", format="png")
                    run_button = gr.Button("Run")
                    # gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")
        with gr.TabItem("Video"):
            gr.Markdown("Under construction")

    # Keep the version dropdown in sync with the selected task.
    select_task.change(fn=update_model_choices, inputs=select_task, outputs=model_name)
    run_button.click(
        fn=segment,
        inputs=[input_image, model_name],
        outputs=[result_image],
    )
if __name__ == "__main__":
    demo.launch(share=False)
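
# Usage: `python app.py` serves the demo at http://127.0.0.1:7860 by default;
# pass share=True to demo.launch() for a temporary public Gradio link.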