| | import os |
| | import gradio as gr |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torchvision import transforms |
| | from PIL import Image |
| | import tempfile |
| |
|
| | from gradio.themes.utils import sizes |
| | from classes_and_palettes import GOLIATH_CLASSES |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Config:
    """Static configuration: asset locations and the available checkpoints."""

    # Everything lives under <this file's directory>/assets/checkpoints/.
    _BASE = os.path.dirname(__file__)
    ASSETS_DIR = os.path.join(_BASE, "assets")
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

    # Model-size key -> TorchScript checkpoint filename.
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
        "0.6b": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
        "1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ModelManager:
    """Loads TorchScript segmentation checkpoints (cached) and runs inference."""

    # name -> loaded TorchScript module; shared process-wide across callers.
    _cache = {}

    # Fall back to CPU so the app can still start on machines without a GPU.
    # (The original unconditionally used "cuda" and crashed on CPU-only hosts.)
    _device = "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def load_model(name: str):
        """Return the TorchScript model for `name` (a key of Config.CHECKPOINTS),
        loading and caching it on first use.

        Raises KeyError for an unknown name; whatever torch.jit.load raises
        for a missing/corrupt checkpoint file.
        """
        if name in ModelManager._cache:
            return ModelManager._cache[name]

        path = os.path.join(Config.CHECKPOINTS_DIR, Config.CHECKPOINTS[name])
        # map_location lets a GPU-exported checkpoint deserialize on CPU-only hosts.
        model = torch.jit.load(path, map_location=ModelManager._device)
        model.eval().to(ModelManager._device)
        ModelManager._cache[name] = model
        return model

    @staticmethod
    @torch.inference_mode()
    def run(model, x, h, w):
        """Run `model` on batch `x`, upsample the logits to (h, w), and return
        per-pixel class ids of shape (batch, h, w).

        `x` is assumed to be a (batch, C, H, W) float tensor on the model's
        device -- TODO confirm against the TorchScript export.
        """
        out = model(x)
        # Resize class logits back to the original image resolution before argmax.
        out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=False)
        return out.argmax(1)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ImageProcessor:
    """Preprocesses PIL images, runs segmentation, and packages the outputs."""

    def __init__(self):
        # Model expects fixed 1024x768 inputs; normalization stats are given
        # in 0-255 space and rescaled to [0, 1] to match ToTensor's output.
        self.tf = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
            ),
        ])
        # Fall back to CPU when no GPU is present (the original hard-coded
        # "cuda" and crashed on CPU-only hosts). Computed the same way the
        # model loader picks its device, so tensors and weights agree.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def process(self, image: Image.Image, model_name: str):
        """Segment `image` with the checkpoint named `model_name`.

        Returns ((image, annotations), npy_path):
        - (image, annotations) is the input format for gr.AnnotatedImage;
        - npy_path points to the raw per-pixel class-id mask saved as .npy
          (caller is responsible for eventually deleting the temp file).
        """
        model = ModelManager.load_model(model_name)
        x = self.tf(image).unsqueeze(0).to(self.device)

        # Predictions are upsampled back to the original image resolution.
        pred = ModelManager.run(model, x, image.height, image.width)
        mask = pred.squeeze(0).cpu().numpy()

        # NamedTemporaryFile(delete=False) instead of the deprecated,
        # race-prone tempfile.mktemp.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as f:
            np.save(f, mask)
            npy_path = f.name

        annotations = self._build_annotations(mask)

        return (image, annotations), npy_path

    def _build_annotations(self, mask: np.ndarray):
        """Convert a class-id mask into [(binary_mask, label), ...] pairs
        for gr.AnnotatedImage, skipping ids outside the known label set."""
        annotations = []
        for class_id in np.unique(mask):
            # Ids beyond the label table (e.g. a background sentinel) are ignored.
            if class_id >= len(GOLIATH_CLASSES):
                continue

            binary_mask = (mask == class_id).astype(np.uint8)
            # Defensive: ids from np.unique always have at least one pixel,
            # but keep the guard so empty masks can never reach the UI.
            if binary_mask.sum() == 0:
                continue

            annotations.append(
                (binary_mask, GOLIATH_CLASSES[class_id])
            )

        return annotations
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GradioInterface:
    """Builds the Gradio Blocks UI and wires it to the image processor."""

    def __init__(self):
        self.processor = ImageProcessor()

    def create(self):
        """Assemble and return the gr.Blocks demo (not yet launched)."""

        def on_run(image, model):
            # Thin adapter so the click handler matches process()'s signature.
            return self.processor.process(image, model)

        with gr.Blocks() as demo:
            with gr.Row():
                # Left column: input image, model picker, and the trigger button.
                with gr.Column(scale=1):
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                    )
                    model_name = gr.Dropdown(
                        label="Model Size",
                        choices=list(Config.CHECKPOINTS.keys()),
                        value="1b",
                    )
                    run_btn = gr.Button("Run Segmentation", variant="primary")

                # Right column: rendered segmentation plus the downloadable mask.
                with gr.Column(scale=2):
                    annotated = gr.AnnotatedImage(
                        label="Segmentation Result",
                        show_legend=True,
                        height=512,
                    )
                    mask_file = gr.File(label="Raw Mask (.npy)")

            run_btn.click(
                fn=on_run,
                inputs=[input_image, model_name],
                outputs=[annotated, mask_file],
            )

        return demo
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Entry point: enable TF32 on Ampere-or-newer GPUs, then serve the app."""
    # TF32 matmul/cuDNN kernels require compute capability >= 8 (Ampere+).
    has_ampere = (
        torch.cuda.is_available()
        and torch.cuda.get_device_properties(0).major >= 8
    )
    if has_ampere:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    app = GradioInterface().create()
    # Bind on all interfaces; do not create a public Gradio share link.
    app.launch(server_name="0.0.0.0", share=False)


if __name__ == "__main__":
    main()
| |
|
| |
|