"""MedDecXtract Gradio demo: decision extraction, patient timeline, and annotator."""
import argparse
import csv
import json
import os
import re
import uuid
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from html import escape as html_escape
from typing import Any, Dict, List, Optional, Tuple, Union

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from dateutil import parser

from data import load_tokenizer
# Provides categories, keys, colors, unicode_symbols, color_map, OTHERS_ID,
# OUTPUT_PATH, SUM_INPUTS, UI text/CSS/JS snippets and the example notes.
from demo_assets import *
from model import load_model


# ---------------------------------------------------------------------------
# Command-line configuration
# ---------------------------------------------------------------------------

def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so boolean-valued options use this instead.
    Accepts true/false, t/f, yes/no, y/n, 1/0 (case-insensitive).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args():
    """Parse command-line options for loading the checkpoint and running the demo.

    Returns:
        argparse.Namespace with the parsed options (``num_labels`` is derived
        afterwards at module level).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--data_dir', default='/data/mohamed/data')
    arg_parser.add_argument('--aim_repo', default='/data/mohamed/')
    arg_parser.add_argument('--ckpt', default='electra-base.pt')
    arg_parser.add_argument('--aim_exp', default='mimic-decisions-1215')
    arg_parser.add_argument('--label_encoding', default='multiclass')
    arg_parser.add_argument('--multiclass', action='store_true')
    arg_parser.add_argument('--debug', action='store_true')
    arg_parser.add_argument('--save_losses', action='store_true')
    arg_parser.add_argument('--task', default='token', choices=['seq', 'token'])
    arg_parser.add_argument('--max_len', type=int, default=512)
    arg_parser.add_argument('--num_layers', type=int, default=3)
    arg_parser.add_argument('--kernels', nargs=3, type=int, default=[1, 2, 3])
    arg_parser.add_argument('--model', default='roberta-base')
    arg_parser.add_argument('--model_name', default='google/electra-base-discriminator')
    arg_parser.add_argument('--gpu', default='0')
    arg_parser.add_argument('--grad_accumulation', default=2, type=int)
    arg_parser.add_argument('--pheno_id', type=int)
    arg_parser.add_argument('--unseen_pheno', type=int)
    arg_parser.add_argument('--text_subset')
    arg_parser.add_argument('--pheno_n', type=int, default=500)
    arg_parser.add_argument('--hidden_size', type=int, default=100)
    arg_parser.add_argument('--emb_size', type=int, default=400)
    arg_parser.add_argument('--total_steps', type=int, default=5000)
    arg_parser.add_argument('--train_log', type=int, default=500)
    arg_parser.add_argument('--val_log', type=int, default=1000)
    arg_parser.add_argument('--seed', default='0')
    arg_parser.add_argument('--num_phenos', type=int, default=10)
    arg_parser.add_argument('--num_decs', type=int, default=9)
    arg_parser.add_argument('--num_umls_tags', type=int, default=33)
    arg_parser.add_argument('--batch_size', type=int, default=8)
    arg_parser.add_argument('--pos_weight', type=float, default=1.25)
    arg_parser.add_argument('--alpha_distil', type=float, default=1)
    arg_parser.add_argument('--distil', action='store_true')
    arg_parser.add_argument('--distil_att', action='store_true')
    arg_parser.add_argument('--distil_ckpt')
    arg_parser.add_argument('--use_umls', action='store_true')
    arg_parser.add_argument('--include_nolabel', action='store_true')
    arg_parser.add_argument('--truncate_train', action='store_true')
    arg_parser.add_argument('--truncate_eval', action='store_true')
    arg_parser.add_argument('--load_ckpt', action='store_true')
    arg_parser.add_argument('--gradio', action='store_true')
    arg_parser.add_argument('--optuna', action='store_true')
    arg_parser.add_argument('--mimic_data', action='store_true')
    arg_parser.add_argument('--eval_only', action='store_true')
    arg_parser.add_argument('--lr', type=float, default=4e-5)
    arg_parser.add_argument('--resample', default='')
    arg_parser.add_argument('--verbose', type=_str2bool, default=True)
    arg_parser.add_argument('--use_crf', type=_str2bool)
    arg_parser.add_argument('--print_spans', action='store_true')
    return arg_parser.parse_args()


args = get_args()
if args.task == 'seq':
    # One output when a single phenotype id is selected, else one per phenotype.
    args.num_labels = 1 if args.pheno_id is not None else args.num_phenos
elif args.task == 'token':
    base_labels = args.num_umls_tags if args.use_umls else args.num_decs
    if args.label_encoding == 'multiclass':
        # Two tags per category (span start / continuation) plus one "other" tag.
        base_labels = base_labels * 2 + 1
    elif args.label_encoding == 'bo':
        base_labels *= 2
    elif args.label_encoding == 'boe':
        base_labels *= 3
    args.num_labels = base_labels


# ---------------------------------------------------------------------------
# Interactive annotator state
# ---------------------------------------------------------------------------

@dataclass
class KeyDef:
    """One annotation shortcut: keyboard key, category name and display style."""
    key: str
    name: str
    desc: str = ''
    color: str = 'lightblue'
    symbol: str = ''


# Inline styles for the annotator's HTML output. `pre-wrap` keeps the note's
# line breaks so the rendered text lines up with the character offsets.
_ANNOT_BOX_STYLE = 'white-space: pre-wrap; line-height: 1.8; padding: 10px;'
_SPAN_STYLE = 'padding: 1px 2px; border-radius: 3px;'
_LEGEND_STYLE = 'margin-top: 12px;'
_LEGEND_ITEM_STYLE = 'padding: 2px 6px; margin-right: 6px; border-radius: 4px;'


class AnnotationState:
    """Text being annotated, its annotations, and an undo history.

    Annotations are half-open character ranges ``(start, end, key)`` into
    ``raw_text`` (rendered as ``raw_text[start:end]``), where ``key`` is the
    shortcut key of a :class:`KeyDef` loaded from ``config_file``.
    """

    MAX_HISTORY = 20  # maximum number of undo snapshots kept

    def __init__(self):
        # Patterns for inline entity / recommendation markup; not used by the
        # methods of this class but kept as public attributes.
        self.entity_regex = r'\[\@.*?\#.*?\*\](?!\#)'
        self.recommend_regex = r'\[\$.*?\#.*?\*\](?!\#)'
        self.history = []
        self.config_file = "configs/default.config"
        self.press_commands = self.read_config()
        # Internal state holds the actual annotations
        self.annotations = []
        self.raw_text = ""

    def read_config(self) -> List[KeyDef]:
        """Load shortcut definitions, writing a default config on first run.

        The default is built from ``keys``, ``categories``, ``colors`` and
        ``unicode_symbols`` (from ``demo_assets``).
        """
        if not os.path.exists(self.config_file):
            default_config = [
                {'key': key, 'name': name, 'color': color, 'symbol': symbol}
                for key, name, color, symbol in zip(keys, categories, colors, unicode_symbols)
            ]
            config_dir = os.path.dirname(self.config_file)
            if config_dir:
                os.makedirs(config_dir, exist_ok=True)
            with open(self.config_file, 'w', encoding='utf-8') as fp:
                json.dump(default_config, fp)
        with open(self.config_file, 'r', encoding='utf-8') as fp:
            config_dict = json.load(fp)
        return [KeyDef(**entry) for entry in config_dict]

    def get_cmd_by_key(self, key: str) -> Optional[KeyDef]:
        """Return the shortcut definition bound to *key*, or None."""
        return next((cmd for cmd in self.press_commands if cmd.key == key), None)

    def _cmd_or_default(self, key: str) -> KeyDef:
        """Like get_cmd_by_key, but falls back to a plain KeyDef for unknown keys."""
        cmd = self.get_cmd_by_key(key)
        return cmd if cmd is not None else KeyDef(key=key, name=key)

    def _push_history(self):
        """Snapshot text and annotations for undo, keeping at most MAX_HISTORY."""
        self.history.append((self.raw_text, list(self.annotations)))
        del self.history[:-self.MAX_HISTORY]

    def set_text(self, text: str):
        """Start over on *text*: clears annotations and undo history."""
        self.raw_text = text
        self.annotations = []
        self.history = []

    def add_annotation(self, start: int, end: int, entity_type: str) -> str:
        """Record ``[start, end)`` as *entity_type*; return the new HTML rendering."""
        self._push_history()
        self.annotations.append((start, end, entity_type))
        return self.get_display_text()

    def remove_annotation(self, start: int, end: int) -> str:
        """Erase annotation coverage of ``[start, end)``; return the new rendering.

        Annotations overlapping the range are trimmed, split in two when the
        range lies strictly inside them, or dropped when fully covered.
        """
        self._push_history()
        kept = []
        for ann_start, ann_end, entity_type in self.annotations:
            if ann_end <= start or ann_start >= end:
                # No overlap with the half-open removal range.
                kept.append((ann_start, ann_end, entity_type))
                continue
            if ann_start < start:
                kept.append((ann_start, start, entity_type))  # part before the range
            if ann_end > end:
                kept.append((end, ann_end, entity_type))  # part after the range
        self.annotations = kept
        return self.get_display_text()

    def undo(self) -> str:
        """Restore the state before the last add/remove; return the rendering."""
        if self.history:
            self.raw_text, self.annotations = self.history.pop()
        return self.get_display_text()

    def get_display_text(self) -> str:
        """Render the text as HTML with highlighted annotations and a legend.

        Overlapping annotations are drawn left to right; the part of a later
        annotation already covered by an earlier one is not drawn again.
        """
        if not self.annotations:
            return f'<div style="{_ANNOT_BOX_STYLE}">{html_escape(self.raw_text)}</div>'

        parts = [f'<div style="{_ANNOT_BOX_STYLE}">']
        last_end = 0
        for start, end, entity_type in sorted(self.annotations, key=lambda a: (a[0], -a[1])):
            if end <= last_end:
                continue  # entirely inside an already-rendered span
            start = max(start, last_end)
            cmd = self._cmd_or_default(entity_type)
            parts.append(html_escape(self.raw_text[last_end:start]))
            parts.append(
                f'<span style="background-color: {cmd.color}; {_SPAN_STYLE}" '
                f'title="{html_escape(cmd.name)}">{html_escape(self.raw_text[start:end])}</span>'
            )
            last_end = end
        parts.append(html_escape(self.raw_text[last_end:]))
        parts.append('</div>')

        parts.append(f'<div style="{_LEGEND_STYLE}"><b>Legend:</b> ')
        for key in sorted({a[2] for a in self.annotations}):
            cmd = self._cmd_or_default(key)
            parts.append(
                f'<span style="background-color: {cmd.color}; {_LEGEND_ITEM_STYLE}">'
                f'{cmd.symbol} {html_escape(cmd.name)}</span>'
            )
        parts.append('</div>')
        return ''.join(parts)

    def get_annotated_text(self, annotator_id=None, discharge_summary_id=None) -> dict:
        """Export annotations as a JSON-serialisable dict, ordered by offset.

        Empty ids are stored as None. Each annotation gets an id sharing a
        random 8-character prefix generated per call.
        """
        unique_id = str(uuid.uuid4())[:8]
        ordered = sorted(self.annotations, key=lambda a: (a[0], -a[1]))
        annotations = []
        for idx, (start, end, entity_type) in enumerate(ordered):
            cmd = self._cmd_or_default(entity_type)
            if cmd.name in categories:
                category = f'Category {categories.index(cmd.name) + 1}: {cmd.name}'
            else:
                category = cmd.name  # custom config name not in `categories`
            annotations.append({
                "decision": self.raw_text[start:end],
                "category": category,
                "start_offset": start,
                "end_offset": end,
                "annotation_id": f'{unique_id}_{idx}',
            })
        return {
            "annotator_id": annotator_id or None,
            "discharge_summary_id": discharge_summary_id or None,
            "annotations": annotations,
        }


# ---------------------------------------------------------------------------
# Annotator event handlers
# NOTE(review): `state` is a single module-level object shared by every
# browser session of this process; concurrent users see each other's work.
# ---------------------------------------------------------------------------

def init_text(text):
    """Reset the annotator to *text* (clearing annotations) and render it."""
    state.set_text(text or "")
    if not text:
        return f'<div style="{_ANNOT_BOX_STYLE}">Enter text to begin...</div>'
    return state.get_display_text()


def _normalize_selection(start, end) -> Tuple[int, int]:
    """Coerce a UI selection to ordered ints clamped to the current text.

    Values arrive via gr.Number and may be floats or None; slicing needs ints.
    """
    text_len = len(state.raw_text)
    low, high = sorted(int(v or 0) for v in (start, end))
    return max(0, min(low, text_len)), max(0, min(high, text_len))


def add_entity(cmd_key, start: int, end: int):
    """Annotate the selection with the category bound to *cmd_key*.

    Returns:
        (html rendering, status message)
    """
    start, end = _normalize_selection(start, end)
    if start == end:
        return state.get_display_text(), "No text selected"
    cmd = state.get_cmd_by_key(cmd_key)
    if not cmd:
        return state.get_display_text(), "Invalid command"
    new_text = state.add_annotation(start, end, cmd.key)
    return new_text, f"Added {cmd.name} entity"


def remove_entity(start: int, end: int):
    """Remove annotation coverage of the selection; returns (html, status)."""
    start, end = _normalize_selection(start, end)
    if start == end:
        return state.get_display_text(), "No text selected"
    return state.remove_annotation(start, end), "Removed annotation"


def undo():
    """Undo the last annotation action; returns (html, status)."""
    return state.undo(), "Undid last action"


def download_annotations(annotator_id, discharge_summary_id):
    """Write the annotations to OUTPUT_PATH as JSON and return that path."""
    annotation_data = state.get_annotated_text(annotator_id, discharge_summary_id)
    with open(OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump(annotation_data, f, indent=4)
    return OUTPUT_PATH


def refresh_annotations(annotator_id, discharge_summary_id):
    """Return the current annotations as a dict for the JSON viewer."""
    return state.get_annotated_text(annotator_id, discharge_summary_id)


def clear_annotations():
    """Drop all annotations (keeping the text) and unlock the input box."""
    state.set_text(state.raw_text)
    return gr.update(interactive=True, elem_classes=[]), state.get_display_text()


def model_predict(text):
    """Run the model on *text* and return ``(start, end, key)`` annotations.

    Spans are ordered by character offset; ``key`` is the shortcut key of
    the predicted category (``keys[category_index]``).
    """
    output, t2c = predict(text)
    spans = indicators_to_spans(output.argmax(-1), t2c)
    return [(s, e, keys[c]) for c, s, e in sorted(spans, key=lambda sp: (sp[1], sp[2]))]


def apply_predictions(text):
    """Replace the annotator's content with *text* annotated by the model."""
    if not text:
        return init_text(text)
    predictions = model_predict(text)
    state.set_text(text)
    for start, end, entity_type in predictions:
        state.add_annotation(start, end, entity_type)
    return state.get_display_text()


