m-ric's picture
m-ric HF staff
Update app.py
da9ffb8 verified
raw
history blame
10.1 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gradio as gr
import spaces
# Load the GPT-2 tokenizer and causal language model once, at import time.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Reuse the end-of-sequence token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")
# Report whether CUDA is usable and, if it is, the name of the current device.
print(f"Is CUDA available: {torch.cuda.is_available()}")
# True  (author's note -- presumably the output observed on GPU hardware; confirm)
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# CSS passed to ``gr.Blocks(css=...)``: lays out the nested <ul>/<li> markup
# of the beam tree as a top-down tree with connector lines, and styles the
# per-node score tables.
STYLE = """
.custom-container {
    width: 100%;
    display: grid;
    align-items: center;
    margin: 0!important;
    overflow: scroll;
}
.prose ul ul {
    margin: 0!important;
    font-size: 10px!important;
}
.prose td, th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
    text-wrap: nowrap;
}
.tree {
    padding: 0px;
    margin: 0!important;
    box-sizing: border-box;
    font-size: 10px;
    width: 100%;
    height: auto;
}
.tree ul {
    padding-top: 20px;
    position: relative;
    transition: .5s;
    margin: 0!important;
    display: flex;
    flex-direction: row;
    justify-content: center;
    gap:10px;
}
.tree li {
    display: inline-table;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-top: 10px;
    transition: .5s;
}
.tree li::before, .tree li::after {
    content: '';
    position: absolute;
    top: 0;
    right: 50%;
    border-top: 2px solid var(--body-text-color);
    width: 55%;
    height: 10px;
}
.tree li::after {
    right: auto;
    left: 50%;
    border-left: 2px solid var(--body-text-color);
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}
.tree ul:has(> li:only-child)::before {
    height:40px;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
    border-right: 2px solid var(--body-text-color);
    border-radius: 0 5px 0 0;
    -webkit-border-radius: 0 5px 0 0;
    -moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
    border-radius: 5px 0 0 0;
    -webkit-border-radius: 5px 0 0 0;
    -moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: '';
    position: absolute;
    top: 0;
    left: 50%;
    border-left: 2px solid var(--body-text-color);
    width: 0;
    height: 20px;
}
.tree li a {
    border: 1px solid var(--body-text-color);
    padding: 5px;
    display: inline-grid;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
}
.tree li a span {
    padding: 5px;
    font-size: 12px;
    text-transform: uppercase;
    letter-spacing: 1px;
    font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
    background: #ffedd5;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
    border-color: #f97316;
    color: #CCC;
}
.chosen {
    background-color: #ea580c;
}
"""
def generate_nodes(token, node):
    """Recursively render a tree node and its descendants as nested HTML.

    Args:
        token: Text shown in the node's box. It is HTML-escaped before being
            inserted, since it can carry user-typed text or characters such
            as ``<`` and ``&`` that would otherwise corrupt the markup.
        node: Object exposing ``table`` (an HTML string, or ``None``) and
            ``children`` (a mapping from token text to child node).

    Returns:
        An HTML fragment consisting of one ``<li>`` element that holds the
        node's box and, when the node has children, a nested ``<ul>`` with
        one ``<li>`` per child.
    """
    from html import escape  # function-scope import keeps this block self-contained

    # A node that has no score table gets the "chosen" highlight class.
    css_class = "chosen" if node.table is None else ""
    parts = [
        f" <li> <a href='#' class='{css_class}'> <span> <b>{escape(token)}</b> </span> ",
        node.table if node.table is not None else "",
        "</a>",
    ]
    if node.children:
        parts.append("<ul> ")
        parts.extend(
            generate_nodes(child_token, child_node)
            for child_token, child_node in node.children.items()
        )
        parts.append("</ul>")
    parts.append("</li>")
    return "".join(parts)
