File size: 14,958 Bytes
2d63e52
88356be
 
fb66618
5abc01d
88356be
89b8c0b
 
e4010d8
 
 
2d63e52
 
46edcf6
 
2d63e52
ad36776
e4010d8
ad36776
 
5da9cef
f6daa90
ad36776
9a4471c
e4010d8
9a4471c
07a76f8
 
 
5582229
 
 
e4010d8
 
 
 
 
fe4c0f1
e4010d8
ad36776
a7a4e14
da19af6
a7a4e14
e4010d8
ad36776
 
07a76f8
 
ad36776
5582229
 
 
57c2a5d
5582229
ad36776
07a76f8
 
 
e4010d8
07a76f8
 
 
ad36776
 
07a76f8
 
 
 
 
 
 
 
 
ad36776
 
07a76f8
 
 
 
 
ad36776
07a76f8
 
 
ad36776
07a76f8
 
5582229
07a76f8
 
ea604ea
07a76f8
 
ad36776
 
07a76f8
ad36776
 
07a76f8
 
 
 
ad36776
 
 
 
 
 
 
07a76f8
 
 
 
 
 
 
 
 
 
ad36776
c3aa211
bfbc171
 
37b41e9
 
 
bfbc171
37b41e9
 
 
 
 
ad36776
27c3fcd
e4010d8
ad36776
 
 
 
5582229
 
fe4c0f1
 
ad36776
 
e4010d8
ad36776
 
 
 
 
7a70688
179959e
ad36776
c3aa211
179959e
ad36776
c3aa211
179959e
fe4c0f1
 
5582229
e4010d8
fe4c0f1
 
 
c3aa211
 
179959e
c3aa211
ad36776
bc46ee1
7a37cfb
e20ac5c
fe4c0f1
ad36776
 
e20ac5c
 
 
ad36776
 
 
 
7a37cfb
e4010d8
ad36776
e4010d8
ad36776
4ef6980
ad36776
c3aa211
ad36776
4ef6980
e20ac5c
c9b1b1f
e20ac5c
ad36776
 
 
 
 
 
c3aa211
e20ac5c
c3aa211
 
 
50809fa
 
c3aa211
50809fa
c3aa211
efa8da2
 
e20ac5c
50809fa
e20ac5c
 
 
c3aa211
e20ac5c
 
50809fa
e20ac5c
 
 
e4010d8
e20ac5c
a7a4e14
50809fa
 
c3aa211
 
50809fa
e4010d8
50809fa
e4010d8
 
 
 
 
 
 
 
 
 
 
 
 
50809fa
e4010d8
e20ac5c
e4010d8
c3aa211
e20ac5c
50809fa
 
c3aa211
e4010d8
 
c3aa211
e20ac5c
e4010d8
e20ac5c
50809fa
e20ac5c
c3aa211
e20ac5c
 
50809fa
 
c3aa211
e4010d8
 
 
50809fa
e4010d8
 
 
 
 
 
 
 
50809fa
e4010d8
50809fa
 
 
 
 
e4010d8
 
 
 
 
 
 
 
 
 
c3aa211
07a76f8
e4010d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad36776
e4010d8
 
 
 
ad36776
e4010d8
 
 
 
07a76f8
 
 
ad36776
e4010d8
 
e20ac5c
e4010d8
ad36776
e4010d8
ad36776
e4010d8
 
 
 
 
ad36776
e4010d8
 
7a37cfb
e20ac5c
 
 
 
c3aa211
 
 
e20ac5c
50809fa
e4010d8
 
c3aa211
e20ac5c
50809fa
 
f6daa90
50809fa
 
 
 
c3aa211
e4010d8
ad36776
e4010d8
 
 
 
 
 
 
 
 
e20ac5c
 
e4010d8
ad36776
5abc01d
e20ac5c
ad36776
 
 
 
 
 
 
 
5e72e33
ad36776
e4010d8
bc46ee1
c3aa211
 
179959e
 
c3aa211
50809fa
4fde691
 
c3aa211
ad36776
e4010d8
66e4ecd
e4010d8
e20ac5c
c3aa211
bc46ee1
e4010d8
4fde691
7a37cfb
 
ad36776
 
c3aa211
 
ad36776
 
 
50809fa
179959e
4fde691
 
 
179959e
c3aa211
 
 
 
