# nlp-qual-space / app.py
# Author: maxspad
# Commit 39b062f: "added a checkbox with default false for word importances"
import streamlit as st
import transformers as tf
import pandas as pd
from datetime import datetime
from plotly import graph_objects as go
from transformers_interpret import SequenceClassificationExplainer
from annotated_text import annotated_text
from palettable.scientific.sequential import Devon_10_r
from palettable.colorbrewer.diverging import RdYlGn_10, PuOr_10, BrBG_10
from overview import NQDOverview
import torch
# Log whether PyTorch can see a CUDA GPU and, if it can, which device is active.
# NOTE(review): load_model() builds its pipelines without a `device` argument,
# so this report appears to be informational only — confirm before relying on it.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    active_device = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(active_device)}")
# Function to load and cache models
@st.experimental_singleton(show_spinner=False)
def load_model(username, prefix, model_name):
    """Build the text-classification pipeline for Hub repo ``{username}/{prefix}-{model_name}``.

    ``st.experimental_singleton`` caches the returned pipeline per argument tuple,
    so each model is constructed once and shared across reruns and sessions.
    ``return_all_scores=True`` makes the pipeline report a score for every class
    instead of only the top prediction.
    """
    repo_id = f'{username}/{prefix}-{model_name}'
    return tf.pipeline('text-classification', repo_id, return_all_scores=True)
@st.experimental_singleton(show_spinner=False)
def load_pickle(f):
    """Read the pickled pandas object stored at path ``f``.

    The result is cached by ``st.experimental_singleton`` and shared between
    sessions, so callers should treat it as read-only. Unpickling can execute
    arbitrary code: only pass trusted files bundled with the app.
    """
    loaded = pd.read_pickle(f)
    return loaded
def _class_index(label):
    """Return k for a pipeline label of the form 'LABEL_<k>', otherwise None."""
    prefix, sep, suffix = str(label).rpartition('_')
    if sep and prefix == 'LABEL' and suffix.isdigit():
        return int(suffix)
    return None


def get_results(model, c):
    """Classify comment ``c`` with ``model`` and summarize the prediction.

    Args:
        model: a text-classification pipeline built with ``return_all_scores=True``
            (or any callable with the same output), so that ``model(c)[0]`` is a
            list holding one ``{'label': ..., 'score': ...}`` dict per class.
        c: the comment text to classify.

    Returns:
        ``{'label': int, 'scores': list}``. ``scores`` has one entry per class,
        ordered by class index. ``label`` is the index of the highest score; ties
        go to the lowest index.

    Raises:
        ValueError: if the model returned no class scores.
    """
    res = model(c)[0]
    if not res:
        raise ValueError('model returned no class scores')
    # When every label has the form 'LABEL_<k>' and the k's are exactly 0..n-1,
    # order the classes by k. Then the predicted index does not depend on the
    # order the pipeline lists them in: legacy return_all_scores output is
    # id-ordered, but top_k=None output is sorted by score.
    # Otherwise keep the order the pipeline returned.
    indices = [_class_index(r['label']) for r in res]
    if None not in indices and sorted(indices) == list(range(len(res))):
        res = [entry for _, entry in sorted(zip(indices, res), key=lambda pair: pair[0])]
    scores = [r['score'] for r in res]
    label = max(range(len(scores)), key=scores.__getitem__)
    return {'label': label, 'scores': scores}
def run_models(model_names, models, c):
    """Score comment ``c`` with every model named in ``model_names``.

    Returns a dict mapping each name to ``get_results(models[name], c)``, in the
    order the names were given. A name missing from ``models`` raises KeyError.
    """
    return {name: get_results(models[name], c) for name in model_names}
