"""Gradio demo estimating age and predicting gender from a lead-I ECG recording.

The user uploads a WFDB record (``.hea`` header + ``.dat`` signal file). The
lead-I channel is resampled to 100 Hz, cropped to 10 s and passed to two Keras
models built by ``models.inception`` whose weights are read from
``models/weights``.
"""
import functools
import os
import shutil
import tempfile

import gradio as gr
import numpy as np
import wfdb
from scipy.signal import resample

from models.inception import build_age_model, build_gender_model

# Input contract shared by both models: TIME seconds of NUM_LEADS lead(s)
# sampled at SAMPLE_FREQUENCY Hz.
SAMPLE_FREQUENCY = 100
TIME = 10
NUM_LEADS = 1

CWD = os.getcwd()


def _lead_index(sig_names, lead="I"):
    """Return the column whose signal name is ``lead``, or 0 if none matches."""
    for index, name in enumerate(sig_names or []):
        if str(name).strip().upper() == lead.upper():
            return index
    # No channel labelled with the requested lead: keep the historical
    # behaviour of using the first column.
    return 0


def load_data(sample_data):
    """Read a WFDB record and return its lead-I signal and sampling frequency.

    Args:
        sample_data: Record path without extension, as accepted by
            ``wfdb.rdsamp``.

    Returns:
        ``(lead_I, sample_frequency)``: the 1-D lead-I samples and the ``fs``
        value from the record's header.
    """
    ecg, meta_data = wfdb.rdsamp(sample_data)
    lead_I = ecg[:, _lead_index(meta_data.get("sig_name"), "I")]
    sample_frequency = meta_data["fs"]
    return lead_I, sample_frequency


def preprocess_ecg(ecg, fs, target_fs=100, max_length=1000):
    """Resample ``ecg`` from ``fs`` to ``target_fs`` Hz and crop to ``max_length``.

    Signals that are shorter than ``max_length`` after resampling are returned
    unchanged (no padding is applied).
    """
    if fs != target_fs:
        ecg = resample(ecg, int(len(ecg) * (target_fs / fs)))
    return ecg[:max_length]


def _build_and_load(builder, weights_name, sample_frequency, recording_time, num_leads):
    """Build a model with ``builder`` and load ``models/weights/<weights_name>`` into it."""
    weights = os.path.join(os.getcwd(), "models", "weights", weights_name)
    model = builder((sample_frequency * recording_time, num_leads), 1)
    model.load_weights(weights)
    return model


# Cached so the network is built and its weights read once per configuration
# instead of on every button click. Callers share the returned model object.
@functools.lru_cache(maxsize=None)
def load_age_model(sample_frequency, recording_time, num_leads):
    """Return the (cached) age-estimation model with lead-I weights loaded."""
    return _build_and_load(
        build_age_model, "model_weights_leadI_age.h5",
        sample_frequency, recording_time, num_leads,
    )


@functools.lru_cache(maxsize=None)
def load_gender_model(sample_frequency, recording_time, num_leads):
    """Return the (cached) gender-prediction model with lead-I weights loaded."""
    return _build_and_load(
        build_gender_model, "model_weights_leadI_gender.h5",
        sample_frequency, recording_time, num_leads,
    )


def _upload_path(uploaded):
    """Return the filesystem path of a Gradio upload (tempfile wrapper or path string)."""
    return uploaded if isinstance(uploaded, str) else uploaded.name


def run(header_file, data_file):
    """Estimate age and predict gender from an uploaded WFDB header/data pair.

    Args:
        header_file: Uploaded ``.hea`` file (Gradio file object or path).
        data_file: Uploaded ``.dat`` file (Gradio file object or path).

    Returns:
        ``(age_text, gender_scores)``: the age estimate rounded to one decimal
        as a string, and ``{"Male": 1 - p, "Female": p}`` where ``p`` is the
        gender model's scalar output.

    Raises:
        ValueError: if a file is missing or the recording is shorter than the
            model's fixed input length.
    """
    if header_file is None or data_file is None:
        raise ValueError("Please upload both a .hea header file and a .dat data file.")

    header_path = _upload_path(header_file)
    data_path = _upload_path(data_file)
    hdr_basename = os.path.basename(header_path)
    record_name = os.path.splitext(hdr_basename)[0]

    # The header refers to its signal file by name, so both files must sit in
    # the same directory. A private per-request directory avoids clobbering
    # files from concurrent requests, leaves sample_data/ untouched and cannot
    # hit SameFileError when an input already lives there.
    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(data_path, os.path.join(work_dir, os.path.basename(data_path)))
        shutil.copyfile(header_path, os.path.join(work_dir, hdr_basename))
        data, fs = load_data(os.path.join(work_dir, record_name))

    input_length = SAMPLE_FREQUENCY * TIME
    ecg = preprocess_ecg(data, fs, target_fs=SAMPLE_FREQUENCY, max_length=input_length)
    if len(ecg) < input_length:
        raise ValueError(
            f"Recording too short: need {TIME} s ({input_length} samples at "
            f"{SAMPLE_FREQUENCY} Hz), got {len(ecg)} samples."
        )
    # Shape the batch explicitly as (batch, time, leads) to match the
    # (sample_frequency * recording_time, num_leads) input the models are built with.
    model_input = np.reshape(ecg, (1, input_length, NUM_LEADS))

    age_model = load_age_model(
        sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS
    )
    gender_model = load_gender_model(
        sample_frequency=SAMPLE_FREQUENCY, recording_time=TIME, num_leads=NUM_LEADS
    )
    age_estimate = float(age_model.predict(model_input).ravel()[0])
    gender_prediction = float(gender_model.predict(model_input).ravel()[0])
    return str(round(age_estimate, 1)), {"Male": 1 - gender_prediction, "Female": gender_prediction}


# Interface layout credited to
# https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
        with gr.Column(scale=1):
            output_age = gr.Textbox(label="Estimated age")
            output_gender = gr.Label(label="Predicted gender")
    with gr.Row():
        predict_btn = gr.Button("Predict")
        predict_btn.click(
            fn=run,
            inputs=[header_file, data_file],
            outputs=[output_age, output_gender],
        )
    with gr.Row():
        gr.Examples(
            examples=[[f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"]],
            inputs=[header_file, data_file],
        )


if __name__ == "__main__":
    demo.launch(share=True)