def generate_markdown_table(scores, sequence_prob, top_k=4, chosen_tokens=None):
    """Render the ``top_k`` highest-scoring tokens as an HTML table.

    Despite the name, the returned string is an HTML ``<table>``. The markup
    is emitted without indentation or blank lines, matching the original
    output.

    Args:
        scores: One score per vocabulary entry, indexable by token id and
            accepted by ``np.argsort`` (assumed to be a 1-D CPU tensor or
            array -- confirm against the caller).
        sequence_prob: Value added to each step score to produce the
            "Total score" column.
        top_k: Number of best-scoring tokens to list.
        chosen_tokens: Optional collection of decoded token strings; rows
            whose token appears in it get the ``chosen`` CSS class.

    Returns:
        The HTML table as a string, best-scoring token first.
    """
    from html import escape  # function-scope import keeps this block self-contained

    rows = [
        "\n<table>"
        "\n<tr>"
        "\n<th><b>Token</b></th>"
        "\n<th><b>Step score</b></th>"
        "\n<th><b>Total score</b></th>"
        "\n</tr>"
    ]
    # argsort is ascending: take the last ``top_k`` ids and reverse them so the
    # highest score comes first.
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        # Membership is tested on the raw decoded text; escaping is applied
        # only where the text is written into the markup.
        item_class = "chosen" if chosen_tokens and token in chosen_tokens else ""
        rows.append(
            f"\n<tr class='{item_class}'>"
            f"\n<td>{escape(token)}</td>"
            f"\n<td>{scores[token_idx]:.4f}</td>"
            f"\n<td>{scores[token_idx] + sequence_prob:.4f}</td>"
            "\n</tr>"
        )
    rows.append("\n</table>")
    return "".join(rows)
def generate_html(start_sentence, original_tree):
    """Wrap the rendered tree in its container markup.

    Args:
        start_sentence: Text shown in the root node's box.
        original_tree: Root node of the tree, passed to ``generate_nodes``.

    Returns:
        An HTML string: a ``custom-container`` div holding a ``tree`` div
        whose single ``<ul>`` contains the rendered root node.
    """
    opening = '<div class="custom-container">\n<div class="tree">\n<ul>'
    # Close both opened <div> elements (the original closed one <div> and
    # emitted a stray </body>).
    closing = "\n</ul>\n</div>\n</div>\n"
    return opening + generate_nodes(start_sentence, original_tree) + closing
from dataclasses import dataclass
from typing import Dict, Optional

import pandas as pd
@dataclass
class BeamNode:
    """One node of the beam-search tree shown by the app.

    The root node holds the start sentence; every other node is stored in its
    parent's ``children`` under the decoded token that leads to it.
    """

    # Running score of the hypothesis that ends at this node.
    cumulative_score: float
    # HTML score table for this node; None until a table has been assigned.
    table: Optional[str]
    # Text accumulated from the root down to this node.
    current_sentence: str
    # Child nodes keyed by decoded token text.
    children: Dict[str, "BeamNode"]
