File size: 10,067 Bytes
2d63e52
88356be
 
fb66618
7a37cfb
88356be
e4010d8
 
 
 
 
 
2d63e52
 
46edcf6
 
2d63e52
ad36776
e4010d8
2eb8692
ad36776
 
5da9cef
e4010d8
ad36776
9a4471c
 
e4010d8
9a4471c
e4010d8
 
 
 
 
 
 
ad36776
a7a4e14
da19af6
a7a4e14
e4010d8
ad36776
e4010d8
ad36776
 
 
 
 
 
 
da19af6
e4010d8
 
 
 
ad36776
 
a7a4e14
ad36776
 
 
e4010d8
ad36776
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dccff1
ad36776
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4010d8
ad36776
 
 
 
 
 
 
 
e4010d8
ad36776
 
 
 
 
 
7a70688
8666c72
ad36776
 
 
8666c72
ad36776
e4010d8
8666c72
e4010d8
ad36776
bc46ee1
7a37cfb
e4010d8
 
ad36776
e4010d8
 
ad36776
e4010d8
a7a4e14
e4010d8
 
a7a4e14
 
ad36776
 
 
7a37cfb
ad36776
 
 
 
7a37cfb
e4010d8
ad36776
e4010d8
ad36776
 
 
 
 
 
 
c9b1b1f
7a37cfb
ad36776
 
 
 
 
 
e4010d8
 
 
a7a4e14
 
e4010d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad36776
e4010d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad36776
e4010d8
 
 
 
ad36776
e4010d8
 
 
 
 
ad36776
e4010d8
 
 
ad36776
e4010d8
ad36776
e4010d8
 
 
 
 
ad36776
e4010d8
 
7a37cfb
e4010d8
 
 
 
 
 
 
 
 
 
 
 
ad36776
e4010d8
 
 
 
 
 
 
 
 
 
 
 
ad36776
e774464
e4010d8
ad36776
 
 
 
 
 
 
 
 
 
e4010d8
bc46ee1
ad36776
e4010d8
66e4ecd
e4010d8
 
 
bc46ee1
e4010d8
 
 
7a37cfb
 
ad36776
 
d121639
ad36776
 
 
88356be
ad36776
8cba28a
ad36776
 
e4010d8
ad36776
bc46ee1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import html
from typing import Optional

