import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms


def img2label(left, right):
    left = Image.fromarray(left.astype('uint8'), 'RGB')
    right = Image.fromarray(right.astype('uint8'), 'RGB')
    # Mirror the right fundus image horizontally so both eyes share one orientation
    r2l = transforms.RandomHorizontalFlip(p=1)
    right = r2l(right)
    # Preprocess both images
    left_img = my_transforms(left).to(device)
    right_img = my_transforms(right).to(device)
    # Load the model (reloaded on every call; it could instead be loaded once at startup)
    model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)
    # Run inference
    with torch.no_grad():
        output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))
    # Post-process: sigmoid scores for the multi-label output
    output = torch.sigmoid(output.squeeze(0))
    output_ = output.cpu().numpy().tolist()
    res_dict = {LABELS[i]: output_[i] for i in range(len(output_))}
    # Keep every label whose score exceeds the 0.4 threshold
    pred = torch.nonzero(output > 0.4).view(-1)
    pred = pred.cpu().numpy().tolist()
    if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
        return res_dict, LABELS[0]
    res = ''
    for i in pred:
        if i == 0:
            continue
        res += ', ' + LABELS[i]
    return res_dict, 'Current health status: ' + res[2:]


if __name__ == '__main__':
    device = torch.device("cpu")
    # Title
    title = "Intelligent Fundus-Image-Based Health Diagnosis and Analysis System"
    # Description shown under the title (Markdown is supported)
    description = "Upload the left and right fundus images and click Submit to analyze possible diseases " \
                  "from the binocular fundus images. The 8 categories are: diabetes, glaucoma, cataract, " \
                  "age-related macular degeneration, hypertension, pathological myopia, other diseases, and normal."
    # Preprocessing transforms (ImageNet mean/std normalization)
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    my_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])
    # Label mapping (index -> class name)
    LABELS = {0: 'Normal',
              1: 'Diabetes',
              2: 'Glaucoma',
              3: 'Cataract',
              4: 'Age-related macular degeneration',
              5: 'Hypertension',
              6: 'Pathological myopia',
              7: 'Other diseases'}
    left_img_dir = 'left.jpg'
    right_img_dir = 'right.jpg'
    examples = [[left_img_dir, right_img_dir]]
    # r = img2label(left_img_dir, right_img_dir)
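    # Quick local sanity check (a sketch, not part of the app): img2label expects image
    # arrays rather than file paths, so the example files would need to be loaded first,
    # e.g. with PIL/numpy (assuming numpy is available):
    # import numpy as np
    # left_arr = np.array(Image.open(left_img_dir).convert('RGB'))
    # right_arr = np.array(Image.open(right_img_dir).convert('RGB'))
    # scores, text = img2label(left_arr, right_arr)
    # print(scores, text)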
    # gr.Image() replaces the gr.inputs.Image() API removed in newer Gradio releases
    demo = gr.Interface(fn=img2label,
                        inputs=[gr.Image(), gr.Image()],
                        outputs=["label", "text"],
                        examples=examples, title=title, description=description)
    demo.launch()
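    # By default launch() serves on localhost only; to reach the demo from outside
    # (e.g. when running in a container), one could pass, for example:
    # demo.launch(server_name="0.0.0.0", server_port=7860)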