# Spaces:
# Build error
# Build error
from collections import Counter
from html import escape as html_escape
from math import ceil
from string import punctuation

import numpy as np
import pandas as pd
import spacy
import streamlit as st
from spacy import displacy
from spacy.lang.en import English
from streamlit.components.v1 import html

import en_ner_bc5cdr_md
# Seed per-session widget settings the first time this session runs the script.
# (Neither key is read anywhere else in the visible script.)
if "visibility" not in st.session_state:
    st.session_state["visibility"] = "visible"
    st.session_state["disabled"] = False
# Load the spaCy NER model (used below for DISEASE/CHEMICAL entities) once per
# session. Streamlit re-executes the whole script on every widget interaction,
# so an uncached spacy.load would reload the model from disk on each rerun.
if "nlp" not in st.session_state:
    st.session_state.nlp = spacy.load("en_ner_bc5cdr_md")
nlp = st.session_state.nlp
st.set_page_config(page_title='Clinical Note Summarization', layout='wide')
st.title('Clinical Note Summarization')
# Widen the sidebar to 400px. When it is collapsed, shift it left by the full
# width so it is hidden completely; a smaller offset leaves part of it showing.
st.markdown(
    """
<style>
[data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
    width: 400px;
}
[data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
    width: 400px;
    margin-left: -400px;
}
</style>
""",
    unsafe_allow_html=True,
)
st.sidebar.markdown('Using transformer model')
## Loading in dataset
df = pd.read_csv('shpi_w_rouge21Nov.csv')
# HADM_ID presumably comes back as float (e.g. "12345.0") when the column has
# missing values; strip only a *trailing* ".0" so IDs display as integers
# without mangling a ".0" that occurs elsewhere in the string.
df['HADM_ID'] = df['HADM_ID'].astype(str).str.replace(r'\.0$', '', regex=True)
# Friendlier column names used throughout the UI.
df.rename(
    columns={
        'SUBJECT_ID': 'Patient_ID',
        'HADM_ID': 'Admission_ID',
        'hpi_input_text': 'Original_Text',
        'hpi_reference_summary': 'Reference_text',
    },
    inplace=True,
)
# Filter selection
st.sidebar.header("Search for Patient:")
# One entry per patient, even when a patient has several admission rows.
patientid = df['Patient_ID'].unique()
patient = st.sidebar.selectbox('Select Patient ID:', patientid)
# Admissions belonging to the chosen patient.
admissionid = df.loc[df['Patient_ID'] == patient, 'Admission_ID'].unique()
HospitalAdmission = st.sidebar.selectbox('Select Admission ID:', admissionid)
# List of Model available; each name is also a summary column in the CSV.
model = st.sidebar.selectbox(
    'Select Model',
    ('BertSummarizer', 'BertGPT2', 't5seq2eq', 't5', 'gensim', 'pysummarizer'),
)
col3, col4 = st.columns(2)
# st.write returns None, so its result is not assigned to anything.
col3.write(f"Patient ID: {patient} ")
col4.write(f"Admission ID: {HospitalAdmission} ")
# Defaults for the note text and the label of the note text area.
inputNote = 'Input note here:'
runtext = ''
# Query out relevant Clinical notes: rows matching both sidebar selections.
original_text = df[
    (df['Patient_ID'] == patient) & (df['Admission_ID'] == HospitalAdmission)
]
reference_text = original_text['Reference_text'].values
original_text2 = original_text['Original_Text'].values
##========= Buttons to the 4 tabs ========
# (button label, text-area label) per note type. A Streamlit button is True
# only on the rerun triggered by its click, so a click retitles the note
# text area for that run only.
_NOTE_TABS = (
    ("🏥 Admission", "Input Admission Note"),
    ('📆Daily Narrative', "Input Daily Narrative Note"),
    ('🗒️Discharge Plan', "Input Discharge Plan"),
    ('📝Social Notes', "Input Social Note"),
)
col1, col2, col3, col4 = st.columns(4)
for _tab_col, (_button_label, _note_label) in zip((col1, col2, col3, col4), _NOTE_TABS):
    with _tab_col:
        if st.button(_button_label):
            inputNote = _note_label