import gradio as gr
import numpy as np
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Padding reuses the end-of-sequence id (presumably because the GPT-2
# tokenizer ships without a dedicated pad token).
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")

# Log which accelerator, if any, this process can see.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    _cuda_device_index = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(_cuda_device_index)}")

STYLE = """
.custom-container {
	width: 100%;
	display: grid;
	align-items: center;
    margin: 0!important;
    overflow: scroll;
}
.prose ul ul {
    margin: 0!important;
    font-size: 10px!important;
}
.prose td, th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
}

.tree {
	padding: 0px;
	margin: 0!important;
	box-sizing: border-box;
    font-size: 10px;
	width: 100%;
    min-width: 2000px;
	height: auto;
	text-align: center;
}
.tree ul {
	padding-top: 20px;
	position: relative;
	transition: .5s;
    margin: 0!important;
    display: flex;
    flex-direction: row;
    justify-content: center;
    gap:10px;
}
.tree li {
	display: inline-table;
	text-align: center;
	list-style-type: none;
	position: relative;
	padding-top: 10px;
	transition: .5s;
}
.tree li::before, .tree li::after {
	content: '';
	position: absolute;
	top: 0;
	right: 50%;
	border-top: 1px solid #ccc;
	width: 51%;
	height: 10px;
}
.tree li::after {
	right: auto;
	left: 50%;
	border-left: 1px solid #ccc;
}
.tree li:only-child::after, .tree li:only-child::before {
	display: none;
}

.tree li:first-child::before, .tree li:last-child::after {
	border: 0 none;
}
.tree li:last-child::before {
	border-right: 1px solid #ccc;
	border-radius: 0 5px 0 0;
	-webkit-border-radius: 0 5px 0 0;
	-moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
	border-radius: 5px 0 0 0;
	-webkit-border-radius: 5px 0 0 0;
	-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
	content: '';
	position: absolute;
	top: 0;
	left: 50%;
	border-left: 1px solid #ccc;
	width: 0;
	height: 20px;
}
.tree li a {
	border: 1px solid #ccc;
	padding: 5px;
	display: inline-grid;
	border-radius: 5px;
	text-decoration-line: none;
	border-radius: 5px;
	transition: .5s;
}
.tree li a span {
	color: #666;
	padding: 5px;
	font-size: 12px;
	text-transform: uppercase;
	letter-spacing: 1px;
	font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
	background: #ffedd5;
	color: #000;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
	border-color: #7c2d12;
}
.chosen {
    background-color: #ea580c;
}
"""


def generate_nodes(token, node):
    """Recursively render one tree node and all its descendants as HTML.

    Args:
        token: Label shown in the node's box (a decoded token, or the prompt
            for the root). It is HTML-escaped, so arbitrary text is safe.
        node: A ``BeamNode``-like object with ``table`` (pre-rendered HTML
            string or ``None``) and ``children`` (mapping of token text to
            child node).

    Returns:
        An ``<li>`` element string; when the node has children it contains a
        nested ``<ul>`` with one ``<li>`` per child.
    """
    # Nodes that never received a scores table are marked with the ``chosen``
    # CSS class (generate_beams only fills ``table`` on nodes it expands).
    css_class = "chosen" if node.table is None else ""
    parts = [
        f" <li> <a href='#' class='{css_class}'> <span> <b>{html.escape(token)}</b> </span> ",
        # The table is already HTML built by generate_markdown_table; keep it raw.
        node.table if node.table is not None else "",
        "</a>",
    ]
    if node.children:
        parts.append("<ul> ")
        parts.extend(
            generate_nodes(child_token, child_node)
            for child_token, child_node in node.children.items()
        )
        parts.append("</ul>")
    parts.append("</li>")
    return "".join(parts)


def generate_markdown_table(scores, sequence_prob, top_k=4, chosen_tokens=None):
    """Render one beam's ``top_k`` best next-token candidates as an HTML table.

    Args:
        scores: 1-D per-token scores for one beam at one step, indexed by
            vocabulary id (in this file, a row of ``outputs.scores``).
        sequence_prob: The beam's running score so far; added to each step
            score to form the "Total score" column.
        top_k: How many candidates to list, best first. ``0`` lists none.
        chosen_tokens: Optional collection of decoded token strings to
            highlight (the tokens beam search kept from this beam).

    Returns:
        An HTML ``<table>`` string with a header row plus one row per candidate.
    """
    rows = [
        """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Total score</b></th>
        </tr>"""
    ]
    # argsort is ascending: reverse to best-first, then keep the first top_k.
    # Convert to a NumPy array before reversing because np.argsort may hand
    # back a torch tensor, and tensors reject negative-step slicing. Slicing
    # with [:top_k] (rather than [-top_k:]) keeps top_k=0 from listing every token.
    best_first = np.array(np.argsort(scores))[::-1][:top_k]
    for token_idx in best_first:
        token = tokenizer.decode([token_idx])
        is_chosen = bool(chosen_tokens) and token in chosen_tokens
        style = "background-color:red" if is_chosen else ""
        rows.append(
            f"""
        <tr style="{style}">
            <td>{html.escape(token)}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{scores[token_idx] + sequence_prob:.4f}</td>
        </tr>"""
        )
    rows.append("""
    </table>""")
    return "".join(rows)


def generate_html(start_sentence, original_tree):
    """Wrap the rendered beam tree in the container markup that ``STYLE`` targets.

    Args:
        start_sentence: Prompt text, used as the root node's label.
        original_tree: Root ``BeamNode`` returned by ``generate_beams``.

    Returns:
        A self-contained HTML fragment: ``div.custom-container`` >
        ``div.tree`` > ``ul`` holding the recursively rendered nodes.
    """
    tree_items = generate_nodes(start_sentence, original_tree)
    # Every opened element is closed here. The previous version never closed
    # the outer div and emitted a stray </body> instead.
    return (
        '<div class="custom-container">\n'
        '    <div class="tree">\n'
        "        <ul>"
        f"{tree_items}\n"
        "        </ul>\n"
        "    </div>\n"
        "</div>\n"
    )


import pandas as pd
from typing import Dict
from dataclasses import dataclass


@dataclass
class BeamNode:
    """One node of the beam-search tree reconstructed by ``generate_beams``.

    A node stands for the text generated so far along one path. Its children
    are keyed by the decoded token that extends that text.
    """

    # Running score of the path ending at this node (0 for the root;
    # generate_beams adds each step's score on top of the parent's value).
    cumulative_score: float
    # HTML table of this node's top next-token candidates, or None when the
    # node was never expanded (generate_beams creates children with None).
    table: Optional[str]
    # Prompt plus every token along the path to this node.
    current_sentence: str
    # Child nodes keyed by their decoded token text.
    children: Dict[str, "BeamNode"]


def generate_beams(start_sentence: str, scores, sequences, beam_indices) -> BeamNode:
    """Rebuild the tree of candidates explored by beam search, one step at a time.

    Instead of reading ``beam_indices``, this replays beam selection from the
    per-step scores:
    - At every step each live beam proposes its ``n_beams`` best next tokens.
    - Duplicate proposals (same token from the same text) are collapsed.
    - The ``n_beams`` candidates with the highest cumulative score survive.
    Each live beam's node gets an HTML table (``generate_markdown_table``)
    that highlights the tokens kept from it.

    NOTE(review): the replay has no special handling for end-of-sequence
    tokens or finished hypotheses, so it may diverge from the beams that
    ``generate()`` actually kept when a beam finishes early. Confirm.

    Args:
        start_sentence: Prompt text; becomes the root node's ``current_sentence``.
        scores: One score matrix per generated token, indexed as
            ``scores[step][beam_ix][token_id]``. Presumably ``outputs.scores``
            from ``model.generate`` with batch size 1 (TODO confirm). Assumed to
            hold per-step scores that do NOT already include the beam's running
            total, since this function adds ``cumulative_score`` on top.
        sequences: Generated token ids; here they are only decoded and printed.
        beam_indices: Not used by this function.

    Returns:
        The root ``BeamNode``; every explored path hangs off its ``children``.
    """
    # Debug output of the final sequences; `sequences` is not used afterwards.
    print(tokenizer.batch_decode(sequences))
    sequences = sequences.cpu().numpy()
    original_tree = BeamNode(
        cumulative_score=0, table=None, current_sentence=start_sentence, children={}
    )
    # Rows in the first step's score matrix; equals the beam width when the
    # batch size is 1 (assumed).
    n_beams = len(scores[0])
    # Every beam starts at the root: the list intentionally holds n_beams
    # references to the same node.
    beam_trees = [original_tree] * n_beams
    for step, step_scores in enumerate(scores):
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_completions,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):
            current_beam = beam_trees[beam_ix]
            # This beam's n_beams highest-scoring next-token ids, best first
            # (argsort is ascending, hence the slice from the end and the reversal).
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            # Candidate cumulative score = this step's score + the beam's running score.
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_completions += [beam_trees[beam_ix].current_sentence] * n_beams
            top_tokens += [
                tokenizer.decode([el]) for el in current_top_token_indexes
            ]

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_completions": current_completions,
                "token": top_tokens,
            }
        )
        # Several beams can hold the same text (e.g. at step 0 every beam is the
        # root), so the same token can be proposed more than once. Keep a single
        # row, the one with the highest cumulative score, per (token, text) pair.
        maxes = top_df.groupby(["token_index", "current_completions"])[
            "cumulative_score"
        ].idxmax()

        top_df = top_df.loc[maxes]

        # Rank all remaining candidates globally and keep the best n_beams.
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Attach a scores table to each live beam's node.
        # Edge case: when several beam indexes point at the same node, the rows
        # kept by the dedup above belong to one of them, so the others would
        # find no selected tokens. Iterating in reverse lets the lowest beam
        # index write last, so its table (with the highlighted tokens) wins.
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            selected_tokens = top_df_selected.loc[top_df_selected["beam_index"] == beam_ix]
            markdown_table = generate_markdown_table(
                step_scores[beam_ix, :],
                current_beam.cumulative_score,
                chosen_tokens=list(selected_tokens["token"].values),
            )
            beam_trees[beam_ix].table = markdown_table

        # Attach one child per selected candidate to the beam it came from.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])

            # Index of the beam this candidate extends.
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])

            previous_len = len(str(original_tree))
            # Children are keyed by decoded text; NOTE(review): two token ids that
            # decode to the same string from the same parent would overwrite each
            # other here.
            beam_trees[source_beam_ix].children[current_token_choice] = BeamNode(
                table=None,
                children={},
                current_sentence=beam_trees[source_beam_ix].current_sentence
                + current_token_choice,
                # .numpy() only works on CPU tensors; assumes `scores` live on the CPU.
                cumulative_score=cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy(),
            )
            # Sanity check that the new child is reachable from the root: the
            # dataclass repr includes nested children, so it must have grown.
            # Re-stringifies the whole tree for every child added.
            assert (
                len(str(original_tree)) > previous_len
            ), "Original tree has not increased size"

        # Reorder beams so slot i is the source node of the i-th selected candidate.
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]

        # Advance each slot to the child created for its candidate above.
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice]
    return original_tree

