m-ric's picture
m-ric HF staff
Update app.py
9a4471c verified
raw
history blame
6.01 kB
# CSS handed to the Gradio app through `gr.Blocks(css=STYLE)`.
# It lays out nested <ul>/<li> lists under `.tree` as a top-down tree: the
# ::before/::after pseudo-elements draw the connector lines between a node and
# its siblings/children, and the hover section recolours a node together with
# the links and connectors of the subtree that follows it.
STYLE = """
.container {
width: 100%;
display: grid;
align-items: center;
margin: 0!important;
}
.prose ul ul {
margin: 0!important;
}
.tree {
padding: 0px;
margin: 0!important;
box-sizing: border-box;
font-size: 16px;
width: 100%;
height: auto;
text-align: center;
}
.tree ul {
padding-top: 20px;
position: relative;
transition: .5s;
margin: 0!important;
}
.tree li {
display: inline-table;
text-align: center;
list-style-type: none;
position: relative;
padding: 10px;
transition: .5s;
}
.tree li::before, .tree li::after {
content: '';
position: absolute;
top: 0;
right: 50%;
border-top: 1px solid #ccc;
width: 51%;
height: 10px;
}
.tree li::after {
right: auto;
left: 50%;
border-left: 1px solid #ccc;
}
.tree li:only-child::after, .tree li:only-child::before {
display: none;
}
.tree li:only-child {
padding-top: 0;
}
.tree li:first-child::before, .tree li:last-child::after {
border: 0 none;
}
.tree li:last-child::before {
border-right: 1px solid #ccc;
border-radius: 0 5px 0 0;
-webkit-border-radius: 0 5px 0 0;
-moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
border-radius: 5px 0 0 0;
-webkit-border-radius: 5px 0 0 0;
-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
content: '';
position: absolute;
top: 0;
left: 50%;
border-left: 1px solid #ccc;
width: 0;
height: 20px;
}
.tree li a {
border: 1px solid #ccc;
padding: 10px;
display: inline-grid;
border-radius: 5px;
text-decoration-line: none;
border-radius: 5px;
transition: .5s;
}
.tree li a span {
border: 1px solid #ccc;
border-radius: 5px;
color: #666;
padding: 8px;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover i, .tree li a:hover span, .tree li a:hover+ul li a {
background: #c8e4f8;
color: #000;
border: 1px solid #94a0b4;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
border-color: #94a0b4;
}
"""
from transformers import GPT2Tokenizer, AutoModelForCausalLM
import numpy as np
# Load the GPT-2 tokenizer and language model once, at import time, so every
# request served by the app reuses the same objects.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Use the end-of-sequence id as the padding id for this tokenizer.
tokenizer.pad_token_id = tokenizer.eos_token_id
# Was `pritn(...)`, which raised NameError at import and stopped the app from starting.
print("Loading finished.")
def generate_html(token, node):
    """Recursively render one tree node and its subtree as nested HTML list markup.

    Args:
        token: Label of the node. It is converted to ``str`` and HTML-escaped
            before insertion, so text containing ``<``, ``>`` or ``&`` (for
            example ``<|endoftext|>``) is shown literally instead of being
            parsed as markup.
        node: Mapping with two keys: ``"table"``, an HTML string inserted
            verbatim inside the node's link (or ``None`` for no table), and
            ``"children"``, a mapping from child label to a child node of the
            same shape.

    Returns:
        The ``<li>...</li>`` fragment for this node. When the node has
        children they are rendered, in mapping order, inside a nested ``<ul>``.
    """
    # Function-scope import: the block needs nothing beyond the stdlib.
    from html import escape

    table_html = node["table"] if node["table"] is not None else ""
    parts = [
        f" <li> <a href='#'> <span> <b>{escape(str(token))}</b> </span> ",
        table_html,  # already HTML, must not be escaped
        "</a>",
    ]

    children = node["children"]
    if children:
        parts.append("<ul> ")
        # Distinct loop names: the original reused `token` and shadowed the parameter.
        parts.extend(
            generate_html(child_token, child_node)
            for child_token, child_node in children.items()
        )
        parts.append("</ul>")

    parts.append("</li>")
    return "".join(parts)
def generate_markdown_table(scores, top_k=4, chosen_tokens=None):
    """Build an HTML table of the ``top_k`` highest-scoring tokens.

    Args:
        scores: One-dimensional collection of scores indexed by token id; it is
            passed to ``np.argsort`` and indexed with the resulting ids.
        top_k: Number of best-scoring tokens to list. Values ``<= 0`` produce a
            table with only the header row.
        chosen_tokens: Optional collection of decoded token strings. A row is
            highlighted in red when its decoded token is in this collection.

    Returns:
        An HTML ``<table>`` string with a header row followed by one row per
        listed token, ordered from lowest to highest score (``np.argsort`` is
        ascending and the last ``top_k`` entries are kept).
    """
    # Function-scope import: only the stdlib is needed for escaping.
    from html import escape

    pieces = [
        "\n<table>\n<tr>\n<th><b>Token</b></th>\n<th><b>Probability</b></th>\n</tr>"
    ]

    # `[-0:]` would select every token, so handle non-positive top_k explicitly.
    top_indices = np.argsort(scores)[-top_k:] if top_k > 0 else []

    for token_idx in top_indices:
        token = tokenizer.decode([token_idx])
        # Compare the raw decoded text; escaping is applied only for display.
        highlighted = bool(chosen_tokens) and token in chosen_tokens
        style = "background-color:red" if highlighted else ""
        # The attribute value is quoted: unquoted, an empty style produced the
        # invalid `<tr style=>`.
        pieces.append(
            f'\n<tr style="{style}">'
            f"\n<td>{escape(token)}</td>"
            f"\n<td>{scores[token_idx]}</td>"
            "\n</tr>"
        )

    pieces.append("\n</table>")
    return "".join(pieces)