def generate_beams(start_sentence, scores, sequences, beam_indices):
    """Rebuild the beam-search tree from the per-step scores.

    Args:
        start_sentence: Text stored as ``current_sentence`` of the root node.
        scores: Sequence with one entry per generation step; each entry is
            indexed as ``scores[step][beam_ix]`` and ``step_scores[beam_ix, :]``
            (presumably the tuple of (n_beams, vocab_size) tensors returned by
            ``model.generate(output_scores=True)`` -- confirm against caller).
        sequences: Moved to CPU and converted to NumPy, but not read again in
            this function.
        beam_indices: Not used in this function.

    Returns:
        The root ``BeamNode``; the hypotheses kept at each step hang below it.

    Raises:
        AssertionError: If adding a child does not make ``str(original_tree)``
            longer.
    """
    sequences = sequences.cpu().numpy()
    original_tree = BeamNode(
        cumulative_score=0, table=None, current_sentence=start_sentence, children={}
    )
    # The number of beams is the length of the first step's score entry.
    n_beams = len(scores[0])
    # Every beam slot starts out referencing the same root node.
    beam_trees = [original_tree] * n_beams
    for step, step_scores in enumerate(scores):
        # Flat candidate lists; each beam contributes n_beams entries.
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_completions,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):
            current_beam = beam_trees[beam_ix]
            # Get the n_beams best token ids for the current beam, best first
            # (argsort is ascending, hence the tail slice and the reversal).
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            # Candidate cumulative score = step score + the beam's running score.
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_completions += [beam_trees[beam_ix].current_sentence] * n_beams
            top_tokens += [
                tokenizer.decode([el]) for el in current_top_token_indexes
            ]
        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_completions": current_completions,
                "token": top_tokens,
            }
        )
        # Deduplicate: for each (token, sentence-so-far) pair keep only the row
        # with the highest cumulative score.
        maxes = top_df.groupby(["token_index", "current_completions"])[
            "cumulative_score"
        ].idxmax()
        top_df = top_df.loc[maxes]
        # Sort all remaining candidates and keep the top n_beams
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]
        # Write the score table of each beam's current node.
        # Several beam slots can reference the same node; the last write wins,
        # so iterating in reverse lets the lowest beam index decide the table.
        # (Per the original note, the selected tokens for the later duplicate
        # slots would be empty.)
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            selected_tokens = top_df_selected.loc[top_df_selected["beam_index"] == beam_ix]
            markdown_table = generate_markdown_table(
                step_scores[beam_ix, :],
                current_beam.cumulative_score,
                chosen_tokens=list(selected_tokens["token"].values),
            )
            beam_trees[beam_ix].table = markdown_table
        # Add new children for each beam
        # Snapshot the running scores before any child is attached.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            # Update the source tree
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])
            # Length of the root's string form, compared again after the
            # insertion as a check that the tree actually grew.
            previous_len = len(str(original_tree))
            beam_trees[source_beam_ix].children[current_token_choice] = BeamNode(
                table=None,
                children={},
                current_sentence=beam_trees[source_beam_ix].current_sentence
                + current_token_choice,
                cumulative_score=cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy(),
            )
            assert (
                len(str(original_tree)) > previous_len
            ), "Original tree has not increased size"
        # Reassign all beams at once: slot i now points at the source node of
        # the i-th selected candidate.
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]
        # Advance all beams by one token, stepping into the child just created.
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice]
    return original_tree
@spaces.GPU
def get_beam_search_html(input_text, number_steps, number_beams):
    """Run beam search from ``input_text`` and return the resulting tree as HTML.

    Args:
        input_text: Prompt to decode from; also used as the root label.
        number_steps: Passed to ``model.generate`` as ``max_new_tokens``.
        number_beams: Passed as both ``num_beams`` and ``num_return_sequences``.

    Returns:
        The HTML string produced by ``generate_html`` for the rebuilt tree.
    """
    encoded_prompt = tokenizer([input_text], return_tensors="pt")
    # NOTE(review): top_k is a sampling option; presumably it has no effect
    # while do_sample=False -- confirm against the transformers version in use.
    generation_options = dict(
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        do_sample=False,
    )
    generation = model.generate(**encoded_prompt, **generation_options)
    beam_tree = generate_beams(
        input_text,
        generation.scores[:],
        generation.sequences[:, :],
        generation.beam_indices[:, :],
    )
    return generate_html(input_text, beam_tree)
# Build the UI: a prompt box, two sliders, a button, and the output area.
with gr.Blocks(
    css=STYLE,
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.yellow,
        font=["monospace"],
        text_size="lg",
    ),
) as demo:
    text = gr.Textbox(value="Today is", label="Sentence to decode from")
    steps = gr.Slider(minimum=1, maximum=7, step=1, value=4, label="Number of steps")
    beams = gr.Slider(minimum=2, maximum=4, step=1, value=3, label="Number of beams")
    button = gr.Button()
    out = gr.Markdown(label="Output")
    # Clicking the button renders the beam-search tree into the output area.
    button.click(
        fn=get_beam_search_html,
        inputs=[text, steps, beams],
        outputs=out,
    )

demo.launch()