haritsahm's picture
Reformat Overview (#3)
15c4bc8
raw
history blame
3 kB
import json
import os
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from monai.bundle import ConfigParser
from utils import page_utils
# ---------------------------------------------------------------------------
# Model / pipeline setup (runs once at import time).
# ---------------------------------------------------------------------------

# Load the bundle's inference configuration from disk.
with open("configs/inference.json", encoding="utf-8") as f:
    inference_config = json.load(f)

# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# * NOTE: device must be hardcoded, config file won't affect the device selection
inference_config["device"] = device

# Resolve the components declared in the config into live objects.
parser = ConfigParser()
parser.read_config(f=inference_config)
parser.read_meta(f="configs/metadata.json")

inference = parser.get_parsed_content("inferer")
# loader = parser.get_parsed_content("dataloader")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")

# Environment variables are always strings, so a plain truthiness test would
# treat USE_FP16=0 or USE_FP16=false as "enabled". Any value other than the
# explicit "off" spellings below still enables half precision, which keeps
# previously working settings (e.g. USE_FP16=1, USE_FP16=True) unchanged.
use_fp16 = os.environ.get('USE_FP16', '').strip().lower() not in (
    '', '0', 'false', 'no', 'off',
)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint that was saved from a GPU can still be loaded on a CPU-only host.
state_dict = torch.load("models/model.pt", map_location=device)
network.load_state_dict(state_dict, strict=True)
network = network.to(device)
network.eval()

# Half precision is only applied when running on CUDA.
if use_fp16 and torch.cuda.is_available():
    network = network.half()
# RGB colour (0-255 per channel) used to paint each label value of the
# predicted type map; label 0 is painted black.
label2color = {
    0: (0, 0, 0),
    1: (225, 24, 69),    # RED
    2: (135, 233, 17),   # GREEN
    3: (0, 87, 233),     # BLUE
    4: (242, 202, 25),   # YELLOW
    5: (137, 49, 239),   # PURPLE
}

# Every PNG found in ``sample_data`` is offered as a clickable example.
example_files = [*Path("sample_data").glob("*.png")]
def visualize_instance_seg_mask(mask, palette=None):
    """Convert a 2-D label mask into a float RGB image.

    Args:
        mask: Array-like of shape (height, width) whose entries are label
            values. Each value must be a key of ``palette``.
        palette: Optional mapping from label value to an ``(r, g, b)`` tuple
            with channels in the 0-255 range. Defaults to the module-level
            ``label2color`` table.

    Returns:
        A float ``numpy`` array of shape (height, width, 3) with channel
        values scaled to the 0-1 range.

    Raises:
        KeyError: If the mask contains a label that is missing from the
            palette.
    """
    if palette is None:
        palette = label2color

    mask = np.asarray(mask)
    height, width = mask.shape[0], mask.shape[1]

    # Look each *distinct* label up once instead of once per pixel; `inverse`
    # holds, for every pixel, the index of its label inside `labels`.
    labels, inverse = np.unique(mask, return_inverse=True)
    colors = np.array(
        [palette[label] for label in labels], dtype=float,
    ).reshape(-1, 3)  # keeps the (n, 3) shape even when the mask is empty

    # Flatten `inverse` explicitly: its shape for multi-dimensional input
    # differs between NumPy releases.
    image = colors[inverse.reshape(-1)].reshape(height, width, 3)
    return image / 255
def query_image(img):
    """Run the model on one image and return a colour overlay of the result.

    Args:
        img: The value Gradio hands over for the image input; the interface
            declares that input with ``type="filepath"``. It is forwarded
            unchanged to ``preprocess`` under the ``"image"`` key.

    Returns:
        A ``numpy`` array: an equal-weight blend of the preprocessed image
        and the coloured type map, flipped and rotated for display.
    """
    sample = preprocess({"image": img})
    sample['image'] = sample['image'].to(device)

    half_precision = use_fp16 and torch.cuda.is_available()
    if half_precision:
        sample['image'] = sample['image'].half()

    # The inferer is given a leading batch dimension of size one.
    with torch.no_grad():
        outputs = inference(sample['image'].unsqueeze(dim=0), network)

    # Drop that batch dimension again from every predicted entry, in place.
    sample["pred"] = outputs
    predictions = sample["pred"]
    for name in list(predictions):
        predictions[name] = predictions[name].squeeze(dim=0)

    sample = postprocess(sample)

    overlay = visualize_instance_seg_mask(sample["type_map"].squeeze())

    # Combine image
    # NOTE(review): permute(1, 2, 0) suggests the image is channel-first
    # (C, H, W) at this point - confirm against the preprocessing config.
    picture = sample["image"].permute(1, 2, 0).cpu().numpy()
    blended = picture * 0.5 + overlay * 0.5

    # Solve rotating problem
    return np.rot90(np.fliplr(blended), k=1)
# Read the HTML page that is shown as the interface description.
with open('index.html', encoding='utf-8') as html_file:
    html_content = html_file.read()

# Default Gradio theme recoloured with the project hue for both the primary
# and the secondary palette, plus explicit primary-button colours.
_brand_hue = page_utils.KALBE_THEME_COLOR
_theme = gr.themes.Default(
    primary_hue=_brand_hue,
    secondary_hue=_brand_hue,
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)

# One image input (delivered to the callback as a file path), one image output.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    theme=_theme,
    description=html_content,
    examples=example_files,
)

# Enable request queuing, then start the web server.
demo.queue(concurrency_count=20).launch()