File size: 12,577 Bytes
2d63e52
88356be
 
fb66618
7a37cfb
88356be
89b8c0b
 
e4010d8
 
 
2d63e52
 
46edcf6
 
2d63e52
ad36776
e4010d8
ad36776
 
5da9cef
57c2a5d
ad36776
9a4471c
e4010d8
9a4471c
07a76f8
 
 
5582229
 
 
 
e4010d8
 
 
 
 
 
 
ad36776
a7a4e14
da19af6
a7a4e14
e4010d8
ad36776
 
07a76f8
 
ad36776
5582229
 
 
 
57c2a5d
5582229
 
ad36776
07a76f8
 
 
e4010d8
07a76f8
 
 
ad36776
 
07a76f8
 
 
 
 
 
 
 
 
ad36776
07a76f8
ad36776
07a76f8
 
 
 
 
ad36776
07a76f8
 
 
ad36776
07a76f8
 
5582229
07a76f8
 
ea604ea
07a76f8
 
ad36776
07a76f8
ad36776
07a76f8
ad36776
 
07a76f8
 
 
 
ad36776
 
 
 
 
 
07a76f8
ad36776
07a76f8
 
 
 
 
 
 
 
 
 
ad36776
37b41e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad36776
27c3fcd
e4010d8
ad36776
 
 
 
efa8da2
5582229
 
 
ad36776
 
e4010d8
ad36776
 
 
 
 
 
7a70688
8666c72
ad36776
 
07a76f8
ad36776
e4010d8
8666c72
5582229
e4010d8
ad36776
bc46ee1
7a37cfb
e20ac5c
 
ad36776
 
e20ac5c
 
 
ad36776
 
 
 
7a37cfb
e4010d8
ad36776
e4010d8
ad36776
4ef6980
ad36776
4ef6980
ad36776
4ef6980
e20ac5c
c9b1b1f
e20ac5c
ad36776
 
 
 
 
 
e20ac5c
 
 
 
efa8da2
 
e20ac5c
 
 
 
 
 
 
 
 
 
e4010d8
07a76f8
e20ac5c
a7a4e14
e20ac5c
 
 
 
 
 
 
e4010d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e20ac5c
e4010d8
 
e20ac5c
e4010d8
 
e20ac5c
e4010d8
e20ac5c
e4010d8
e20ac5c
 
 
 
 
e4010d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07a76f8
e4010d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad36776
e4010d8
 
 
 
ad36776
e4010d8
 
 
 
07a76f8
 
 
ad36776
e4010d8
 
e20ac5c
e4010d8
ad36776
e4010d8
ad36776
e4010d8
 
 
 
 
ad36776
e4010d8
 
7a37cfb
e20ac5c
 
 
 
 
e4010d8
 
 
 
e20ac5c
 
e4010d8
ad36776
e4010d8
 
 
 
 
 
 
 
 
e20ac5c
 
e4010d8
ad36776
e774464
e20ac5c
ad36776
 
 
 
 
 
 
 
e20ac5c
ad36776
e4010d8
bc46ee1
efa8da2
4fde691
 
85a45ca
ad36776
e4010d8
66e4ecd
e4010d8
 
e20ac5c
bc46ee1
e4010d8
4fde691
7a37cfb
 
ad36776
 
27c3fcd
ad36776
 
 
4fde691
 
 
 
 
 
 
 
 
 
efa8da2
e20ac5c
 
 
efa8da2
ad36776
4fde691
 
 
ad36776
bc46ee1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
import html
from typing import Optional

import gradio as gr
import numpy as np
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the GPT-2 tokenizer and model once, at import time; the functions below
# read them as module globals.
tokenizer, model = (
    AutoTokenizer.from_pretrained("gpt2"),
    AutoModelForCausalLM.from_pretrained("gpt2"),
)

print("Loading finished.")

# Diagnostic logging only: report whether CUDA is visible and, if so, which device.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print("CUDA device: " + torch.cuda.get_device_name(torch.cuda.current_device()))