def display_tree(scores, sequences, beam_indices, root_label="Today is"):
    """Render the generated sequences as an HTML tree of token choices.

    Every sequence is walked step by step from a shared root; sequences that
    pick the same tokens share the same path, so the result is a prefix tree.
    Each visited node gets a table (from ``generate_markdown_table``) of the
    best-scoring tokens at that step, with the tokens actually taken
    highlighted.

    Args:
        scores: Sequence with one entry per generation step; each entry is
            indexed as ``[beam, :]`` to get the score vector for one beam.
        sequences: Tensor of token ids indexed as ``[sequence, step]``; it is
            moved to CPU and converted to a NumPy array here.
        beam_indices: Indexed as ``[sequence, step]``, giving the beam each
            token was drawn from. Negative entries are treated as padding
            after a finished hypothesis.
        root_label: Text shown on the root node. Defaults to ``"Today is"``,
            the value that used to be hard-coded.

    Returns:
        The HTML string for the whole tree. It is also printed to stdout.
    """
    display = '<div class="container">\n<div class="tree">\n<ul>'

    sequences = sequences.cpu().numpy()
    print(tokenizer.batch_decode(sequences))

    original_tree = {"table": None, "children": {}}
    for sequence, sequence_beams in zip(sequences, beam_indices):
        current_tree = original_tree
        for step, step_scores in enumerate(scores):
            current_beam = int(sequence_beams[step])
            if current_beam < 0:
                # NOTE(review): transformers pads beam_indices with -1 once a
                # hypothesis has ended; indexing the scores with it would
                # silently read the last beam. Stop walking this sequence.
                break

            current_token_choice = tokenizer.decode([sequence[step]])
            children = current_tree["children"]
            if current_token_choice not in children:
                children[current_token_choice] = {"table": None, "children": {}}

            # Rewrite the scores table even if it was there before, since new
            # chosen nodes may have appeared among the children of this node.
            current_tree["table"] = generate_markdown_table(
                step_scores[current_beam, :],
                chosen_tokens=children.keys(),
            )
            current_tree = children[current_token_choice]

    display += generate_html(root_label, original_tree)

    # Close both wrapper <div>s; the original closed one and emitted a stray </body>.
    display += "\n</ul>\n</div>\n</div>\n"
    print(display)
    return display
def get_tables(input_text, number_steps, number_beams):
    """Run sampled beam search from ``input_text`` and return the tree HTML.

    Args:
        input_text: Prompt to continue.
        number_steps: Number of new tokens to generate (``max_new_tokens``).
        number_beams: Number of beams; the same number of sequences is returned.

    Returns:
        The HTML string produced by ``display_tree`` for the generated beams.
    """
    inputs = tokenizer([input_text], return_tensors="pt")
    # UI widgets may deliver numbers as floats; generate() needs integers.
    number_steps = int(number_steps)
    number_beams = int(number_beams)

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        temperature=1.0,
        do_sample=True,
    )

    # `len(inputs)` counts the keys of the tokenizer output (input_ids,
    # attention_mask), not prompt tokens, so the old slicing was only right
    # for two-token prompts. Use the real prompt length instead.
    prompt_length = inputs["input_ids"].shape[-1]
    num_generated = outputs.sequences.shape[-1] - prompt_length

    # NOTE(review): assumes beam_indices holds one column per generated token
    # first, with any padding after them. Confirm against the installed
    # transformers version.
    return display_tree(
        outputs.scores,
        outputs.sequences[:, prompt_length:],
        outputs.beam_indices[:, :num_generated],
    )
import gradio as gr
# Build the demo page: a prompt box, two sliders for the search settings, a
# button, and a Markdown area that receives the HTML returned by `get_tables`.
with gr.Blocks(
    css=STYLE,
    theme=gr.themes.Soft(
        text_size="lg",
        font=["monospace"],
        primary_hue=gr.themes.colors.green,
    ),
) as demo:
    text = gr.Textbox(
        label="Sentence to decode from🪶",
        value="Today is",
    )
    steps = gr.Slider(
        label="Number of steps",
        minimum=1,
        maximum=10,
        step=1,
        value=4,
    )
    beams = gr.Slider(
        label="Number of beams",
        minimum=1,
        maximum=3,
        step=1,
        value=3,
    )
    button = gr.Button()
    out = gr.Markdown(label="Output")

    # Clicking the button runs the search with the current widget values.
    button.click(
        get_tables,
        inputs=[text, steps, beams],
        outputs=out,
    )

# Start the web server for the app.
demo.launch()