5a53522
50809fa
 
 
 
 
 
e20ac5c
50809fa
 
 
 
 
 
 
f6daa90
50809fa
ad36776
4fde691
 
50809fa
 
 
 
 
ad36776
bc46ee1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gradio as gr
import spaces

# Load the GPT-2 tokenizer and causal language model once, at import time.
# Both are module-level globals used by the helper functions below.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

print("Loading finished.")

# Startup diagnostics: report whether CUDA is visible to torch and, if so,
# the name of the current device.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Custom CSS handed to `gr.Blocks(css=STYLE)` at the bottom of this file.
# It styles the HTML produced by generate_html / generate_nodes /
# generate_markdown_table: the `.tree` nested lists and their connector lines,
# the `#root` node, `.nonfinal` / `.end-of-text` node boxes, and the
# `.chosen-token` / `.selected-sequence` highlights.
STYLE = """
.custom-container {
	display: grid;
	align-items: center;
    margin: 0!important;
    overflow-y: hidden;
}
.prose ul ul {
    font-size: 10px!important;
}
.prose li {
    margin-bottom: 0!important;
}
.prose table {
    margin-bottom: 0!important;
}
.prose td, th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
    text-wrap:nowrap;
}
.tree {
	padding: 0px;
	margin: 0!important;
	box-sizing: border-box;
    font-size: 10px;
	width: 100%;
	height: auto;
	text-align: center;
    display:inline-block;
}
#root {
    display: inline-grid!important;
    width:auto!important;
    min-width: 220px;
}
.tree ul {
    padding-left: 20px;
    position: relative;
    transition: all 0.5s ease 0s;
    display: flex;
    flex-direction: column;
    gap: 10px;
    margin: 0px !important;
}
.tree li {
    display: flex;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-left: 20px;
    transition: all 0.5s ease 0s;
    flex-direction: row;
    justify-content: start;
    align-items: center;
}
.tree li::before, .tree li::after {
    content: "";
    position: absolute;
    left: 0px;
    border-left: 1px solid var(--body-text-color);
    width: 20px;
}
.tree li::before {
    top: 0;
    height:50%;
}
.tree li::after {
    top: 50%;
    height: 55%;
    bottom: auto;
    border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, li:only-child::before {
    display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
	border-bottom: 1px solid var(--body-text-color);
	border-radius: 0px 0px 0px 5px;
	-webkit-border-radius: 0px 0px 0px 5px;
	-moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
	border-radius: 5px 0 0 0;
	-webkit-border-radius: 5px 0 0 0;
	-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: "";
    position: absolute;
    left: 0;
    top: 50%;
    border-top: 1px solid var(--body-text-color);
    width: 20px;
    height: 0;
}
.tree ul:has(> li:only-child)::before {
    width:40px;
}
.child:before {
    border-right: 2px solid var(--body-text-color);
    border-bottom: 2px solid var(--body-text-color);
    content: "";
    position: absolute;
    width: 10px;
    left: 8px;
    height: 10px;
    top: 50%;
    margin-top: -5px;
    transform: rotate(315deg);
}
.tree li a {
	border: 1px solid var(--body-text-color);
	padding: 5px;
	border-radius: 5px;
	text-decoration-line: none;
	border-radius: 5px;
	transition: .5s;
    display: flex;
    align-items: center;
    justify-content: space-between;
    overflow: hidden;
}
.tree li a span {
	padding: 5px;
	font-size: 12px;
	letter-spacing: 1px;
	font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
	background: var(--primary-600);
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before, .tree li a:hover+ul a::before {
	border-color: var(--primary-400);
}
.end-of-text, .chosen-token {
    background-color: var(--primary-500);
}
.end-of-text {
    width:auto!important;
}
.nonfinal {
    width:280px;
    min-width: 280px;
}
.selected-sequence {
    background-color: var(--secondary-500)!important;
}
"""


def clean(s):
    r"""Return *s* with whitespace made visible for display.

    Every newline becomes the two characters ``\n`` and every tab becomes
    ``\t``; leading and trailing whitespace that remains is then stripped.
    """
    # One pass over the string; the replacement text contains no real
    # newline/tab characters, so this matches sequential replacement.
    visible = s.translate({ord("\n"): r"\n", ord("\t"): r"\t"})
    return visible.strip()


