Spaces:
Running
on
Zero
Running
on
Zero
import html

import torch
# Startup diagnostics: report whether PyTorch can see a CUDA device and, when
# it can, print the name of the currently selected device.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(
        "CUDA device: "
        f"{torch.cuda.get_device_name(torch.cuda.current_device())}"
    )
# Stylesheet passed to gr.Blocks(css=...). It lays out the nested <ul>/<li>
# markup built by generate_html as a top-down tree: the ::before/::after
# pseudo-elements draw connector lines between parent and child nodes, and the
# hover rules highlight a hovered node together with the nodes in the <ul>
# that directly follows it.
STYLE = """
.container {
    width: 100%;
    display: grid;
    align-items: center;
    margin: 0!important;
}
.prose ul ul {
    margin: 0!important;
}
.tree {
    padding: 0px;
    margin: 0!important;
    box-sizing: border-box;
    font-size: 16px;
    width: 100%;
    height: auto;
    text-align: center;
}
.tree ul {
    padding-top: 20px;
    position: relative;
    transition: .5s;
    margin: 0!important;
}
.tree li {
    display: inline-table;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding: 10px;
    transition: .5s;
}
.tree li::before, .tree li::after {
    content: '';
    position: absolute;
    top: 0;
    right: 50%;
    border-top: 1px solid #ccc;
    width: 51%;
    height: 10px;
}
.tree li::after {
    right: auto;
    left: 50%;
    border-left: 1px solid #ccc;
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}
.tree li:only-child {
    padding-top: 0;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
    border-right: 1px solid #ccc;
    border-radius: 0 5px 0 0;
    -webkit-border-radius: 0 5px 0 0;
    -moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
    border-radius: 5px 0 0 0;
    -webkit-border-radius: 5px 0 0 0;
    -moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: '';
    position: absolute;
    top: 0;
    left: 50%;
    border-left: 1px solid #ccc;
    width: 0;
    height: 20px;
}
.tree li a {
    border: 1px solid #ccc;
    padding: 10px;
    display: inline-grid;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
}
.tree li a span {
    border: 1px solid #ccc;
    border-radius: 5px;
    color: #666;
    padding: 8px;
    font-size: 12px;
    text-transform: uppercase;
    letter-spacing: 1px;
    font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover i, .tree li a:hover span, .tree li a:hover+ul li a {
    background: #c8e4f8;
    color: #000;
    border: 1px solid #94a0b4;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
    border-color: #94a0b4;
}
"""
from transformers import GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer | |
import numpy as np | |
# Load the Gemma-2B tokenizer and causal language model once at import time;
# the request handler reuses these module-level objects on every call.
# NOTE(review): nothing here moves `model` to a GPU, so it stays on whatever
# device from_pretrained uses by default — confirm this is intended.
MODEL_CHECKPOINT = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(MODEL_CHECKPOINT)

# Padding reuses the end-of-sequence id; this overrides any pad token the
# tokenizer already defines.
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")
def generate_html(token, node):
    """Recursively render one tree node and its subtree as nested HTML list items.

    Args:
        token: Label shown on this node. It is HTML-escaped, so decoded model
            tokens such as ``<eos>`` and user text containing ``<`` or ``&`` are
            displayed literally instead of being parsed as markup.
        node: Dict with a ``"table"`` entry (a pre-built HTML snippet or
            ``None``) and a ``"children"`` entry mapping child labels to child
            nodes of the same shape.

    Returns:
        An HTML string ``<li>...</li``> for this node; it contains a nested
        ``<ul>`` when the node has children.
    """
    parts = [f" <li> <a href='#'> <span> <b>{html.escape(str(token))}</b> </span> "]
    # The table is already HTML, so it is inserted verbatim (not escaped).
    if node["table"] is not None:
        parts.append(node["table"])
    parts.append("</a>")

    children = node["children"]
    if children:
        parts.append("<ul> ")
        parts.extend(
            generate_html(child_token, child_node)
            for child_token, child_node in children.items()
        )
        parts.append("</ul>")
    parts.append("</li>")
    return "".join(parts)
