bjornsing commited on
Commit
a53ae2a
·
1 Parent(s): 118a1b1

Adding gender prediction

Browse files
app.py CHANGED
@@ -24,14 +24,21 @@ def preprocess_ecg(ecg,fs):
24
  pass
25
  return ecg
26
 
27
- def load_model(sample_frequency,recording_time, num_leads):
28
  cwd = os.getcwd()
29
- weights = f"{cwd}/models/weights/model_weights_leadI.h5"
30
- model = build_model((sample_frequency * recording_time, num_leads), 1)
31
  model.load_weights(weights)
32
  return model
33
 
34
 
 
 
 
 
 
 
 
35
  def run(header_file, data_file):
36
  SAMPLE_FREQUENCY = 100
37
  TIME = 10
@@ -43,9 +50,11 @@ def run(header_file, data_file):
43
  shutil.copyfile(header_file.name, f"{demo_dir}/{hdr_basename}")
44
  data, fs = load_data(f"{demo_dir}/{hdr_basename.split('.')[0]}")
45
  ecg = preprocess_ecg(data,fs)
46
- model = load_model(sample_frequency=SAMPLE_FREQUENCY,recording_time=TIME,num_leads=NUM_LEADS)
47
- predicion = model.predict(np.expand_dims(ecg,0)).ravel()[0]
48
- return str(round(predicion,1))
 
 
49
 
50
  # Give credit to https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py for interface
51
 
@@ -59,14 +68,14 @@ with gr.Blocks() as demo:
59
  header_file = gr.File(label = "header_file", file_types=[".hea"],)
60
  data_file = gr.File(label = "data_file", file_types=[".dat"])
61
  with gr.Column(scale=1):
62
- output_age = gr.Textbox(label = "Predicted age")
63
- #output_gender = gr.Textbox(label = "Predicted gender")
64
  #with gr.Row():
65
  # ecg_graph = gr.Plot(label = "ECG Signal Visualisation")
66
  with gr.Row():
67
  predict_btn = gr.Button("Predict")
68
  predict_btn.click(fn= run, inputs = [#pred_type,
69
- header_file, data_file], outputs=[output_age])
70
  with gr.Row():