# CSS for the beam-tree HTML produced by generate_html / generate_nodes /
# generate_markdown_table. Rules are scoped under `.tree` / `.prose` so they
# do not leak onto the rest of the Gradio page (previously `a:before`,
# `li:only-child::before` and `th` were unscoped and styled every link,
# only-child list item and table header on the page).
STYLE = """
.custom-container {
	display: grid;
	align-items: center;
    margin: 0!important;
    overflow: auto;
}
.prose ul ul {
    font-size: 10px!important;
}
.prose li {
    margin-bottom: 0!important;
}
.prose table {
    margin-bottom: 0!important;
}

.prose td, .prose th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
}

.tree {
	padding: 0px;
	margin: 0!important;
	box-sizing: border-box;
    font-size: 10px;
	width: 100%;
	height: auto;
	text-align: center;
    display:inline-block;
}

#root {
    display: inline-grid!important;
    width:auto!important;
    min-width: 220px;
}

.tree ul {
    padding-left: 20px;
    position: relative;
    transition: all 0.5s ease 0s;
    display: flex;
    flex-direction: column;
    gap: 10px;
    margin: 0px !important;
}
.tree li {
    display: flex;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-left: 20px;
    transition: all 0.5s ease 0s;
    flex-direction: row;
    justify-content: start;
    align-items: center;
}

.tree li::before, .tree li::after {
    content: "";
    position: absolute;
    left: 0px;
    border-left: 1px solid var(--body-text-color);
    width: 20px;
}
.tree li::before {
    top: 0;
    height:50%;
}
.tree li::after {
    top: 50%;
    height: 55%;
    bottom: auto;
    border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}

.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
	border-bottom: 1px solid var(--body-text-color);
	border-radius: 0px 0px 0px 5px;
	-webkit-border-radius: 0px 0px 0px 5px;
	-moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
	border-radius: 5px 0 0 0;
	-webkit-border-radius: 5px 0 0 0;
	-moz-border-radius: 5px 0 0 0;
}

.tree ul ul::before {
    content: "";
    position: absolute;
    left: 0;
    top: 50%;
    border-top: 1px solid var(--body-text-color);
    width: 20px;
    height: 0;
}
.tree ul:has(> li:only-child)::before {
    width:40px;
}

.tree a:before {
    border-right: 1px solid var(--body-text-color);
    border-bottom: 1px solid var(--body-text-color);
    content: "";
    position: absolute;
    width: 10px;
    left: 0px;
    height: 10px;
    top: 50%;
    margin-top: -5px;
    margin-left: 6px;
    transform: rotate(315deg);
}


.tree li a {
	border: 1px solid var(--body-text-color);
	padding: 5px;
	border-radius: 5px;
	text-decoration-line: none;
	border-radius: 5px;
	transition: .5s;
    width: 280px;
    display: flex;
    align-items: center;
    justify-content: space-around;
}
.tree li a span {
	padding: 5px;
	font-size: 12px;
	text-transform: uppercase;
	letter-spacing: 1px;
	font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
	background: #ffedd5;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
	border-color: #7c2d12;
}
.chosen {
    background-color: #ea580c;
    width:auto!important;
}
"""


def clean(s):
    """Make a decoded token safe and readable inside the HTML tree.

    Newlines and tabs are replaced by a literal backslash followed by ``n`` /
    ``t`` so whitespace-only tokens stay visible. ``&``, ``<`` and ``>`` are
    then HTML-escaped: every caller interpolates the result into HTML markup,
    where a token such as ``"<"`` or ``"</"`` would otherwise be parsed as
    (part of) a tag and corrupt the tree.

    Args:
        s: raw token text.

    Returns:
        The display-ready string.
    """
    visible = s.replace("\n", r"\n").replace("\t", r"\t")
    # quote=False: results go into element text, never into attribute values.
    return html.escape(visible, quote=False)


def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
    """Render the ``top_k`` highest-scoring next tokens of one beam as an HTML table.

    Args:
        scores: per-token scores for one beam at one step, indexed by token id
            (``generate_beams`` passes ``step_scores[beam_ix, :]``).
        previous_cumul_score: cumulative score of the beam before this step.
        score_divider: divisor applied to the "Total score" column
            (``generate_beams`` passes the node's ``children_score_divider``).
        top_k: number of candidate tokens to list.
        chosen_tokens: decoded token strings that were kept as next beams;
            rows whose decoded token is in this list get the ``chosen`` class.

    Returns:
        An HTML ``<table>`` string, best candidate first.
    """
    parts = [
        """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Total score</b></th>
        </tr>"""
    ]
    # argsort is ascending: keep the last top_k indices and reverse them so the
    # best-scoring token comes first.
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        item_class = "chosen" if chosen_tokens and token in chosen_tokens else ""
        # The attribute value is quoted: unquoted, an empty class produced the
        # invalid markup `<tr class=>`.
        parts.append(
            f"""
        <tr class="{item_class}">
            <td>{clean(token)}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{(scores[token_idx] + previous_cumul_score)/score_divider:.4f}</td>
        </tr>"""
        )
    parts.append("""
    </table>""")
    return "".join(parts)


