import os

import pandas as pd
import torch
from transformers import pipeline
import streamlit as st
import plotly.express as px
from captum import attr
from captum.attr._utils.visualization import format_special_tokens, _get_color
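# Fall back to the pure-Python protobuf implementation (commonly needed when the
# compiled protobuf extension conflicts with transformers/sentencepiece builds).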
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
def results_to_df(results: dict, metric_name: str):
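    """Collect one metric's per-topic scores from a nested {topic: {metric: value}} dict into a one-column DataFrame."""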
metric_scores = []
for topic, results_dict in results.items():
for metric_name_cur, metric_value in results_dict.items():
if metric_name == metric_name_cur:
metric_scores.append(metric_value)
return pd.DataFrame({metric_name: metric_scores})
def create_boxplot_1df(results: dict, metric_name: str):
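    """Box plot of a single run's per-topic scores for the given metric."""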
df = results_to_df(results, metric_name)
fig = px.box(df, y=metric_name)
return fig
def create_boxplot_2df(results1, results2, metric_name):
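    """Overlay the metric's distributions from two runs as histograms with box-plot marginals."""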
df1 = results_to_df(results1, metric_name)
df2 = results_to_df(results2, metric_name)
df2["Run"] = "Run 2"
df1["Run"] = "Run 1"
df = pd.concat([df1, df2])
    # Overlay both runs: histogram of scores coloured by run, with a box-plot marginal
fig = px.histogram(df, x=metric_name, color="Run", marginal="box", hover_data=df.columns)
return fig
def create_boxplot_diff(results1, results2, metric_name):
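    """Histogram (with box-plot marginal) of per-topic score differences, run 1 minus run 2."""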
df1 = results_to_df(results1, metric_name)
df2 = results_to_df(results2, metric_name)
diff = df1[metric_name] - df2[metric_name]
x_axis = f"Difference in {metric_name} from 1 to 2"
fig = px.histogram(pd.DataFrame({x_axis: diff}), x=x_axis, marginal="box")
return fig
def summarize_attributions(attributions):
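    """Collapse the embedding dimension and L2-normalise to one attribution score per input token."""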
attributions = attributions.sum(dim=-1).squeeze(0)
attributions = attributions / torch.norm(attributions)
return attributions
def get_words(words, importances):
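    """Colour-code each token as an HTML tag whose background reflects its attribution score."""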
words_colored = []
for word, importance in zip(words, importances[: len(words)]):
word = format_special_tokens(word)
color = _get_color(importance)
        # wrap the token in an HTML span; markup modelled on captum's word-importance rendering
        unwrapped_tag = '<span style="background-color: {color}; opacity:1.0; line-height:1.75">{word}</span>'.format(
            color=color, word=word
        )
words_colored.append(unwrapped_tag)
return words_colored
@st.cache_resource
def get_model(model_name: str):
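    """Load (and cache) the re-ranking pipeline together with the prompt formatter it expects."""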
if model_name == "MonoT5":
pipe = pipeline('text2text-generation',
model='castorini/monot5-small-msmarco-10k',
tokenizer='castorini/monot5-small-msmarco-10k',
device='cpu')
def formatter(query, doc):
return f"Query: {query} Document: {doc} Relevant:"
        return pipe, formatter
    raise ValueError(f"Unknown model name: {model_name}")
def prep_func(pipe, formatter):
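    """Prepare a cached saliency function for the given pipeline; heavy setup happens once here."""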
    # computed once per loaded model and reused for every (query, doc) pair
    # a single decoder start token (the pad token for T5) is enough to read off the first-step logits
    decoder_input_ids = torch.full((1, 1), pipe.model.config.decoder_start_token_id, dtype=torch.long, device='cpu')
decoder_embedding_layer = pipe.model.base_model.decoder.embed_tokens
decoder_inputs_emb = decoder_embedding_layer(decoder_input_ids)
token_false_id = pipe.tokenizer.get_vocab()['▁false']
token_true_id = pipe.tokenizer.get_vocab()["▁true"]
    # recomputed (and cached by Streamlit) for each (query, doc) pair
@st.cache_data
def get_saliency(query, doc):
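        """Attribute the model's relevance score for (query, doc) back to the input tokens and return them colour-coded."""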
input_ids = pipe.tokenizer(
[formatter(query, doc)],
padding=False,
truncation=True,
return_tensors="pt",
max_length=pipe.tokenizer.model_max_length,
)["input_ids"].to('cpu')
embedding_layer = pipe.model.base_model.encoder.embed_tokens
inputs_emb = embedding_layer(input_ids)
def forward_from_embeddings(inputs_embeds, decoder_inputs_embeds):
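            """Forward pass from embeddings; returns the probability of the 'true' (relevant) token at the first decoding step."""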
logits = pipe.model.forward(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds)['logits'][:, -1, :]
batch_scores = logits[:, [token_false_id, token_true_id]]
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
scores = batch_scores[:, 1].exp() # relevant token
return scores
        # plain gradient saliency w.r.t. the encoder and decoder input embeddings
        saliency = attr.Saliency(forward_from_embeddings)
        attributions_enc, _ = saliency.attribute(
            inputs=(inputs_emb, decoder_inputs_emb)
        )  # decoder attributions are not visualised
        attributions_normed = summarize_attributions(attributions_enc)
return "\n".join(get_words(pipe.tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).tolist()), attributions_normed))
return get_saliency
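# Small smoke test: score one query/document pair and print the colour-coded tokens.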
if __name__ == "__main__":
query = "how to add dll to visual studio?"
doc = "StackOverflow In the days of 16-bit Windows, a WPARAM was a 16-bit word, while LPARAM was a 32-bit long. These distinctions went away in Win32; they both became 32-bit values. ... WPARAM is defined as UINT_PTR , which in 64-bit Windows is an unsigned, 64-bit value."
model, formatter = get_model("MonoT5")
get_saliency = prep_func(model, formatter)
print(get_saliency(query, doc))