"""Gradio app that visualizes GPT-2 beam-search decoding as an HTML tree."""

import html
from dataclasses import dataclass
from typing import Dict, Optional

import gradio as gr
import numpy as np
import pandas as pd
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

print("Loading finished.")

print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Stylesheet for the rendered tree. Colors come from Gradio theme variables
# (primary = yellow, secondary = blue, see the gr.Blocks theme below).
STYLE = """
.custom-container {
    display: grid;
    align-items: center;
    margin: 0!important;
    overflow-y: hidden;
}
.prose ul ul {
    font-size: 10px!important;
}
.prose li {
    margin-bottom: 0!important;
}
.prose table {
    margin-bottom: 0!important;
}
.prose td, th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
    text-wrap:nowrap;
}
.tree {
    padding: 0px;
    margin: 0!important;
    box-sizing: border-box;
    font-size: 10px;
    width: 100%;
    height: auto;
    text-align: center;
    display:inline-block;
}
#root {
    display: inline-grid!important;
    width:auto!important;
    min-width: 220px;
}
.tree ul {
    padding-left: 20px;
    position: relative;
    transition: all 0.5s ease 0s;
    display: flex;
    flex-direction: column;
    gap: 10px;
    margin: 0px !important;
}
.tree li {
    display: flex;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-left: 20px;
    transition: all 0.5s ease 0s;
    flex-direction: row;
    justify-content: start;
    align-items: center;
}
.tree li::before, .tree li::after {
    content: "";
    position: absolute;
    left: 0px;
    border-left: 1px solid var(--body-text-color);
    width: 20px;
}
.tree li::before {
    top: 0;
    height:50%;
}
.tree li::after {
    top: 50%;
    height: 55%;
    bottom: auto;
    border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, li:only-child::before {
    display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
    border-bottom: 1px solid var(--body-text-color);
    border-radius: 0px 0px 0px 5px;
    -webkit-border-radius: 0px 0px 0px 5px;
    -moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
    border-radius: 5px 0 0 0;
    -webkit-border-radius: 5px 0 0 0;
    -moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: "";
    position: absolute;
    left: 0;
    top: 50%;
    border-top: 1px solid var(--body-text-color);
    width: 20px;
    height: 0;
}
.tree ul:has(> li:only-child)::before {
    width:40px;
}
.child:before {
    border-right: 2px solid var(--body-text-color);
    border-bottom: 2px solid var(--body-text-color);
    content: "";
    position: absolute;
    width: 10px;
    left: 8px;
    height: 10px;
    top: 50%;
    margin-top: -5px;
    transform: rotate(315deg);
}
.tree li a {
    border: 1px solid var(--body-text-color);
    padding: 5px;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
    display: flex;
    align-items: center;
    justify-content: space-between;
    overflow: hidden;
}
.tree li a span {
    padding: 5px;
    font-size: 12px;
    letter-spacing: 1px;
    font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
    background: var(--primary-500);
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before, .tree li a:hover+ul a::before {
    border-color: var(--primary-500);
}
.chosen-token {
    background-color: var(--primary-400);
}
.chosen-token td, .chosen-token tr {
    color: black!important;
}
.end-of-text {
    width:auto!important;
}
.nonfinal {
    width:280px;
    min-width: 280px;
}
.selected-sequence {
    background-color: var(--secondary-500);
}
.nonselected-sequence {
    background-color: var(--primary-500);
}
"""


def clean(s):
    """Make decoded text printable on a single line.

    Newlines and tabs are replaced by the visible two-character escapes
    "\\n" and "\\t", and surrounding whitespace is stripped.
    """
    return s.replace("\n", r"\n").replace("\t", r"\t").strip()


