File size: 4,699 Bytes
4a5ddda 02ad25e 4a5ddda 02ad25e 3c0932f 4a5ddda 3c0932f 02ad25e 3c0932f 02ad25e 4a5ddda a53ae2a 4a5ddda a53ae2a 4a5ddda a53ae2a 02ad25e 4a5ddda 3c0932f 02ad25e 3c0932f a53ae2a 02ad25e 3c0932f 02ad25e a53ae2a 02ad25e a53ae2a 02ad25e 118a1b1 4a5ddda 02ad25e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import os
import wfdb
import shutil
import numpy as np
import gradio as gr
from models.inception import *
from scipy.signal import resample
import pandas as pd
def load_apple_data():
    """Placeholder loader that is not implemented yet; always yields ``None``.

    NOTE(review): the name suggests this is meant to read Apple ECG exports --
    confirm the intended behaviour before relying on it.
    """
    # Intentionally empty: falling off the end of the function returns None.
# Load a WFDB record from disk.
def load_data(sample_data):
    """Read a WFDB record and return its first channel plus the sampling rate.

    Args:
        sample_data: Record path handed unchanged to ``wfdb.rdsamp``.

    Returns:
        A ``(lead_I, sample_frequency)`` tuple: column 0 of the signal array
        (used as lead I by this app) and the ``"fs"`` entry of the record's
        metadata dictionary.
    """
    signals, fields = wfdb.rdsamp(sample_data)
    return signals[:, 0], fields["fs"]
def preprocess_ecg(ecg, fs, target_fs=100, max_samples=1000):
    """Resample a 1-D ECG trace to ``target_fs`` and cap its length.

    Args:
        ecg: 1-D sequence of samples.
        fs: Sampling frequency of ``ecg``.
        target_fs: Frequency to resample to. Defaults to 100, the value that
            was previously hard-coded.
        max_samples: Maximum number of samples kept, counted from the start of
            the (resampled) signal. Defaults to 1000, the previous hard-coded cap.

    Returns:
        The signal at ``target_fs`` truncated to at most ``max_samples`` samples.
        Signals shorter than ``max_samples`` are returned at their own length;
        no padding is applied.
    """
    if fs != target_fs:
        # Multiply before dividing so the output length does not depend on the
        # rounding of an inexact intermediate ratio such as 100 / 360.
        n_out = int(len(ecg) * target_fs / fs)
        ecg = resample(ecg, n_out)
    # Slicing is a no-op for signals that are already short enough.
    return ecg[:max_samples]
def load_age_model(sample_frequency, recording_time, num_leads):
    """Build the age model and load its lead-I weights.

    Args:
        sample_frequency: Samples per second of the model input.
        recording_time: Recording length; multiplied by ``sample_frequency``
            to get the first dimension of the input shape.
        num_leads: Second dimension of the input shape.

    Returns:
        The object returned by ``build_age_model`` after ``load_weights`` has
        been called with the weights file under the current working directory.
    """
    input_shape = (sample_frequency * recording_time, num_leads)
    age_net = build_age_model(input_shape, 1)
    # The weights path is resolved against the process working directory.
    age_net.load_weights(f"{os.getcwd()}/models/weights/model_weights_leadI_age.h5")
    return age_net
def load_gender_model(sample_frequency, recording_time, num_leads):
    """Build the gender model and load its lead-I weights.

    Args:
        sample_frequency: Samples per second of the model input.
        recording_time: Recording length; multiplied by ``sample_frequency``
            to get the first dimension of the input shape.
        num_leads: Second dimension of the input shape.

    Returns:
        The object returned by ``build_gender_model`` after ``load_weights``
        has been called with the weights file under the current working
        directory.
    """
    input_shape = (sample_frequency * recording_time, num_leads)
    gender_net = build_gender_model(input_shape, 1)
    # The weights path is resolved against the process working directory.
    gender_net.load_weights(f"{os.getcwd()}/models/weights/model_weights_leadI_gender.h5")
    return gender_net
def _stage_upload(src_path, dest_dir):
    """Copy an uploaded file into ``dest_dir`` and return the destination path."""
    os.makedirs(dest_dir, exist_ok=True)
    dest_path = f"{dest_dir}/{os.path.basename(src_path)}"
    try:
        shutil.copyfile(src_path, dest_path)
    except shutil.SameFileError:
        # The file already lives in dest_dir, so there is nothing to copy.
        pass
    return dest_path