def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
    """Build the HTML score table shown inside a tree node.

    Despite the name, the returned string is an HTML ``<table>`` with one row
    per candidate token, best score first.

    Args:
        scores: Per-token scores for one beam at one step; sorted with
            ``np.argsort`` and indexed by token id.
        previous_cumul_score: Cumulative score of the beam before this step.
        score_divider: Divisor applied to the cumulative score to obtain the
            "Total score" column.
        top_k: Number of highest-scoring tokens to list.
        chosen_tokens: Decoded token strings whose rows get the
            ``chosen-token`` CSS class; ``None`` or empty highlights nothing.

    Returns:
        The table as an HTML string.
    """
    from html import escape

    rows = [
        """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Total score</b></th>
        </tr>"""
    ]
    # argsort is ascending, so take the last top_k ids and reverse them.
    best_token_ids = np.array(np.argsort(scores)[-top_k:])[::-1]
    for token_idx in best_token_ids:
        token = tokenizer.decode([token_idx])
        item_class = "chosen-token" if chosen_tokens and token in chosen_tokens else ""
        # Token text is arbitrary model vocabulary: escape it so characters
        # such as '<' or '&' cannot break the surrounding markup.
        token_label = escape(clean(token))
        rows.append(
            f"""
        <tr class="{item_class}">
            <td>{token_label}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{(scores[token_idx] + previous_cumul_score)/score_divider:.4f}</td>
        </tr>"""
        )
    rows.append(
        """
    </table>"""
    )
    return "".join(rows)


def generate_nodes(node, step):
    """Recursively generate the HTML ``<li>`` element for a tree node.

    Args:
        node: The node to render. Reads ``current_token_ix``,
            ``is_selected_sequence``, ``is_final``, ``total_score``,
            ``table`` and ``children``.
        step: Depth of ``node`` in the tree; only passed on (incremented)
            to the recursive calls.

    Returns:
        HTML for this node and all of its descendants.
    """
    from html import escape

    token = tokenizer.decode([node.current_token_ix])
    # Token text is arbitrary model vocabulary: escape it so characters such
    # as '<' or '&' cannot break the surrounding markup.
    token_label = escape(clean(token))

    selected_class = "selected-sequence" if node.is_selected_sequence else ""

    # Final nodes are leaves: show the sequence's total score, no table.
    if node.is_final:
        return f"<li> <a href='#' class='end-of-text child {selected_class}'> <span> <b>{token_label}</b> <br>Total score: {node.total_score:.2f}</span> </a> </li>"

    parts = [
        f"<li> <a href='#' class='nonfinal child {selected_class}'> <span> <b>{token_label}</b> </span>"
    ]
    if node.table is not None:
        parts.append(node.table)
    parts.append("</a>")

    if node.children:
        parts.append("<ul> ")
        parts.extend(
            generate_nodes(subnode, step=step + 1)
            for subnode in node.children.values()
        )
        parts.append("</ul>")
    parts.append("</li>")

    return "".join(parts)


def generate_html(start_sentence, original_tree):
    """Render the whole beam tree as an HTML fragment.

    Args:
        start_sentence: The prompt text shown in the root node. It is
            HTML-escaped because it comes straight from user input.
        original_tree: Root node of the tree; its ``table`` is shown in the
            root box and its ``children`` are rendered with
            ``generate_nodes``.

    Returns:
        An HTML string: a ``custom-container`` div wrapping a ``tree`` div
        that contains the nested ``<ul>``/``<li>`` structure.
    """
    from html import escape

    # Same guard as generate_nodes: never print the literal text "None".
    root_table = original_tree.table if original_tree.table is not None else ""
    root_label = escape(start_sentence, quote=False)

    html_output = f"""<div class="custom-container">
				<div class="tree">
                <ul> <li> <a href='#' id='root'> <span> <b>{root_label}</b> </span> {root_table} </a>"""
    html_output += "<ul> "
    html_output += "".join(
        generate_nodes(subnode, step=1)
        for subnode in original_tree.children.values()
    )
    html_output += "</ul>"
    # Close both wrapper divs opened above (the original emitted a stray
    # </body> in place of the second </div>).
    html_output += """
        </li> </ul>
        </div>
    </div>
    """
    return html_output


import pandas as pd
from typing import Dict
from dataclasses import dataclass