def generate_nodes(token_ix, node, step):
    """Render one beam-tree node, and recursively its children, as an HTML ``<li>``.

    Args:
        token_ix: id of the token this node was reached by; shown together
            with its decoded text.
        node: the ``BeamNode`` to render. Nodes without a ``table`` are
            given the ``chosen`` CSS class.
        step: depth of the node; children are rendered with ``step + 1``
            (the value is not otherwise used in the markup).

    Returns:
        The HTML string for this node's ``<li>`` subtree.
    """
    has_no_table = node.table is None
    label = clean(tokenizer.decode([token_ix]))
    css_class = "chosen" if has_no_table else ""

    pieces = [
        f" <li> <a href='#' class='{css_class}'> <span> <b>{token_ix}:<br>{label}</b> </span> ",
        "" if has_no_table else node.table,
        "</a>",
    ]
    if node.children:
        pieces.append("<ul> ")
        pieces.extend(
            generate_nodes(child_ix, child_node, step=step + 1)
            for child_ix, child_node in node.children.items()
        )
        pieces.append("</ul>")
    pieces.append("</li>")
    return "".join(pieces)


def generate_html(start_sentence, original_tree):
    """Render the whole beam-search tree as an HTML fragment.

    The markup is well-formed: every opened ``div``/``ul``/``li`` is closed.
    The previous version left the root ``<li>`` and one ``<div>`` unclosed
    and emitted a stray ``</body>``. The user-supplied prompt is also
    HTML-escaped before insertion.

    Args:
        start_sentence: the prompt text, shown in the root node.
        original_tree: root ``BeamNode``; its ``table`` is shown in the root
            and its ``children`` are rendered with ``generate_nodes``.

    Returns:
        The HTML string, styled by the ``STYLE`` CSS classes.
    """
    parts = [
        '<div class="custom-container">',
        '<div class="tree">',
        "<ul>",
        # `or ""` guards against an unexpanded root (table is None) rendering as "None".
        f"<li> <a href='#' id='root'> <span> <b>{html.escape(start_sentence)}</b> </span> "
        f"{original_tree.table or ''} </a>",
    ]
    if original_tree.children:
        parts.append("<ul> ")
        parts.extend(
            generate_nodes(token_ix, subnode, step=1)
            for token_ix, subnode in original_tree.children.items()
        )
        parts.append("</ul>")
    parts += ["</li>", "</ul>", "</div>", "</div>"]
    # Newline-joined without blank lines, so a Markdown renderer keeps treating
    # the whole fragment as one raw-HTML block.
    return "\n".join(parts)


import pandas as pd
from typing import Dict
from dataclasses import dataclass


@dataclass
class BeamNode:
    """One node of the beam-search tree reconstructed by ``generate_beams``.

    The tree is rendered to HTML by ``generate_html`` / ``generate_nodes``.
    """

    # Sum of the step scores along the path from the root (the root has 0).
    cumulative_score: float
    # `length ** length_penalty` divisor used for the totals in this node's table.
    children_score_divider: float
    # HTML table of candidate next tokens. It is None until generate_beams
    # writes it, and nodes added at the final step keep None (hence Optional).
    table: Optional[str]
    # Prompt followed by the decoded tokens on the path to this node.
    current_sentence: str
    # Child nodes keyed by the token id appended to reach them.
    children: Dict[int, "BeamNode"]


def generate_beams(start_sentence, scores, sequences, length_penalty):
    """Replay a beam search from its per-step scores and build the beam tree.

    At every step the top ``n_beams`` tokens of each live beam are collected,
    duplicates reached through beams sharing the same sentence are collapsed,
    and the overall best ``n_beams`` candidates become the next beams. Each
    expanded node gets an HTML score table (``generate_markdown_table``).

    Args:
        start_sentence: the prompt text; the root node's sentence.
        scores: ``outputs.scores`` from ``model.generate``: one per-step
            score array indexed ``[beam_ix, token_id]``.
        sequences: ``outputs.sequences``; currently unused by the replay.
        length_penalty: exponent of the per-node divider
            ``sequence_length ** length_penalty`` shown in the tables.

    Returns:
        The root ``BeamNode``; ``children`` map token id -> child node.
    """
    # Number of prompt tokens. `len()` of the BatchEncoding returned by the
    # tokenizer counts its fields (input_ids, attention_mask), not tokens, so
    # read the sequence dimension of input_ids instead.
    prompt_ids = tokenizer([start_sentence], return_tensors="pt")["input_ids"]
    input_length = prompt_ids.shape[-1]
    original_tree = BeamNode(
        cumulative_score=0,
        table=None,
        current_sentence=start_sentence,
        children={},
        children_score_divider=((input_length + 1) ** length_penalty),
    )
    n_beams = len(scores[0])
    # All beams start at the root: the same node object appears n_beams times.
    beam_trees = [original_tree] * n_beams
    for step, step_scores in enumerate(scores):
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_completions,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):
            current_beam = beam_trees[beam_ix]
            # Get top cumulative scores for the current beam
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_completions += [beam_trees[beam_ix].current_sentence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_completions": current_completions,
                "token": top_tokens,
            }
        )
        # Beams that share a sentence (e.g. all beams at step 0) propose the same
        # candidates; keep only the best-scoring row per (token, sentence).
        maxes = top_df.groupby(["token_index", "current_completions"])[
            "cumulative_score"
        ].idxmax()

        top_df = top_df.loc[maxes]

        # Sort all top probabilities and keep top n_beams
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Write the scores table - one per beam source.
        # Edge case: if several beam indexes are actually on the same beam, the
        # selected tokens by beam_index for the second one will be empty. Iterating
        # in reverse lets the lowest index (which holds the selections) write last.
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            selected_tokens = top_df_selected.loc[
                top_df_selected["beam_index"] == beam_ix
            ]
            markdown_table = generate_markdown_table(
                step_scores[beam_ix, :],
                current_beam.cumulative_score,
                current_beam.children_score_divider,
                chosen_tokens=list(selected_tokens["token"].values),
            )
            beam_trees[beam_ix].table = markdown_table

        # Add new children for each beam. Cumulative scores are snapshotted first
        # because several selections may come from the same source beam.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])

            # Update the source tree
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])

            cumulative_score = (
                cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy()
            )
            beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
                table=None,
                children={},
                current_sentence=beam_trees[source_beam_ix].current_sentence
                + current_token_choice,
                cumulative_score=cumulative_score,
                children_score_divider=((input_length + step + 1) ** length_penalty),
            )

        # Reassign all beams at once
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]

        # Advance all beams by one token
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]

    return original_tree

