File size: 3,457 Bytes
4a5ddda
 
02ad25e
4a5ddda
 
 
02ad25e
4a5ddda
 
02ad25e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a5ddda
 
 
 
 
 
 
 
 
02ad25e
4a5ddda
 
 
02ad25e
 
 
 
 
 
 
4a5ddda
02ad25e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a5ddda
02ad25e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import wfdb
import shutil
import tempfile
import numpy as np
import gradio as gr
from models.inception import *
from scipy.signal import resample


def load_data(sample_data):
    """Read a WFDB record and return its first signal channel and sampling rate.

    Parameters
    ----------
    sample_data : str
        Record path passed straight to ``wfdb.rdsamp`` (the record name,
        i.e. the path without a file extension).

    Returns
    -------
    tuple
        ``(lead_I, sample_frequency)``: column 0 of the signal matrix returned
        by ``wfdb.rdsamp`` and the ``"fs"`` entry of its metadata dict.
    """
    signals, fields = wfdb.rdsamp(sample_data)
    # NOTE(review): assumes lead I is stored as the first channel of the
    # record; confirm for every dataset fed to the app.
    first_channel = signals[:, 0]
    return first_channel, fields["fs"]

def preprocess_ecg(ecg, fs, target_fs=100, max_samples=1000):
    """Resample an ECG trace to ``target_fs`` Hz and cap its length.

    Parameters
    ----------
    ecg : array_like
        Signal samples; resampling is applied along axis 0.
    fs : float
        Sampling frequency of ``ecg`` in Hz (must be non-zero).
    target_fs : float, optional
        Sampling frequency to convert to, in Hz. Defaults to 100, matching
        the previously hard-coded value.
    max_samples : int or None, optional
        Maximum number of samples kept after resampling. Defaults to 1000
        (10 s at 100 Hz). Longer signals are truncated to their leading
        ``max_samples`` samples; shorter ones are returned unchanged (no
        padding). ``None`` disables the cap.

    Returns
    -------
    The resampled and possibly truncated signal.
    """
    if fs != target_fs:
        # Same duration at the new rate: scale the sample count by the rate
        # ratio, truncating toward zero as before.
        ecg = resample(ecg, int(len(ecg) * (target_fs / fs)))
    if max_samples is not None and len(ecg) > max_samples:
        # Keep only the leading window of the recording.
        ecg = ecg[:max_samples]
    return ecg

def load_model(sample_frequency, recording_time, num_leads):
    """Build the Inception network and load the lead-I weights into it.

    Parameters
    ----------
    sample_frequency : int
        Sampling rate (Hz) of the model input.
    recording_time : int
        Length of the input window in seconds.
    num_leads : int
        Number of input channels.

    Returns
    -------
    The model from ``build_model`` with weights loaded from
    ``<cwd>/models/weights/model_weights_leadI.h5``.
    """
    # Path is resolved against the current working directory, so the app
    # must be launched from the project root.
    weights_path = f"{os.getcwd()}/models/weights/model_weights_leadI.h5"
    n_timesteps = sample_frequency * recording_time
    # NOTE(review): second build_model argument presumably is the number of
    # outputs (one regression value); confirm in models.inception.
    network = build_model((n_timesteps, num_leads), 1)
    network.load_weights(weights_path)
    return network


def run(header_file, data_file):
    """Predict age from an uploaded WFDB record (header + signal file).

    Parameters
    ----------
    header_file, data_file :
        Values from the two Gradio ``File`` components. Either objects whose
        ``.name`` attribute is the on-disk path (Gradio 3 temp-file wrappers)
        or plain path strings (Gradio 4 ``filepath`` mode). ``None`` when the
        user has not uploaded that file.

    Returns
    -------
    str
        The model prediction rounded to one decimal place. If either file is
        missing, a short message asking for both files is returned instead.
    """
    SAMPLE_FREQUENCY = 100
    TIME = 10
    NUM_LEADS = 1

    if header_file is None or data_file is None:
        return "Please upload both a .hea header file and a .dat data file."

    # Accept both file-like wrappers (with .name) and plain path strings.
    hdr_path = getattr(header_file, "name", header_file)
    dat_path = getattr(data_file, "name", data_file)
    hdr_basename = os.path.basename(hdr_path)
    dat_basename = os.path.basename(dat_path)
    # splitext strips only the final extension, so record names containing
    # dots are preserved (split('.')[0] would cut them short).
    record_name = os.path.splitext(hdr_basename)[0]

    # The header refers to its signal file by name (WFDB header format), so
    # both files are placed side by side under their original names. A
    # private directory per request keeps uploads out of sample_data/ and
    # stops concurrent requests from overwriting each other's files.
    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(dat_path, os.path.join(work_dir, dat_basename))
        shutil.copyfile(hdr_path, os.path.join(work_dir, hdr_basename))
        data, fs = load_data(os.path.join(work_dir, record_name))

    ecg = preprocess_ecg(data, fs)
    model = load_model(
        sample_frequency=SAMPLE_FREQUENCY,
        recording_time=TIME,
        num_leads=NUM_LEADS,
    )
    # Add a leading batch axis; take the single scalar output.
    prediction = model.predict(np.expand_dims(ecg, 0)).ravel()[0]
    return str(round(prediction, 1))

# Give credit to https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py for interface

CWD = os.getcwd()

# Bundled example record: each row pairs a WFDB header with its signal file.
# TODO: an earlier draft also sketched an Age/Gender model selector, a
# "Predicted gender" output, an ECG plot, and extra examples from
# demo_data/test; re-add them once the corresponding models are available.
_EXAMPLE_RECORDS = [
    [f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"],
]

with gr.Blocks() as demo:
    with gr.Row():
        # Left column: the two halves of a WFDB record.
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
        # Right column: model output.
        with gr.Column(scale=1):
            output_age = gr.Textbox(label="Predicted age")
    with gr.Row():
        predict_btn = gr.Button("Predict")
        predict_btn.click(
            fn=run,
            inputs=[header_file, data_file],
            outputs=[output_age],
        )
    with gr.Row():
        gr.Examples(
            examples=_EXAMPLE_RECORDS,
            inputs=[header_file, data_file],
        )

if __name__ == "__main__":
    demo.launch()