@dataclass
class BeamNode:
    """One node of the beam-search tree built by ``generate_beams``."""

    # Vocabulary id of the token at this node (the root is built with None).
    current_token_ix: int
    # Sum of the step scores along the path from the root (0 for the root).
    cumulative_score: float
    # Divisor passed to generate_markdown_table for this node's score table.
    children_score_divider: float
    # HTML score table for this node; None until generate_beams assigns it.
    table: str
    # Start sentence followed by the decoded tokens on the path to this node.
    current_sequence: str
    # Child nodes, keyed by the token index chosen to reach them.
    children: Dict[int, "BeamNode"]
    # cumulative_score divided by a length-penalty term (None for the root).
    total_score: float
    # True when the node is at the last step or its token is the EOS token.
    is_final: bool
    # True when current_sequence is one of the decoded output sequences.
    is_selected_sequence: bool


def generate_beams(start_sentence, scores, length_penalty, decoded_sequences):
    """Rebuild the beam-search tree from the per-step scores.

    Args:
        start_sentence: Prompt text; becomes the root's ``current_sequence``.
        scores: One entry per generation step; ``scores[step][beam_ix]`` holds
            the scores over token ids for that beam (the caller passes
            ``outputs.scores`` from ``model.generate``).
        length_penalty: Exponent used in the score divisors.
        decoded_sequences: Decoded output sequences; a node whose
            ``current_sequence`` is in this collection is flagged as selected.

    Returns:
        The root ``BeamNode`` with the whole tree attached through
        ``children``.
    """
    # NOTE(review): this takes len() of the tokenizer's return object rather
    # than of its "input_ids"; confirm it yields the intended length, since
    # every score divisor below depends on it.
    input_length = len(tokenizer([start_sentence], return_tensors="pt"))
    original_tree = BeamNode(
        cumulative_score=0,
        current_token_ix=None,
        table=None,
        current_sequence=start_sentence,
        children={},
        children_score_divider=((input_length + 1) ** length_penalty),
        total_score=None,
        is_final=False,
        is_selected_sequence=False,
    )
    n_beams = len(scores[0])
    # Every beam slot starts out referencing the same root object.
    beam_trees = [original_tree] * n_beams

    for step, step_scores in enumerate(scores):
        # Flat candidate lists, filled in parallel: one entry per
        # (beam, candidate token) pair.
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_completions,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):  # Get possible descendants for each beam
            current_beam = beam_trees[beam_ix]

            # skip if the beam is already final
            if current_beam.is_final:
                continue

            # Get top cumulative scores for the current beam
            # (the n_beams best token ids, highest score first)
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_completions += [beam_trees[beam_ix].current_sequence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_completions": current_completions,
                "token": top_tokens,
            }
        )
        # Deduplicate: for each (token, source sequence) pair keep only the
        # row with the highest cumulative score.
        maxes = top_df.groupby(["token_index", "current_completions"])[
            "cumulative_score"
        ].idxmax()

        top_df = top_df.loc[maxes]

        # Sort all top probabilities and keep top n_beams
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Write the scores table, one per source beam.
        # Edge case: if several beam indexes are actually on the same beam, the
        # selected tokens by beam_index for the second one will be empty. So we
        # iterate in reverse, letting the lowest beam index write last.
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            selected_tokens = top_df_selected.loc[
                top_df_selected["beam_index"] == beam_ix
            ]
            markdown_table = generate_markdown_table(
                step_scores[beam_ix, :],
                current_beam.cumulative_score,
                current_beam.children_score_divider,
                chosen_tokens=list(selected_tokens["token"].values),
            )
            beam_trees[beam_ix].table = markdown_table

        # Add new children for each beam
        # (snapshot the parents' cumulative scores before attaching children)
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])

            # Update the source tree
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])

            cumulative_score = (
                cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy()
            )
            current_sequence = (
                beam_trees[source_beam_ix].current_sequence + current_token_choice
            )
            beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
                current_token_ix=current_token_choice_ix,
                table=None,
                children={},
                current_sequence=current_sequence,
                cumulative_score=cumulative_score,
                total_score=cumulative_score
                / ((input_length + step - 1) ** length_penalty),
                children_score_divider=((input_length + step) ** length_penalty),
                # A node is final at the last step or when it is the EOS token.
                is_final=(
                    step == len(scores) - 1
                    or current_token_choice_ix == tokenizer.eos_token_id
                ),
                is_selected_sequence=(current_sequence in decoded_sequences),
            )

        # Reassign all beams at once
        # (slot i now references the parent node of the i-th selected row)
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]

        # Advance all beams by one token
        # (step from each parent to the child created above)
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]

    return original_tree