def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
    """Render an HTML table listing the ``top_k`` highest-scoring next tokens.

    Args:
        scores: per-token scores of one beam at one step, indexed by token id
            (``generate_beams`` passes ``step_scores[beam_ix, :]``).
        previous_cumul_score: cumulative score of the sequence being extended;
            added to each step score for the "Total score" column.
        score_divider: length-penalty divisor applied to that total.
        top_k: number of candidate tokens to list.
        chosen_tokens: decoded tokens that beam search kept from this beam;
            their rows get the ``chosen-token`` CSS class.

    Returns:
        An HTML ``<table>`` string. It contains no newlines: a blank line would
        end the raw-HTML block when the result is rendered by ``gr.Markdown``.
    """
    rows = [
        "<table><tr>"
        "<th><b>Token</b></th><th><b>Step score</b></th><th><b>Total score</b></th>"
        "</tr>"
    ]
    # argsort is ascending: take the last top_k ids and reverse -> best first.
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        item_class = ""
        if chosen_tokens and token in chosen_tokens:
            item_class = "chosen-token"
        total_score = (scores[token_idx] + previous_cumul_score) / score_divider
        rows.append(
            f"<tr class='{item_class}'>"
            f"<td>{html.escape(clean(token))}</td>"
            f"<td>{scores[token_idx]:.4f}</td>"
            f"<td>{total_score:.4f}</td>"
            "</tr>"
        )
    rows.append("</table>")
    return "".join(rows)


def generate_nodes(node, step):
    """Recursively generate HTML for the tree nodes.

    Args:
        node: a non-root ``BeamNode``.
        step: depth of ``node`` below the root; only propagated to children.

    Returns:
        An ``<li>`` element for ``node``. Final nodes show their total score
        and are colored by whether ``model.generate`` returned that sequence.
        Non-final nodes show their candidate table followed by a nested
        ``<ul>`` holding their children.
    """
    token = html.escape(clean(tokenizer.decode([node.current_token_ix])))

    if node.is_final:
        if node.is_selected_sequence:
            selected_class = "selected-sequence"
        else:
            selected_class = "nonselected-sequence"
        return (
            f"<li> <a class='end-of-text child {selected_class}'> <span> "
            f"<b>{token}</b> <br>Total score: {node.total_score:.2f}</span> </a> </li>"
        )

    html_content = f"<li> <a class='nonfinal child'> <span> <b>{token}</b> </span>"
    if node.table is not None:
        html_content += node.table
    html_content += "</a>"

    if node.children:
        # The <ul> must directly follow the <a>: the CSS hover rules use `a:hover+ul`.
        html_content += "<ul> "
        for subnode in node.children.values():
            html_content += generate_nodes(subnode, step=step + 1)
        html_content += "</ul>"
    html_content += "</li>"
    return html_content


def generate_html(start_sentence, original_tree):
    """Render the full beam-search tree rooted at ``original_tree`` as HTML.

    Args:
        start_sentence: the user prompt, shown (escaped) in the root node.
        original_tree: root ``BeamNode`` as returned by ``generate_beams``.

    Returns:
        A single-line HTML string wrapped in the ``custom-container`` and
        ``tree`` containers that the STYLE sheet targets.
    """
    root_table = original_tree.table if original_tree.table is not None else ""
    parts = [
        "<div class='custom-container'><div class='tree'><ul>",
        f"<li> <a id='root'> <span> <b>{html.escape(clean(start_sentence))}</b> </span> "
        f"{root_table} </a>",
    ]
    if original_tree.children:
        parts.append("<ul> ")
        parts.extend(
            generate_nodes(subnode, step=1)
            for subnode in original_tree.children.values()
        )
        parts.append("</ul>")
    parts.append("</li></ul></div></div>")
    return "".join(parts)


@dataclass
class BeamNode:
    """One node of the reconstructed beam-search tree."""

    # Token id appended at this node; None for the root (the prompt itself).
    current_token_ix: Optional[int]
    # Sum of the step scores from the root down to and including this token.
    cumulative_score: float
    # Length-penalty divisor used for this node's candidate table.
    children_score_divider: float
    # HTML table of this node's top candidate tokens; None until computed.
    table: Optional[str]
    # Decoded text: prompt followed by every token down to this node.
    current_sequence: str
    # Child nodes keyed by the token id that produced them.
    children: Dict[int, "BeamNode"]
    # Length-normalized cumulative score; None for the root.
    total_score: Optional[float]
    # True when this token is EOS or was produced at the last generation step.
    is_final: bool
    # True when this sequence is one of those decoded from model.generate.
    is_selected_sequence: bool


