import gradio as gr
import os, requests
import numpy as np
import torch
import cv2
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import InferenceCellViTParser, InferenceCellViT
from cell_segmentation.inference.inference_cellvit_experiment_monuseg import InferenceCellViTMoNuSegParser, MoNuSegInference
## local | remote
RUN_MODE = "remote"
if RUN_MODE != "local":
    ## pretrained checkpoint
    os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/model_best.pth")
    ## examples
    os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/1.png")
    os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/2.png")
    os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/3.png")
    os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/4.png")
## step 1: set up model
device = "cpu"

## pannuke set
pannuke_parser = InferenceCellViTParser()
pannuke_configurations = pannuke_parser.parse_arguments()
pannuke_inf = InferenceCellViT(
    run_dir=pannuke_configurations["run_dir"],
    checkpoint_name=pannuke_configurations["checkpoint_name"],
    gpu=pannuke_configurations["gpu"],
    magnification=pannuke_configurations["magnification"],
)
# load the PanNuke checkpoint and rebuild the matching architecture
pannuke_checkpoint = torch.load(
    pannuke_inf.run_dir / pannuke_inf.checkpoint_name, map_location="cpu"
)
pannuke_model = pannuke_inf.get_model(model_type=pannuke_checkpoint["arch"])
pannuke_model.load_state_dict(pannuke_checkpoint["model_state_dict"])
# put the model in eval mode
pannuke_model.to(device)
pannuke_model.eval()
## monuseg set
monuseg_parser = InferenceCellViTMoNuSegParser()
monuseg_configurations = monuseg_parser.parse_arguments()
monuseg_inf = MoNuSegInference(
    model_path=monuseg_configurations["model"],
    dataset_path=monuseg_configurations["dataset"],
    outdir=monuseg_configurations["outdir"],
    gpu=monuseg_configurations["gpu"],
    patching=monuseg_configurations["patching"],
    magnification=monuseg_configurations["magnification"],
    overlap=monuseg_configurations["overlap"],
)
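# Inference callback: run the selected model on the uploaded image and return the
# prediction visualization. Both inference paths are assumed to write their result
# image to "pred_img.png" in the working directory.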
def click_process(image_input, type_dataset):
    if type_dataset == "pannuke":
        pannuke_inf.run_single_image_inference(pannuke_model, image_input)
    else:
        monuseg_inf.run_single_image_inference(monuseg_inf.model, image_input)
    image_output = cv2.imread("pred_img.png")
    image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)
    return image_output
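## step 2: build the Gradio interface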
demo = gr.Blocks(title="LKCell")
with demo:
    gr.Markdown(value="""
**Gradio demo for LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels**. Check our [Github Repo](https://github.com/ziwei-cui/LKCellv1).
""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True, height=480)
            with gr.Row():
                Type_dataset = gr.Radio(choices=["pannuke", "monuseg"], label="input image's dataset type", value="pannuke")
        with gr.Column():
            with gr.Row():
                image_output = gr.Image(type="numpy", label="Output", height=480)
            with gr.Row():
                Button_run = gr.Button("Submit")
                clear_button = gr.ClearButton(components=[Image_input, Type_dataset, image_output], value="Clear")
    Button_run.click(fn=click_process, inputs=[Image_input, Type_dataset], outputs=[image_output])
    ## guideline
    gr.Markdown(value="""
**Guideline**
1. Upload your image or select one from the examples.
2. Choose the dataset type of the input image ("pannuke" or "monuseg").
3. Click the Submit button to get the output.
""")
    # if RUN_MODE != "local":
    gr.Examples(examples=[
            ['1.png', "pannuke"],
            ['2.png', "pannuke"],
            ['3.png', "monuseg"],
            ['4.png', "monuseg"],
        ],
        inputs=[Image_input, Type_dataset], outputs=[image_output], label="Examples")
    gr.HTML(value="""
<p style="text-align:center; color:orange"> <a href='https://github.com/ziwei-cui/LKCellv1' target='_blank'>Github Repo</a></p>
""")
    gr.Markdown(value="""
Template is adapted from [Here](https://huggingface.co/spaces/menghanxia/disco)
""")
if RUN_MODE == "local":
    demo.launch(server_name='127.0.0.1', server_port=8003)
else:
    demo.launch()