@spaces.GPU
def get_beam_search_html(input_text, number_steps, number_beams, length_penalty):
    """Run beam search on *input_text* and build both outputs for the UI.

    Args:
        input_text: Prompt to decode from.
        number_steps: Passed to ``model.generate`` as ``max_new_tokens``.
        number_beams: Used for both ``num_beams`` and ``num_return_sequences``.
        length_penalty: Passed to ``model.generate`` and to ``generate_beams``.

    Returns:
        A ``(html, markdown)`` tuple: the rendered beam tree and a textual
        summary listing every returned sequence with its score.
    """
    encoded_prompt = tokenizer([input_text], return_tensors="pt")

    generation = model.generate(
        **encoded_prompt,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        length_penalty=length_penalty,
        output_scores=True,
        do_sample=False,
    )

    # Sequences are padded anyway so you can batch decode them
    decoded_sequences = tokenizer.batch_decode(generation.sequences)

    # Collect the summary piece by piece and join once at the end.
    summary_parts = [
        "The conclusive sequences are the ones that end in an `<|endoftext|>` token or at the end of generation.",
        "\n\nThey are ranked by their scores, as given by the formula `score = cumulative_score / (output_length ** length_penalty)`.\n\n",
        "Only the top `num_beams` scoring sequences are returned: in the tree they are highlighted in **<span style='color:var(--secondary-500)!important'>blue</span>**.",
        " The non-selected sequences are also shown in the tree, highlighted in **<span style='color:var(--primary-500)!important'>yellow</span>**.",
        "\n#### <span style='color:var(--secondary-600)!important'>Output sequences:</span>",
    ]
    for rank, sequence in enumerate(decoded_sequences):
        shown_sequence = clean(sequence.replace("<s> ", ""))
        summary_parts.append(
            f"\n- Score `{generation.sequences_scores[rank]:.2f}`: `{shown_sequence}`"
        )
    markdown = "".join(summary_parts)

    tree_root = generate_beams(
        input_text,
        generation.scores[:],
        length_penalty,
        decoded_sequences,
    )
    return generate_html(input_text, tree_root), markdown


# Gradio UI: an intro text, the four input controls, a button, and two
# Markdown outputs that receive the return values of get_beam_search_html.
# The yellow/blue theme hues are the colours referred to by the
# var(--primary-*) / var(--secondary-*) CSS variables used in STYLE.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.yellow,
        secondary_hue=gr.themes.colors.blue,
    ),
    css=STYLE,
) as demo:
    gr.Markdown(
        """# <span style='color:var(--primary-500)!important'>Beam Search Visualizer</span>
                
Play with the parameters below to understand how beam search decoding works!
                
#### <span style='color:var(--primary-500)!important'>Parameters:</span>
- **Sentence to decode from** (`inputs`): the input sequence to your decoder.
- **Number of steps** (`max_new_tokens`): the number of tokens to generate
- **Number of beams** (`num_beams`): the number of beams to use
- **Length penalty** (`length_penalty`): the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
This parameter will not impact the beam search paths, but only influence the choice of sequences in the end towards longer or shorter sequences.
"""
    )
    # Prompt input, pre-filled with an example sentence.
    text = gr.Textbox(
        label="Sentence to decode from",
        value="Conclusion: thanks a lot. This article was originally published on",
    )
    # The three numeric controls share one row.
    with gr.Row():
        steps = gr.Slider(
            label="Number of steps", minimum=1, maximum=8, step=1, value=4
        )
        beams = gr.Slider(
            label="Number of beams", minimum=2, maximum=4, step=1, value=3
        )
        length_penalty = gr.Slider(
            label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
        )
    button = gr.Button()
    # Both outputs are Markdown components; the first receives the tree HTML,
    # the second the textual summary.
    out_html = gr.Markdown()
    out_markdown = gr.Markdown()
    # The input order must match get_beam_search_html's parameter order.
    button.click(
        get_beam_search_html,
        inputs=[text, steps, beams, length_penalty],
        outputs=[out_html, out_markdown],
    )

# Start the app as soon as the module is executed.
demo.launch()