st.title('Assess the *QuAL*ity of your feedback')
# Introductory blurb under the title; st.caption renders its text as Markdown.
st.caption(
"""Medical education requires high-quality *written* feedback,
but evaluating these *supervisor narrative comments* is time-consuming.
The QuAL score has validity evidence for measuring the quality of short
comments in this context. We developed an NLP/ML-powered tool to
assess written comment quality via the QuAL score with high accuracy.
See our paper in *Academic Medicine* at [https://doi.org/10.1097/ACM.0000000000005634](https://doi.org/10.1097/ACM.0000000000005634)
*Try it for yourself!*
""")
### Load models
# Specify which models to load
USERNAME = 'maxspad'
PREFIX = 'nlp-qual'
models_to_load = ['qual', 'q1', 'q2i', 'q3i']
n_models = float(len(models_to_load))
models = {}
# Show a progress bar while models are downloading,
# then hide it when done
lc_placeholder = st.empty()
loader_container = lc_placeholder.container()
loader_container.caption('Loading models... please wait...')
pbar = loader_container.progress(0.0)
for i, mn in enumerate(models_to_load):
    models[mn] = load_model(USERNAME, PREFIX, mn)
    # Advance the bar only after this model has finished loading, so it shows
    # completed work rather than running one model ahead of the download.
    pbar.progress((i + 1.0) / n_models)
lc_placeholder.empty()
### Load example data
examples = load_pickle('test.pkl')
### Process input
# Draw one random example comment; the seed comes from the clock, so each rerun
# (to the second) can pick a different one.
raw_ex = examples['comment'].sample(1, random_state=int(datetime.now().timestamp())).tolist()[0]
if isinstance(raw_ex, str):
    # '_x000D_' is presumably an Excel-escaped carriage return left in the data.
    # Remove it before trimming so the whitespace around it is trimmed too.
    ex = raw_ex.replace('_x000D_', '').strip()
    # Only a comment that is literally "nan" (a stringified missing value) is
    # shown as blank. Replacing every "nan" substring would corrupt ordinary
    # words such as "pregnant" or "maintenance".
    if ex == 'nan':
        ex = 'blank'
else:
    # Non-text cells (presumably NaN for a missing comment) cannot be displayed.
    ex = 'blank'
# Seed the text box with the random example on the session's first run only.
if 'comment' not in st.session_state:
    st.session_state['comment'] = ex

with st.form('comment_form'):
    comment = st.text_area('Try a comment:', value=st.session_state['comment'])
    left_col, right_col = st.columns([1,9], gap='medium')
    submitted = left_col.form_submit_button('Submit')
    trying_example = right_col.form_submit_button('Try an example!')

# Store which button was pressed and the comment to score (the typed text for
# "Submit", the freshly drawn example otherwise), then rerun the script so the
# page is rebuilt from session state.
if submitted or trying_example:
    st.session_state['button_clicked'] = 'submit' if submitted else 'example'
    st.session_state['comment'] = comment if submitted else ex
    st.experimental_rerun()

results = run_models(models_to_load, models, st.session_state['comment'])
# Rebuild the overall QuAL label from the sub-models:
#   q1 level of detail (0-3) + 1 if a suggestion was found (q2i == 0)
#   + 1 if the suggestion is linked (q3i == 0).
# Linking cannot count when there is no suggestion, so q3i is forced to 1 then.
q2_label = results['q2i']['label']
if q2_label == 1:
    results['q3i']['label'] = 1  # can't have connection if no suggestion
suggestion_points = int(q2_label == 0)
link_points = int(results['q3i']['label'] == 0)
results['qual']['label'] = results['q1']['label'] + suggestion_points + link_points

overview = NQDOverview(st, results)
overview.draw()
def rescale(x):
    """Linearly map ``x`` from [-1, 1] onto [0, 1] (-1 -> 0, 0 -> 0.5, 1 -> 1)."""
    return 0.5 * (x + 1.0)
