Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from PIL import Image | |
import torchvision.transforms as transforms | |
def img2label(left, right): | |
left = Image.fromarray(left.astype('uint8'), 'RGB') | |
right = Image.fromarray(right.astype('uint8'), 'RGB') | |
# 将右眼底镜像反转 | |
r2l = transforms.RandomHorizontalFlip(p=1) | |
right = r2l(right) | |
# 调整图片 | |
left_img = my_transforms(left).to(device) | |
right_img = my_transforms(right).to(device) | |
# 读取模型 | |
model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device) | |
# 模型推理 | |
with torch.no_grad(): | |
output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0)) | |
# 结果处理 | |
output = torch.sigmoid(output.squeeze(0)) | |
output_ = output.cpu().numpy().tolist() | |
res_dict = {LABELS[i]: output_[i] for i in range(len(output_))} | |
pred = torch.nonzero(output > 0.4).view(-1) | |
pred = pred.cpu().numpy().tolist() | |
if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0): | |
return res_dict, LABELS[0] | |
res = '' | |
for i in pred: | |
if i == 0: | |
continue | |
res += ', ' + LABELS[i] | |
return res_dict, '目前的身体状态:' + res[2:] | |
if __name__ == '__main__': | |
device = torch.device("cpu") | |
# 标题 | |
title = "基于眼底图像的智能健康诊断分析系统" | |
# 标题下的描述,支持md格式 | |
description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \ | |
"包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类" | |
# transforms设置 | |
norm_mean = [0.485, 0.456, 0.406] | |
norm_std = [0.229, 0.224, 0.225] | |
my_transforms = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(norm_mean, norm_std) | |
]) | |
# 标签设置 | |
LABELS = {0: '正常', | |
1: '糖尿病', | |
2: '青光眼', | |
3: '白内障', | |
4: '年龄性黄斑变性', | |
5: '高血压', | |
6: '病理性近视', | |
7: '其他疾病'} | |
left_img_dir = 'left.jpg' | |
right_img_dir = 'right.jpg' | |
examples = [[left_img_dir, right_img_dir]] | |
# r = img2label(left_img_dir, right_img_dir) | |
demo = gr.Interface(fn=img2label, | |
inputs=[gr.inputs.Image(), gr.inputs.Image()], | |
outputs=["label", "text"], | |
examples=examples, title=title, description=description) | |
demo.launch() | |