71
  gr.Examples(examples=[[f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"],\
72
  # [f"{CWD}/demo_data/test/00008_lr.hea", f"{CWD}/demo_data/test/00008_lr.dat", "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest."], \
 
24
  pass
25
  return ecg
26
 
27
+ def load_age_model(sample_frequency,recording_time, num_leads):
28
  cwd = os.getcwd()
29
+ weights = f"{cwd}/models/weights/model_weights_leadI_age.h5"
30
+ model = build_age_model((sample_frequency * recording_time, num_leads), 1)
31
  model.load_weights(weights)
32
  return model
33
 
34
 
35
+ def load_gender_model(sample_frequency,recording_time, num_leads):
36
+ cwd = os.getcwd()
37
+ weights = f"{cwd}/models/weights/model_weights_leadI_gender.h5"
38
+ model = build_gender_model((sample_frequency * recording_time, num_leads), 1)
39
+ model.load_weights(weights)
40
+ return model
41
+
42
  def run(header_file, data_file):
43
  SAMPLE_FREQUENCY = 100
44
  TIME = 10
 
50
  shutil.copyfile(header_file.name, f"{demo_dir}/{hdr_basename}")
51
  data, fs = load_data(f"{demo_dir}/{hdr_basename.split('.')[0]}")
52
  ecg = preprocess_ecg(data,fs)
53
+ age_model = load_age_model(sample_frequency=SAMPLE_FREQUENCY,recording_time=TIME,num_leads=NUM_LEADS)
54
+ gender_model = load_gender_model(sample_frequency=SAMPLE_FREQUENCY,recording_time=TIME,num_leads=NUM_LEADS)
55
+ age_estimate = age_model.predict(np.expand_dims(ecg,0)).ravel()[0]
56
+ gender_prediction = gender_model.predict(np.expand_dims(ecg,0)).ravel()[0]
57
+ return str(round(age_estimate,1)), {"Male": 1- gender_prediction, "Female": gender_prediction}
58
 
59
  # Give credit to https://huggingface.co/spaces/Tej3/ECG_Classification/blob/main/app.py for interface
60
 
 
68
  header_file = gr.File(label = "header_file", file_types=[".hea"],)
69
  data_file = gr.File(label = "data_file", file_types=[".dat"])
70
  with gr.Column(scale=1):
71
+ output_age = gr.Textbox(label = "Estimated age")
72
+ output_gender = gr.Label( label = "Predicted gender")
73
  #with gr.Row():
74
  # ecg_graph = gr.Plot(label = "ECG Signal Visualisation")
75
  with gr.Row():
76
  predict_btn = gr.Button("Predict")
77
  predict_btn.click(fn= run, inputs = [#pred_type,
78
+ header_file, data_file], outputs=[output_age,output_gender])
79
  with gr.Row():
80
  gr.Examples(examples=[[f"{CWD}/sample_data/ath_001.hea", f"{CWD}/sample_data/ath_001.dat"],\
81
  # [f"{CWD}/demo_data/test/00008_lr.hea", f"{CWD}/demo_data/test/00008_lr.dat", "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest."], \
models/__pycache__/inception.cpython-39.pyc CHANGED
Binary files a/models/__pycache__/inception.cpython-39.pyc and b/models/__pycache__/inception.cpython-39.pyc differ
 
models/inception.py CHANGED
@@ -70,7 +70,7 @@ def _shortcut_layer(input_tensor, out_tensor):
70
  return x
71
 
72
 
73
- def build_model(
74
  input_shape: Tuple[int, int],
75
  nb_classes: int,
76
  depth: int = 6,
@@ -105,4 +105,43 @@ def build_model(
105
  metrics=[tf.keras.metrics.MeanSquaredError()],
106
  )
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  return model
 
70
  return x
71
 
72
 
73
+ def build_age_model(
74
  input_shape: Tuple[int, int],
75
  nb_classes: int,
76
  depth: int = 6,
 
105
  metrics=[tf.keras.metrics.MeanSquaredError()],
106
  )
107
 
108
+ return model
109
+
110
+
111
+
112
+ def build_gender_model(
113
+ input_shape: Tuple[int, int],
114
+ nb_classes: int,
115
+ depth: int = 6,
116
+ use_residual: bool = True,
117
+ )-> tf.keras.models.Model:
118
+ """
119
+ Model proposed by HI Fawas et al 2019 "Finding AlexNet for Time Series Classification - InceptionTime"
120
+ """
121
+ input_layer = tf.keras.layers.Input(input_shape)
122
+
123
+ x = input_layer
124
+ input_res = input_layer
125
+
126
+ for d in range(depth):
127
+
128
+ x = _inception_module(x)
129
+
130
+ if use_residual and d % 3 == 2:
131
+ x = _shortcut_layer(input_res, x)
132
+ input_res = x
133
+
134
+ gap_layer = tf.keras.layers.GlobalAveragePooling1D()(x)
135
+
136
+ output_layer = tf.keras.layers.Dense(units=nb_classes, activation="sigmoid")(
137
+ gap_layer
138
+ )
139
+
140
+ model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
141
+ model.compile(
142
+ loss=tf.keras.losses.BinaryCrossentropy(),
143
+ optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
144
+ metrics=[tf.keras.metrics.AUC(curve='ROC',name="AUROC")],
145
+ )
146
+
147
  return model
models/weights/{model_weights_leadI.h5 → model_weights_leadI_age.h5} RENAMED
File without changes
models/weights/model_weights_leadI_gender.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac1616b91eb8f740aaae4c44c08948b3b3d1469e628eb3c145b7e9670264f023
3
+ size 1833768
sample_data/ath_008.dat ADDED
Binary file (120 kB). View file
 
sample_data/ath_008.hea ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ath_008 12 500 5000
2
+ ath_008.dat 16 50000/mV 16 0 -11527 54167 0 I
3
+ ath_008.dat 16 50000/mV 16 0 -18070 20408 0 II
4
+ ath_008.dat 16 50000/mV 16 0 -15660 38252 0 III
5
+ ath_008.dat 16 50000/mV 16 0 16820 59491 0 AVR
6
+ ath_008.dat 16 50000/mV 16 0 4368 62574 0 AVL
7
+ ath_008.dat 16 50000/mV 16 0 -17613 24434 0 AVF
8
+ ath_008.dat 16 50000/mV 16 0 11148 20294 0 V1
9
+ ath_008.dat 16 50000/mV 16 0 10557 25409 0 V2
10
+ ath_008.dat 16 50000/mV 16 0 8134 4135 0 V3
11
+ ath_008.dat 16 50000/mV 16 0 1343 20358 0 V4
12
+ ath_008.dat 16 50000/mV 16 0 -12126 42898 0 V5
13
+ ath_008.dat 16 50000/mV 16 0 -22817 49296 0 V6
14
+ #SL12: Normal sinus rhythm, RSR' or QR pattern in V1 suggests right ventricular conduction delay, Borderline ECG
15
+ #C: Normal sinus rhythm, Incomplete right bundle branch block, Normal ECG
sample_data/ath_013.dat ADDED
Binary file (120 kB). View file
 
sample_data/ath_013.hea ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ath_013 12 500 5000
2
+ ath_013.dat 16 50000/mV 16 0 5460 51465 0 I
3
+ ath_013.dat 16 50000/mV 16 0 -18724 34405 0 II
4
+ ath_013.dat 16 50000/mV 16 0 -28034 29205 0 III
5
+ ath_013.dat 16 50000/mV 16 0 9138 13588 0 AVR
6
+ ath_013.dat 16 50000/mV 16 0 19917 20645 0 AVL
7
+ ath_013.dat 16 50000/mV 16 0 -24408 31082 0 AVF
8
+ ath_013.dat 16 50000/mV 16 0 21564 56921 0 V1
9
+ ath_013.dat 16 50000/mV 16 0 20969 29226 0 V2
10
+ ath_013.dat 16 50000/mV 16 0 18256 12273 0 V3
11
+ ath_013.dat 16 50000/mV 16 0 -456 37253 0 V4
12
+ ath_013.dat 16 50000/mV 16 0 -7056 1140 0 V5
13
+ ath_013.dat 16 50000/mV 16 0 -14084 59880 0 V6
14
+ #SL12: Marked sinus bradycardia, Right axis deviation, Abnormal ECG
15
+ #C: Sinus bradycardia, Normal ECG