File size: 5,317 Bytes
aea73e2
 
 
 
 
 
 
 
 
 
3bc3097
6a26a7c
44bb304
6a26a7c
f2d83f6
 
 
 
aea73e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f95222d
 
aea73e2
 
f95222d
 
aea73e2
f2d83f6
 
 
 
 
 
aea73e2
 
 
 
 
59daad4
aea73e2
 
 
 
 
 
 
 
 
f2d83f6
 
 
aea73e2
 
f2d83f6
aea73e2
f2d83f6
aea73e2
 
 
 
 
bd518fa
 
 
aea73e2
 
 
 
 
 
 
 
f2d83f6
aea73e2
59daad4
aea73e2
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import os, requests
import numpy as np
import torch
import cv2
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import InferenceCellViTParser,InferenceCellViT
from cell_segmentation.inference.inference_cellvit_experiment_monuseg import InferenceCellViTMoNuSegParser,MoNuSegInference


## Run mode: "local" or "remote".
## In any mode other than "local", model_best.pth and the example images
## 1.png-4.png are fetched from the LKCell Hugging Face repository at startup;
## in "local" mode they are expected to be on disk already.
RUN_MODE = "remote"

# Files pulled from the LKCell Hugging Face repository when not running locally.
_HF_RESOLVE_URL = "https://huggingface.co/xiazhi/LKCell/resolve/main/"
_REMOTE_FILES = ("model_best.pth", "1.png", "2.png", "3.png", "4.png")

if RUN_MODE != "local":
    for _file_name in _REMOTE_FILES:
        # A plain `wget URL` never overwrites an existing file: on every restart it
        # would download again into "<name>.1", "<name>.2", ... so skip files we have.
        if os.path.exists(_file_name):
            continue
        # Download under a temporary name and rename only on success, so an
        # interrupted transfer never leaves a truncated file that later runs would skip.
        _partial = _file_name + ".part"
        if os.system(f"wget -O {_partial} {_HF_RESOLVE_URL}{_file_name}") != 0:
            raise RuntimeError(f"Failed to download {_HF_RESOLVE_URL}{_file_name}")
        os.replace(_partial, _file_name)

## step 1: set up the two models (PanNuke and MoNuSeg); `device` is the CPU.

device = "cpu"

## PanNuke: build the network named in the checkpoint, then load its weights.
pannuke_parser = InferenceCellViTParser()
pannuke_configurations = pannuke_parser.parse_arguments()
pannuke_inf = InferenceCellViT(
    run_dir=pannuke_configurations["run_dir"],
    checkpoint_name=pannuke_configurations["checkpoint_name"],
    gpu=pannuke_configurations["gpu"],
    magnification=pannuke_configurations["magnification"],
)
pannuke_checkpoint = torch.load(
    pannuke_inf.run_dir / pannuke_inf.checkpoint_name,
    map_location="cpu",
)
# The checkpoint carries the architecture name ("arch") alongside the weights.
pannuke_model = pannuke_inf.get_model(model_type=pannuke_checkpoint["arch"])
pannuke_model.load_state_dict(pannuke_checkpoint["model_state_dict"])
# torch.nn.Module.to() and .eval() both return the module itself, so the
# device move and the switch to eval mode can be chained in place.
pannuke_model.to(device).eval()


## MoNuSeg: MoNuSegInference is configured straight from the parsed options;
## its network is used later as `monuseg_inf.model`.
monuseg_parser = InferenceCellViTMoNuSegParser()
monuseg_configurations = monuseg_parser.parse_arguments()
monuseg_inf = MoNuSegInference(
    model_path=monuseg_configurations["model"],
    dataset_path=monuseg_configurations["dataset"],
    outdir=monuseg_configurations["outdir"],
    gpu=monuseg_configurations["gpu"],
    patching=monuseg_configurations["patching"],
    magnification=monuseg_configurations["magnification"],
    overlap=monuseg_configurations["overlap"],
)


def _read_prediction_rgb(path):
    """Read an image file produced by an inference run and return it in RGB order.

    Raises:
        FileNotFoundError: if ``path`` is missing or unreadable (``cv2.imread``
            signals this by returning None rather than raising).
    """
    image_bgr = cv2.imread(path)
    if image_bgr is None:
        raise FileNotFoundError(
            f"Inference output '{path}' was not found or could not be decoded."
        )
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


