Spaces:
Running
Running
import functools
import os

import gradio as gr
import torch
from modelscope import snapshot_download
from PIL import Image
from torchvision.transforms import transforms
# ModelScope repository holding the trained classifier (save.pt) and the
# bundled examples/ pictures that the rest of this script reads.
_MODEL_REPO_ID = "MuGemSt/HEp2"

# Downloaded at import time; cache_dir points at ./__pycache__, presumably so
# the snapshot stays out of version control -- NOTE(review): confirm intent.
MODEL_DIR = snapshot_download(_MODEL_REPO_ID, cache_dir="./__pycache__")
TRANSLATE = { | |
"Centromere": "着丝粒 Centromere", | |
"Golgi": "高尔基体 Golgi", | |
"Homogeneous": "同质 Homogeneous", | |
"NuMem": "记忆体 NuMem", | |
"Nucleolar": "核仁 Nucleolar", | |
"Speckled": "斑核 Speckled", | |
} | |
CLASSES = list(TRANSLATE.keys()) | |
# Preprocessing pipeline, built once at import time rather than on every call.
# It is deterministic on purpose: no random augmentation, so the same picture
# always produces the same model input and therefore the same prediction.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Mean/std of 0.5 per channel maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def embeding(img_path: str) -> torch.Tensor:
    """Load an image file and turn it into a normalized model input tensor.

    The image is converted to RGB, resized so its shorter side is 224 px,
    center-cropped to 224x224, converted to a CHW float tensor and normalized
    with mean/std 0.5 per channel.

    Args:
        img_path: Path to an image file readable by PIL.

    Returns:
        A tensor of shape (3, 224, 224) with no batch dimension.
    """
    # Close the underlying file handle; convert() returns an independent image.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB")
    return _PREPROCESS(rgb)
@functools.lru_cache(maxsize=1)
def _load_model() -> torch.nn.Module:
    """Load the classifier from MODEL_DIR once, switch it to eval mode, and cache it."""
    # NOTE(review): torch.load unpickles a whole module, which can execute
    # arbitrary code; acceptable only because save.pt comes from this app's own
    # model repo. On torch>=2.6 (weights_only defaults to True) loading a full
    # pickled module may require weights_only=False -- confirm the pinned version.
    model = torch.load(f"{MODEL_DIR}/save.pt", map_location=torch.device("cpu"))
    # Inference mode for dropout / batch-norm layers (batch size is 1 here).
    model.eval()
    return model


def infer(target: str):
    """Classify an uploaded HEp-2 cell image.

    Args:
        target: Filesystem path of the uploaded image (the Gradio input uses
            ``type="filepath"``); falsy when nothing was uploaded.

    Returns:
        A ``(picture name, bilingual class label)`` tuple, or
        ``(None, prompt message)`` when no image was provided.
    """
    if not target:
        return None, "请上传细胞图片 Please upload a cell picture!"

    model = _load_model()
    batch = embeding(target).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():  # no autograd graph is needed for prediction
        logits = model(batch)
    # Assumes logits are (1, len(CLASSES)) in CLASSES order -- TODO confirm.
    predicted = int(logits.argmax(dim=1).item())
    return os.path.basename(target), TRANSLATE[CLASSES[predicted]]
if __name__ == "__main__":
    # One bundled example picture per class, named after its class key.
    example_imgs = [f"{MODEL_DIR}/examples/{name}.png" for name in CLASSES]

    # The Interface is built inside an outer Blocks; `demo` is what gets launched.
    with gr.Blocks() as demo:
        image_input = gr.Image(
            type="filepath", label="上传细胞图像 Upload a cell picture"
        )
        name_output = gr.Textbox(label="图片名 Picture name", show_copy_button=True)
        label_output = gr.Textbox(
            label="识别结果 Recognition result", show_copy_button=True
        )
        gr.Interface(
            fn=infer,
            inputs=image_input,
            outputs=[name_output, label_output],
            title=(
                "请上传 PNG 格式的 HEp2 细胞图片<br>"
                "It is recommended to upload HEp2 cell images in PNG format."
            ),
            examples=example_imgs,
            allow_flagging="never",
            cache_examples=False,
        )

    demo.launch()