# LKCell / app.py
# Last commit by: xiazhi
# Commit: 44bb304 "Update app.py" (verified)
import gradio as gr
import os, requests
import numpy as np
import torch
import cv2
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import InferenceCellViTParser,InferenceCellViT
from cell_segmentation.inference.inference_cellvit_experiment_monuseg import InferenceCellViTMoNuSegParser,MoNuSegInference
## Run mode: "local" | "remote"
RUN_MODE = "remote"

# Hugging Face repo that hosts the model weights and the example images.
_HF_BASE_URL = "https://huggingface.co/xiazhi/LKCell/resolve/main/"


def _download_if_missing(filename, base_url=_HF_BASE_URL, timeout=60):
    """Download *filename* from *base_url* into the working directory unless it already exists.

    Unlike a bare ``wget`` call, this is idempotent across restarts (wget would
    write ``filename.1``, ``filename.2`` ... on every start) and it fails
    loudly instead of silently continuing without the file.

    The payload is streamed to ``filename + ".part"`` and renamed into place
    only after the transfer completes, so an interrupted download never leaves
    a truncated file that a later start would mistake for a complete one.

    Args:
        filename: name of the file in the repo; also used as the local path.
        base_url: URL prefix the file name is appended to.
        timeout: seconds passed to ``requests`` for connect/read timeouts.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on network failures or timeouts.
    """
    if os.path.exists(filename):
        return
    tmp_path = filename + ".part"
    with requests.get(base_url + filename, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in response.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    # Atomic on POSIX and Windows: readers see either no file or the full file.
    os.replace(tmp_path, filename)


if RUN_MODE != "local":
    # Model weights, saved to the working directory.
    # NOTE(review): presumably this is the checkpoint the PanNuke parser's
    # run_dir/checkpoint_name defaults point at — confirm.
    _download_if_missing("model_best.pth")
    ## examples
    for _example in ("1.png", "2.png", "3.png", "4.png"):
        _download_if_missing(_example)
## step 1: set up model
# Inference is pinned to CPU.
# NOTE(review): the parsed "gpu" option is still forwarded to both inference
# objects below; confirm they ignore it (or honour `device`) on CPU-only hosts.
device = "cpu"


def _load_pannuke_model(inference, map_location="cpu"):
    """Instantiate the model recorded in *inference*'s checkpoint and load its weights.

    Args:
        inference: the ``InferenceCellViT`` whose ``run_dir / checkpoint_name``
            locates the checkpoint file.
        map_location: forwarded to ``torch.load`` so tensors are loaded there.

    Returns:
        The model built by ``inference.get_model(model_type=checkpoint["arch"])``
        with ``checkpoint["model_state_dict"]`` loaded into it.

    The checkpoint dict is deliberately local. ``load_state_dict`` copies the
    tensors into the model's own parameters, so keeping the dict alive (as the
    former module global did) only pins a second copy of the weights, plus
    whatever else the checkpoint stores, for the whole process lifetime.
    """
    checkpoint = torch.load(
        inference.run_dir / inference.checkpoint_name, map_location=map_location
    )
    model = inference.get_model(model_type=checkpoint["arch"])
    model.load_state_dict(checkpoint["model_state_dict"])
    return model


## pannuke set
pannuke_parser = InferenceCellViTParser()
pannuke_configurations = pannuke_parser.parse_arguments()
pannuke_inf = InferenceCellViT(
    run_dir=pannuke_configurations["run_dir"],
    checkpoint_name=pannuke_configurations["checkpoint_name"],
    gpu=pannuke_configurations["gpu"],
    magnification=pannuke_configurations["magnification"],
)
pannuke_model = _load_pannuke_model(pannuke_inf)
# Move to the target device and switch to inference mode (disables dropout, etc.).
pannuke_model.to(device)
pannuke_model.eval()

## monuseg set
# MoNuSegInference builds its own model; click_process uses it via `monuseg_inf.model`.
monuseg_parser = InferenceCellViTMoNuSegParser()
monuseg_configurations = monuseg_parser.parse_arguments()
monuseg_inf = MoNuSegInference(
    model_path=monuseg_configurations["model"],
    dataset_path=monuseg_configurations["dataset"],
    outdir=monuseg_configurations["outdir"],
    gpu=monuseg_configurations["gpu"],
    patching=monuseg_configurations["patching"],
    magnification=monuseg_configurations["magnification"],
    overlap=monuseg_configurations["overlap"],
)
# CPU budget: inputs with a side larger than this are resized to a square of this size.
MAX_INPUT_SIDE = 512


def exceeds_size_limit(image, limit=MAX_INPUT_SIDE):
    """Return True if either spatial dimension (height or width) of *image* exceeds *limit*."""
    height, width = image.shape[:2]
    return height > limit or width > limit


def read_rgb(path):
    """Read the image file at *path* with OpenCV and convert it from BGR to RGB.

    Raises:
        gr.Error: if the file is missing or unreadable (``cv2.imread`` returns None).
    """
    image = cv2.imread(path)
    if image is None:
        raise gr.Error(f"Inference did not produce a readable '{path}'.")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def click_process(image_input, type_dataset):
    """Gradio callback: segment nuclei in one image and return the two rendered predictions.

    Args:
        image_input: the uploaded image as a numpy array, or None if nothing was uploaded.
        type_dataset: "pannuke" selects the PanNuke model; any other value
            (the UI offers "monuseg") runs the MoNuSeg inference.

    Returns:
        ``(raw_pred, pred_img)``: RGB arrays read back from ``raw_pred.png`` and
        ``pred_img.png`` in the working directory after inference.

    Raises:
        gr.Error: when no image was provided or an output file is missing.
    """
    if image_input is None:
        raise gr.Error("Please upload an image or select one of the examples first.")
    # Resize when EITHER side is too large (the old `and` check let e.g. 4000x300 through).
    # cv2.resize takes dsize as (width, height); the target is square, so order is moot.
    if exceeds_size_limit(image_input):
        image_input = cv2.resize(image_input, (MAX_INPUT_SIDE, MAX_INPUT_SIDE))
    if type_dataset == "pannuke":
        pannuke_inf.run_single_image_inference(pannuke_model, image_input)
    else:
        monuseg_inf.run_single_image_inference(monuseg_inf.model, image_input)
    # The results are read from fixed file names in the working directory,
    # presumably written by run_single_image_inference.
    # NOTE(review): if requests are ever processed concurrently, these shared
    # paths can race between users — confirm the queue concurrency is 1.
    return read_rgb("raw_pred.png"), read_rgb("pred_img.png")
# Build the Gradio UI: input column (image + dataset selector), output column
# (two prediction images), action buttons, guideline text and examples.
demo = gr.Blocks(title="LkCell")
with demo:
    gr.Markdown(value="""
**Gradio demo for LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels**. Check our [Github Repo](https://github.com/hustvl/LKCell) 😛.
""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True,height=480)
            with gr.Row():
                Type_dataset = gr.Radio(choices=["pannuke", "monuseg"], label=" input image's dataset type",value="pannuke")
        with gr.Column():
            image_output = gr.Image(type="numpy", label="image prediction",height=480,width=480)
            image_output2 = gr.Image(type="numpy", label="all predictions",height=480)
    with gr.Row():
        Button_run = gr.Button("🚀 Submit (发送) ")
        clear_button = gr.ClearButton(components=[Image_input,Type_dataset,image_output,image_output2],value="🧹 Clear (清除)")
    # The callback receives (image, dataset type) and fills both output images.
    Button_run.click(fn=click_process, inputs=[Image_input, Type_dataset ], outputs=[image_output,image_output2])
    ## guideline
    gr.Markdown(value="""
🔔**Guideline**
1. Upload your image or select one from the examples.
2. Set up the arguments: "Type_dataset" to enjoy two dataset type's inference
3. Due to the limit of CPU , we resize the input image whose size is larger than (512,512) to (512,512)
4. Run the Submit button to get the output.
""")
    # NOTE(review): the example files are read from the working directory;
    # presumably they are fetched at startup only when RUN_MODE != "local",
    # so a local run needs them present already — confirm.
    gr.Examples(examples=[
        ['1.png', "pannuke"],
        ['2.png', "pannuke"],
        ['3.png', "monuseg"],
        ['4.png', "monuseg"],
        ],
        inputs=[Image_input, Type_dataset], outputs=[image_output,image_output2], label="Examples")
    gr.HTML(value="""
<p style="text-align:center; color:orange"> <a href='https://github.com/hustvl/LKCell' target='_blank'>Github Repo</a></p>
""")
    gr.Markdown(value="""
Template is adapted from [Here](https://huggingface.co/spaces/menghanxia/disco)
""")
# Local development binds to loopback on a fixed port; any other mode keeps
# Gradio's own defaults for host and port.
_launch_options = (
    {"server_name": '127.0.0.1', "server_port": 8003} if RUN_MODE == "local" else {}
)
demo.launch(**_launch_options)