def click_process(image_input, type_dataset):
    """Segment one uploaded image and return the two rendered prediction images.

    Args:
        image_input: array from the ``gr.Image(type="numpy")`` input, or None
            when the user pressed Submit without providing an image.
        type_dataset: "pannuke" selects the PanNuke model; any other value
            (the UI only offers "monuseg") selects the MoNuSeg model.

    Returns:
        Tuple ``(raw_pred, pred_img)`` of RGB arrays loaded from
        ``raw_pred.png`` and ``pred_img.png`` in the working directory.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
        FileNotFoundError: if an expected output file is missing after inference.
    """
    if image_input is None:
        raise gr.Error("Please upload an image or pick one of the examples first.")

    # To bound CPU time, images larger than 512 px on BOTH sides are resized to
    # exactly 512x512 (aspect ratio is not preserved).
    if image_input.shape[0] > 512 and image_input.shape[1] > 512:
        image_input = cv2.resize(image_input, (512, 512))

    if type_dataset == "pannuke":
        pannuke_inf.run_single_image_inference(pannuke_model, image_input)
    else:
        monuseg_inf.run_single_image_inference(monuseg_inf.model, image_input)

    # NOTE(review): results are read back from fixed file names in the working
    # directory, presumably written by run_single_image_inference. Concurrent
    # requests could overwrite each other's files, and a failed run could
    # surface a previous request's images — confirm the Gradio queue serializes calls.
    return _read_prediction_rgb("raw_pred.png"), _read_prediction_rgb("pred_img.png")


# Gradio UI: input image and dataset selector on the left, the two rendered
# predictions on the right, a Submit/Clear row wired to click_process, then
# usage notes and clickable examples.
demo = gr.Blocks(title="LkCell")
with demo:
    gr.Markdown(value="""
                    **Gradio demo for LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels**. Check our [Github Repo](https://github.com/hustvl/LKCell) 😛.
                    """)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True, height=480)
            with gr.Row():
                Type_dataset = gr.Radio(choices=["pannuke", "monuseg"], label=" input image's dataset type", value="pannuke")

        with gr.Column():
            image_output = gr.Image(type="numpy", label="image prediction", height=480, width=480)
            image_output2 = gr.Image(type="numpy", label="all predictions", height=480)

    with gr.Row():
        # Labels carry a Chinese gloss: 发送 = "send", 清除 = "clear".
        Button_run = gr.Button("🚀 Submit (发送) ")
        clear_button = gr.ClearButton(components=[Image_input, Type_dataset, image_output, image_output2], value="🧹 Clear (清除)")

    Button_run.click(fn=click_process, inputs=[Image_input, Type_dataset], outputs=[image_output, image_output2])

    ## guideline
    gr.Markdown(value="""    
                    🔔**Guideline**
                    1. Upload your image or select one from the examples.
                    2. Set up the arguments: "Type_dataset" to enjoy two dataset type's inference
                    3. Due to the limit of CPU , we resize the input image whose size is larger than (512,512) to (512,512)
                    4. Run the Submit button to get the output.
                    """)
    # 1.png-4.png are downloaded at startup unless RUN_MODE is "local"; in local
    # mode they must already be in the working directory.
    # No `fn` is passed, so choosing an example only fills in the inputs.
    gr.Examples(examples=[
                ['1.png', "pannuke"],
                ['2.png', "pannuke"],
                ['3.png', "monuseg"],
                ['4.png', "monuseg"],
                ],
                inputs=[Image_input, Type_dataset], outputs=[image_output, image_output2], label="Examples")
    gr.HTML(value="""
                <p style="text-align:center; color:orange"> <a href='https://github.com/hustvl/LKCell' target='_blank'>Github Repo</a></p>
                    """)
    gr.Markdown(value="""    
                    Template is adapted from [Here](https://huggingface.co/spaces/menghanxia/disco)
                    """)

# Local runs bind to the loopback interface on a fixed port; every other mode
# (e.g. a hosted Space) falls back to Gradio's default host and port.
_launch_options = (
    {"server_name": "127.0.0.1", "server_port": 8003}
    if RUN_MODE == "local"
    else {}
)
demo.launch(**_launch_options)