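"""Helpers for a Hugging Face Space that visualizes MonoT5 relevance saliency.

Plotting utilities build per-topic metric distributions with Plotly; the
saliency utilities color query/document tokens by gradient-based attribution
computed with Captum.
"""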
import os

# Force the pure-Python protobuf implementation before transformers (and,
# transitively, protobuf) is imported; the C++ backend can clash with the
# sentencepiece tokenizer used by T5.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import pandas as pd
import torch
import plotly.express as px
from transformers import pipeline
from captum import attr
# Private Captum helpers used to color tokens by attribution score
from captum.attr._utils.visualization import format_special_tokens, _get_color
def results_to_df(results: dict, metric_name: str):
    """Collect one metric across all topics into a single-column DataFrame."""
    metric_scores = []
    for topic, results_dict in results.items():
        for metric_name_cur, metric_value in results_dict.items():
            if metric_name == metric_name_cur:
                metric_scores.append(metric_value)
    return pd.DataFrame({metric_name: metric_scores})
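# Example of the expected `results` shape (hypothetical values, for
# illustration only): each topic maps to a dict of metric name -> score.
#   results = {"topic-1": {"ndcg": 0.42, "map": 0.31},
#              "topic-2": {"ndcg": 0.55, "map": 0.40}}
#   results_to_df(results, "ndcg")  # -> DataFrame with a single "ndcg" column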
def create_boxplot_1df(results: dict, metric_name: str):
    df = results_to_df(results, metric_name)
    fig = px.box(df, y=metric_name)
    return fig
def create_boxplot_2df(results1, results2, metric_name):
    df1 = results_to_df(results1, metric_name)
    df2 = results_to_df(results2, metric_name)
    df1["Run"] = "Run 1"
    df2["Run"] = "Run 2"
    df = pd.concat([df1, df2])
    # Overlaid histograms of the two runs with box-plot marginals
    fig = px.histogram(df, x=metric_name, color="Run", marginal="box", hover_data=df.columns)
    return fig
def create_boxplot_diff(results1, results2, metric_name):
    df1 = results_to_df(results1, metric_name)
    df2 = results_to_df(results2, metric_name)
    # Assumes both runs scored the same topics in the same order
    diff = df1[metric_name] - df2[metric_name]
    x_axis = f"Difference in {metric_name} from 1 to 2"
    fig = px.histogram(pd.DataFrame({x_axis: diff}), x=x_axis, marginal="box")
    return fig
def summarize_attributions(attributions):
    """Reduce per-dimension attributions to one L2-normalized score per token."""
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions
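# Shape sketch (assuming a batch of one, as produced by get_saliency below):
#   input:  (1, seq_len, d_model) gradients w.r.t. the token embeddings
#   sum over d_model -> (1, seq_len); squeeze -> (seq_len,)
#   dividing by the L2 norm makes scores comparable across inputs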
def get_words(words, importances):
    """Wrap each token in an HTML span colored by its attribution score."""
    words_colored = []
    for word, importance in zip(words, importances[: len(words)]):
        word = format_special_tokens(word)
        color = _get_color(importance)
        unwrapped_tag = f'<span style="background-color: {color}; opacity:1.0; line-height:1.75">{word}</span>'
        words_colored.append(unwrapped_tag)
    return words_colored
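# The returned spans are raw HTML. In a Streamlit app they could be rendered
# with something like the following (a sketch, not part of this script):
#   st.markdown(" ".join(words_colored), unsafe_allow_html=True)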
def get_model(model_name: str):
    if model_name == "MonoT5-Small":
        pipe = pipeline('text2text-generation',
                        model='castorini/monot5-small-msmarco-10k',
                        tokenizer='castorini/monot5-small-msmarco-10k',
                        device='cpu')
    elif model_name == "MonoT5-3B":
        pipe = pipeline('text2text-generation',
                        model='castorini/monot5-3b-msmarco-10k',
                        tokenizer='castorini/monot5-3b-msmarco-10k',
                        device='cpu')
    else:
        # Fail loudly rather than return an unbound `pipe`
        raise ValueError(f"Unknown model name: {model_name!r}")

    def formatter(query, doc):
        return f"Query: {query} Document: {doc} Relevant:"

    return pipe, formatter
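# Usage sketch ("MonoT5-Small" and "MonoT5-3B" are the two names handled above):
#   pipe, formatter = get_model("MonoT5-Small")
#   pipe(formatter("what is a dll?", "A DLL is a dynamic-link library ..."))
#   # monoT5 generates a "true"/"false" token as its relevance judgement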
def prep_func(pipe, formatter):
    # Variables that only need to be computed once per model
    decoder_input_ids = pipe.tokenizer(["<pad>"], return_tensors="pt", add_special_tokens=False, truncation=True).input_ids.to('cpu')
    decoder_embedding_layer = pipe.model.base_model.decoder.embed_tokens
    decoder_inputs_emb = decoder_embedding_layer(decoder_input_ids)
    # monoT5 scores relevance via the "true"/"false" token it generates first
    token_false_id = pipe.tokenizer.get_vocab()['▁false']
    token_true_id = pipe.tokenizer.get_vocab()['▁true']
    # This function needs to be run for each query/document combination
    def get_saliency(query, doc):
        input_ids = pipe.tokenizer(
            [formatter(query, doc)],
            padding=False,
            truncation=True,
            return_tensors="pt",
            max_length=pipe.tokenizer.model_max_length,
        )["input_ids"].to('cpu')
        embedding_layer = pipe.model.base_model.encoder.embed_tokens
        inputs_emb = embedding_layer(input_ids)

        def forward_from_embeddings(inputs_embeds, decoder_inputs_embeds):
            # Logits for the first generated token only
            logits = pipe.model.forward(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds)['logits'][:, -1, :]
            # Restrict to the "false"/"true" logits and normalize between them
            batch_scores = logits[:, [token_false_id, token_true_id]]
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            scores = batch_scores[:, 1].exp()  # probability of the "true" (relevant) token
            return scores
        saliency = attr.Saliency(forward_from_embeddings)
        # Saliency returns one attribution tensor per input; keep the encoder's
        attributions_enc, _ = saliency.attribute(
            inputs=(inputs_emb, decoder_inputs_emb)
        )
        attributions_normed = summarize_attributions(attributions_enc)
        return "\n".join(get_words(pipe.tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).tolist()), attributions_normed))

    return get_saliency
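# Note: attr.Saliency computes plain input gradients. If integrated gradients
# are preferred, attr.IntegratedGradients(forward_from_embeddings) could be
# swapped in above; its attribute() accepts the same inputs tuple, plus
# optional baselines and n_steps arguments.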
if __name__ == "__main__":
    query = "how to add dll to visual studio?"
    doc = "StackOverflow In the days of 16-bit Windows, a WPARAM was a 16-bit word, while LPARAM was a 32-bit long. These distinctions went away in Win32; they both became 32-bit values. ... WPARAM is defined as UINT_PTR , which in 64-bit Windows is an unsigned, 64-bit value."
    model, formatter = get_model("MonoT5-Small")
    get_saliency = prep_func(model, formatter)
    print(get_saliency(query, doc))