def get_explained_words(comment, pipe, label, cmap):
    """Attribute ``pipe``'s prediction of class ``label`` to the tokens of ``comment``.

    Args:
        comment: text to explain.
        pipe: a transformers pipeline; its ``.model`` and ``.tokenizer`` are handed
            to ``SequenceClassificationExplainer``.
        label: class name to explain, e.g. ``'LABEL_3'``.
        cmap: a palettable colormap; its ``.mpl_colormap`` maps [0, 1] to RGBA floats.

    Returns:
        List of ``(token, '', css_color)`` tuples for ``annotated_text``. The
        empty middle element is the (unused) annotation label.
    """
    cls_explainer = SequenceClassificationExplainer(pipe.model, pipe.tokenizer)
    # Drop the first and last attributions (presumably the tokenizer's special
    # start/end tokens, e.g. [CLS]/[SEP]).
    word_attributions = cls_explainer(comment, class_name=label)[1:-1]
    to_disp = []
    for word, word_score in word_attributions:
        # WordPiece marks continuation pieces with a leading "##"; strip it for display.
        if word.startswith('##'):
            word = word[2:]
        # Scores are assumed to lie in [-1, 1]. rescale() moves them into the
        # colormap's [0, 1] domain: negative -> low end, positive -> high end.
        r, g, b, a = cmap.mpl_colormap(rescale(word_score))
        # CSS rgba() takes 0-255 color channels but a 0-1 alpha. Scaling alpha
        # by 255 would rely on the browser clamping it.
        color = f'rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {float(a):g})'
        to_disp.append((word, '', color))
    return to_disp
# Human-readable wording for each model's integer class label.
# Overall QuAL score, 0-5.
qual_map = dict(enumerate([
    'minimal',
    'very low',
    'low',
    'average',
    'above average',
    'excellent',
]))
# Q1, level of detail, 0-3.
q1_map = dict(enumerate(["minimal", "low", "moderate", "high"]))
# Q2 (the page also reuses it for Q3): 0 means the feature was found, 1 means it was not.
q2i_map = dict(enumerate(["did", "did not"]))
def _show_word_importances(intro, model_key, class_name):
    """Write ``intro``, then color each word of the current comment by how much it
    pushed ``models[model_key]`` toward (green) or away from (red) ``class_name``."""
    st.write(intro)
    with st.spinner("Calculating word importances, may take a while..."):
        annotated_text(get_explained_words(
            st.session_state['comment'], models[model_key], class_name, RdYlGn_10))


with st.expander('Expand to explore further'):
    qual_label = results["qual"]["label"]
    st.write(f'Your comment was rated as a QuAL score of **{qual_label}**, indicating **{qual_map[qual_label]}** quality feedback.')
    # Off by default: the attribution pass is much slower than plain prediction.
    do_word_importances = st.checkbox("Calculate word importance. This provides more detail on model reasoning below, but takes much longer to compute.",
                                      value=False)

    st.markdown('### Level of Detail')
    st.write(f"The model identified a **{q1_map[results['q1']['label']]}** level of detail in your comment.")
    if do_word_importances:
        # LABEL_3 is the highest level of detail (q1_map[3] == "high").
        _show_word_importances(
            "Below are words that pointed the model toward (green) or against (red) identifying a high level of detail:",
            'q1', 'LABEL_3')

    st.markdown('### Suggestion for Improvement')
    st.write(f"The model **{q2i_map[results['q2i']['label']]}** predict that you provided a suggestion for improvement in your comment.")
    if do_word_importances:
        # LABEL_0 means a suggestion was found (q2i_map[0] == "did").
        _show_word_importances(
            "Below are words that pointed the model toward (green) or against (red) identifying a suggestion for improvement:",
            'q2i', 'LABEL_0')

    # Suggestion linking is only shown when a suggestion was found.
    if results['q2i']['label'] == 0:
        st.markdown('### Suggestion Linking')
        st.write(f"The model **{q2i_map[results['q3i']['label']]}** predict that you linked your suggestion.")
        if do_word_importances:
            _show_word_importances(
                "Below are words that pointed the model toward (green) or against (red) identifying a linked suggestion:",
                'q3i', 'LABEL_0')