File size: 2,682 Bytes
be05fd1
 
 
 
 
 
 
3b16162
 
be05fd1
 
3b16162
be05fd1
 
3b16162
 
be05fd1
 
 
 
c177f41
be05fd1
 
 
c177f41
be05fd1
9ed6175
 
 
3b16162
 
be05fd1
3b16162
86d1958
3b16162
 
 
 
 
86d1958
be05fd1
 
 
865f10a
be05fd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c177f41
be05fd1
 
 
 
 
 
 
 
 
3b16162
 
 
7bd876a
9ed6175
 
86d1958
9ed6175
1e947cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms


def img2label(left, right):
    left = Image.fromarray(left.astype('uint8'), 'RGB')
    right = Image.fromarray(right.astype('uint8'), 'RGB')
    # 将右眼底镜像反转
    r2l = transforms.RandomHorizontalFlip(p=1)
    right = r2l(right)

    # 调整图片
    left_img = my_transforms(left).to(device)
    right_img = my_transforms(right).to(device)

    # 读取模型
    model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)

    # 模型推理
    with torch.no_grad():
        output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))

    # 结果处理
    output = torch.sigmoid(output.squeeze(0))
    output_ = output.cpu().numpy().tolist()
    res_dict = {LABELS[i]: output_[i] for i in range(len(output_))}

    pred = torch.nonzero(output > 0.4).view(-1)
    pred = pred.cpu().numpy().tolist()

    if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
        return res_dict, LABELS[0]
    res = ''
    for i in pred:
        if i == 0:
            continue
        res += ', ' + LABELS[i]
    return res_dict, '目前的身体状态:' + res[2:]


if __name__ == '__main__':
    device = torch.device("cpu")

    # 标题
    title = "基于眼底图像的智能健康诊断分析系统"
    # 标题下的描述,支持md格式
    description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \
                  "包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类"

    # transforms设置
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    my_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    # 标签设置
    LABELS = {0: '正常',
              1: '糖尿病',
              2: '青光眼',
              3: '白内障',
              4: '年龄性黄斑变性',
              5: '高血压',
              6: '病理性近视',
              7: '其他疾病'}

    left_img_dir = 'left.jpg'
    right_img_dir = 'right.jpg'
    examples = [[left_img_dir, right_img_dir]]
    # r = img2label(left_img_dir, right_img_dir)
    demo = gr.Interface(fn=img2label,
                        inputs=[gr.inputs.Image(), gr.inputs.Image()],
                        outputs=["label", "text"],
                        examples=examples, title=title, description=description)
    demo.launch()