# Show the note text itself, not numpy's array repr (brackets, quotes, escaped
# newlines) that str() on the .values array would produce.
runtext = st.text_area(inputNote, "\n\n".join(map(str, original_text2)), height=300)
# Extract words associated with each entity
def genEntities(ann, entity):
    """Render the distinct entity strings of one NER class as a coloured banner.

    Args:
        ann: DataFrame with ``Entity`` and ``Class`` columns.
        entity: Class label to render, e.g. ``'DISEASE'`` or ``'CHEMICAL'``.

    Nothing is rendered when ``entity`` does not occur in ``ann['Class']``.
    """
    # Entity class -> CSS background colour; unknown classes get a neutral one.
    ent_col = {'DISEASE': 'pink', 'CHEMICAL': 'orange'}
    if entity not in ann['Class'].unique():
        return
    ent = ann.loc[ann['Class'] == entity, 'Entity'].unique()
    # Entity strings come from the note text and are injected into raw HTML,
    # so escape them first.
    entlist = ",".join(html_escape(str(e)) for e in ent)
    colour = ent_col.get(entity, 'lightgray')
    st.markdown(
        f'<p style="background-color:{colour};color:#080808;font-size:16px;">{entlist}</p>',
        unsafe_allow_html=True,
    )
def visualize(run_text, output):
    """Split a note and its summary into '.'-delimited sentence fragments.

    Not called anywhere in the visible script; presumably groundwork for
    highlighting which note sentences appear in the summary.

    Args:
        run_text: The full note text.
        output: The summary text.

    Returns:
        ``(summary_fragments, note_fragments)``: two lists produced by
        ``str.split('.')``. Fragments keep their surrounding whitespace, and a
        trailing ``''`` appears when the text ends with '.'.
    """
    return output.split('.'), run_text.split('.')
def run_model(input_text):
    """Display the precomputed summary for the selected model and admission.

    No summarization happens here. Each selectable model has a column of the
    same name in the loaded CSV, and this shows that column's value(s) for the
    rows in the module-level ``original_text`` selection. The model name is
    read from the module-level ``model`` sidebar selection.

    Args:
        input_text: Note text from the text area. Currently unused; kept so
            existing callers keep working.
    """
    known_models = ('BertSummarizer', 'BertGPT2', 't5seq2eq', 't5', 'gensim', 'pysummarizer')
    # Guard against a selection with no backing column instead of failing with
    # an unbound local / KeyError.
    if model not in known_models or model not in original_text.columns:
        st.error(f"No precomputed summaries available for model '{model}'.")
        return
    output = original_text[model].values
    st.write('Summary')
    # Show the summary text rather than numpy's array repr.
    summary = "\n\n".join(map(str, output))
    st.success(summary if summary else "No summary found for this admission.")
# Run NER over the selected admission's note text. Join the values instead of
# calling str() on the array, which would feed numpy's repr (brackets, quotes,
# escaped newlines) to the model.
doc = nlp("\n\n".join(map(str, original_text2)))
colors = {"DISEASE": "pink", "CHEMICAL": "orange"}
options = {"ents": ["DISEASE", "CHEMICAL"], "colors": colors}
ent_html = displacy.render(doc, style="ent", options=options)
col1, col2 = st.columns([1, 2])
with col1:
    # The button's return value is not checked, so the summary below renders
    # on every run whether or not it was clicked.
    st.button('Summarize')
    run_model(runtext)
    # Show the reference summary text rather than numpy's array repr.
    st.text_area('Reference text', "\n\n".join(map(str, reference_text)), height=150)
##====== Storing the Diseases/Text
# One entry per distinct entity string. If the same text is tagged more than
# once, the label of its last occurrence wins; order follows first occurrence.
ent_bc = {span.text: span.label_ for span in doc.ents}
table = {"Entity": list(ent_bc.keys()), "Class": list(ent_bc.values())}
trans_df = pd.DataFrame(table)
with col2:
    # Return value ignored: the entity lists below render on every run.
    st.button('NER')
    # A bold heading followed by that class's entity banner.
    for _ent_class in ('DISEASE', 'CHEMICAL'):
        st.markdown(f'**{_ent_class}**')
        genEntities(trans_df, _ent_class)
    # Full note with entities highlighted inline by displacy.
    st.markdown('**NER**')
    st.markdown(ent_html, unsafe_allow_html=True)