File size: 2,123 Bytes
7d1df75
 
 
 
 
 
 
965b267
7d1df75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from functools import lru_cache

import torch
import gradio as gr
from PIL import Image
from torchvision.transforms import transforms
from modelscope import snapshot_download

# Download the "MuGemSt/HEp2" snapshot through modelscope as soon as this
# module is imported. The returned local directory must contain save.pt
# (loaded by infer) and examples/<Class>.png (used by the demo launcher).
# NOTE(review): "./__pycache__" is normally Python's bytecode cache directory
# and is frequently ignored/cleaned by tooling — confirm it is the intended
# location for the model cache.
MODEL_DIR = snapshot_download("MuGemSt/HEp2", cache_dir="./__pycache__")
TRANSLATE = {
    "Centromere": "着丝粒 Centromere",
    "Golgi": "高尔基体 Golgi",
    "Homogeneous": "同质 Homogeneous",
    "NuMem": "记忆体 NuMem",
    "Nucleolar": "核仁 Nucleolar",
    "Speckled": "斑核 Speckled",
}
CLASSES = list(TRANSLATE.keys())


# Deterministic inference-time preprocessing, built once and reused.
# Resize(224) scales the shorter side to 224, CenterCrop(224) takes the
# central 224x224 patch, ToTensor maps pixels to [0, 1] in CHW layout, and
# Normalize(0.5, 0.5) rescales each channel to [-1, 1].
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def embeding(img_path: str):
    """Load an image file and turn it into a normalized input tensor.

    The (misspelled) name is kept for compatibility with existing callers.

    Args:
        img_path: Path to an image file readable by PIL.

    Returns:
        A float tensor of shape (3, 224, 224) with values in [-1, 1].
        It has no batch dimension; callers add one with ``unsqueeze(0)``.
    """
    # No random augmentation here (the previous RandomAffine made repeated
    # predictions on the same image disagree). Open the file in a context
    # manager so its handle is closed; convert() returns an independent,
    # fully loaded copy that outlives the file.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB")
    return _PREPROCESS(rgb)


@lru_cache(maxsize=1)
def _load_model() -> torch.nn.Module:
    """Load the classifier from MODEL_DIR once, switch it to eval mode, cache it."""
    # save.pt is called directly as a model in infer(), so it is a pickled
    # module rather than a state_dict. torch >= 2.6 defaults torch.load to
    # weights_only=True, which refuses such files, so opt out explicitly.
    # Unpickling can execute code: only load this trusted model snapshot.
    model = torch.load(
        f"{MODEL_DIR}/save.pt",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
    # Disable training-only behaviour (dropout, batch-norm statistic updates).
    model.eval()
    return model


def infer(target: str):
    """Classify one HEp-2 cell image.

    Args:
        target: Path of the uploaded image. Gradio passes None (or an
            empty value) when nothing has been uploaded.

    Returns:
        A ``(picture_name, label)`` tuple for the two output text boxes.
        With no input it returns ``(None, <upload prompt>)``.
    """
    # Check for missing input before touching the model, so empty submits
    # stay cheap.
    if not target:
        return None, "请上传细胞图片 Please upload a cell picture!"

    model = _load_model()
    batch = embeding(target).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(batch)
    # Index of the highest-scoring class for the single image in the batch.
    predicted = int(logits.argmax(dim=1).item())
    return os.path.basename(target), TRANSLATE[CLASSES[predicted]]


if __name__ == "__main__":
    # Each class ships one sample picture named after its class key.
    example_imgs = [f"{MODEL_DIR}/examples/{cls_name}.png" for cls_name in CLASSES]
    # Bilingual page heading; the two literals join into one string.
    page_title = (
        "请上传 PNG 格式的 HEp2 细胞图片<br>"
        "It is recommended to upload HEp2 cell images in PNG format."
    )

    with gr.Blocks() as demo:
        # One image (passed to infer as a file path) in; two text boxes out,
        # matching the (picture name, recognition result) tuple from infer.
        gr.Interface(
            fn=infer,
            inputs=gr.Image(type="filepath", label="上传细胞图像 Upload a cell picture"),
            outputs=[
                gr.Textbox(label="图片名 Picture name", show_copy_button=True),
                gr.Textbox(label="识别结果 Recognition result", show_copy_button=True),
            ],
            title=page_title,
            examples=example_imgs,
            allow_flagging="never",
            cache_examples=False,
        )

    demo.launch()