mohdelgaar commited on
Commit
59dd739
·
1 Parent(s): c47c7dc
Files changed (2) hide show
  1. app.py +239 -1
  2. demo.py +0 -241
app.py CHANGED
@@ -1,8 +1,12 @@
 
1
  import argparse
2
  import torch
 
3
  from data import load_tokenizer
4
  from model import load_model
5
- from demo import run_gradio
 
 
6
 
7
  parser = argparse.ArgumentParser()
8
  parser.add_argument('--data_dir', default='/data/mohamed/data')
@@ -72,6 +76,240 @@ elif args.task == 'token':
72
  elif args.label_encoding == 'boe':
73
  args.num_labels *= 3
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
  tokenizer = load_tokenizer(args.model_name)
 
1
+ import re
2
  import argparse
3
  import torch
4
+ import gradio as gr
5
  from data import load_tokenizer
6
  from model import load_model
7
+ from datetime import datetime
8
+ from dateutil import parser
9
+ from demo_assets import *
10
 
11
  parser = argparse.ArgumentParser()
12
  parser.add_argument('--data_dir', default='/data/mohamed/data')
 
76
  elif args.label_encoding == 'boe':
77
  args.num_labels *= 3
78
 
79
+ categories = ['Contact related', 'Gathering additional information', 'Defining problem',
80
+ 'Treatment goal', 'Drug related', 'Therapeutic procedure related', 'Evaluating test result',
81
+ 'Deferment', 'Advice and precaution', 'Legal and insurance related']
82
+ unicode_symbols = [
83
+ "\U0001F91D", # Handshake
84
+ "\U0001F50D", # Magnifying glass
85
+ "\U0001F9E9", # Puzzle piece
86
+ "\U0001F3AF", # Target
87
+ "\U0001F48A", # Pill
88
+ "\U00002702", # Surgical scissors
89
+ "\U0001F9EA", # Test tube
90
+ "\U000023F0", # Alarm clock
91
+ "\U000026A0", # Warning sign
92
+ "\U0001F4C4" # Document
93
+ ]
94
+
95
+ OTHERS_ID = 18
96
+ def postprocess_labels(text, logits, t2c):
97
+ tags = [None for _ in text]
98
+ labels = logits.argmax(-1)
99
+ for i,cat in enumerate(labels):
100
+ if cat != OTHERS_ID:
101
+ char_ids = t2c(i)
102
+ if char_ids is None:
103
+ continue
104
+ for idx in range(char_ids.start, char_ids.end):
105
+ if tags[idx] is None and idx < len(text):
106
+ tags[idx] = categories[cat // 2]
107
+ for i in range(len(text)-1):
108
+ if text[i] == ' ' and (text[i+1] == ' ' or tags[i-1] == tags[i+1]):
109
+ tags[i] = tags[i-1]
110
+ return tags
111
+
112
+ def indicators_to_spans(labels, t2c = None):
113
+ def add_span(c, start, end):
114
+ if t2c(start) is None or t2c(end) is None:
115
+ start, end = -1, -1
116
+ else:
117
+ start = t2c(start).start
118
+ end = t2c(end).end
119
+ span = (c, start, end)
120
+ spans.add(span)
121
+
122
+ spans = set()
123
+ num_tokens = len(labels)
124
+ num_classes = OTHERS_ID // 2
125
+ start = None
126
+ cls = None
127
+ for t in range(num_tokens):
128
+ if start and labels[t] == cls + 1:
129
+ continue
130
+ elif start:
131
+ add_span(cls // 2, start, t - 1)
132
+ start = None
133
+ # if not start and labels[t] in [2*x for x in range(num_classes)]:
134
+ if not start and labels[t] != OTHERS_ID:
135
+ start = t
136
+ cls = int(labels[t]) // 2 * 2
137
+ return spans
138
+
139
+ def extract_date(text):
140
+ pattern = r'(?<=Date: )\s*(\[\*\*.*?\*\*\]|\d{1,4}[-/]\d{1,2}[-/]\d{1,4})'
141
+ match = re.search(pattern, text).group(1)
142
+ start, end = None, None
143
+ for i, c in enumerate(match):
144
+ if start is None and c.isnumeric():
145
+ start = i
146
+ elif c.isnumeric():
147
+ end = i + 1
148
+ match = match[start:end]
149
+ return match
150
+
151
+
152
+
153
+ def run_gradio(model, tokenizer):
154
+ def predict(text):
155
+ encoding = tokenizer.encode_plus(text)
156
+ x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
157
+ mask = torch.ones_like(x)
158
+ output = model.generate(x, mask)[0]
159
+ return output, encoding.token_to_chars
160
+
161
+ def process(text):
162
+ if text is not None:
163
+ output, t2c = predict(text)
164
+ tags = postprocess_labels(text, output, t2c)
165
+ with open('log.csv', 'a') as f:
166
+ f.write(f'{datetime.now()},{text}\n')
167
+ return list(zip(text, tags))
168
+ else:
169
+ return text
170
+
171
+ def process_sum(*inputs):
172
+ global sum_c
173
+ dates = {}
174
+ for i in range(sum_c):
175
+ text = inputs[i]
176
+ output, t2c = predict(text)
177
+ spans = indicators_to_spans(output.argmax(-1), t2c)
178
+ date = extract_date(text)
179
+ present_decs = set(cat for cat, _, _ in spans)
180
+ decs = {k: [] for k in sorted(present_decs)}
181
+ for c, s, e in spans:
182
+ decs[c].append(text[s:e])
183
+ dates[date] = decs
184
+
185
+ out = ""
186
+ for date in sorted(dates.keys(), key = lambda x: parser.parse(x)):
187
+ out += f'## **[{date}]**\n\n'
188
+ decs = dates[date]
189
+ for c in decs:
190
+ out += f'### {unicode_symbols[c]} ***{categories[c]}***\n\n'
191
+ for dec in decs[c]:
192
+ out += f'{dec}\n\n'
193
+
194
+ return out
195
+
196
+ global sum_c
197
+ sum_c = 1
198
+ SUM_INPUTS = 20
199
+ def update_inputs(inputs):
200
+ outputs = []
201
+ if inputs is None:
202
+ c = 0
203
+ else:
204
+ inputs = [open(f.name).read() for f in inputs]
205
+ for i, text in enumerate(inputs):
206
+ outputs.append(gr.update(value=text, visible=True))
207
+ c = len(inputs)
208
+
209
+ n = SUM_INPUTS
210
+ for i in range(n - c):
211
+ outputs.append(gr.update(value='', visible=False))
212
+ global sum_c; sum_c = c
213
+ return outputs
214
+
215
+ def add_ex(*inputs):
216
+ global sum_c
217
+ new_idx = sum_c
218
+ if new_idx < SUM_INPUTS:
219
+ out = inputs[:new_idx] + (gr.update(visible=True),) + inputs[new_idx+1:]
220
+ sum_c += 1
221
+ else:
222
+ out = inputs
223
+ return out
224
+
225
+ def sub_ex(*inputs):
226
+ global sum_c
227
+ new_idx = sum_c - 1
228
+ if new_idx > 0:
229
+ out = inputs[:new_idx] + (gr.update(visible=False),) + inputs[new_idx+1:]
230
+ sum_c -= 1
231
+ else:
232
+ out = inputs
233
+ return out
234
+
235
+
236
+ device = model.backbone.device
237
+ # colors = ['aqua', 'blue', 'fuchsia', 'teal', 'green', 'olive', 'lime', 'silver', 'purple', 'red',
238
+ # 'yellow', 'navy', 'gray', 'white', 'maroon', 'black']
239
+ colors = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd']
240
+
241
+ color_map = {cat: colors[i] for i,cat in enumerate(categories)}
242
+
243
+ det_desc = ['Admit, discharge, follow-up, referral',
244
+ 'Ordering test, consulting colleague, seeking external information',
245
+ 'Diagnostic conclusion, evaluation of health state, etiological inference, prognostic judgment',
246
+ 'Quantitative or qualitative',
247
+ 'Start, stop, alter, maintain, refrain',
248
+ 'Start, stop, alter, maintain, refrain',
249
+ 'Positive, negative, ambiguous test results',
250
+ 'Transfer responsibility, wait and see, change subject',
251
+ 'Advice or precaution',
252
+ 'Sick leave, drug refund, insurance, disability']
253
+
254
+ desc = '### Zones (categories)\n'
255
+ desc += '| | |\n| --- | --- |\n'
256
+ for i,cat in enumerate(categories):
257
+ desc += f'| {unicode_symbols[i]} **{cat}** | {det_desc[i]}|\n'
258
+
259
+ #colors
260
+ #markdown labels
261
+ #legend and desc
262
+ #css font-size
263
+ css = '.category-legend {border:1px dashed black;}'\
264
+ '.text-sm {font-size: 1.5rem; line-height: 200%;}'\
265
+ '.gr-sample-textbox {width: 1000px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}'\
266
+ '.text-limit label textarea {height: 150px !important; overflow: scroll; }'\
267
+ '.text-gray-500 {color: #111827; font-weight: 600; font-size: 1.25em; margin-top: 1.6em; margin-bottom: 0.6em;'\
268
+ 'line-height: 1.6;}'\
269
+ '#sum-out {border: 2px solid #007bff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);'
270
+ title='Clinical Decision Zoning'
271
+ with gr.Blocks(title=title, css=css) as demo:
272
+ gr.Markdown(f'# {title}')
273
+ with gr.Tab("Label a Clinical Note"):
274
+ with gr.Row():
275
+ with gr.Column():
276
+ gr.Markdown("## Enter a Discharge Summary or Clinical Note"),
277
+ text_input = gr.Textbox(
278
+ # value=examples[0],
279
+ label="",
280
+ placeholder="Enter text here...")
281
+ text_btn = gr.Button('Run')
282
+ with gr.Column():
283
+ gr.Markdown("## Labeled Summary or Note"),
284
+ text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False, color_map=color_map)
285
+ gr.Examples(text_examples, inputs=text_input)
286
+ with gr.Tab("Summarize Patient History"):
287
+ with gr.Row():
288
+ with gr.Column():
289
+ sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
290
+ sum_inputs.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
291
+ for i in range(2, SUM_INPUTS + 1)])
292
+ sum_btn = gr.Button('Run')
293
+ with gr.Row():
294
+ ex_add = gr.Button("+")
295
+ ex_sub = gr.Button("-")
296
+ upload = gr.File(label='Upload clinical notes', file_type='text', file_count='multiple')
297
+ gr.Examples(sum_examples, inputs=upload,
298
+ fn = update_inputs, outputs=sum_inputs, run_on_click=True)
299
+ with gr.Column():
300
+ gr.Markdown("## Summarized Clinical Decision History")
301
+ sum_out = gr.Markdown(elem_id='sum-out')
302
+ gr.Markdown(desc)
303
+
304
+ # Functions
305
+ text_input.submit(process, inputs=text_input, outputs=text_out)
306
+ text_btn.click(process, inputs=text_input, outputs=text_out)
307
+ upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
308
+ ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
309
+ ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
310
+ sum_btn.click(process_sum, inputs=sum_inputs, outputs=sum_out)
311
+ # demo = gr.TabbedInterface([text_demo, sum_demo], ["Label a Clinical Note", "Summarize Patient History"])
312
+ demo.launch(share=False)
313
 
