rap-sam / main.py
HarborYuan's picture
add rap_sam
502989e
raw
history blame
6.08 kB
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
# mm libs
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmengine import Config, print_log
from mmengine.structures import InstanceData
from PIL import ImageDraw
# Target length (pixels) of the longer image side; store_img resizes uploads to it.
IMG_SIZE = 1024
# HTML heading rendered through gr.Markdown by register_title.
TITLE = "<center><strong><font size='8'>🚀RAP-SAM: Towards Real-Time All-Purpose Segment Anything</font></strong></center>"
# Page CSS passed to gr.Blocks: centre <h1> headings and justify the ".about" text block.
CSS = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# Build the RAP-SAM model once, at import time, from its mmengine config file.
# NOTE(review): the path is relative, so the app must be started from the directory
# that contains `app/configs/` — confirm against the deployment layout.
model_cfg = Config.fromfile('app/configs/rap_sam_r50_12e_adaptor.py')
model = MODELS.build(model_cfg.model)
# Use the GPU when one is visible; the normalisation constants below and the
# inference tensors created in segment_point are all placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
model = model.eval()
# NOTE(review): presumably loads the pretrained weights named in the config — confirm.
# It runs after .to(device); if it replaces parameter tensors instead of copying into
# them, they could end up on CPU.
model.init_weights()
# Per-channel normalisation constants on a 0-255 scale, reshaped to (3, 1, 1) so they
# broadcast over the (1, 3, H, W) image tensor built in segment_point.
# NOTE(review): these equal the common ImageNet RGB statistics ×255; verify they match
# the config's data preprocessor.
mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]
class IMGState:
    """Per-session state for the point-prompt demo.

    Holds the resized input image (a NumPy array) and the point prompts clicked
    on it. register_point_mode wraps an instance in ``gr.State`` so the Gradio
    handlers can read and update it.
    """

    def __init__(self):
        # Resized input image as a NumPy array; None before an upload / after clear().
        self.img = None
        # Clicked prompts as [x, y] pairs, as reported by the Gradio select event.
        self.selected_points = []
        # True until set_img() is called, and again after clear().
        self.available_to_set = True

    def set_img(self, img):
        """Store a newly uploaded image and drop prompts clicked on any previous image."""
        self.img = img
        # Points clicked on an earlier image do not refer to this one.
        self.selected_points = []
        self.available_to_set = False

    def clear(self):
        """Forget the image and all prompts, returning to the initial state."""
        self.img = None
        self.selected_points = []
        self.available_to_set = True

    def clean(self):
        """Drop the point prompts but keep the current image."""
        self.selected_points = []

    @property
    def available(self):
        """True until an image is set, and again after clear()."""
        return self.available_to_set

    @classmethod
    def cls_clean(cls, state):
        """Gradio handler for "Clean Prompts".

        Drops the prompts and returns ``(un-annotated image, None)`` for the
        input and segmentation components. If no image has been stored yet,
        returns ``(None, None)`` instead of failing on ``Image.fromarray(None)``.
        """
        state.clean()
        if state.img is None:
            return None, None
        return Image.fromarray(state.img), None

    @classmethod
    def cls_clear(cls, state):
        """Gradio handler for clearing the input image: reset state and both outputs."""
        state.clear()
        return None, None
def store_img(img, img_state):
    """Gradio upload handler: normalise and resize the image, then store it.

    The image is converted to 3-channel RGB, because segment_point permutes the
    array to CHW and normalises it with 3-channel mean/std. It is then resized so
    that its longer side is ``IMG_SIZE`` pixels.

    Args:
        img: the uploaded ``PIL.Image.Image``.
        img_state: the session's ``IMGState``; receives the resized image as a
            NumPy array.

    Returns:
        ``(resized PIL image, None)``. The resized image replaces the input
        component's value, and the segmentation output is cleared.
    """
    # RGBA / greyscale / palette images would otherwise yield 4- or 1-channel arrays.
    img = img.convert("RGB")
    w, h = img.size
    scale = IMG_SIZE / max(w, h)
    # round() keeps the long side at exactly IMG_SIZE despite float error (int()
    # could truncate 1023.999... to 1023); max(1, ...) stops a side collapsing to 0 px.
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
    img_state.set_img(np.array(img))
    print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current')
    return img, None
def get_points_with_draw(image, img_state, evt: gr.SelectData):
    """Gradio select handler: record the clicked point and draw it on the image.

    Only the most recent click is kept as the prompt. The marker is drawn on a
    fresh copy of the stored (un-annotated) image, so exactly one point is shown.

    Args:
        image: current value of the input component; returned unchanged when no
            image has been stored yet.
        img_state: the session's ``IMGState``.
        evt: Gradio selection event; ``evt.index`` holds the clicked (x, y).

    Returns:
        A PIL image with the clicked point drawn as a filled green circle.
    """
    x, y = evt.index[0], evt.index[1]
    print_log(f"Point: {x}_{y}", logger='current')
    if img_state.img is None:
        # Nothing stored to annotate (no upload has been processed yet).
        return image
    point_radius, point_color = 10, (97, 217, 54)
    # A single point prompt: each new click replaces the previous one.
    img_state.selected_points = [[x, y]]
    image = Image.fromarray(img_state.img)
    draw = ImageDraw.Draw(image)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )
    return image
