File size: 2,196 Bytes
be05fd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bd876a
 
 
be05fd1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms


def img2label(left, right):
    left_img = Image.open(left).convert('RGB')
    right_img = Image.open(right).convert('RGB')
    # 将右眼底镜像反转
    r2l = transforms.RandomHorizontalFlip(p=1)
    right_img = r2l(right_img)

    # 调整图片
    left_img = my_transforms(left_img).to(device)
    right_img = my_transforms(right_img).to(device)

    # 读取模型
    model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)

    with torch.no_grad():
        output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))

    output = torch.sigmoid(output.squeeze(0))
    pred = output.cpu().numpy().tolist()

    return {LABELS[i]: pred[i] for i in range(len(pred))}


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 标题
    title = "基于眼底图像的智能健康诊断分析系统"
    # 标题下的描述,支持md格式
    description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \
                  "包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类"

    # transforms设置
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    my_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    LABELS = {0: '正常',
              1: '糖尿病',
              2: '青光眼',
              3: '白内障',
              4: '年龄性黄斑变性',
              5: '高血压',
              6: '病理性近视',
              7: '其他疾病'}

    # left_img_dir = 'left.jpg'
    # right_img_dir = 'right.jpg'
    # r = img2label(left_img_dir, right_img_dir)
    demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()], outputs='label',
                        title=title, description=description)
    demo.launch(share=True)