@spaces.GPU
def get_beam_search_html(input_text, number_steps, number_beams):
    """Run beam search on ``input_text`` and return the explored tree as HTML.

    Args:
        input_text: Prompt to continue.
        number_steps: Number of new tokens to generate.
        number_beams: Beam width; also the number of returned sequences.

    Returns:
        The HTML string produced by ``generate_html`` for the reconstructed tree.

    Raises:
        gr.Error: If ``input_text`` is empty, which would give ``generate()``
            no input tokens to condition on.
    """
    if not input_text:
        raise gr.Error("Please enter a sentence to decode from.")
    # Cast defensively in case the sliders deliver floats; generate() expects
    # integer counts.
    number_steps = int(number_steps)
    number_beams = int(number_beams)

    inputs = tokenizer([input_text], return_tensors="pt")

    # top_k is intentionally not passed: it only applies to sampling, and with
    # do_sample=False transformers ignores it and emits a warning.
    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=False,
    )

    original_tree = generate_beams(
        input_text,
        outputs.scores,
        outputs.sequences,
        outputs.beam_indices,
    )
    return generate_html(input_text, original_tree)


# Page theme: large monospace text with an orange accent, matching STYLE's palette.
_theme = gr.themes.Soft(
    text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.orange
)

# UI layout: prompt box, step and beam sliders, and a button that renders the
# beam-search tree into a Markdown component.
with gr.Blocks(theme=_theme, css=STYLE) as demo:
    text = gr.Textbox(label="Sentence to decode from", value="Today is")
    steps = gr.Slider(label="Number of steps", minimum=1, maximum=10, step=1, value=4)
    beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(fn=get_beam_search_html, inputs=[text, steps, beams], outputs=out)

demo.launch()