314
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
315
  tokenizer = load_tokenizer(args.model_name)
demo.py DELETED
@@ -1,241 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- from datetime import datetime
4
- from dateutil import parser
5
- from demo_assets import *
6
- import re
7
-
8
- categories = ['Contact related', 'Gathering additional information', 'Defining problem',
9
- 'Treatment goal', 'Drug related', 'Therapeutic procedure related', 'Evaluating test result',
10
- 'Deferment', 'Advice and precaution', 'Legal and insurance related']
11
- unicode_symbols = [
12
- "\U0001F91D", # Handshake
13
- "\U0001F50D", # Magnifying glass
14
- "\U0001F9E9", # Puzzle piece
15
- "\U0001F3AF", # Target
16
- "\U0001F48A", # Pill
17
- "\U00002702", # Surgical scissors
18
- "\U0001F9EA", # Test tube
19
- "\U000023F0", # Alarm clock
20
- "\U000026A0", # Warning sign
21
- "\U0001F4C4" # Document
22
- ]
23
-
24
- OTHERS_ID = 18
25
- def postprocess_labels(text, logits, t2c):
26
- tags = [None for _ in text]
27
- labels = logits.argmax(-1)
28
- for i,cat in enumerate(labels):
29
- if cat != OTHERS_ID:
30
- char_ids = t2c(i)
31
- if char_ids is None:
32
- continue
33
- for idx in range(char_ids.start, char_ids.end):
34
- if tags[idx] is None and idx < len(text):
35
- tags[idx] = categories[cat // 2]
36
- for i in range(len(text)-1):
37
- if text[i] == ' ' and (text[i+1] == ' ' or tags[i-1] == tags[i+1]):
38
- tags[i] = tags[i-1]
39
- return tags
40
-
41
- def indicators_to_spans(labels, t2c = None):
42
- def add_span(c, start, end):
43
- if t2c(start) is None or t2c(end) is None:
44
- start, end = -1, -1
45
- else:
46
- start = t2c(start).start
47
- end = t2c(end).end
48
- span = (c, start, end)
49
- spans.add(span)
50
-
51
- spans = set()
52
- num_tokens = len(labels)
53
- num_classes = OTHERS_ID // 2
54
- start = None
55
- cls = None
56
- for t in range(num_tokens):
57
- if start and labels[t] == cls + 1:
58
- continue
59
- elif start:
60
- add_span(cls // 2, start, t - 1)
61
- start = None
62
- # if not start and labels[t] in [2*x for x in range(num_classes)]:
63
- if not start and labels[t] != OTHERS_ID:
64
- start = t
65
- cls = int(labels[t]) // 2 * 2
66
- return spans
67
-
68
- def extract_date(text):
69
- pattern = r'(?<=Date: )\s*(\[\*\*.*?\*\*\]|\d{1,4}[-/]\d{1,2}[-/]\d{1,4})'
70
- match = re.search(pattern, text).group(1)
71
- start, end = None, None
72
- for i, c in enumerate(match):
73
- if start is None and c.isnumeric():
74
- start = i
75
- elif c.isnumeric():
76
- end = i + 1
77
- match = match[start:end]
78
- return match
79
-
80
-
81
-
82
- def run_gradio(model, tokenizer):
83
- def predict(text):
84
- encoding = tokenizer.encode_plus(text)
85
- x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
86
- mask = torch.ones_like(x)
87
- output = model.generate(x, mask)[0]
88
- return output, encoding.token_to_chars
89
-
90
- def process(text):
91
- if text is not None:
92
- output, t2c = predict(text)
93
- tags = postprocess_labels(text, output, t2c)
94
- with open('log.csv', 'a') as f:
95
- f.write(f'{datetime.now()},{text}\n')
96
- return list(zip(text, tags))
97
- else:
98
- return text
99
-
100
- def process_sum(*inputs):
101
- global sum_c
102
- dates = {}
103
- for i in range(sum_c):
104
- text = inputs[i]
105
- output, t2c = predict(text)
106
- spans = indicators_to_spans(output.argmax(-1), t2c)
107
- date = extract_date(text)
108
- present_decs = set(cat for cat, _, _ in spans)
109
- decs = {k: [] for k in sorted(present_decs)}
110
- for c, s, e in spans:
111
- decs[c].append(text[s:e])
112
- dates[date] = decs
113
-
114
- out = ""
115
- for date in sorted(dates.keys(), key = lambda x: parser.parse(x)):
116
- out += f'## **[{date}]**\n\n'
117
- decs = dates[date]
118
- for c in decs:
119
- out += f'### {unicode_symbols[c]} ***{categories[c]}***\n\n'
120
- for dec in decs[c]:
121
- out += f'{dec}\n\n'
122
-
123
- return out
124
-
125
- global sum_c
126
- sum_c = 1
127
- SUM_INPUTS = 20
128
- def update_inputs(inputs):
129
- outputs = []
130
- if inputs is None:
131
- c = 0
132
- else:
133
- inputs = [open(f.name).read() for f in inputs]
134
- for i, text in enumerate(inputs):
135
- outputs.append(gr.update(value=text, visible=True))
136
- c = len(inputs)
137
-
138
- n = SUM_INPUTS
139
- for i in range(n - c):
140
- outputs.append(gr.update(value='', visible=False))
141
- global sum_c; sum_c = c
142
- return outputs
143
-
144
- def add_ex(*inputs):
145
- global sum_c
146
- new_idx = sum_c
147
- if new_idx < SUM_INPUTS:
148
- out = inputs[:new_idx] + (gr.update(visible=True),) + inputs[new_idx+1:]
149
- sum_c += 1
150
- else:
151
- out = inputs
152
- return out
153
-
154
- def sub_ex(*inputs):
155
- global sum_c
156
- new_idx = sum_c - 1
157
- if new_idx > 0:
158
- out = inputs[:new_idx] + (gr.update(visible=False),) + inputs[new_idx+1:]
159
- sum_c -= 1
160
- else:
161
- out = inputs
162
- return out
163
-
164
-
165
- device = model.backbone.device
166
- # colors = ['aqua', 'blue', 'fuchsia', 'teal', 'green', 'olive', 'lime', 'silver', 'purple', 'red',
167
- # 'yellow', 'navy', 'gray', 'white', 'maroon', 'black']
168
- colors = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd']
169
-
170
- color_map = {cat: colors[i] for i,cat in enumerate(categories)}
171
-
172
- det_desc = ['Admit, discharge, follow-up, referral',
173
- 'Ordering test, consulting colleague, seeking external information',
174
- 'Diagnostic conclusion, evaluation of health state, etiological inference, prognostic judgment',
175
- 'Quantitative or qualitative',
176
- 'Start, stop, alter, maintain, refrain',
177
- 'Start, stop, alter, maintain, refrain',
178
- 'Positive, negative, ambiguous test results',
179
- 'Transfer responsibility, wait and see, change subject',
180
- 'Advice or precaution',
181
- 'Sick leave, drug refund, insurance, disability']
182
-
183
- desc = '### Zones (categories)\n'
184
- desc += '| | |\n| --- | --- |\n'
185
- for i,cat in enumerate(categories):
186
- desc += f'| {unicode_symbols[i]} **{cat}** | {det_desc[i]}|\n'
187
-
188
- #colors
189
- #markdown labels
190
- #legend and desc
191
- #css font-size
192
- css = '.category-legend {border:1px dashed black;}'\
193
- '.text-sm {font-size: 1.5rem; line-height: 200%;}'\
194
- '.gr-sample-textbox {width: 1000px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}'\
195
- '.text-limit label textarea {height: 150px !important; overflow: scroll; }'\
196
- '.text-gray-500 {color: #111827; font-weight: 600; font-size: 1.25em; margin-top: 1.6em; margin-bottom: 0.6em;'\
197
- 'line-height: 1.6;}'\
198
- '#sum-out {border: 2px solid #007bff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);'
199
- title='Clinical Decision Zoning'
200
- with gr.Blocks(title=title, css=css) as demo:
201
- gr.Markdown(f'# {title}')
202
- with gr.Tab("Label a Clinical Note"):
203
- with gr.Row():
204
- with gr.Column():
205
- gr.Markdown("## Enter a Discharge Summary or Clinical Note"),
206
- text_input = gr.Textbox(
207
- # value=examples[0],
208
- label="",
209
- placeholder="Enter text here...")
210
- text_btn = gr.Button('Run')
211
- with gr.Column():
212
- gr.Markdown("## Labeled Summary or Note"),
213
- text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False, color_map=color_map)
214
- gr.Examples(text_examples, inputs=text_input)
215
- with gr.Tab("Summarize Patient History"):
216
- with gr.Row():
217
- with gr.Column():
218
- sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
219
- sum_inputs.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
220
- for i in range(2, SUM_INPUTS + 1)])
221
- sum_btn = gr.Button('Run')
222
- with gr.Row():
223
- ex_add = gr.Button("+")
224
- ex_sub = gr.Button("-")
225
- upload = gr.File(label='Upload clinical notes', file_type='text', file_count='multiple')
226
- gr.Examples(sum_examples, inputs=upload,
227
- fn = update_inputs, outputs=sum_inputs, run_on_click=True)
228
- with gr.Column():
229
- gr.Markdown("## Summarized Clinical Decision History")
230
- sum_out = gr.Markdown(elem_id='sum-out')
231
- gr.Markdown(desc)
232
-
233
- # Functions
234
- text_input.submit(process, inputs=text_input, outputs=text_out)
235
- text_btn.click(process, inputs=text_input, outputs=text_out)
236
- upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
237
- ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
238
- ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
239
- sum_btn.click(process_sum, inputs=sum_inputs, outputs=sum_out)
240
- # demo = gr.TabbedInterface([text_demo, sum_demo], ["Label a Clinical Note", "Summarize Patient History"])
241
- demo.launch(share=False)