import html
import os

# protobuf picks its implementation the first time it is imported, and
# streamlit / transformers import it transitively. This must therefore be set
# before those imports, otherwise the workaround has no effect.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import streamlit as st
import torch
from captum import attr
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
from captum.attr import visualization as viz
from captum.attr._utils.visualization import (
    _get_color,
    format_special_tokens,
    format_word_importances,
)
from transformers import pipeline


def results_to_df(results: dict, metric_name: str):
    """Collect one metric across all topics into a single-column DataFrame.

    Args:
        results: Mapping of topic -> {metric name -> value}.
        metric_name: Metric to extract; topics that lack it are skipped.

    Returns:
        A DataFrame with one column named ``metric_name``, holding the values
        in the iteration order of ``results``.
    """
    metric_scores = [
        topic_scores[metric_name]
        for topic_scores in results.values()
        if metric_name in topic_scores
    ]
    return pd.DataFrame({metric_name: metric_scores})


def _metric_by_topic(results: dict, metric_name: str) -> pd.Series:
    """Return ``metric_name`` per topic as a Series indexed by topic."""
    scores = {
        topic: topic_scores[metric_name]
        for topic, topic_scores in results.items()
        if metric_name in topic_scores
    }
    # An explicit dtype for the empty case avoids pandas' empty-Series warning.
    return pd.Series(scores, name=metric_name, dtype=None if scores else float)


def create_boxplot_1df(results: dict, metric_name: str):
    """Return a plotly box plot of ``metric_name`` over all topics of one run."""
    df = results_to_df(results, metric_name)
    return px.box(df, y=metric_name)


def create_boxplot_2df(results1, results2, metric_name):
    """Overlay the per-topic distribution of ``metric_name`` for two runs.

    Returns a plotly histogram coloured by run ("Run 1" / "Run 2") with a box
    plot in the marginal.
    """
    df = pd.concat(
        [
            results_to_df(results1, metric_name).assign(Run="Run 1"),
            results_to_df(results2, metric_name).assign(Run="Run 2"),
        ],
        ignore_index=True,  # avoid duplicate row labels across the two runs
    )
    return px.histogram(
        df, x=metric_name, color="Run", marginal="box", hover_data=df.columns
    )


def create_boxplot_diff(results1, results2, metric_name):
    """Histogram of the per-topic difference ``run1 - run2`` for ``metric_name``.

    The runs are matched by topic key, not by position, so each difference
    compares the same topic. Topics present in only one run are dropped.
    """
    scores1 = _metric_by_topic(results1, metric_name)
    scores2 = _metric_by_topic(results2, metric_name)
    # Series subtraction aligns on the topic index; unmatched topics become NaN.
    diff = (scores1 - scores2).dropna()
    x_axis = f"Difference in {metric_name} from 1 to 2"
    return px.histogram(
        pd.DataFrame({x_axis: diff.to_numpy()}), x=x_axis, marginal="box"
    )


def summarize_attributions(attributions):
    """Reduce token-embedding attributions to one L2-normalised score per token.

    Sums over the last dimension and squeezes a leading size-1 dimension. This
    presumably expects a (1, seq_len, hidden) tensor; TODO confirm with callers.
    An all-zero attribution vector is returned unchanged rather than divided by
    zero, which would produce NaN.
    """
    attributions = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(attributions)
    if norm > 0:
        attributions = attributions / norm
    return attributions


def get_words(words, importances):
    """Wrap each token in an HTML ``<mark>`` tag coloured by its importance.

    Args:
        words: Token strings; special tokens are rewritten via captum's
            ``format_special_tokens``.
        importances: Per-token scores; extra entries beyond ``len(words)`` are
            ignored.

    Returns:
        A list of HTML snippets, one per token.
    """
    words_colored = []
    for word, importance in zip(words, importances):
        word = format_special_tokens(word)
        color = _get_color(importance)
        # Tokens come from user-supplied text and are rendered as HTML, so
        # escape them.
        unwrapped_tag = (
            '<mark style="background-color: {color}; opacity:1.0; '
            'line-height:1.75"><font color="black"> {word} </font></mark>'
        ).format(color=color, word=html.escape(word))
        words_colored.append(unwrapped_tag)
    return words_colored


