|
|
|
|
|
import argparse |
|
import glob |
|
import multiprocessing as mp |
|
import os |
|
import subprocess
import sys

# Install detectron2 from its GitHub repository at startup. Invoking pip as
# `<current interpreter> -m pip` guarantees the package is installed into the
# environment that is actually running this script (a bare `pip` on PATH may
# belong to a different Python), and passing an argument list avoids a shell.
# The exit status is deliberately not checked, preserving the original
# best-effort behaviour: if the install fails and detectron2 is not already
# present, the detectron2 imports below will raise ImportError.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/detectron2.git",
    ],
    check=False,
)

# When run as a script, sys.path[0] is the script's own directory; insert its
# parent right after it so the local `cat_seg` / `demo` imports below can
# presumably be resolved from there (confirm against the repo layout).
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
|
|
|
|
import tempfile |
|
import time |
|
import warnings |
|
|
|
import cv2 |
|
import numpy as np |
|
import tqdm |
|
|
|
from detectron2.config import get_cfg |
|
from detectron2.data.detection_utils import read_image |
|
from detectron2.projects.deeplab import add_deeplab_config |
|
from detectron2.utils.logger import setup_logger |
|
|
|
from cat_seg import add_cat_seg_config |
|
from demo.predictor import VisualizationDemo |
|
import gradio as gr |
|
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc |
|
|
|
|
|
# NOTE(review): presumably a window title left over from an OpenCV-window
# version of this demo; it is not referenced anywhere else in this file.
WINDOW_NAME = "MaskFormer demo"
|
|
|
|
|
def setup_cfg(args):
    """Build the frozen detectron2 config used by the demo.

    Starts from detectron2's default config, registers the DeepLab and
    CAT-Seg extension keys, then applies the YAML file named by
    ``args.config_file`` followed by the ``KEY VALUE`` overrides listed in
    ``args.opts``.

    Args:
        args: parsed command line (see ``get_parser``); only ``config_file``
            and ``opts`` are read.

    Returns:
        The merged config, frozen so it cannot be modified afterwards.
    """
    config = get_cfg()
    # Register the extension keys first so the file and the overrides are
    # allowed to set them.
    for register_keys in (add_deeplab_config, add_cat_seg_config):
        register_keys(config)
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
|
|
|
|
|
def get_parser():
    """Return the command-line parser for this demo.

    Options:
        --config-file: path to the YAML config
            (default ``configs/vitl_swinb_384.yaml``).
        --input: one or more image paths, or a single glob pattern. It is
            not read anywhere else in this file.
        --opts: ``KEY VALUE`` config overrides that consume the rest of the
            command line. Passing ``--opts`` replaces the built-in default
            overrides entirely rather than adding to them.
    """
    default_overrides = [
        "MODEL.WEIGHTS", "model_final.pth",
        "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
        "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
        "TEST.SLIDING_WINDOW", "True",
        "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
    ]

    parser = argparse.ArgumentParser(
        description="Detectron2 demo for builtin configs"
    )
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        default="configs/vitl_swinb_384.yaml",
        help="path to config file",
    )
    parser.add_argument(
        "--input",
        nargs="+",
        help=(
            "A list of space separated input images; "
            "or a single glob pattern such as 'directory/*.jpg'"
        ),
    )
    parser.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=default_overrides,
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return parser
|
|
|
def save_masks(preds, text, out_dir="masks"):
    """Write one binary PNG mask per class name.

    Each pixel is assigned the class whose score is highest along dim 0 of
    ``preds['sem_seg']``. For every class name, a mask is written to
    ``<out_dir>/mask_<name>.png``. The mask is 255 where that class won and
    0 elsewhere.

    Args:
        preds: prediction dict whose ``'sem_seg'`` entry is a tensor with one
            score map per class along dim 0 (assumed ``(num_classes, H, W)``,
            TODO confirm against VisualizationDemo).
        text: class names, in the same order as the score maps.
        out_dir: directory to write into; created if missing. Defaults to
            ``"masks"``, the previously hard-coded location.
    """
    labels = preds['sem_seg'].argmax(dim=0).cpu().numpy()
    # cv2.imwrite does not create missing directories, and its return value
    # is not checked here, so without this the masks would be silently lost.
    os.makedirs(out_dir, exist_ok=True)
    for idx, name in enumerate(text):
        # In this script the names come from a user-facing textbox; keep them
        # from injecting path separators into the output file name.
        safe_name = name.replace("/", "_").replace("\\", "_")
        path = os.path.join(out_dir, f"mask_{safe_name}.png")
        # Explicit 8-bit 0/255 image instead of relying on OpenCV's implicit
        # conversion of an int64 (bool * 255) array.
        mask = (labels == idx).astype(np.uint8) * 255
        cv2.imwrite(path, mask)
|
|
|
def predict(image, text):
    """Segment ``image`` into the comma-separated classes in ``text``.

    This is the Gradio callback. On every call it re-reads the command line,
    rebuilds the config and constructs a fresh ``VisualizationDemo`` for the
    requested classes. It writes per-class masks via ``save_masks`` and
    returns the rendered visualization as an image array.

    Args:
        image: input image as delivered by ``gr.Image()`` (presumably an
            ``(H, W, 3)`` uint8 array, to be confirmed).
        text: comma-separated class names.

    Returns:
        ``(H, W, 3)`` uint8 array of the rendered figure, with the channel
        order reversed relative to the canvas' RGB output.
        NOTE(review): the reversal presumably compensates for a channel swap
        elsewhere in the pipeline; confirm before changing it.
    """
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    demo = VisualizationDemo(cfg, text=text)
    predictions, visualized_output = demo.run_on_image(image)
    save_masks(predictions, text.split(','))

    canvas = fc(visualized_output.fig)
    canvas.draw()
    # FigureCanvasAgg.tostring_rgb() was deprecated in Matplotlib 3.8 and
    # removed in 3.10. buffer_rgba() is the supported accessor: it is already
    # shaped (H, W, 4), and dropping alpha yields the same RGB pixels.
    rgb = np.asarray(canvas.buffer_rgba())[..., :3]
    # Copy so the returned array does not alias the canvas' internal buffer.
    return rgb[..., ::-1].copy()
|
|
|
if __name__ == "__main__":
    # Parse the command line and build the config once at startup, so a bad
    # --config-file or --opts fails before the UI launches. `predict`
    # re-reads the command line itself on each request, so `cfg` is not
    # passed along to it.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    image_input = gr.Image()
    classes_input = gr.Textbox(placeholder="Classes to segment")
    iface = gr.Interface(
        fn=predict,
        inputs=[image_input, classes_input],
        outputs="image",
    )
    iface.launch()
|
|