import streamlit as st
import transformers as tf
import pandas as pd
from datetime import datetime
from transformers_interpret import SequenceClassificationExplainer
from annotated_text import annotated_text
from palettable.colorbrewer.diverging import RdYlGn_10
from overview import NQDOverview
import torch

# Report CUDA availability at startup (useful when debugging deployments)
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Load and cache a text-classification pipeline from the Hugging Face Hub.
# (st.experimental_singleton is the caching decorator in older Streamlit
# releases; newer releases call it st.cache_resource.)
@st.experimental_singleton(show_spinner=False)
def load_model(username, prefix, model_name):
    p = tf.pipeline('text-classification', f'{username}/{prefix}-{model_name}',
                    return_all_scores=True)
    return p

@st.experimental_singleton(show_spinner=False)
def load_pickle(f):
    return pd.read_pickle(f)

def get_results(model, c):
    """Run a pipeline on comment `c`; return the argmax label index and all scores."""
    res = model(c)[0]
    scores = [r['score'] for r in res]
    label = max(range(len(scores)), key=lambda i: scores[i])
    return {'label': label, 'scores': scores}

def run_models(model_names, models, c):
    """Run each named model on comment `c` and collect the results."""
    results = {}
    for mn in model_names:
        results[mn] = get_results(models[mn], c)
    return results
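# Illustrative note (not executed by the app): with return_all_scores=True
# (spelled top_k=None in newer transformers releases), a text-classification
# pipeline returns one list of {'label', 'score'} dicts per input string,
# shaped roughly like this for a single input (scores are made-up values):
#   [[{'label': 'LABEL_0', 'score': 0.08}, {'label': 'LABEL_1', 'score': 0.92}]]
# get_results() indexes [0] to unwrap the single-item batch, then takes the
# argmax over the scores, so 'label' is the integer index of the top class.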
st.title('Assess the *QuAL*ity of your feedback')
st.caption(
"""Medical education requires high-quality *written* feedback, but evaluating these *supervisor narrative comments* is time-consuming. The QuAL score has validity evidence for measuring the quality of short comments in this context. We developed an NLP/ML-powered tool to assess written comment quality via the QuAL score with high accuracy. See our paper in *Academic Medicine* at [https://doi.org/10.1097/ACM.0000000000005634](https://doi.org/10.1097/ACM.0000000000005634)

*Try it for yourself!*
""")

### Load models
# Specify which models to load
USERNAME = 'maxspad'
PREFIX = 'nlp-qual'
models_to_load = ['qual', 'q1', 'q2i', 'q3i']
n_models = float(len(models_to_load))
models = {}

# Show a progress bar while the models download, then hide it when done
lc_placeholder = st.empty()
loader_container = lc_placeholder.container()
loader_container.caption('Loading models... please wait...')
pbar = loader_container.progress(0.0)
for i, mn in enumerate(models_to_load):
    pbar.progress((i + 1.0) / n_models)
    models[mn] = load_model(USERNAME, PREFIX, mn)
lc_placeholder.empty()

### Load example data
examples = load_pickle('test.pkl')

### Process input
# Sample a random example comment, seeded by the current time
ex = examples['comment'].sample(1, random_state=int(datetime.now().timestamp())).tolist()[0]
try:
    ex = ex.strip().replace('_x000D_', '').replace('nan', 'blank')
except AttributeError:
    # The sampled value was not a string (e.g. NaN)
    ex = 'blank'

if 'comment' not in st.session_state:
    st.session_state['comment'] = ex

with st.form('comment_form'):
    comment = st.text_area('Try a comment:', value=st.session_state['comment'])
    left_col, right_col = st.columns([1, 9], gap='medium')
    submitted = left_col.form_submit_button('Submit')
    trying_example = right_col.form_submit_button('Try an example!')

    if submitted:
        st.session_state['button_clicked'] = 'submit'
        st.session_state['comment'] = comment
        st.experimental_rerun()
    elif trying_example:
        st.session_state['button_clicked'] = 'example'
        st.session_state['comment'] = ex
        st.experimental_rerun()

results = run_models(models_to_load, models, st.session_state['comment'])

# Combine the subscores into the overall QuAL score, ignoring Q3 if Q2 found
# no suggestion. Q1 contributes 0-3; Q2 and Q3 each contribute 1 when their
# "inverse" label is 0, so the total spans 0-5. For example, Q1 = 2 with a
# suggestion present (q2i = 0) but not linked (q3i = 1) gives 2 + 1 + 0 = 3.
if results['q2i']['label'] == 1:
    results['q3i']['label'] = 1  # can't have a linked suggestion without a suggestion
results['qual']['label'] = (results['q1']['label']
                            + (not results['q2i']['label'])
                            + (not results['q3i']['label']))

overview = NQDOverview(st, results)
overview.draw()

def rescale(x):
    # Map an attribution score in [-1, 1] onto [0, 1] for colormap lookup
    return (x + 1.0) / 2.0

def get_explained_words(comment, pipe, label, cmap):
    cls_explainer = SequenceClassificationExplainer(
        pipe.model, pipe.tokenizer)
    # Drop the first and last entries, which are the [CLS]/[SEP] special tokens
    word_attributions = cls_explainer(comment, class_name=label)[1:-1]
    # Build (word, annotation, color) tuples for annotated_text
    to_disp = [
        (word, '', f'rgba{tuple([int(c * 255) for c in cmap.mpl_colormap(rescale(word_score))])}')
        for word, word_score in word_attributions
    ]
    return to_disp
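# Illustrative note (not executed by the app): SequenceClassificationExplainer
# returns a list of (token, attribution) pairs with attributions roughly in
# [-1, 1], positive values pushing toward the requested class. For a toy input
# it might look like (made-up numbers):
#   [('[CLS]', 0.0), ('great', 0.41), ('detail', 0.27), ('[SEP]', 0.0)]
# rescale() shifts each score into [0, 1] to index the matplotlib colormap,
# and the resulting RGBA tuple becomes the highlight color behind each word.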
qual_map = {
    0: 'minimal', 1: 'very low', 2: 'low',
    3: 'average', 4: 'above average', 5: 'excellent'
}
q1_map = {0: 'minimal', 1: 'low', 2: 'moderate', 3: 'high'}
q2i_map = {0: 'did', 1: 'did not'}

with st.expander('Expand to explore further'):
    st.write(f'Your comment was rated as a QuAL score of **{results["qual"]["label"]}**, indicating **{qual_map[results["qual"]["label"]]}** quality feedback.')
    do_word_importances = st.checkbox(
        'Calculate word importance. This provides more detail on model reasoning below, but takes much longer to compute.',
        value=False)

    st.markdown('### Level of Detail')
    st.write(f"The model identified a **{q1_map[results['q1']['label']]}** level of detail in your comment.")
    if do_word_importances:
        st.write('Below are words that pointed the model toward (green) or against (red) identifying a high level of detail:')
        with st.spinner('Calculating word importances, may take a while...'):
            annotated_text(get_explained_words(st.session_state['comment'], models['q1'], 'LABEL_3', RdYlGn_10))

    st.markdown('### Suggestion for Improvement')
    st.write(f"The model predicts that you **{q2i_map[results['q2i']['label']]}** provide a suggestion for improvement in your comment.")
    if do_word_importances:
        st.write('Below are words that pointed the model toward (green) or against (red) identifying a suggestion for improvement:')
        with st.spinner('Calculating word importances, may take a while...'):
            annotated_text(get_explained_words(st.session_state['comment'], models['q2i'], 'LABEL_0', RdYlGn_10))

    if results['q2i']['label'] == 0:
        st.markdown('### Suggestion Linking')
        st.write(f"The model predicts that you **{q2i_map[results['q3i']['label']]}** link your suggestion.")
        if do_word_importances:
            st.write('Below are words that pointed the model toward (green) or against (red) identifying a linked suggestion:')
            with st.spinner('Calculating word importances, may take a while...'):
                annotated_text(get_explained_words(st.session_state['comment'], models['q3i'], 'LABEL_0', RdYlGn_10))
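# To try the app locally (assuming this file is saved as app.py and test.pkl
# sits alongside it):
#   streamlit run app.py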