File size: 4,673 Bytes
aea73e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import os, requests
import numpy as np
import torch
import cv2
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import InferenceCellViTParser,InferenceCellViT
from cell_segmentation.inference.inference_cellvit_experiment_monuseg import InferenceCellViTMoNuSegParser,MoNuSegInference


## Run mode: "local" (assets already on disk) | "remote" (download assets at startup)
RUN_MODE = "remote"

# Model checkpoint and example images hosted on the Hugging Face Hub.
ASSET_BASE_URL = "https://huggingface.co/xiazhi/LKCell-demo/resolve/main/"
REMOTE_ASSETS = ("model_best.pth", "1.png", "2.png", "3.png", "4.png")


def _download_assets(base_url=ASSET_BASE_URL, filenames=REMOTE_ASSETS, dest_dir="."):
    """Download each file in *filenames* from *base_url* into *dest_dir*.

    Files that already exist are skipped. Re-running `wget` would otherwise
    create `name.1`, `name.2`, ... duplicates on every restart. Each download
    goes to a `.part` file first and is renamed only on success, so an
    interrupted transfer is never mistaken for a complete file later.

    Raises:
        urllib.error.URLError: if a download fails. Failing loudly here beats a
            confusing error later when the checkpoint is loaded.
    """
    import urllib.request

    for name in filenames:
        target = os.path.join(dest_dir, name)
        if os.path.exists(target):
            continue
        partial = target + ".part"
        urllib.request.urlretrieve(base_url + name, partial)
        os.replace(partial, target)


if RUN_MODE != "local":
    _download_assets()

## step 1: set up model

# Device the models run on; the PanNuke checkpoint is mapped straight onto it.
device = "cpu"

## pannuke set
# Build the PanNuke inference wrapper from the parsed configuration.
# NOTE(review): parse_arguments() presumably reads CLI arguments/defaults — confirm.
pannuke_parser = InferenceCellViTParser()
pannuke_configurations = pannuke_parser.parse_arguments()
pannuke_inf = InferenceCellViT(
    run_dir=pannuke_configurations["run_dir"],
    checkpoint_name=pannuke_configurations["checkpoint_name"],
    gpu=pannuke_configurations["gpu"],
    magnification=pannuke_configurations["magnification"],
)

# Load the checkpoint directly onto `device` so loading and inference agree.
# The checkpoint dict supplies the architecture name ("arch") and the weights
# ("model_state_dict").
pannuke_checkpoint = torch.load(
    pannuke_inf.run_dir / pannuke_inf.checkpoint_name, map_location=device
)
pannuke_model = pannuke_inf.get_model(model_type=pannuke_checkpoint["arch"])
pannuke_model.load_state_dict(pannuke_checkpoint["model_state_dict"])
# Move the model to the target device, then switch to eval mode for inference
# (affects layers such as dropout / batch-norm).
pannuke_model.to(device)
pannuke_model.eval()


## monuseg set
# Build the MoNuSeg inference wrapper. Each constructor keyword (left) is filled
# from the parsed configuration entry with the key on the right.
monuseg_parser = InferenceCellViTMoNuSegParser()
monuseg_configurations = monuseg_parser.parse_arguments()
_MONUSEG_KWARG_TO_CONFIG_KEY = {
    "model_path": "model",
    "dataset_path": "dataset",
    "outdir": "outdir",
    "gpu": "gpu",
    "patching": "patching",
    "magnification": "magnification",
    "overlap": "overlap",
}
monuseg_inf = MoNuSegInference(
    **{
        kwarg: monuseg_configurations[config_key]
        for kwarg, config_key in _MONUSEG_KWARG_TO_CONFIG_KEY.items()
    }
)


def click_process(image_input, type_dataset):
    """Run nuclei segmentation on one image and return the rendered prediction.

    Args:
        image_input: image from the ``gr.Image(type="numpy")`` input, or ``None``
            when the user pressed Submit without uploading/selecting an image.
        type_dataset: ``"pannuke"`` selects the PanNuke model; any other value
            selects the MoNuSeg model.

    Returns:
        The prediction visualisation read back from ``pred_img.png``, as an RGB array.

    Raises:
        gr.Error: if no image was given or the prediction file cannot be read.
            Gradio shows this message to the user instead of a stack trace.
    """
    if image_input is None:
        raise gr.Error("Please upload an image or select an example first.")

    if type_dataset == "pannuke":
        pannuke_inf.run_single_image_inference(pannuke_model, image_input)
    else:
        monuseg_inf.run_single_image_inference(monuseg_inf.model, image_input)

    # The result is read back from a fixed file in the working directory, which
    # the inference call is expected to write.
    # NOTE(review): this path is shared across requests — concurrent submissions
    # could read each other's output.
    pred_path = "pred_img.png"
    image_output = cv2.imread(pred_path)
    if image_output is None:
        # cv2.imread signals a missing/unreadable file by returning None, not raising.
        raise gr.Error(f"Inference did not produce a readable '{pred_path}'.")
    # OpenCV loads BGR; the Gradio image component expects RGB.
    return cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)


demo = gr.Blocks(title="LkCell")
with demo:
    # Header with the project description and repository link.
    gr.Markdown(value="""
                    **Gradio demo for LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels**. Check our [Github Repo](https://github.com/ziwei-cui/LKCellv1) 😛.
                    """)
    # Layout: input column (image + dataset selector) beside the output column.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True, height=480)
            with gr.Row():
                Type_dataset = gr.Radio(choices=["pannuke", "monuseg"], label=" input image's dataset type", value="pannuke")

        with gr.Column():
            with gr.Row():
                image_output = gr.Image(type="numpy", label="Output", height=480)
    with gr.Row():
        Button_run = gr.Button("🚀 Submit (发送) ")
        # Resets the input image, the dataset selector and the output image.
        clear_button = gr.ClearButton(components=[Image_input, Type_dataset, image_output], value="🧹 Clear (清除)")

    # Submit runs segmentation with the chosen dataset's model.
    Button_run.click(fn=click_process, inputs=[Image_input, Type_dataset], outputs=[image_output])

    ## guideline
    gr.Markdown(value="""    
                    🔔**Guideline**
                    1. Upload your image or select one from the examples.
                    2. Set up the arguments: "Type_dataset".
                    3. Run the Submit button to get the output.
                    """)
    # Example images are downloaded at startup in remote mode; in local mode
    # they must already be present in the working directory.
    gr.Examples(examples=[
                ['1.png', "pannuke"],
                ['2.png', "pannuke"],
                ['3.png', "monuseg"],
                ['4.png', "monuseg"],
                ],
                inputs=[Image_input, Type_dataset], outputs=[image_output], label="Examples")
    gr.HTML(value="""
                <p style="text-align:center; color:orange"> <a href='https://github.com/ziwei-cui/LKCellv1' target='_blank'>Github Repo</a></p>
                    """)
    gr.Markdown(value="""    
                    Template is adapted from [Here](https://huggingface.co/spaces/menghanxia/disco)
                    """)

# Local runs bind to a fixed loopback address and port; remote (hosted) runs
# keep Gradio's default launch settings.
launch_options = {"server_name": '127.0.0.1', "server_port": 8003} if RUN_MODE == "local" else {}
demo.launch(**launch_options)