def _read_csv_ecg(csv_path, sample_frequency, duration, target_frequency):
    """Parse a ';'-separated ECG export and return one resampled, rescaled window."""
    # Read everything as text: if pandas parsed the column to float itself,
    # the ``.str`` clean-up below would raise AttributeError.
    frame = pd.read_csv(csv_path, skiprows=12, sep=";", header=None, dtype=str)
    cleaned = (
        frame[0]
        .str.replace(",", ".", regex=False)  # decimal comma -> decimal point
        .str.replace("−", "-", regex=False)  # Unicode minus -> ASCII hyphen-minus
    )
    samples = np.asarray(cleaned.astype(float))
    # Keep the second block of ``duration * sample_frequency`` samples.
    # NOTE(review): this slices by sample count using the *model* frequency,
    # not the recording's own rate -- confirm this matches the export format.
    start = duration * sample_frequency
    window = samples[start:start * 2]
    # NOTE(review): the /1000 presumably converts microvolts to millivolts -- confirm.
    return resample(window, duration * target_frequency) / 1000


def run(header_file, data_file):
    """Estimate age and gender from an uploaded ECG recording.

    Args:
        header_file: Uploaded WFDB ``.hea`` file object exposing ``.name``.
            Only used (and required) when ``data_file`` is not a CSV.
        data_file: Uploaded ``.dat`` or ``.csv`` file object exposing ``.name``.

    Returns:
        A tuple ``(age, gender)`` where ``age`` is the age estimate rounded to
        one decimal and rendered as a string, and ``gender`` maps ``"Male"``
        and ``"Female"`` to complementary scores taken from the gender model.

    Raises:
        ValueError: If ``data_file`` is missing, or ``header_file`` is missing
            for a non-CSV upload.
    """
    SAMPLE_FREQUENCY = 100
    TIME = 10
    NUM_LEADS = 1
    NEW_SAMPLE_FREQUENCY = 100

    if data_file is None:
        raise ValueError("Please upload a data file (.dat or .csv).")

    demo_dir = f"{CWD}/sample_data"
    if data_file.name.endswith(".csv"):
        csv_path = _stage_upload(data_file.name, demo_dir)
        ecg = _read_csv_ecg(csv_path, SAMPLE_FREQUENCY, TIME, NEW_SAMPLE_FREQUENCY)
    else:
        if header_file is None:
            raise ValueError("A .hea header file is required alongside a .dat data file.")
        # wfdb needs header and signal side by side, so stage both together.
        _stage_upload(data_file.name, demo_dir)
        _stage_upload(header_file.name, demo_dir)
        # splitext strips only the final extension, so record names that
        # themselves contain dots are preserved.
        record_name = os.path.splitext(os.path.basename(header_file.name))[0]
        data, fs = load_data(f"{demo_dir}/{record_name}")
        ecg = preprocess_ecg(data, fs)

    age_model = load_age_model(
        sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS
    )
    gender_model = load_gender_model(
        sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS
    )

    # Both models consume the same single-recording batch; build it once.
    batch = np.expand_dims(ecg, 0)
    age_estimate = age_model.predict(batch).ravel()[0]
    gender_prediction = gender_model.predict(batch).ravel()[0]
    return str(round(age_estimate, 1)), {"Male": 1 - gender_prediction, "Female": gender_prediction}
# Interface layout credit:
# https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py
CWD = os.getcwd()

with gr.Blocks() as demo:
    # Disabled model selector:
    # with gr.Row():
    #     pred_type = gr.Radio(['Age', 'Gender'], label="Select Model")

    # Uploads in the first column, model outputs in the second.
    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat", ".csv"])
        with gr.Column(scale=1):
            output_age = gr.Textbox(label="Estimated age")
            output_gender = gr.Label(label="Predicted gender")

    # Disabled signal plot:
    # with gr.Row():
    #     ecg_graph = gr.Plot(label="ECG Signal Visualisation")

    with gr.Row():
        predict_btn = gr.Button("Predict")
        predict_btn.click(
            fn=run,
            inputs=[header_file, data_file],
            outputs=[output_age, output_gender],
        )

    with gr.Row():
        # Only the bundled ath_001 record is offered as an example; further
        # example rows pointing at demo_data/test/*_lr records are disabled.
        gr.Examples(
            examples=[
                [f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"],
            ],
            inputs=[header_file, data_file],
        )

if __name__ == "__main__":
    demo.launch(share=True)
    # Disabled alternative entry point:
    # iface = gr.Interface(fn=run, inputs="text", outputs="text")
    # iface.launch()