@spaces.GPU
def get_beam_search_html(input_text, number_steps, number_beams, length_penalty):
    """Run beam search on ``input_text`` and render the result for the UI.

    Args:
        input_text: prompt to decode from.
        number_steps: number of new tokens to generate.
        number_beams: beam width; also the number of returned sequences.
        length_penalty: length-penalty exponent. It is used both by
            ``model.generate`` and by the score tables in the tree.

    Returns:
        ``(tree_html, markdown)``: the beam-tree HTML and a Markdown list of
        the final sequences with their scores.
    """
    # Slider values may arrive as floats; generate() needs integer counts.
    number_steps = int(number_steps)
    number_beams = int(number_beams)
    length_penalty = float(length_penalty)

    inputs = tokenizer([input_text], return_tensors="pt")

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        # Previously hard-coded to -10.0, so the "Length penalty" slider changed
        # only the displayed tables, never the search itself.
        length_penalty=length_penalty,
        output_scores=True,
        do_sample=False,
    )

    decoded_sequences = tokenizer.batch_decode(outputs.sequences)
    lines = ["Output sequences:"]
    for sequence, score in zip(decoded_sequences, outputs.sequences_scores):
        lines.append(f"- {sequence.replace('<s> ', '')} (score {score:.2f})")
    markdown = "\n".join(lines)

    original_tree = generate_beams(
        input_text,
        outputs.scores,
        outputs.sequences,
        length_penalty,
    )
    tree_html = generate_html(input_text, original_tree)
    return tree_html, markdown


# Gradio UI: a prompt box, three sliders, a run button and two Markdown outputs.
# The button calls get_beam_search_html, whose (tree HTML, sequence list) return
# values fill out_html and out_markdown respectively.
with gr.Blocks(
    theme=gr.themes.Soft(
        text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.yellow
    ),
    # STYLE holds the CSS classes used by the tree HTML from generate_html.
    css=STYLE,
) as demo:
    gr.Markdown("""# Beam search visualizer
                
Play with the parameters below to understand how beam search decoding works!
                
#### Parameters: 
- **Sentence to decode from**: the input sequence to your decoder.
- **Number of steps**: the number of tokens to generate
- **Number of beams**: the number of beams to use
- **Length penalty**: the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
""")
    text = gr.Textbox(label="Sentence to decode from", value="Conclusion: thanks a lot. This article was originally published on")
    with gr.Row():
        # NOTE(review): slider values are passed straight to get_beam_search_html;
        # they may arrive as floats depending on the Gradio version — confirm.
        steps = gr.Slider(label="Number of steps", minimum=1, maximum=8, step=1, value=4)
        beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
        length_penalty = gr.Slider(label="Length penalty", minimum=-4, maximum=4, step=0.5, value=1)
    button = gr.Button()
    # Receives the tree HTML (first return value of get_beam_search_html).
    out_html = gr.Markdown()
    # Receives the "Output sequences:" list (second return value).
    out_markdown = gr.Markdown()
    button.click(get_beam_search_html, inputs=[text, steps, beams, length_penalty], outputs=[out_html, out_markdown])

# Start the Gradio server; this runs whenever the module is executed.
demo.launch()