import streamlit as st
import transformers as tf
import pandas as pd
from datetime import datetime
from plotly import graph_objects as go
from transformers_interpret import SequenceClassificationExplainer
from annotated_text import annotated_text
from palettable.colorbrewer.diverging import RdYlGn_10
from overview import NQDOverview
import torch
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Function to load and cache models; st.cache_resource keeps one pipeline
# instance per model across Streamlit reruns instead of re-downloading it
def load_model(username, prefix, model_name):
    p = tf.pipeline('text-classification', f'{username}/{prefix}-{model_name}',
                    return_all_scores=True)
    return p
load_model = st.cache_resource(load_model)
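# The hub repo id is assembled as '{username}/{prefix}-{model_name}',
# e.g. 'maxspad/nlp-qual-qual' for the overall QuAL model loaded below.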
# Load and cache the pickled example comments
@st.cache_data
def load_pickle(f):
    return pd.read_pickle(f)
def get_results(model, c):
    # The pipeline returns one list of {'label', 'score'} dicts per input text
    res = model(c)[0]
    scores = [r['score'] for r in res]
    # Predicted label = index of the highest-scoring class
    label = max(range(len(scores)), key=lambda i: scores[i])
    return {'label': label, 'scores': scores}
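# For reference, with return_all_scores=True the raw pipeline output for a
# single comment looks like (illustrative scores):
# [[{'label': 'LABEL_0', 'score': 0.02}, ..., {'label': 'LABEL_3', 'score': 0.91}]]
# so get_results would return {'label': 3, 'scores': [0.02, ..., 0.91]}.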
def run_models(model_names, models, c):
    results = {}
    for mn in model_names:
        results[mn] = get_results(models[mn], c)
    return results
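# Example (illustrative values): run_models(['qual', 'q1'], models, comment)
# -> {'qual': {'label': 4, 'scores': [...]}, 'q1': {'label': 2, 'scores': [...]}}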
st.title('Assess the *QuAL*ity of your feedback')
st.caption(
    """Medical education requires high-quality *written* feedback,
but evaluating these *supervisor narrative comments* is time-consuming.
The QuAL score has validity evidence for measuring the quality of short
comments in this context. We developed an NLP/ML-powered tool to
assess written comment quality via the QuAL score with high accuracy.
See our paper in *Academic Medicine* at [https://doi.org/10.1097/ACM.0000000000005634](https://doi.org/10.1097/ACM.0000000000005634)

*Try it for yourself!*
""")
### Load models
# Specify which models to load
USERNAME = 'maxspad'
PREFIX = 'nlp-qual'
models_to_load = ['qual', 'q1', 'q2i', 'q3i']
n_models = float(len(models_to_load))
models = {}
# Show a progress bar while models are downloading,
# then hide it when done
lc_placeholder = st.empty()
loader_container = lc_placeholder.container()
loader_container.caption('Loading models... please wait...')
pbar = loader_container.progress(0.0)
for i, mn in enumerate(models_to_load):
    pbar.progress((i + 1.0) / n_models)
    models[mn] = load_model(USERNAME, PREFIX, mn)
lc_placeholder.empty()
### Load example data
examples = load_pickle('test.pkl')

### Process input
ex = examples['comment'].sample(1, random_state=int(datetime.now().timestamp())).tolist()[0]
try:
    ex = ex.strip().replace('_x000D_', '').replace('nan', 'blank')
except AttributeError:
    # A missing comment comes through as float NaN, which has no .strip()
    ex = 'blank'

if 'comment' not in st.session_state:
    st.session_state['comment'] = ex
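# st.session_state persists across script reruns, so seeding it once here
# keeps the text box stable until the user submits or asks for an example.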
with st.form('comment_form'):
    comment = st.text_area('Try a comment:', value=st.session_state['comment'])
    left_col, right_col = st.columns([1, 9], gap='medium')
    submitted = left_col.form_submit_button('Submit')
    trying_example = right_col.form_submit_button('Try an example!')

    if submitted:
        st.session_state['button_clicked'] = 'submit'
        st.session_state['comment'] = comment
        st.experimental_rerun()
    elif trying_example:
        st.session_state['button_clicked'] = 'example'
        st.session_state['comment'] = ex
        st.experimental_rerun()
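# st.experimental_rerun() restarts the script from the top, so the text_area
# above is redrawn with the just-updated st.session_state['comment'].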
results = run_models(models_to_load, models, st.session_state['comment'])

# Modify results to sum the QuAL score and to ignore Q3 if Q2 found no suggestion.
# The q2i/q3i models are inverted: label 1 means the element is absent.
if results['q2i']['label'] == 1:
    results['q3i']['label'] = 1  # can't have a linked suggestion without a suggestion
results['qual']['label'] = (results['q1']['label']
                            + int(not results['q2i']['label'])
                            + int(not results['q3i']['label']))
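# Worked example (illustrative values only): q1 = 2, q2i = 0 ("did" suggest),
# q3i = 1 ("did not" link) gives QuAL = 2 + 1 + 0 = 3, i.e. "average" quality
# in qual_map below. The total ranges from 0 to 5.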
overview = NQDOverview(st, results)
overview.draw()
def rescale(x):
    # Map a word attribution in [-1, 1] onto [0, 1] for the matplotlib colormap
    return (x + 1.0) / 2.0
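# e.g. rescale(-1.0) == 0.0 (strongly against), rescale(0.0) == 0.5 (neutral),
# rescale(1.0) == 1.0 (strongly toward the target class).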
def get_explained_words(comment, pipe, label, cmap):
    cls_explainer = SequenceClassificationExplainer(
        pipe.model,
        pipe.tokenizer)
    # Drop the leading [CLS] and trailing [SEP] special tokens
    word_attributions = cls_explainer(comment, class_name=label)[1:-1]
    # Build (word, label, background-color) tuples for annotated_text,
    # coloring each word by its rescaled attribution score
    to_disp = [
        (word, '', f'rgba{tuple([int(c * 255) for c in cmap.mpl_colormap(rescale(word_score))])}')
        for word, word_score in word_attributions
    ]
    return to_disp
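# Each tuple renders as one highlighted chip in annotated_text, e.g.
# (illustrative values): ('detailed', '', 'rgba(26, 152, 80, 255)'),
# where a green background marks a word that pushed toward the target class.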
qual_map = {
    0: 'minimal',
    1: 'very low',
    2: 'low',
    3: 'average',
    4: 'above average',
    5: 'excellent'
}
q1_map = {
    0: "minimal",
    1: "low",
    2: "moderate",
    3: "high"
}
q2i_map = {
    0: "did",
    1: "did not"
}
with st.expander('Expand to explore further'):
    st.write(f'Your comment was rated as a QuAL score of **{results["qual"]["label"]}**, indicating **{qual_map[results["qual"]["label"]]}** quality feedback.')
    do_word_importances = st.checkbox("Calculate word importance. This provides more detail on model reasoning below, but takes much longer to compute.",
                                      value=False)

    st.markdown('### Level of Detail')
    st.write(f"The model identified a **{q1_map[results['q1']['label']]}** level of detail in your comment.")
    if do_word_importances:
        st.write("Below are words that pointed the model toward (green) or against (red) identifying a high level of detail:")
        with st.spinner("Calculating word importances, may take a while..."):
            annotated_text(get_explained_words(st.session_state['comment'], models['q1'], 'LABEL_3', RdYlGn_10))

    st.markdown('### Suggestion for Improvement')
    st.write(f"The model predicts that you **{q2i_map[results['q2i']['label']]}** provide a suggestion for improvement in your comment.")
    if do_word_importances:
        st.write("Below are words that pointed the model toward (green) or against (red) identifying a suggestion for improvement:")
        with st.spinner("Calculating word importances, may take a while..."):
            annotated_text(get_explained_words(st.session_state['comment'], models['q2i'], 'LABEL_0', RdYlGn_10))

    if results['q2i']['label'] == 0:
        st.markdown('### Suggestion Linking')
        st.write(f"The model predicts that you **{q2i_map[results['q3i']['label']]}** link your suggestion.")
        if do_word_importances:
            st.write("Below are words that pointed the model toward (green) or against (red) identifying a linked suggestion:")
            with st.spinner("Calculating word importances, may take a while..."):
                annotated_text(get_explained_words(st.session_state['comment'], models['q3i'], 'LABEL_0', RdYlGn_10))