def generate_beams(start_sentence, scores, length_penalty, decoded_sequences, beam_indexes_source):
    """Rebuild the beam-search tree from the per-step scores of ``model.generate``.

    Args:
        start_sentence: prompt text; becomes the root's ``current_sequence``.
        scores: sequence of per-step score tensors, indexed ``[step][beam][token]``.
        length_penalty: exponent applied to the generated length when
            normalizing scores.
        decoded_sequences: decoded texts of the returned sequences; nodes
            whose sequence matches one (ignoring ``<|endoftext|>``) are marked
            ``is_selected_sequence``.
        beam_indexes_source: accepted for interface compatibility; unused.

    Returns:
        The root ``BeamNode``.
    """

    def length_divider(num_generated_tokens):
        # Scores are normalized by (number of generated tokens) ** length_penalty.
        return num_generated_tokens**length_penalty

    original_tree = BeamNode(
        cumulative_score=0,
        current_token_ix=None,
        table=None,
        current_sequence=start_sentence,
        children={},
        # The root's candidates are the first generated token.
        children_score_divider=length_divider(1),
        total_score=None,
        is_final=False,
        is_selected_sequence=False,
    )
    selected_sequences = {
        sequence.replace("<|endoftext|>", "") for sequence in decoded_sequences
    }
    n_beams = len(scores[0])
    # All beams start from the same root node (shared reference on purpose).
    beam_trees = [original_tree] * n_beams

    for step, step_scores in enumerate(scores):
        # Gather the top n_beams descendants of every still-running beam.
        top_token_indexes, top_cumulative_scores, beam_indexes = [], [], []
        sequences, top_tokens = [], []
        for beam_ix, current_beam in enumerate(beam_trees):
            if current_beam.is_final:
                continue
            current_top_token_indexes = list(
                np.array(step_scores[beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            top_cumulative_scores += list(
                np.array(step_scores[beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            sequences += [current_beam.current_sequence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]

        if not top_token_indexes:
            # Every beam has ended: nothing left to expand.
            break

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_sequence": sequences,
                "token": top_tokens,
            }
        )
        # Beams sharing the same text (e.g. all beams at step 0) propose the
        # same continuations: keep only the best-scoring copy of each.
        maxes = top_df.groupby(["token_index", "current_sequence"])[
            "cumulative_score"
        ].idxmax()
        top_df = top_df.loc[maxes]

        # Sort all top probabilities and keep top n_beams.
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Write the scores table - one per beam source. Final nodes are never
        # rendered with a table, so they are skipped.
        for beam_ix in reversed(range(len(beam_trees))):
            current_beam = beam_trees[beam_ix]
            if current_beam.table is None and not current_beam.is_final:
                selected_tokens = top_df_selected.loc[
                    top_df_selected["current_sequence"] == current_beam.current_sequence
                ]
                current_beam.table = generate_markdown_table(
                    step_scores[beam_ix, :],
                    current_beam.cumulative_score,
                    current_beam.children_score_divider,
                    chosen_tokens=list(selected_tokens["token"].values),
                )

        # Add the selected children to their source beams.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for _, row in top_df_selected.iterrows():
            source_beam_ix = int(row["beam_index"])
            current_token_choice_ix = row["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])

            cumulative_score = (
                cumulative_scores[source_beam_ix]
                + step_scores[source_beam_ix][current_token_choice_ix].numpy()
            )
            current_sequence = (
                beam_trees[source_beam_ix].current_sequence + current_token_choice
            )
            beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
                current_token_ix=current_token_choice_ix,
                table=None,
                children={},
                current_sequence=current_sequence,
                cumulative_score=cumulative_score,
                # step is 0-based, so this node is generated token number step + 1.
                total_score=cumulative_score / length_divider(step + 1),
                children_score_divider=length_divider(step + 2),
                is_final=(
                    step == len(scores) - 1
                    or current_token_choice_ix == tokenizer.eos_token_id
                ),
                is_selected_sequence=(
                    current_sequence.replace("<|endoftext|>", "") in selected_sequences
                ),
            )

        # Advance: beams are re-ordered by descending cumulative score, and each
        # one moves to the child just created for it.
        beam_trees = [
            beam_trees[int(row["beam_index"])].children[row["token_index"]]
            for _, row in top_df_selected.iterrows()
        ]

    return original_tree


@spaces.GPU
def get_beam_search_html(
    input_text, number_steps, number_beams, length_penalty, num_return_sequences
):
    """Run beam search on ``input_text`` and return ``(tree_html, summary_markdown)``.

    The tree is rebuilt from ``outputs.scores``. The markdown explains how
    conclusive sequences are scored and lists every returned sequence with
    its score.
    """
    inputs = tokenizer([input_text], return_tensors="pt")

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=num_return_sequences,
        return_dict_in_generate=True,
        length_penalty=length_penalty,
        output_scores=True,
        do_sample=False,
        # GPT-2 defines no pad token; reuse EOS explicitly.
        pad_token_id=tokenizer.eos_token_id,
    )
    markdown = "The conclusive sequences are the ones that end in an `<|endoftext|>` token or at the end of generation."
    markdown += "\n\nThey are ranked by their scores, as given by the formula `score = cumulative_score / (output_length ** length_penalty)`.\n\n"
    markdown += "Only the top `num_return_sequences` scoring sequences are returned: in the tree they are highlighted in **blue**."
    markdown += " The non-selected sequences are also shown in the tree, highlighted in **yellow**."
    markdown += "\n#### Output sequences:"
    # Sequences are padded anyway so you can batch decode them.
    decoded_sequences = tokenizer.batch_decode(outputs.sequences)
    for sequence_score, sequence in zip(outputs.sequences_scores, decoded_sequences):
        markdown += f"\n- Score `{sequence_score:.2f}`: `{clean(sequence)}`"

    original_tree = generate_beams(
        input_text,
        outputs.scores[:],
        length_penalty,
        decoded_sequences,
        outputs.beam_indices,
    )
    html_output = generate_html(input_text, original_tree)
    return html_output, markdown


def change_num_return_sequences(n_beams):
    """Rebuild the return-sequences slider so it never exceeds ``n_beams``."""
    return gr.Slider(
        label="Number of return sequences",
        minimum=1,
        maximum=n_beams,
        step=1,
        value=n_beams,
    )


with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.yellow,
        secondary_hue=gr.themes.colors.blue,
    ),
    css=STYLE,
) as demo:
    gr.Markdown(
        """# Beam Search Visualizer

Play with the parameters below to understand how beam search decoding works!

#### Parameters:
- **Sentence to decode from** (`inputs`): the input sequence to your decoder.
- **Number of steps** (`max_new_tokens`): the number of tokens to generate.
- **Number of beams** (`num_beams`): the number of beams to use.
- **Length penalty** (`length_penalty`): the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
This parameter will not impact the beam search paths, but only influence the choice of sequences in the end towards longer or shorter sequences.
- **Number of return sequences** (`num_return_sequences`): the number of sequences to be returned at the end of generation. Should be `<= num_beams`.
"""
    )
    text = gr.Textbox(
        label="Sentence to decode from",
        value="Conclusion: thanks a lot. That's all for today",
    )
    with gr.Row():
        n_steps = gr.Slider(
            label="Number of steps", minimum=1, maximum=10, step=1, value=4
        )
        n_beams = gr.Slider(
            label="Number of beams", minimum=2, maximum=4, step=1, value=3
        )
        length_penalty = gr.Slider(
            label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
        )
        num_return_sequences = gr.Slider(
            label="Number of return sequences", minimum=1, maximum=3, step=1, value=2
        )

    n_beams.change(
        fn=change_num_return_sequences, inputs=n_beams, outputs=num_return_sequences
    )

    button = gr.Button()
    out_html = gr.Markdown()
    out_markdown = gr.Markdown()
    button.click(
        get_beam_search_html,
        inputs=[text, n_steps, n_beams, length_penalty, num_return_sequences],
        outputs=[out_html, out_markdown],
    )

demo.launch()