File size: 4,699 Bytes
4a5ddda
 
02ad25e
4a5ddda
 
 
02ad25e
3c0932f
4a5ddda
 
3c0932f
 
 
 
 
 
 
02ad25e
 
 
 
 
 
3c0932f
 
02ad25e
 
 
 
 
 
 
 
 
 
4a5ddda
a53ae2a
4a5ddda
a53ae2a
 
4a5ddda
 
 
 
a53ae2a
 
 
 
 
 
 
02ad25e
4a5ddda
 
 
3c0932f
02ad25e
3c0932f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a53ae2a
 
 
 
 
02ad25e
 
 
 
 
 
 
 
 
 
 
3c0932f
02ad25e
a53ae2a
 
02ad25e
 
 
 
 
a53ae2a
02ad25e
 
 
 
 
 
 
 
 
118a1b1
4a5ddda
02ad25e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import wfdb
import shutil
import numpy as np
import gradio as gr
from models.inception import *
from scipy.signal import resample
import pandas as pd



def load_apple_data():
    """Placeholder for a dedicated Apple ECG export loader; not implemented.

    Always returns ``None``. CSV uploads are currently handled by the CSV
    branch of ``run`` instead.
    """
    return



# load wfdb data
def load_data(sample_data):
    """Read a WFDB record and return its first signal column and sampling rate.

    Args:
        sample_data: WFDB record name/path (no extension), passed straight to
            ``wfdb.rdsamp``.

    Returns:
        Tuple ``(lead_I, fs)``: the first signal column (assumed to be lead I —
        confirm against the record's header) and ``meta_data["fs"]``.
    """
    signals, fields = wfdb.rdsamp(sample_data)
    first_lead = signals[:, 0]
    return first_lead, fields["fs"]



def preprocess_ecg(ecg, fs, target_fs=100, max_samples=1000):
    """Resample a single-lead ECG to ``target_fs`` and cap its length.

    Args:
        ecg: 1-D sequence of samples.
        fs: sampling frequency of ``ecg`` in Hz.
        target_fs: desired sampling frequency in Hz (default 100, the rate the
            models in this file are built for).
        max_samples: maximum number of samples kept (default 1000, i.e. 10 s
            at 100 Hz).

    Returns:
        The signal (resampled when ``fs != target_fs``) truncated to at most
        ``max_samples`` samples. Shorter signals are returned unpadded.
        NOTE(review): the models are built for a fixed-length input, so a
        recording shorter than ``max_samples`` will likely fail at predict time.
    """
    if fs != target_fs:
        # Multiply before dividing: precomputing ``target_fs / fs`` rounds the
        # ratio and can truncate the sample count by one
        # (e.g. 49 * (100 / 4900) == 0.999... -> 0 samples instead of 1).
        num_samples = int(len(ecg) * target_fs / fs)
        ecg = resample(ecg, num_samples)
    return ecg[:max_samples]

def load_age_model(sample_frequency, recording_time, num_leads):
    """Build the lead-I age model and load its pretrained weights.

    The model input shape is ``(sample_frequency * recording_time, num_leads)``;
    weights are read from ``models/weights/model_weights_leadI_age.h5`` relative
    to the current working directory.

    Returns:
        The model returned by ``build_age_model`` with weights loaded.
    """
    weights_path = f"{os.getcwd()}/models/weights/model_weights_leadI_age.h5"
    input_shape = (sample_frequency * recording_time, num_leads)
    # Second argument presumably the number of outputs — confirm in models.inception.
    age_model = build_age_model(input_shape, 1)
    age_model.load_weights(weights_path)
    return age_model


def load_gender_model(sample_frequency, recording_time, num_leads):
    """Build the lead-I gender model and load its pretrained weights.

    The model input shape is ``(sample_frequency * recording_time, num_leads)``;
    weights are read from ``models/weights/model_weights_leadI_gender.h5``
    relative to the current working directory.

    Returns:
        The model returned by ``build_gender_model`` with weights loaded.
    """
    weights_path = f"{os.getcwd()}/models/weights/model_weights_leadI_gender.h5"
    input_shape = (sample_frequency * recording_time, num_leads)
    # Second argument presumably the number of outputs — confirm in models.inception.
    gender_model = build_gender_model(input_shape, 1)
    gender_model.load_weights(weights_path)
    return gender_model

def _upload_path(upload):
    """Return the filesystem path behind a Gradio ``File`` value.

    Some Gradio releases pass a tempfile-like object exposing ``.name``,
    others pass the path string itself; both are accepted here.
    """
    return getattr(upload, "name", upload)


def _copy_into(src_path, dest_dir):
    """Copy ``src_path`` into ``dest_dir`` keeping its basename; return the copy's path."""
    dest_path = os.path.join(dest_dir, os.path.basename(src_path))
    try:
        shutil.copyfile(src_path, dest_path)
    except shutil.SameFileError:
        # The file already lives in dest_dir (possibly a bundled sample); nothing to copy.
        pass
    return dest_path


