# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu), Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------

import os
# Install the detectron2 fork before the project imports below are executed.
# NOTE(review): presumably the xdecoder package depends on this fork — confirm.
os.system("python -m pip install git+https://github.com/MaureenZOU/detectron2-xyz.git")

import gradio as gr
import torch
import argparse

from xdecoder.BaseModel import BaseModel
from xdecoder import build_model
from utils.distributed import init_distributed
from utils.arguments import load_opt_from_config_files

# Provides the task handlers dispatched to below (referring_segmentation,
# open_semseg, image_captioning, region_retrieval, ...).
from tasks import *


def parse_option():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with a ``conf_files`` attribute holding the config
        path (default: ``configs/xdecoder/svlp_focalt_lang.yaml``).
    """
    parser = argparse.ArgumentParser('X-Decoder All-in-One Demo', add_help=False)
    parser.add_argument('--conf_files', default="configs/xdecoder/svlp_focalt_lang.yaml", metavar="FILE", help='path to config file', )
    args = parser.parse_args()
    return args


# ---------------------------------------------------------------- build args
args = parse_option()
opt = load_opt_from_config_files(args.conf_files)
opt = init_distributed(opt)

# ----------------------------------------------------------------- META DATA
pretrained_pth_last = "xdecoder_focalt_last.pt"
pretrained_pth_novg = "xdecoder_focalt_last_novg.pt"
_CHECKPOINT_BASE_URL = "https://projects4jw.blob.core.windows.net/x-decoder/release/"

for _ckpt in (pretrained_pth_last, pretrained_pth_novg):
    if not os.path.exists(_ckpt):
        # wget saves into the working directory under the URL's basename,
        # which is exactly the local path checked above.
        os.system("wget {}".format(_CHECKPOINT_BASE_URL + _ckpt))
    if not os.path.exists(_ckpt):
        # Fail early with a clear message instead of inside from_pretrained.
        raise FileNotFoundError(
            "Checkpoint {} is missing and could not be downloaded from {}".format(
                _ckpt, _CHECKPOINT_BASE_URL + _ckpt))

# --------------------------------------------------------------- build model
# model_last: loaded from xdecoder_focalt_last.pt (segmentation / inpainting).
# model_cap:  loaded from xdecoder_focalt_last_novg.pt (captioning / retrieval).
model_last = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth_last).eval().cuda()
model_cap = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth_novg).eval().cuda()

with torch.no_grad():
    # Run the language encoder once on a placeholder vocabulary before serving.
    # NOTE(review): presumably this initialises cached text embeddings inside
    # the encoder so the first request works — TODO confirm in lang_encoder.
    model_last.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=True)
    model_cap.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=True)

# ------------------------------------------------------------ task dispatch
# Task name -> (handler, model argument passed as the handler's first arg).
# Insertion order is the order of the radio buttons in the UI. The two
# list-valued entries pass the models in different orders; presumably each
# handler expects its own order — do not normalise without checking tasks/.
TASKS = {
    "Referring Segmentation": (referring_segmentation, model_last),
    "Open Vocabulary Semantic Segmentation": (open_semseg, model_last),
    "Open Vocabulary Instance Segmentation": (open_instseg, model_last),
    "Open Vocabulary Panoptic Segmentation": (open_panoseg, model_last),
    "Image Captioning": (image_captioning, model_cap),
    "Text Retrieval": (text_retrieval, model_cap),
    "Referring Inpainting": (referring_inpainting, model_last),
    "Referring Captioning (Beta)": (referring_captioning, [model_last, model_cap]),
    "Image/Region Retrieval (Only Support Exampled 80 images)": (region_retrieval, [model_cap, model_last]),
}

# Must be a key of TASKS, otherwise the radio button starts with no valid value.
DEFAULT_TASK = "Open Vocabulary Semantic Segmentation"


# ----------------------------------------------------------- inference model
@torch.no_grad()
def inference(image, task, *args, **kwargs):
    """Run the selected X-Decoder task on one uploaded image.

    Args:
        image: PIL image from the Gradio image input (``type='pil'``); it is
            converted to RGB before being handed to the task.
        task: name of the task to run; must be a key of ``TASKS``.
        *args, **kwargs: the remaining UI inputs (the xdecoder_text,
            inpainting_text and task_description textboxes), forwarded
            unchanged to the task handler.

    Returns:
        Whatever the task handler returns. NOTE(review): presumably the three
        values bound to the interface outputs (segmentation image, text,
        inpainting image) — TODO confirm against tasks/.

    Raises:
        ValueError: if no image was provided or ``task`` is not a known task.
    """
    if image is None:
        raise ValueError("Please upload an image.")
    try:
        handler, model_arg = TASKS[task]
    except KeyError:
        raise ValueError("Unknown task: {!r}".format(task)) from None

    image = image.convert("RGB")
    # Run the handler under CUDA float16 autocast (mixed precision).
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        return handler(model_arg, image, *args, **kwargs)


# ---------------------------------------------------------------- launch app
title = "X-Decoder All-in-One Demo"
# NOTE(review): this string appears to have originally been HTML with links;
# only the visible link text survives here — TODO restore the hrefs.
description = "\nProject Page | Paper | Github Repo | Video\n"
article = "The Demo is Run on X-Decoder (Focal-T)."

inputs = [
    gr.Image(type='pil'),
    gr.Radio(choices=list(TASKS), type="value", value=DEFAULT_TASK, label="Task"),
    gr.Textbox(label="xdecoder_text"),
    gr.Textbox(label="inpainting_text"),
    gr.Textbox(label="task_description"),
]

gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=[
        gr.Image(type="pil", label="segmentation results"),
        gr.Textbox(label="text results"),
        gr.Image(type="pil", label="inpainting results"),
    ],
    examples=[
        ["./images/fruit.jpg", "Referring Segmentation", "The larger watermelon.,The front white flower.,White tea pot.,Flower bunch.,white vase.,The peach on the left.,The brown knife.", '', 'Format: s,s,s'],
        ["./images/animals.png", "Open Vocabulary Semantic Segmentation", "zebra,antelope,giraffe,ostrich,sky,water,grass,sand,tree", '', 'Format: x,x,x'],
        ["./images/street.jpg", "Open Vocabulary Panoptic Segmentation", "stuff:building,sky,street,tree,rock,sidewalk;thing:car,person,traffic light", '', 'Format: stuff:x,x,x;thing:y,y,y'],
        ["./images/owls.jpeg", "Open Vocabulary Instance Segmentation", "owl", '', 'Format: y,y,y'],
        ["./images/mountain.jpeg", "Image Captioning", "", '', ''],
        ["./images/rose.webp", "Text Retrieval", "lily,rose,peoney,tulip", '', 'Format: s,s,s'],
        ["./images/region_retrieval.png", "Image/Region Retrieval (Only Support Exampled 80 images)", "The tangerine on the plate.", '', 'Please describe the object in a detailed way.'],
        ["./images/landscape.jpg", "Referring Captioning (Beta)", "cloud", '', 'Please fill in a noun/noun phrase. (may start with a/the)'],
        ["./images/apples.jpg", "Referring Inpainting", "a yellow apple", 'a pear', 'x-decoder + ldm (inference takes ~40s.)'],
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging='never',
    cache_examples=True,
).launch(share=True)