# Part of the code is from: fashn-ai/sapiens-body-part-segmentation
import os
import gradio as gr
import numpy as np
import spaces
import torch
from gradio.themes.utils import sizes
from PIL import Image
from torchvision import transforms
from utils.vis_utils import get_palette, visualize_mask_with_overlay
# NOTE (assumption): LABELS_TO_IDS is used by segment() below but was never
# imported in the original WIP file; it is assumed to live in the project's
# config module alongside SAPIENS_LITE_MODELS_PATH.
from config import SAPIENS_LITE_MODELS_PATH, LABELS_TO_IDS
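
# Allow TF32 matmul/conv kernels on Ampere (compute capability >= 8) and newer GPUs.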
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
CHECKPOINTS_DIR = "checkpoints"
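
# NOTE (assumption): CHECKPOINTS is referenced below but never defined in this
# WIP file. It is assumed to map a model version string (e.g. "0.3B") to a
# TorchScript checkpoint filename; the segmentation entries of
# SAPIENS_LITE_MODELS_PATH look like the intended source.
CHECKPOINTS = SAPIENS_LITE_MODELS_PATH["seg"]
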
def load_model(checkpoint_name: str):
    checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
    model = torch.jit.load(checkpoint_path)
    model.eval()
    model.to("cuda")
    return model

# Load every checkpoint once at startup so segment() can look models up by name.
# (This line was commented out in the WIP original, which left MODELS undefined;
# it is re-enabled here so the demo runs.)
MODELS = {name: load_model(name) for name in CHECKPOINTS.keys()}

@torch.inference_mode()
def run_model(model, input_tensor, height, width):
    output = model(input_tensor)
    output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
    _, preds = torch.max(output, 1)
    return preds
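
# Preprocessing: resize to the 1024x768 input size the Sapiens-lite models
# expect, then normalize with ImageNet statistics.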
transform_fn = transforms.Compose(
    [
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
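
# On Hugging Face ZeroGPU Spaces, @spaces.GPU attaches a GPU to the process
# for the duration of each call.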
@spaces.GPU
def segment(image: Image.Image, model_name: str) -> Image.Image:
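    """Segment `image` with the selected Sapiens model and overlay the class mask."""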
    input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
    model = MODELS[model_name]
    preds = run_model(model, input_tensor, height=image.height, width=image.width)
    mask = preds.squeeze(0).cpu().numpy()
    mask_image = Image.fromarray(mask.astype("uint8"))
    blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
    return blended_image
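
# Repopulate the model-version dropdown whenever the selected task changes.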
def update_model_choices(task):
    model_choices = list(SAPIENS_LITE_MODELS_PATH[task.lower()].keys())
    return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)
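
# Gradio UI: the Image tab is wired to segment(); the Video tab is a placeholder.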
with gr.Blocks() as demo:
    gr.Markdown("# Sapiens Arena 🤸🏽‍♂️ - WIP devmode - not yet available")
    with gr.Tabs():
        with gr.TabItem("Image"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png")
                    # The displayed task names lowercase to the SAPIENS_LITE_MODELS_PATH
                    # keys. (The original also passed a `choices=` keyword alongside the
                    # positional choices list, which raises a TypeError.)
                    select_task = gr.Radio(
                        choices=["Seg", "Pose", "Depth", "Normal"],
                        label="Task",
                        info="Choose the task to perform",
                    )
                    model_name = gr.Dropdown(
                        label="Model Version",
                        choices=list(SAPIENS_LITE_MODELS_PATH["seg"].keys()),
                        value="0.3B",
                    )
                    # example_model = gr.Examples(
                    #     inputs=input_image,
                    #     examples_per_page=10,
                    #     examples=[
                    #         os.path.join(ASSETS_DIR, "examples", img)
                    #         for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
                    #     ],
                    # )
                with gr.Column():
                    result_image = gr.Image(label="Segmentation Result", format="png")
                    run_button = gr.Button("Run")
                    # gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")
        with gr.TabItem("Video"):
            gr.Markdown("Under construction")

    select_task.change(fn=update_model_choices, inputs=select_task, outputs=model_name)
    run_button.click(
        fn=segment,
        inputs=[input_image, model_name],
        outputs=[result_image],
    )
if __name__ == "__main__":
    demo.launch(share=False)