def generate_markdown_table(scores, top_k=4, chosen_tokens=None):
    """Render the ``top_k`` highest-scoring tokens of one score row as an HTML table.

    Despite the name, the returned text is HTML (it is shown inside the tree
    rendered by generate_html).

    Args:
        scores: One-dimensional scores indexed by token id. This may be a torch
            tensor or anything numpy can convert to an array.
        top_k: Number of rows to show, highest score first. Values <= 0 give a
            table with only the header row.
        chosen_tokens: Optional container of decoded token strings; rows whose
            token is in it are highlighted with a red background.

    Returns:
        An HTML ``<table>`` string.
    """
    if torch.is_tensor(scores):
        # numpy cannot read CUDA or bfloat16 tensors directly.
        scores = scores.detach().float().cpu()
    values = np.asarray(scores, dtype=np.float64)
    # Descending order so the most likely token is the first row; clamp top_k
    # so 0 or a negative value cannot slice out the whole vocabulary.
    best_ids = np.argsort(values)[::-1][: max(int(top_k), 0)]

    parts = [
        """
<table>
<tr>
<th><b>Token</b></th>
<th><b>Probability</b></th>
</tr>"""
    ]
    for token_id in best_ids:
        token = tokenizer.decode([int(token_id)])
        highlighted = bool(chosen_tokens) and token in chosen_tokens
        row_style = "background-color:red" if highlighted else ""
        # NOTE(review): the column is labelled "Probability" but the number is
        # the raw entry of `scores`; for generate() output scores these may be
        # processed log-scores — confirm and convert if probabilities are wanted.
        parts.append(
            f"""
<tr style="{row_style}">
<td>{html.escape(token)}</td>
<td>{values[token_id]:.4f}</td>
</tr>"""
        )
    parts.append(
        """
</table>"""
    )
    return "".join(parts)
def display_tree(scores, sequences, beam_indices, root_label="Today is"):
    """Build the HTML tree of generated tokens, one branch per returned sequence.

    Args:
        scores: Per-step score tensors (``outputs.scores`` from generate());
            ``scores[step][beam]`` is the score row of that beam at that step.
        sequences: Generated token ids only (prompt already removed), one row per
            returned sequence.
        beam_indices: For each returned sequence and step, the row of
            ``scores[step]`` that produced the token. Negative values are
            padding and end the walk for that sequence.
        root_label: Text shown on the root node, normally the prompt. Defaults
            to the label previously hard-coded here.

    Returns:
        An HTML string with the tree wrapped in ``container`` / ``tree`` divs.
    """
    sequences = sequences.cpu().numpy()
    print(tokenizer.batch_decode(sequences))

    root = {"table": None, "children": {}}
    # Never index past the shortest of the three per-step inputs.
    num_steps = min(len(scores), sequences.shape[1], beam_indices.shape[1])
    for sequence_ix in range(len(sequences)):
        node = root
        for step in range(num_steps):
            beam = int(beam_indices[sequence_ix, step])
            if beam < 0:
                # Padding: this sequence finished before the last step, so any
                # remaining positions do not belong to a real beam.
                break
            token = tokenizer.decode([int(sequences[sequence_ix, step])])
            node["children"].setdefault(token, {"table": None, "children": {}})
            # Rebuild the table on every visit: a newly chosen child may have
            # just been added and must be highlighted.
            node["table"] = generate_markdown_table(
                scores[step][beam, :],
                chosen_tokens=node["children"].keys(),
            )
            node = node["children"][token]

    return (
        """<div class="container">
<div class="tree">
<ul>"""
        + generate_html(root_label, root)
        + """
</ul>
</div>
</div>
"""
    )


def get_tables(input_text, number_steps, number_beams):
    """Sample continuations of ``input_text`` with beam sampling and render them.

    Args:
        input_text: Prompt to continue; also used as the root label of the tree.
        number_steps: Maximum number of new tokens.
        number_beams: Number of beams and of returned sequences. Both counts
            are cast to ``int`` because UI sliders may deliver floats.

    Returns:
        The HTML tree produced by display_tree.
    """
    number_steps = int(number_steps)
    number_beams = int(number_beams)
    inputs = tokenizer([input_text], return_tensors="pt")
    # len(inputs) would count the encoding's keys (input_ids, attention_mask),
    # not tokens; the prompt length is the width of input_ids.
    prompt_length = inputs["input_ids"].shape[1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        temperature=1.0,
        do_sample=True,
    )
    num_generated_steps = len(outputs.scores)
    beam_indices = getattr(outputs, "beam_indices", None)
    if beam_indices is None:
        # Plain sampling (a single beam) returns no beam indices; every token
        # then comes from row 0 of each step's scores.
        beam_indices = torch.zeros(
            (outputs.sequences.shape[0], num_generated_steps), dtype=torch.long
        )
    return display_tree(
        outputs.scores,
        outputs.sequences[:, prompt_length:],
        beam_indices[:, :num_generated_steps],
        root_label=input_text,
    )
import gradio as gr | |
# Gradio UI: a prompt box plus step/beam sliders feed get_tables, whose HTML
# tree is rendered in the Markdown output component.
with gr.Blocks(
    css=STYLE,
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.green,
        font=["monospace"],
        text_size="lg",
    ),
) as demo:
    # Inputs: the prompt and the decoding budget.
    text = gr.Textbox(value="Today is", label="Sentence to decode from🪶")
    steps = gr.Slider(
        minimum=1, maximum=10, value=4, step=1, label="Number of steps"
    )
    beams = gr.Slider(
        minimum=1, maximum=3, value=3, step=1, label="Number of beams"
    )
    button = gr.Button()
    # Output: the tree HTML returned by get_tables.
    out = gr.Markdown(label="Output")
    button.click(fn=get_tables, inputs=[text, steps, beams], outputs=out)

demo.launch()