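"""Streamlit app that assesses the quality of written supervisor feedback
in medical education via the QuAL score, using fine-tuned Hugging Face
text-classification models.

See the accompanying paper in Academic Medicine:
https://doi.org/10.1097/ACM.0000000000005634
"""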
import streamlit as st
import transformers as tf
import pandas as pd
from datetime import datetime
from plotly import graph_objects as go
from transformers_interpret import SequenceClassificationExplainer
from annotated_text import annotated_text
from palettable.colorbrewer.diverging import RdYlGn_10
from overview import NQDOverview

import torch

# Log CUDA availability at startup (useful when debugging deployments)
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Load a text-classification pipeline from the Hugging Face Hub and cache it
@st.experimental_singleton(show_spinner=False)
def load_model(username, prefix, model_name):
    p = tf.pipeline('text-classification', f'{username}/{prefix}-{model_name}', return_all_scores=True)
    return p

@st.experimental_singleton(show_spinner=False)
def load_pickle(f):
    return pd.read_pickle(f)

def get_results(model, c):
    # With return_all_scores=True the pipeline returns one score per class;
    # the predicted label is the index of the highest-scoring class
    res = model(c)[0]
    scores = [r['score'] for r in res]
    label = max(range(len(scores)), key=lambda i: scores[i])
    return {'label': label, 'scores': scores}

def run_models(model_names, models, c):
    results = {}
    for mn in model_names:
        results[mn] = get_results(models[mn], c)
    return results


st.title('Assess the *QuAL*ity of your feedback')
st.caption(
"""Medical education requires high-quality *written* feedback, 
but evaluating these *supervisor narrative comments* is time-consuming. 
The QuAL score has validity evidence for measuring the quality of short 
comments in this context. We developed an NLP/ML-powered tool to 
assess written comment quality via the QuAL score with high accuracy.
See our paper in *Academic Medicine* at [https://doi.org/10.1097/ACM.0000000000005634](https://doi.org/10.1097/ACM.0000000000005634)

*Try it for yourself!*
""")

### Load models
# Specify which models to load: 'qual' (overall score), 'q1' (level of detail),
# 'q2i' (suggestion for improvement), 'q3i' (suggestion linking)
USERNAME = 'maxspad'
PREFIX = 'nlp-qual'
models_to_load = ['qual', 'q1', 'q2i', 'q3i']
n_models = float(len(models_to_load))
models = {}
# Show a progress bar while models are downloading, 
# then hide it when done
lc_placeholder = st.empty()
loader_container = lc_placeholder.container()
loader_container.caption('Loading models... please wait...')
pbar = loader_container.progress(0.0)
for i, mn in enumerate(models_to_load):
    pbar.progress((i+1.0) / n_models)
    models[mn] = load_model(USERNAME, PREFIX, mn)
lc_placeholder.empty()
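# (The pipelines are cached by st.experimental_singleton, so reruns after the
# first load return immediately and the progress bar clears almost instantly)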

### Load example data
examples = load_pickle('test.pkl')

### Process input
# Sample a random example comment, reseeded with the current time so the
# "Try an example!" button surfaces a different comment on each rerun
ex = examples['comment'].sample(1, random_state=int(datetime.now().timestamp())).tolist()[0]
try:
    # Strip Excel carriage-return artifacts ('_x000D_') and stringified NaNs
    ex = ex.strip().replace('_x000D_', '').replace('nan', 'blank')
except AttributeError:
    # Non-string sample (e.g. a true NaN); fall back to a placeholder
    ex = 'blank'
if 'comment' not in st.session_state:
    st.session_state['comment'] = ex
with st.form('comment_form'):
    comment = st.text_area('Try a comment:', value=st.session_state['comment'])
    left_col, right_col = st.columns([1,9], gap='medium')
    submitted = left_col.form_submit_button('Submit')
    trying_example = right_col.form_submit_button('Try an example!')
    
    if submitted:
        st.session_state['button_clicked'] = 'submit'
        st.session_state['comment'] = comment
        st.experimental_rerun()
    elif trying_example:
        st.session_state['button_clicked'] = 'example'
        st.session_state['comment'] = ex
        st.experimental_rerun()
    
results = run_models(models_to_load, models, st.session_state['comment'])
# Post-process: ignore Q3 (linking) when Q2 found no suggestion, then
# recompute the overall QuAL score as the sum of the three subscores
if results['q2i']['label'] == 1:
    results['q3i']['label'] = 1  # can't have a linked suggestion without a suggestion
results['qual']['label'] = results['q1']['label'] + (not results['q2i']['label']) + (not results['q3i']['label'])

# Render the results overview panel (from overview.py)
overview = NQDOverview(st, results)
overview.draw()

def rescale(x):
    # Map attribution scores from [-1, 1] to [0, 1] for colormap lookup
    return (x + 1.0) / 2.0

def get_explained_words(comment, pipe, label, cmap):
    # Attribute the prediction for `label` to individual tokens;
    # [1:-1] drops the special start/end tokens from the attributions
    cls_explainer = SequenceClassificationExplainer(
        pipe.model,
        pipe.tokenizer)
    word_attributions = cls_explainer(comment, class_name=label)[1:-1]

    # Build (word, annotation, color) tuples for annotated_text, coloring
    # each word by its rescaled attribution score
    to_disp = [
        (word, '', f'rgba{tuple([int(c*255) for c in cmap.mpl_colormap(rescale(word_score))])}')
        for word, word_score in word_attributions
    ]
    return to_disp

# Human-readable labels for each model's predicted classes
qual_map = {
    0: 'minimal',
    1: 'very low',
    2: 'low',
    3: 'average',
    4: 'above average',
    5: 'excellent'
}

q1_map = {
    0: "minimal",
    1: "low",
    2: "moderate",
    3: "high"
}

q2i_map = {
    0: "did",
    1: "did not"
}
            
with st.expander('Expand to explore further'):
    st.write(f'Your comment was rated as a QuAL score of **{results["qual"]["label"]}**, indicating **{qual_map[results["qual"]["label"]]}** quality feedback.')
    
    do_word_importances = st.checkbox("Calculate word importance. This provides more detail on model reasoning below, but takes much longer to compute.",
                                      value=False)
    
    st.markdown('### Level of Detail')
    st.write(f"The model identified a **{q1_map[results['q1']['label']]}** level of detail in your comment.")
    if do_word_importances:
        st.write("Below are words that pointed the model toward (green) or against (red) identifying a high level of detail:")
        with st.spinner("Calculating word importances, may take a while..."):
            annotated_text(get_explained_words(st.session_state['comment'], models['q1'], 'LABEL_3', RdYlGn_10))

    st.markdown('### Suggestion for Improvement')
    st.write(f"The model predicts that you **{q2i_map[results['q2i']['label']]}** provide a suggestion for improvement in your comment.")
    if do_word_importances:
        st.write("Below are words that pointed the model toward (green) or against (red) identifying a suggestion for improvement:")
        with st.spinner("Calculating word importances, may take a while..."):
            annotated_text(get_explained_words(st.session_state['comment'], models['q2i'], 'LABEL_0', RdYlGn_10))

    if results['q2i']['label'] == 0:
        st.markdown('### Suggestion Linking')
        st.write(f"The model predicts that you **{q2i_map[results['q3i']['label']]}** link your suggestion.")
        if do_word_importances:
            st.write("Below are words that pointed the model toward (green) or against (red) identifying a linked suggestion:")
            with st.spinner("Calculating word importances, may take a while..."):
                annotated_text(get_explained_words(st.session_state['comment'], models['q3i'], 'LABEL_0', RdYlGn_10))

