# Provenance (header of the file-view page this source was copied from):
#   author: bjornsing, commit a53ae2a "Adding gender prediction", 4.05 kB
import os
import shutil
import tempfile

import gradio as gr
import numpy as np
import wfdb
from scipy.signal import resample

from models.inception import *
def load_data(sample_data):
    """Read a WFDB record and return its first signal column and sampling rate.

    Args:
        sample_data: Record path without file extension, as accepted by
            ``wfdb.rdsamp``.

    Returns:
        A ``(lead_I, sample_frequency)`` tuple: column 0 of the record's signal
        matrix, and the ``"fs"`` entry of the record's metadata.
    """
    signals, fields = wfdb.rdsamp(sample_data)
    return signals[:, 0], fields["fs"]
def preprocess_ecg(ecg, fs, target_fs=100, n_samples=1000):
    """Resample a single-lead ECG to ``target_fs`` Hz and keep at most ``n_samples``.

    The defaults (100 Hz, 1000 samples) match the ``SAMPLE_FREQUENCY * TIME``
    input length that ``run`` builds the models with.

    Args:
        ecg: 1-D signal array.
        fs: Sampling frequency of ``ecg`` in Hz.
        target_fs: Sampling frequency to resample to.
        n_samples: Maximum number of samples kept after resampling. Longer
            signals are truncated; shorter ones are returned unpadded.

    Returns:
        The resampled and possibly truncated signal.

    Raises:
        ValueError: If ``fs`` is not positive.
    """
    if fs <= 0:
        raise ValueError(f"sampling frequency must be positive, got {fs}")
    if fs != target_fs:
        # scipy.signal.resample (Fourier method) to the target rate; the new
        # length scales the sample count by target_fs / fs.
        ecg = resample(ecg, int(len(ecg) * (target_fs / fs)))
    # Slicing is a no-op for signals already at or below the limit.
    return ecg[:n_samples]
def load_age_model(sample_frequency, recording_time, num_leads):
    """Build the lead-I age model and load its pretrained weights.

    Args:
        sample_frequency: Input sampling rate in Hz.
        recording_time: Input duration in seconds.
        num_leads: Number of input channels.

    Returns:
        The model returned by ``build_age_model`` with weights loaded from
        ``<cwd>/models/weights/model_weights_leadI_age.h5``.
    """
    weights_path = f"{os.getcwd()}/models/weights/model_weights_leadI_age.h5"
    input_shape = (sample_frequency * recording_time, num_leads)
    # Second argument is passed through unchanged; presumably the output
    # size of the network -- confirm against models.inception.
    age_model = build_age_model(input_shape, 1)
    age_model.load_weights(weights_path)
    return age_model
def load_gender_model(sample_frequency, recording_time, num_leads):
    """Build the lead-I gender model and load its pretrained weights.

    Args:
        sample_frequency: Input sampling rate in Hz.
        recording_time: Input duration in seconds.
        num_leads: Number of input channels.

    Returns:
        The model returned by ``build_gender_model`` with weights loaded from
        ``<cwd>/models/weights/model_weights_leadI_gender.h5``.
    """
    weights_path = f"{os.getcwd()}/models/weights/model_weights_leadI_gender.h5"
    input_shape = (sample_frequency * recording_time, num_leads)
    # Second argument is passed through unchanged; presumably the output
    # size of the network -- confirm against models.inception.
    gender_model = build_gender_model(input_shape, 1)
    gender_model.load_weights(weights_path)
    return gender_model
def run(header_file, data_file):
    """Estimate age and predict gender from an uploaded WFDB record.

    Args:
        header_file: Uploaded ``.hea`` file from the Gradio ``File`` component,
            either an object exposing a ``.name`` path or a plain path string.
        data_file: Uploaded ``.dat`` file, in the same form as ``header_file``.

    Returns:
        ``(age, gender)``: the age estimate rounded to one decimal as a string,
        and a ``{"Male": ..., "Female": ...}`` confidence dict for ``gr.Label``.

    Raises:
        gr.Error: If either file was not uploaded.
    """
    SAMPLE_FREQUENCY = 100
    TIME = 10
    NUM_LEADS = 1

    if header_file is None or data_file is None:
        raise gr.Error("Please upload both a .hea header file and a .dat data file.")

    # Accept both file-wrapper objects (``.name``) and plain path strings.
    header_path = getattr(header_file, "name", header_file)
    data_path = getattr(data_file, "name", data_file)
    hdr_basename = os.path.basename(header_path)
    data_basename = os.path.basename(data_path)
    # Strip only the final extension so record names containing dots survive.
    record_name = os.path.splitext(hdr_basename)[0]

    # Stage both files side by side (wfdb looks for the signal file next to
    # the header) in a private per-request directory, so concurrent uploads
    # cannot overwrite each other or the bundled sample_data examples.
    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(data_path, os.path.join(work_dir, data_basename))
        shutil.copyfile(header_path, os.path.join(work_dir, hdr_basename))
        data, fs = load_data(os.path.join(work_dir, record_name))

    ecg = preprocess_ecg(data, fs)
    batch = np.expand_dims(ecg, 0)  # add a leading batch axis: (1, n_samples)

    age_model = load_age_model(sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS)
    gender_model = load_gender_model(sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS)

    # Cast numpy scalars to plain floats so the Label payload serialises cleanly.
    age_estimate = float(age_model.predict(batch).ravel()[0])
    # The gender model's single output is used as the "Female" confidence.
    female_score = float(gender_model.predict(batch).ravel()[0])
    return str(round(age_estimate, 1)), {"Male": 1 - female_score, "Female": female_score}
# Interface layout adapted from
# https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py
CWD = os.getcwd()

# Bundled example record offered in the UI, as [header file, data file].
_EXAMPLE_RECORDS = [
    [f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"],
]

with gr.Blocks() as demo:
    with gr.Row():
        # Left column: the two files making up one WFDB record.
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
        # Right column: the two model outputs produced by run().
        with gr.Column(scale=1):
            output_age = gr.Textbox(label="Estimated age")
            output_gender = gr.Label(label="Predicted gender")
    with gr.Row():
        predict_btn = gr.Button("Predict")
        predict_btn.click(
            fn=run,
            inputs=[header_file, data_file],
            outputs=[output_age, output_gender],
        )
    with gr.Row():
        gr.Examples(examples=_EXAMPLE_RECORDS, inputs=[header_file, data_file])
def main():
    """Launch the Gradio ``demo`` app with ``share=True``."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()