state = AnnotationState()
# JS array literal of all shortcut keys, substituted into the shortcut script.
key_list_str = json.dumps([cmd.key for cmd in state.press_commands])
shortcut_js = shortcut_js_template % key_list_str


# ---------------------------------------------------------------------------
# Model output post-processing
# ---------------------------------------------------------------------------

def postprocess_labels(text, logits, t2c):
    """Assign a category name (or None) to every character of *text*.

    Args:
        text: the input string.
        logits: per-token scores; ``argmax(-1)`` gives one label id per token.
        t2c: token index -> character span (``.start``/``.end``) or None.

    Returns:
        List of ``len(text)`` entries, a category name or None, suitable for
        ``zip(text, tags)``.
    """
    tags = [None] * len(text)
    labels = logits.argmax(-1).tolist()
    for i, cat in enumerate(labels):
        if cat == OTHERS_ID:
            continue
        char_ids = t2c(i)
        if char_ids is None:
            continue  # special token with no characters
        for idx in range(char_ids.start, min(char_ids.end, len(text))):
            if tags[idx] is None:
                tags[idx] = categories[cat // 2]
    # Give a space the tag of its left neighbour when it sits inside a run of
    # spaces or between two characters with the same tag, so highlighted
    # spans do not break at word boundaries. Starts at 1: index 0 has no
    # left neighbour.
    for i in range(1, len(text) - 1):
        if text[i] == ' ' and (text[i + 1] == ' ' or tags[i - 1] == tags[i + 1]):
            tags[i] = tags[i - 1]
    return tags


def indicators_to_spans(labels, t2c=None):
    """Convert per-token label ids into character-level decision spans.

    A token whose label is not ``OTHERS_ID`` opens a span of category
    ``label // 2``; following tokens labelled ``2 * category + 1`` extend it.

    Args:
        labels: 1-D sequence (list or tensor) of per-token label ids.
        t2c: token index -> character span with ``.start``/``.end`` (or None
            for special tokens). When omitted, token indices are used as
            offsets (token ``i`` covers ``[i, i + 1)``).

    Returns:
        Set of ``(category, char_start, char_end)`` with half-open offsets.
        Tokens without a character mapping are trimmed from span edges; a
        span with no mapped token is dropped.
    """
    def char_range(tok):
        if t2c is None:
            return tok, tok + 1
        chars = t2c(tok)
        return None if chars is None else (chars.start, chars.end)

    def add_span(cat, first, last):
        while first <= last and char_range(first) is None:
            first += 1
        while last >= first and char_range(last) is None:
            last -= 1
        if first <= last:
            spans.add((cat, char_range(first)[0], char_range(last)[1]))

    ids = labels.tolist() if hasattr(labels, 'tolist') else [int(x) for x in labels]
    spans = set()
    start = None
    cls = None  # even label id of the open span's category
    for t, label in enumerate(ids):
        if start is not None and label == cls + 1:
            continue  # continuation of the open span
        if start is not None:
            add_span(cls // 2, start, t - 1)
            start = None
        if label != OTHERS_ID:
            start = t
            cls = label // 2 * 2
    if start is not None:
        add_span(cls // 2, start, len(ids) - 1)  # span running to the last token
    return spans


_DATE_PATTERN = re.compile(r'(?<=Date: )\s*(\[\*\*.*?\*\*\]|\d{1,4}[-/]\d{1,2}[-/]\d{1,4})')


def extract_date(text):
    """Return the date following the first ``Date: `` in a clinical note.

    Handles plain dates (``2150-01-31``, ``1/31/2150``) and de-identified
    ``[**...**]`` placeholders, keeping only the substring from the first
    digit to the last digit.

    Raises:
        ValueError: if the note contains no ``Date:`` field.
    """
    match = _DATE_PATTERN.search(text)
    if match is None:
        raise ValueError("Clinical note has no 'Date: ...' field")
    raw = match.group(1)
    start = end = None
    for i, c in enumerate(raw):
        if c.isnumeric():
            if start is None:
                start = i
            end = i + 1
    return raw[start:end]


# ---------------------------------------------------------------------------
# Model loading and inference
# ---------------------------------------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = load_tokenizer(args.model_name)
model = load_model(args, device)[0]
model.eval()
torch.set_grad_enabled(False)  # inference-only application
# Send inputs to wherever load_model actually placed the weights.
device = model.backbone.device


def predict(text):
    """Run the token classifier on *text*.

    Returns:
        (output, token_to_chars): the first element of ``model.generate`` for
        the single-item batch (per-token scores; ``argmax(-1)`` gives label
        ids), and the tokenizer's token -> character-span mapping.
    """
    encoding = tokenizer.encode_plus(text)
    x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
    mask = torch.ones_like(x)
    output = model.generate(x, mask)[0]
    return output, encoding.token_to_chars


def process(text):
    """Tag each character of *text* for the highlighted-text view.

    Also appends ``timestamp,text`` to ``log.csv``.
    NOTE(review): this persists submitted clinical text to disk; confirm
    that is acceptable for the deployment.
    """
    if text is None:
        return text
    output, t2c = predict(text)
    tags = postprocess_labels(text, output, t2c)
    # csv.writer quotes commas/newlines in the note so each call is one record.
    with open('log.csv', 'a', newline='', encoding='utf-8') as f:
        csv.writer(f).writerow([datetime.now(), text])
    return list(zip(text, tags))


# ---------------------------------------------------------------------------
# Patient visualization (multi-note summary and timeline)
# NOTE(review): `sum_c` and `structured_data` are process-wide globals shared
# by all sessions.
# ---------------------------------------------------------------------------

sum_c = 1  # number of note textboxes currently in use
structured_data = []  # events from the last upload/summary, used by the filter


def _format_summary(events):
    """Render events as Markdown grouped by note date, then decision category."""
    by_date = defaultdict(lambda: defaultdict(list))
    for event in events:
        by_date[event['date']][event['decision_cat']].append(event['details'])
    parts = []
    for date in sorted(by_date, key=parser.parse):
        parts.append(f'## **[{date}]**\n\n')
        for cat in sorted(by_date[date]):
            parts.append(f'### {unicode_symbols[cat]} ***{categories[cat]}***\n\n')
            parts.extend(f'{detail}\n\n' for detail in by_date[date][cat])
    return ''.join(parts)


def get_structured_data(*inputs):
    """Extract decision events from the first ``sum_c`` notes.

    Blank notes are skipped. Each event is a dict with ``date`` (string),
    ``timestamp`` (parsed datetime), ``decision_cat`` (category index),
    ``decision_type`` (category name) and ``details`` (span text); events of
    one note are in text order.
    """
    data = []
    for text in inputs[:sum_c]:
        if not text or not text.strip():
            continue
        output, t2c = predict(text)
        spans = indicators_to_spans(output.argmax(-1), t2c)
        date = extract_date(text)
        timestamp = parser.parse(date)
        for c, s, e in sorted(spans, key=lambda sp: (sp[1], sp[2], sp[0])):
            data.append({
                'date': date,
                'timestamp': timestamp,
                'decision_cat': c,
                'decision_type': categories[c],
                'details': text[s:e],
            })
    return data


def process_sum(*inputs):
    """Markdown summary of the decisions in the first ``sum_c`` notes."""
    return _format_summary(get_structured_data(*inputs))


def _read_uploaded(file_obj):
    """Read an uploaded file (object with ``.name`` or a plain path) as text."""
    path = getattr(file_obj, 'name', file_obj)
    with open(path, encoding='utf-8', errors='replace') as fp:
        return fp.read()


def update_inputs(inputs):
    """Load uploaded notes into the note textboxes (at most SUM_INPUTS).

    Shows one textbox per file, hides the rest, updates ``sum_c`` and
    recomputes ``structured_data``.
    """
    global sum_c, structured_data
    texts = [] if inputs is None else [_read_uploaded(f) for f in inputs][:SUM_INPUTS]
    outputs = [gr.update(value=text, visible=True) for text in texts]
    outputs.extend(gr.update(value='', visible=False) for _ in range(SUM_INPUTS - len(texts)))
    sum_c = len(texts)
    structured_data = get_structured_data(*texts) if texts else []
    return outputs


def add_ex(*inputs):
    """Reveal the next hidden note textbox, if any remain."""
    global sum_c
    if sum_c >= SUM_INPUTS:
        return inputs
    out = inputs[:sum_c] + (gr.update(visible=True),) + inputs[sum_c + 1:]
    sum_c += 1
    return out


def sub_ex(*inputs):
    """Hide the last visible note textbox, always keeping the first one."""
    global sum_c
    last = sum_c - 1
    if last <= 0:
        return inputs
    out = inputs[:last] + (gr.update(visible=False),) + inputs[last + 1:]
    sum_c -= 1
    return out


def create_timeline_plot(data: List[Dict[str, Any]]):
    """Strip plot of decision events over time, one row per decision type."""
    if not data:
        fig = go.Figure()
        fig.update_layout(title='Patient Timeline', height=600)
        return fig
    df = pd.DataFrame(data).sort_values('timestamp')
    fig = px.strip(df, x='timestamp', y='decision_type', color='decision_type',
                   hover_data=['date', 'details'], stripmode='overlay',
                   title='Patient Timeline',
                   labels={'timestamp': 'Date', 'decision_type': 'Decision type'})
    fig.update_traces(jitter=1.0, marker=dict(size=10, opacity=0.6))
    fig.update_layout(height=600)
    return fig


def filter_timeline(decision_types: Union[str, List[str], None], start_date: str,
                    end_date: str) -> go.Figure:
    """Re-plot ``structured_data`` restricted to types and an inclusive date range.

    *decision_types* may be a single name or a list; containing ``"All"``
    disables type filtering. Dates are parsed with ``dateutil``.
    """
    if isinstance(decision_types, str):
        decision_types = [decision_types]
    decision_types = decision_types or []
    events = structured_data
    if 'All' not in decision_types:
        events = [e for e in events if e['decision_type'] in decision_types]
    start = parser.parse(start_date)
    end = parser.parse(end_date)
    events = [e for e in events if start <= e['timestamp'] <= end]
    return create_timeline_plot(events)


def generate_summary(*inputs) -> Tuple[str, go.Figure]:
    """Recompute ``structured_data``; return (Markdown summary, timeline figure)."""
    global structured_data
    structured_data = get_structured_data(*inputs)
    return _format_summary(structured_data), create_timeline_plot(structured_data)


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

BUTTONS_PER_ROW = 3  # entity buttons per row in the annotator


def _input_lock_update():
    """Lock the annotator textbox while annotations exist (offsets depend on it)."""
    locked = bool(state.annotations)
    return gr.update(interactive=not locked, elem_classes=['locked-input'] if locked else [])


with gr.Blocks(head=shortcut_js, title='MedDecXtract', css=css) as demo:
    gr.Image('assets/logo.png', height=100, container=False, show_download_button=False)
    gr.Markdown(title)

    with gr.Tab("Decision Extraction & Classification"):
        gr.Markdown(label_desc)
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Enter a Discharge Summary or Clinical Note")
                text_input = gr.Textbox(label="", placeholder="Enter text here...")
                text_btn = gr.Button('Run')
            with gr.Column():
                gr.Markdown("## Labeled Summary or Note")
                text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False,
                                        color_map=color_map)
        gr.Examples(text_examples, inputs=text_input)

    with gr.Tab("Patient Visualization"):
        gr.Markdown(vis_desc)
        with gr.Column():
            sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
            sum_inputs.extend([
                gr.Text(label='Clinical Note %d' % i, visible=False, elem_classes='text-limit')
                for i in range(2, SUM_INPUTS + 1)
            ])
            with gr.Row():
                ex_add = gr.Button("+")
                ex_sub = gr.Button("-")
            upload = gr.File(label='Upload clinical notes', file_types=['text'],
                             file_count='multiple')
            gr.Examples(sum_examples, inputs=upload, fn=update_inputs, outputs=sum_inputs,
                        run_on_click=True)
        with gr.Column():
            with gr.Row():
                decision_type = gr.Dropdown(["All"] + categories, multiselect=True,
                                            label="Decision Type", value="All")
                start_date = gr.Textbox(label="Start Date (MM/DD/YYYY)", value="01/01/2006")
                end_date = gr.Textbox(label="End Date (MM/DD/YYYY)", value="12/31/2024")
            filter_button = gr.Button("Filter Timeline")
            timeline_plot = gr.Plot()
        summary_button = gr.Button("Generate Summary")
        with gr.Accordion('Summary'):
            summary_output = gr.Markdown(elem_id='sum-out')

    with gr.Tab("Interactive Narrative Annotator"):
        gr.Markdown(annotator_desc)
        with gr.Row():
            with gr.Column():
                annot_text_input = gr.Textbox(
                    label="Enter Text to Annotate",
                    placeholder="Enter or paste text here...",
                    lines=5,
                    elem_id='annot_text_input',
                )
                gr.Examples(text_examples, inputs=annot_text_input)
                msg_output = gr.Textbox(label="Status Messages", interactive=False)
            with gr.Column():
                display_area = gr.HTML(
                    label="Annotated Text",
                    value=f'<div style="{_ANNOT_BOX_STYLE}">Output box</div>',
                )
                num_buttons = len(state.press_commands)
                entity_buttons = []
                with gr.Group():
                    predict_btn = gr.Button("Generate Predictions", size='lg', variant='primary')
                    # max(..., 1) keeps one row (with remove/undo) even with no commands.
                    for row_start in range(0, max(num_buttons, 1), BUTTONS_PER_ROW):
                        with gr.Row():
                            for cmd in state.press_commands[row_start:row_start + BUTTONS_PER_ROW]:
                                entity_buttons.append(
                                    gr.Button(f"{cmd.symbol} {cmd.name} ({cmd.key})",
                                              elem_id=f'btn_{cmd.key}', size='sm'))
                            if row_start + BUTTONS_PER_ROW >= num_buttons:
                                remove_btn = gr.Button("Remove (q)", size='sm',
                                                       variant='secondary', elem_id='btn_q')
                                undo_btn = gr.Button("Undo (z)", size='sm', elem_id='btn_z')
                clear_btn = gr.Button("Clear Annotations", size='lg', variant='stop')

        with gr.Accordion("Download/View Annotations \U0001F4BE", open=False):
            with gr.Row():
                annotator_id = gr.Textbox(label="Annotator ID",
                                          placeholder="Enter your annotator ID")
                discharge_summary_id = gr.Textbox(label="Discharge Summary ID",
                                                  placeholder="Enter the discharge summary ID")
            with gr.Row():
                download_file = gr.File(interactive=False, visible=True, label="Download")
                annotations_json = gr.JSON(label="Annotations JSON")
            refresh_btn = gr.Button("🔄 Refresh Annotations", elem_id="refresh_btn")
            download_btn = gr.Button("Download Annotated Text", elem_id="download_btn")

        # Hidden components filled with the textbox selection by `select_js`.
        selection_start = gr.Number(value=0, visible=False)
        selection_end = gr.Number(value=0, visible=False)

    gr.Markdown(desc)

    # --- Annotator events ---
    annot_text_input.change(init_text, annot_text_input, display_area)
    for btn, cmd in zip(entity_buttons, state.press_commands):
        # `c=cmd.key` binds the key now; a plain closure would see the last cmd.
        btn.click(
            lambda s=None, e=None, c=cmd.key: add_entity(c, s, e),
            [selection_start, selection_end],
            [display_area, msg_output],
            js=select_js,
        ).then(_input_lock_update, outputs=annot_text_input)
    remove_btn.click(
        remove_entity, [selection_start, selection_end], [display_area, msg_output],
        js=select_js,
    ).then(_input_lock_update, outputs=annot_text_input)
    undo_btn.click(undo, None, [display_area, msg_output]).then(
        _input_lock_update, outputs=annot_text_input)
    download_btn.click(download_annotations, [annotator_id, discharge_summary_id], download_file)
    refresh_btn.click(refresh_annotations, [annotator_id, discharge_summary_id], annotations_json)
    clear_btn.click(clear_annotations, outputs=[annot_text_input, display_area])
    predict_btn.click(apply_predictions, annot_text_input, display_area).then(
        _input_lock_update, outputs=annot_text_input)

    # --- Extraction tab events ---
    text_input.submit(process, inputs=text_input, outputs=text_out)
    text_btn.click(process, inputs=text_input, outputs=text_out)

    # --- Visualization tab events ---
    upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
    ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
    ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
    filter_button.click(filter_timeline, inputs=[decision_type, start_date, end_date],
                        outputs=timeline_plot)
    summary_button.click(generate_summary, inputs=sum_inputs,
                         outputs=[summary_output, timeline_plot])

demo.launch(share=True)