import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms


def img2label(left, right):
    left = Image.fromarray(left.astype('uint8'), 'RGB')
    right = Image.fromarray(right.astype('uint8'), 'RGB')
    # 将右眼底镜像反转
    r2l = transforms.RandomHorizontalFlip(p=1)
    right = r2l(right)

    # 调整图片
    left_img = my_transforms(left).to(device)
    right_img = my_transforms(right).to(device)

    # 读取模型
    model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)

    with torch.no_grad():
        output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))

    output = torch.sigmoid(output.squeeze(0))
    # pred = output.cpu().numpy().tolist()
    # return {LABELS[i]: pred[i] for i in range(len(pred))}
    pred = torch.nonzero(output > 0.4).view(-1)
    pred = pred.cpu().numpy().tolist()

    if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
        return LABELS[0]
    res = ''
    for i in pred:
        if i == 0:
            continue
        res += ', ' + LABELS[i]
    return '目前的身体状态:' + res[2:]


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 标题
    title = "基于眼底图像的智能健康诊断分析系统"
    # 标题下的描述,支持md格式
    description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \
                  "包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类"

    # transforms设置
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    my_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    LABELS = {0: '正常',
              1: '糖尿病',
              2: '青光眼',
              3: '白内障',
              4: '年龄性黄斑变性',
              5: '高血压',
              6: '病理性近视',
              7: '其他疾病'}

    left_img_dir = 'left.jpg'
    right_img_dir = 'right.jpg'
    examples = [[left_img_dir, right_img_dir]]
    # r = img2label(left_img_dir, right_img_dir)
    demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()],
                        outputs="text", examples=examples,
                        title=title, description=description)
    demo.launch(share=True)