def _round_up(value, multiple):
    """Return the smallest multiple of ``multiple`` that is >= ``value``."""
    return -(-value // multiple) * multiple


def segment_point(image, img_state):
    """Gradio "Segment" handler: run RAP-SAM on the stored image and overlay the mask.

    Steps:
    1. Normalise the stored image.
    2. Zero-pad it on the right/bottom to a multiple of 32 pixels.
    3. Pass it to ``model.predict_with_point`` with the selected point, if any.
    4. Blend the first predicted mask onto the image in green.

    Args:
        image: current value of the input component; returned unchanged.
        img_state: the session's ``IMGState``.

    Returns:
        ``(image, overlay)``. ``overlay`` is a PIL image, or ``None`` when no
        image has been uploaded yet.
    """
    output_img = img_state.img
    if output_img is None:
        # "Segment" clicked before any upload: nothing to run on.
        return image, None
    h, w = output_img.shape[:2]

    # HWC array -> normalised (1, C, H, W) float tensor.
    img_tensor = torch.tensor(output_img, device=device, dtype=torch.float32).permute((2, 0, 1))[None]
    img_tensor = (img_tensor - mean) / std
    # Pad right/bottom to multiples of 32 (presumably the network's overall stride).
    im_w = _round_up(w, 32)
    im_h = _round_up(h, 32)
    img_tensor = F.pad(img_tensor, (0, im_w - w, 0, im_h - h), 'constant', 0)

    data_sample = DetDataSample()
    if img_state.selected_points:
        input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device)
        # Each (x, y) becomes (x-3, y-3, x+3, y+3) in point_coords —
        # presumably read by the model as a tiny box around the click.
        gt_instances = InstanceData(
            point_coords=torch.cat([input_points - 3, input_points + 3], 1),
        )
        gt_instances.pb_labels = torch.ones(len(gt_instances), dtype=torch.long, device=device)
        data_sample.gt_instances_collected = gt_instances
    data_sample.set_metainfo(dict(batch_input_shape=(im_h, im_w)))
    data_sample.set_metainfo(dict(img_shape=(h, w)))
    batch_data_samples = [data_sample]

    with torch.no_grad():
        masks, _ = model.predict_with_point(img_tensor, batch_data_samples)

    # First image, first mask, cropped back to the unpadded size.
    masks = masks[0, 0, :h, :w] > 0.
    # NOTE(review): predict_with_point's return type is not visible here. If it is a
    # torch tensor (possibly on GPU), convert it so NumPy boolean indexing works.
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    color = np.zeros((*masks.shape, 3), dtype=np.uint8)
    color[masks] = np.array([97, 217, 54])
    # 70% image + 30% green inside the mask; outside it the image is just darkened.
    output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
    return image, Image.fromarray(output_img)
def register_title():
    """Add the page heading (TITLE, rendered as Markdown) in a single full-width column."""
    with gr.Row(), gr.Column(scale=1):
        gr.Markdown(TITLE)
def register_point_mode():
    """Build the "Point mode" tab and wire its event handlers.

    Layout:
    - Top row: the clickable input image and the read-only segmentation result,
      side by side.
    - Below: the "Segment" and "Clean Prompts" buttons.

    Every handler receives the tab's IMGState held in ``gr.State``.
    """
    with gr.Tab("Point mode"):
        session = gr.State(IMGState())

        # Input on the left, result on the right.
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column(scale=1):
                result_image = gr.Image(label="Segment", interactive=False, type="pil")

        # Action buttons.
        with gr.Row(), gr.Column(), gr.Row(), gr.Column():
            run_button = gr.Button("Segment", variant="primary")
            reset_button = gr.Button("Clean Prompts", variant="secondary")

        both_images = [input_image, result_image]
        input_image.upload(store_img, [input_image, session], both_images)
        input_image.select(get_points_with_draw, [input_image, session], input_image)
        run_button.click(segment_point, [input_image, session], both_images)
        reset_button.click(IMGState.cls_clean, session, both_images)
        input_image.clear(IMGState.cls_clear, session, both_images)
def build_demo():
    """Assemble the RAP-SAM Gradio app (title plus point-mode tab) and return the Blocks."""
    with gr.Blocks(css=CSS, title="RAP-SAM") as blocks:
        register_title()
        register_point_mode()
    return blocks
if __name__ == '__main__':
    # Build the UI and enable the request queue. api_open=False stops the backend
    # REST routes from being used to bypass the queue. Then serve on all interfaces.
    demo = build_demo()
    demo.queue(api_open=False).launch(server_name='0.0.0.0')