@st.cache_resource
def get_model(model_name: str):
    """Load the reranking pipeline and its prompt formatter (cached by Streamlit).

    Returns:
        ``(pipe, formatter)``, where ``formatter(query, doc)`` builds the model
        input string.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    if model_name != "MonoT5":
        raise ValueError(f"Unsupported model: {model_name!r}")

    checkpoint = "castorini/monot5-small-msmarco-10k"
    pipe = pipeline(
        "text2text-generation", model=checkpoint, tokenizer=checkpoint, device="cpu"
    )

    def formatter(query, doc):
        return f"Query: {query} Document: {doc} Relevant:"

    return pipe, formatter


def prep_func(pipe, formatter):
    """Build a cached ``get_saliency(query, doc)`` function for a MonoT5 pipeline.

    Per-model state is computed once here: the true/false token ids, the
    decoder input embedding, the forward function and the Saliency explainer.
    The returned function computes gradient saliency of the "true" probability
    with respect to each encoder input token. It returns the tokens as
    newline-joined HTML snippets coloured by importance.
    """
    vocab = pipe.tokenizer.get_vocab()
    token_false_id = vocab["▁false"]
    token_true_id = vocab["▁true"]

    # MonoT5 predicts "true"/"false" at the first decoding step, which is
    # conditioned only on the decoder start token.
    decoder_input_ids = torch.full(
        (1, 1), pipe.model.config.decoder_start_token_id, dtype=torch.long
    )
    encoder_embedding_layer = pipe.model.base_model.encoder.embed_tokens
    decoder_embedding_layer = pipe.model.base_model.decoder.embed_tokens
    with torch.no_grad():
        decoder_inputs_emb = decoder_embedding_layer(decoder_input_ids)
    # Detached leaf that requires grad: Saliency can differentiate w.r.t. it
    # without keeping a graph back into the embedding weights.
    decoder_inputs_emb.requires_grad_()

    def forward_from_embeddings(inputs_embeds, decoder_inputs_embeds):
        """Return P(relevant) for each batch row, computed from embeddings."""
        logits = pipe.model.forward(
            inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds
        )["logits"][:, -1, :]
        batch_scores = logits[:, [token_false_id, token_true_id]]
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        return batch_scores[:, 1].exp()  # probability of the "true" token

    saliency = attr.Saliency(forward_from_embeddings)

    # Runs once per (query, doc) pair; results are cached by Streamlit.
    @st.cache_data
    def get_saliency(query, doc):
        input_ids = pipe.tokenizer(
            [formatter(query, doc)],
            padding=False,
            truncation=True,
            return_tensors="pt",
            max_length=pipe.tokenizer.model_max_length,
        )["input_ids"].to("cpu")
        with torch.no_grad():
            inputs_emb = encoder_embedding_layer(input_ids)
        inputs_emb.requires_grad_()

        # Saliency returns one attribution tensor per input (encoder, decoder);
        # only the encoder-token attributions are displayed.
        encoder_attr, _decoder_attr = saliency.attribute(
            inputs=(inputs_emb, decoder_inputs_emb)
        )
        attributions_normed = summarize_attributions(encoder_attr)
        tokens = pipe.tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).tolist())
        return "\n".join(get_words(tokens, attributions_normed))

    return get_saliency


if __name__ == "__main__":
    query = "how to add dll to visual studio?"
    doc = (
        "StackOverflow In the days of 16-bit Windows, a WPARAM was a 16-bit word, "
        "while LPARAM was a 32-bit long. These distinctions went away in Win32; "
        "they both became 32-bit values. ... WPARAM is defined as UINT_PTR , "
        "which in 64-bit Windows is an unsigned, 64-bit value."
    )
    model, formatter = get_model("MonoT5")
    get_saliency = prep_func(model, formatter)
    print(get_saliency(query, doc))