def _read_apple_csv(csv_path, time_s, fs, new_fs):
    """Parse a ';'-separated CSV ECG export and return ``time_s * new_fs`` samples / 1000."""
    df = pd.read_csv(csv_path, skiprows=12, sep=";", header=None, decimal=",")
    # Cast to str first: if pandas already parsed the column as float, the
    # ``.str`` accessor would raise. Then normalise decimal commas and the
    # Unicode minus sign before converting to float.
    values = (
        df[0].astype(str)
        .str.replace(",", ".", regex=False)
        .str.replace("−", "-", regex=False)
        .astype(float)
    )
    ecg = np.asarray(values)
    # Take the second window of ``time_s * fs`` samples and resample it to the
    # model length. NOTE(review): the window is sized with the model rate ``fs``,
    # not the CSV's own sampling rate — confirm this matches the export format.
    window = ecg[time_s * fs:2 * time_s * fs]
    # Division by 1000 presumably converts µV to mV — confirm against training data.
    return resample(window, time_s * new_fs) / 1000


def _read_wfdb_upload(header_path, data_path, dest_dir):
    """Copy a WFDB .hea/.dat pair into ``dest_dir`` and return the preprocessed first lead."""
    _copy_into(data_path, dest_dir)
    header_copy = _copy_into(header_path, dest_dir)
    # wfdb wants the record name without extension; splitext keeps dots inside
    # the basename (``a.b.hea`` -> ``a.b``), unlike ``split('.')[0]``.
    record_name = os.path.splitext(header_copy)[0]
    lead_i, fs = load_data(record_name)
    return preprocess_ecg(lead_i, fs)


def run(header_file, data_file):
    """Gradio callback: estimate age and predict gender from an uploaded ECG.

    Args:
        header_file: WFDB ``.hea`` upload; required for ``.dat`` input, ignored
            for ``.csv`` input.
        data_file: ``.dat`` (WFDB) or ``.csv`` upload.

    Returns:
        ``(age_text, gender_confidences)``: the age estimate rounded to one
        decimal as a string, and ``{"Male": 1 - p, "Female": p}`` where ``p`` is
        the gender model's output.

    Raises:
        gr.Error: if a required upload is missing.
    """
    SAMPLE_FREQUENCY = 100
    TIME = 10
    NUM_LEADS = 1
    NEW_SAMPLE_FREQUENCY = 100
    demo_dir = f"{CWD}/sample_data"

    if data_file is None:
        raise gr.Error("Please upload a data file (.dat or .csv).")
    data_path = _upload_path(data_file)
    os.makedirs(demo_dir, exist_ok=True)

    if data_path.lower().endswith(".csv"):
        csv_copy = _copy_into(data_path, demo_dir)
        ecg = _read_apple_csv(csv_copy, TIME, SAMPLE_FREQUENCY, NEW_SAMPLE_FREQUENCY)
    else:
        if header_file is None:
            raise gr.Error("Please upload the .hea header file that belongs to the .dat file.")
        ecg = _read_wfdb_upload(_upload_path(header_file), data_path, demo_dir)

    age_model = load_age_model(sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS)
    gender_model = load_gender_model(sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS)
    # Add a leading batch axis of size 1 for Keras predict.
    batch = np.expand_dims(ecg, 0)
    age_estimate = age_model.predict(batch).ravel()[0]
    # Plain float so the confidences dict holds builtin numbers for gr.Label.
    gender_prediction = float(gender_model.predict(batch).ravel()[0])
    return str(round(age_estimate, 1)), {"Male": 1 - gender_prediction, "Female": gender_prediction}

# Interface adapted from https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py (credit to its authors).

CWD = os.getcwd()

# Bundled example offered under the form; each row is [header file, data file].
_example_records = [
    [f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"],
]

# TODO: a model selector (gr.Radio(['Age', 'Gender'])) and an ECG signal plot
# were sketched for this UI but are not wired up yet.
with gr.Blocks() as demo:
    # Row 1: uploads on the left, predictions on the right.
    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat", ".csv"])
        with gr.Column(scale=1):
            output_age = gr.Textbox(label="Estimated age")
            output_gender = gr.Label(label="Predicted gender")

    # Row 2: the button that triggers ``run`` on the two uploads.
    with gr.Row():
        predict_btn = gr.Button("Predict")
        predict_btn.click(
            fn=run,
            inputs=[header_file, data_file],
            outputs=[output_age, output_gender],
        )

    # Row 3: clickable examples that pre-fill the upload fields.
    with gr.Row():
        gr.Examples(examples=_example_records, inputs=[header_file, data_file])

if __name__ == "__main__":
    demo.launch(share=True)

#iface = gr.Interface(fn=run, inputs="text", outputs="text")
#iface.launch()