Spaces:
Running
Running
import argparse | |
from collections import defaultdict | |
import datetime | |
import json | |
import os, sys | |
import time | |
import concurrent | |
import math | |
import gradio as gr | |
import requests | |
import logging | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import fairseq | |
fairseq_path = os.path.dirname(os.path.dirname(fairseq.__file__)) | |
sys.path.insert(1, f"{fairseq_path}") | |
from fs_plugins.models.glat_decomposed_with_link import GlatDecomposedLink | |
sys.path.insert(1, f"{fairseq_path}/examples") | |
from mass.s2s_model import TransformerMASSModel | |
from transformer.hub_interface import TransformerHubInterface | |
logger = logging.getLogger(__name__)

# Markdown rendered at the top of the demo page (description, resources, terms of use).
notice_markdown = ("""
# Directed Acyclic Transformer: A Non-Autoregressive Sequence-to-Sequence Model designed for Parallel Text Generation.
- **Fast Generation**: DA-Transformer offers faster inference compared to autoregressive Transformers (with fairseq implementation), with a reduction in latency by 7~14x and an increase in throughput by ~20x.
- **High Quality**: DA-Transformer performs competitively with autoregressive Transformers, even with pre-trained models like BART, in a variety of text generation tasks.
- **Easy Training**: DA-Transformer can be trained end-to-end without requiring knowledge distillation, making it simple and straightforward to train.
## Resources
- [[Github]](https://github.com/thu-coai/DA-Transformer)
- Papers: [[Machine Translation]](https://proceedings.mlr.press/v162/huang22m/huang22m.pdf) [[Pre-training]](https://arxiv.org/pdf/2304.11791.pdf)
## Terms of use
By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It does not guarantee the correctness of the output text. The service may collect user data for future research.
## This demo contains models
- [En-De Translation]()
- [Zh-En Translation]()
- [Question Generation]()
""")

# Markdown rendered below the tabs; currently only a blank line.
learn_more_markdown = ("""
""")

# Extra page CSS: make <pre> blocks wrap long lines (with vendor-prefixed fallbacks).
css = """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
"""
# Beam-search decoding options shared by every DA-Transformer checkpoint below.
_DAT_DECODE_ARGS = {
    "decode_strategy": "beamsearch",
    "decode_max_workers": 1,
    "decode_threads_per_worker": 4,
    "decode_gamma": 0,
    "decode_beam_size": 200,
    "decode_batch_size": 1,
    "decode_top_cand": 5,
    "decode_max_beam_per_length": 10,
    "max_decoder_batch_tokens": 2048,
}
# Additional decoding options used only by the DA-Transformer translation checkpoints.
_DAT_TRANSLATION_ARGS = {
    "decode_dedup": True,
    "decode_alpha": 1.1,
}
# Additional decoding options used only by the DA-Transformer question-generation checkpoint.
_DAT_QG_ARGS = {
    "decode_no_consecutive_repeated_tokens": 3,
    "decode_no_repeated_tokens": 2,
}

# Registry of demo models. For each name:
#   "class"              -- class whose ``from_pretrained(**args)`` builds the model
#   "args"               -- keyword arguments for ``from_pretrained``
#   "examples"           -- sample inputs; the first one is used for warm-up
#   "expected_load_time" -- rough load time in seconds, drives the simulated progress bar
available_models = {
    "dat_base_translation_ende": {
        "class": GlatDecomposedLink,
        "args": {
            "model_name_or_path": "hfhub://thu-coai/dat_base_translation_ende",
            **_DAT_DECODE_ARGS,
            **_DAT_TRANSLATION_ARGS,
        },
        "examples": ["I am a fast translation model."],
        "expected_load_time": 17,
    },
    "dat_base_translation_zhen": {
        "class": GlatDecomposedLink,
        "args": {
            "model_name_or_path": "hfhub://thu-coai/dat_base_translation_zhen",
            **_DAT_DECODE_ARGS,
            **_DAT_TRANSLATION_ARGS,
        },
        "examples": ["我是一个高速的机器翻译模型。"],
        "expected_load_time": 17,
    },
    "dat_uncased_squad": {
        "class": GlatDecomposedLink,
        "args": {
            "model_name_or_path": "hfhub://thu-coai/dat_uncased_squad",
            **_DAT_DECODE_ARGS,
            **_DAT_QG_ARGS,
        },
        "examples": ["Two [SEP] Two additional teams of 40 attendants each will accompany the flame on its mainland China route."],
        "expected_load_time": 20,
    },
    "mass_uncased_squad": {
        "class": TransformerMASSModel,
        "args": {"model_name_or_path": "hfhub://thu-coai/mass_uncased_squad"},
        "examples": ["Two [SEP] Two additional teams of 40 attendants each will accompany the flame on its mainland China route."],
        "expected_load_time": 10,
    },
    "transformer_base_translation_ende": {
        "class": TransformerHubInterface,
        "args": {"model_name_or_path": "hfhub://thu-coai/transformer_base_translation_ende"},
        "examples": ["I am a fast translation model."],
        "expected_load_time": 10,
    },
    "transformer_base_translation_zhen": {
        "class": TransformerHubInterface,
        "args": {"model_name_or_path": "hfhub://thu-coai/transformer_base_translation_zhen"},
        "examples": ["我是一个高速的机器翻译模型。"],
        "expected_load_time": 10,
    },
}
# Tasks offered in the "Speed Comparison" tab. Each entry lists the two model
# names (keys of ``available_models``) that are run on the same input, the
# example inputs shown under the textbox, and the textbox placeholder.
compare_available_types = {
    "Translation Zh-En: DA-Transformer v.s. Autoregressive Transformer": {
        "models": ['dat_base_translation_zhen', 'transformer_base_translation_zhen'],
        "examples": ["我是一个高速的机器翻译模型。", "非自回归模型可以用来加速自然语言生成。",
                     "使用本服务前,用户必须同意以下条款:该服务是仅供非商业用途的研究预览。它不保证输出文本的正确性。本服务可能会收集用户数据以供将来研究。"],
        "placeholder": "请输入一个中文句子。 (The model will translate the input into English.)",
    },
    "Question Generation: DA-Transformer v.s. MASS": {
        "models": ['dat_uncased_squad', "mass_uncased_squad"],
        "examples": ["Two [SEP] Two additional teams of 40 attendants each will accompany the flame on its mainland China route.",
                     "DA-Transformer [SEP] Directed Acyclic Transformer (DA-Transformer) is a non-autoregressive sequence-to-sequence model designed for parallel text generation."],
        # User-facing hint; fixed typo ("should be appearred" -> "should appear").
        "placeholder": "Answer [SEP] Your Passage Here (the answer should appear in the passage).",
    },
    "Translation En-De: DA-Transformer v.s. Autoregressive Transformer": {
        "models": ['dat_base_translation_ende', 'transformer_base_translation_ende'],
        "examples": ["I am a fast translation model.", "Non-autoregressive models are designed for fast natural language generation.",
                     "By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only."],
        "placeholder": "Any English sentence here. (The model will translate the input into German.)",
    },
}
def _detail_entry(model_name, compare_key):
    """Build an inspection-tab entry reusing the examples/placeholder of a comparison task."""
    compare_entry = compare_available_types[compare_key]
    return {
        "model": model_name,
        "examples": compare_entry["examples"],
        "placeholder": compare_entry["placeholder"],
    }


# Tasks offered in the "DA-Transformer Inspection" tab: the single model to run
# plus examples/placeholder shared with the matching speed-comparison task.
detail_available_types = {
    "Translation Zh-En": _detail_entry(
        'dat_base_translation_zhen',
        'Translation Zh-En: DA-Transformer v.s. Autoregressive Transformer'),
    "Question Generation": _detail_entry(
        'dat_uncased_squad',
        'Question Generation: DA-Transformer v.s. MASS'),
    "Translation En-De": _detail_entry(
        'dat_base_translation_ende',
        'Translation En-De: DA-Transformer v.s. Autoregressive Transformer'),
}
# Cache of already-loaded models keyed by their ``available_models`` name; filled by ``submit``.
models = dict()
# Thread pool that runs loading / warm-up / generation; assigned in the ``__main__`` block.
workers = None
def softplus(x, beta=1):
    """Return ``log(1 + exp(beta * x)) / beta`` computed without overflow.

    Rewritten as ``max(x, 0) + log1p(exp(-|beta * x|)) / beta`` so that
    ``exp`` only ever sees a non-positive argument.
    """
    scaled = x * beta
    correction = math.log1p(math.exp(-abs(scaled)))
    return correction / beta + max(x, 0)
def get_fake_progress(min_progress, max_progress, used_time, expected_time):
    """Interpolate a simulated progress value between ``min_progress`` and ``max_progress``.

    The fraction ``1 - softplus(expected_time - used_time) / expected_time`` is
    clamped below at 0. It grows with ``used_time`` and approaches (but never
    reaches) 1 once ``used_time`` exceeds ``expected_time``. Assumes
    ``expected_time > 0``.
    """
    remaining = softplus(expected_time - used_time)
    fraction = 1 - remaining / expected_time
    if fraction < 0:
        fraction = 0
    span = max_progress - min_progress
    return min_progress + span * fraction
def generate(model, model_input):
    """Translate ``model_input`` with ``model`` and wrap the text as ``{"output": ...}``."""
    translated = model.translate(model_input)
    return dict(output=translated)
def generate_detail(model, model_input):
    """Run ``model.generate_graph`` and return its text plus the accompanying graph data.

    ``generate_graph`` is expected to return an ``(output, graph_info)`` pair.
    """
    pair = model.generate_graph(model_input)
    text, graph = pair
    return {"output": text, "graph_info": graph}
def load_model(model_name):
    """Instantiate ``model_name`` via its registered class's ``from_pretrained``.

    Args:
        model_name: a key of ``available_models``.

    Returns:
        The loaded model object.

    Raises:
        KeyError: if ``model_name`` is not registered in ``available_models``.
    """
    # Explicit check rather than ``assert`` so the validation is not stripped under ``python -O``.
    if model_name not in available_models:
        raise KeyError(f"Unknown model: {model_name!r}")
    spec = available_models[model_name]
    return spec['class'].from_pretrained(**spec['args'])
def warmup_model(model, model_name):
    """Run one throw-away translation of the model's first registered example.

    The result is discarded; the call is made only so that later timed runs
    are not the first inference on this model.
    """
    first_example = available_models[model_name]['examples'][0]
    model.translate(first_example)
def _wait_with_progress(future, progress, min_progress, max_progress, start_time, expected_time, desc):
    """Wait for ``future`` while animating ``progress``; return the future's result.

    The real progress of loading/inference is not observable, so the bar is
    advanced once per second using ``get_fake_progress``. Exceptions raised by
    the submitted callable propagate to the caller.
    """
    # ``import concurrent`` alone does not guarantee the ``futures`` submodule is loaded.
    import concurrent.futures

    while True:
        try:
            return future.result(timeout=1)
        except concurrent.futures.TimeoutError:
            if future.done():
                # Either the task finished right as the wait expired, or the task itself
                # raised TimeoutError (an alias of the builtin on Python 3.11+). result()
                # no longer blocks now, so report it instead of polling forever.
                return future.result()
            progress(get_fake_progress(min_progress=min_progress,
                                       max_progress=max_progress,
                                       used_time=time.time() - start_time,
                                       expected_time=expected_time),
                     desc=desc)


def submit(model_name, model_input, generate_fn, request: gr.Request, progress=gr.Progress()):
    """Load (if needed), warm up and run ``model_name`` on ``model_input`` in the worker pool.

    Steps are executed on ``workers`` while this thread updates a simulated
    progress bar: model loading (first use only, cached in ``models``), one
    warm-up translation, then ``generate_fn(model, model_input)``.

    Returns:
        ``(result, inference_time)``. ``result`` is the dict returned by
        ``generate_fn`` and ``inference_time`` is the generation wall time in
        seconds. If generation raises ``RuntimeError``, ``result`` is
        ``{"output": "Runtime Error: ..."}`` and the time is 0, so callers can
        still read ``result["output"]``.

    Raises:
        RuntimeError: if the worker pool has not been created.
    """
    if workers is None:
        raise RuntimeError("No workers")
    load_desc = "Downloading Checkpoints and Loading Models"

    current_progress = 0
    progress(0, desc=load_desc)
    loaded_now = model_name not in models
    if loaded_now:
        expected_load_time = available_models[model_name]['expected_load_time']
        load_start = time.time()
        future = workers.submit(load_model, model_name)
        model = _wait_with_progress(future, progress, current_progress, 0.8,
                                    load_start, expected_load_time, load_desc)
        logger.info(f"Model Loaded: {model_name} Load Time: {time.time() - load_start}")
        current_progress = 0.8
        models[model_name] = model
    else:
        model = models[model_name]

    # Warm up so the timed generation below is more representative.
    progress(current_progress, desc=load_desc)
    target_progress = 0.9 if loaded_now else 0.5
    warmup_start = time.time()
    future = workers.submit(warmup_model, model, model_name)
    _wait_with_progress(future, progress, current_progress, target_progress,
                        warmup_start, 1, load_desc)
    current_progress = target_progress

    # Timed generation.
    progress(current_progress, desc="Running")
    try:
        generate_start = time.time()
        future = workers.submit(generate_fn, model, model_input)
        result = _wait_with_progress(future, progress, current_progress, 1,
                                     generate_start, 1, "Running")
    except RuntimeError as err:
        # Keep the same dict shape as a successful result so callers can index "output".
        return {"output": f"Runtime Error: {err}"}, 0
    inference_time = time.time() - generate_start

    # Truncate each field so very long outputs/graphs don't flood the log.
    result_abbrev = {}
    for key, value in result.items():
        log_str = str(value)
        if len(log_str) > 1024:
            log_str = log_str[:1024] + "..."
        result_abbrev[key] = log_str
    logger.info(f"Input: [{model_input}] Output: [{result_abbrev}] Inference Time: {inference_time}")
    return result, inference_time
def compare_init_state(model_selector):
    """Create the per-submission state: one ``{"model_name": ...}`` dict per compared model."""
    pair = compare_available_types[model_selector]['models']
    return [{"model_name": pair[position]} for position in (0, 1)]
def compare_refresh(model_selector, samples):
    """Reset the comparison tab for the selected task.

    Refills ``samples`` in place with the task's examples (one-element rows).

    Returns:
        Updates for (input box, first output box, second output box, examples
        dataset), the refilled ``samples`` list, and an update hiding the speed plot.
    """
    task = compare_available_types[model_selector]
    left_name = task['models'][0]
    right_name = task['models'][1]
    samples.clear()
    samples.extend([example] for example in task['examples'])
    return (
        gr.Textbox.update(value="", placeholder=task['placeholder']),
        gr.Textbox.update(visible=True, label=left_name),
        gr.Textbox.update(visible=True, label=right_name),
        gr.Dataset.update(samples=samples),
        samples,
        gr.Plot.update(visible=False),
    )
def compare_submit(model_input, idx, state, request: gr.Request, progress=gr.Progress()):
    """Run the ``idx``-th compared model on ``model_input`` and record its inference time.

    Returns:
        The text for that model's output box and the updated ``state``.
    """
    model_name = state[idx]['model_name']
    model_output, inference_time = submit(model_name, model_input, generate, request, progress)
    state[idx]['inference_time'] = inference_time
    # Accept a bare error string as well as the usual {"output": ...} dict, so a
    # failed generation is shown to the user instead of raising TypeError here.
    if isinstance(model_output, dict):
        return model_output['output'], state
    return str(model_output), state
def compare_dataset_click(examples, samples):
    """Return the input text of the clicked example; ``examples`` is its row index in ``samples``."""
    selected_row = samples[examples]
    return selected_row[0]
def compare_show_plot(state):
    """Draw a horizontal bar chart of both models' inference times and reveal it.

    Args:
        state: two dicts, each with ``model_name`` and ``inference_time`` keys.

    Returns:
        An update making the speed row visible, and an update setting the plot figure.
    """
    names = [state[0]['model_name'], state[1]['model_name']]
    times = [state[0]['inference_time'], state[1]['inference_time']]
    # Work on an explicit figure/axes instead of pyplot's implicit "current figure".
    fig, ax = plt.subplots(figsize=(12, 2.5))
    bars = ax.barh(names, times, 0.75)
    ax.bar_label(bars, fmt="%.2f")
    ax.set_yticks(np.arange(len(names)), labels=names)
    ax.set_xlabel('Inference Time on CPU (s)')
    fig.tight_layout()
    # Drop the figure from pyplot's global registry. The Figure object stays usable
    # for rendering, but repeated submissions no longer accumulate open figures.
    plt.close(fig)
    return gr.Row.update(visible=True), gr.Plot.update(value=fig, visible=True)
def compare_clear():
    """Blank the input and both output boxes and hide the speed-chart row."""
    blank = ""
    hide_row = gr.Row.update(visible=False)
    return blank, blank, blank, hide_row
# Empty starting sample list passed to both tabs' examples widgets and sample states.
example_list = list()
def build_tab_compare():
    """Build the "Speed Comparison" tab inside the current Gradio Blocks context.

    Returns:
        ``load(fn)``: calling it with an event listener such as ``demo.load`` or
        ``tab.select`` registers ``compare_refresh`` on that event.
    """
    state = gr.State()
    samples = gr.State(example_list)
    available_type_names = list(compare_available_types.keys())
    with gr.Row(elem_id="compare_model_selector_row"):
        model_selector = gr.Dropdown(
            choices=available_type_names,
            value=available_type_names[0] if len(available_type_names) > 0 else "",
            interactive=True,
            show_label=False).style(container=False)
    with gr.Row(elem_id="compare_model_input"):
        model_input = gr.Textbox(lines=5, label="input")
    examples = gr.Dataset(components=[model_input],
                          label="Examples",
                          type='index',
                          samples=example_list,
                          visible=True)
    with gr.Row():
        clear_btn = gr.Button(value="Clear")
        submit_btn = gr.Button(value="Submit", variant="primary")
    with gr.Row(elem_id="compare_model_output"):
        model_output1 = gr.Textbox(lines=5, label="output", visible=False)
        model_output2 = gr.Textbox(lines=5, label="output", visible=False)
    with gr.Row(elem_id="compare_model_speed", visible=False) as row:
        with gr.Column():
            model_speed = gr.Plot(value=None, label="Speed")
            gr.Markdown("**Note the above time is measured on a free cloud server, which does not use GPU and is thus different from the setting in the papers.**")

    # compare_refresh returns six values; every event wired to it uses the same six
    # outputs, so the stale speed plot is also hidden on page load / tab select.
    refresh_outputs = [model_input, model_output1, model_output2, examples, samples, model_speed]
    model_selector.change(compare_refresh, [model_selector, samples], refresh_outputs)
    clear_btn.click(compare_clear, None, [model_input, model_output1, model_output2, row])
    # Run the two models one after the other on the same input, then plot their timings.
    # The hidden Number components feed the index of the model slot to compare_submit.
    submit_btn.click(compare_init_state, [model_selector], [state]).\
        then(compare_submit, [model_input, gr.Number(value=0, visible=False, precision=0), state], [model_output1, state]).\
        then(compare_submit, [model_input, gr.Number(value=1, visible=False, precision=0), state], [model_output2, state]).\
        then(compare_show_plot, [state], [row, model_speed])
    examples.click(compare_dataset_click, [examples, samples], [model_input])

    def load(fn):
        fn(compare_refresh, [model_selector, samples], refresh_outputs)
    return load
def detail_init_state(model_selector):
    """Create the per-submission state for the inspection tab (model name and a zeroed counter)."""
    chosen_model = detail_available_types[model_selector]['model']
    return dict(model_name=chosen_model, cnt=0)
def detail_refresh(model_selector, samples):
    """Reset the inspection tab for the selected task.

    Refills ``samples`` in place with the task's examples (one-element rows).

    Returns:
        Updates for (input box, output box, examples dataset), the refilled
        ``samples`` list, and a hidden-plot update.
    """
    task = detail_available_types[model_selector]
    samples.clear()
    samples.extend([text] for text in task['examples'])
    input_update = gr.Textbox.update(value="", placeholder=task['placeholder'])
    output_update = gr.Textbox.update(visible=True, label=task['model'])
    examples_update = gr.Dataset.update(samples=samples)
    speed_update = gr.Plot.update(visible=False)
    return input_update, output_update, examples_update, samples, speed_update
def detail_submit(model_input, state, request: gr.Request, progress=gr.Progress()):
    """Run the selected DA-Transformer with graph output and hand the graph to the front end.

    Returns:
        The output text, the updated ``state`` (with ``inference_time`` and
        ``graph_info``), an update showing the DAG panel only when a graph is
        available, and ``state`` serialised to JSON for the graph JavaScript.
    """
    model_name = state['model_name']
    model_output, inference_time = submit(model_name, model_input, generate_detail, request, progress)
    state['inference_time'] = inference_time
    if isinstance(model_output, dict):
        output_text = model_output.get('output', "")
        graph_info = model_output.get('graph_info')
    else:
        # A failed generation may come back as a bare error message rather than a dict.
        output_text, graph_info = str(model_output), None
    state["graph_info"] = graph_info
    show_graph = gr.Row.update(visible=graph_info is not None)
    return output_text, state, show_graph, json.dumps(state)
def detail_dataset_click(examples, samples):
    """Look up the clicked example row (index ``examples``) and return its input text."""
    clicked = samples[examples]
    text = clicked[0]
    return text
def detail_clear():
    """Blank the input and output boxes and hide the DAG visualisation row."""
    hidden_graph = gr.Row.update(visible=False)
    return "", "", hidden_graph
def build_tab_detail():
    """Build the "DA-Transformer Inspection" tab inside the current Gradio Blocks context.

    The tab runs one DA-Transformer model and draws its DAG via the JavaScript
    helpers ``startNetwork`` / ``updateNetwork``. These helpers are not defined
    in this file; presumably they come from graph.html or global.js (TODO confirm).

    Returns:
        ``load(fn)``: calling it with an event listener registers ``detail_refresh``.
    """
    state = gr.State()
    samples = gr.State(example_list)
    available_type_names = list(detail_available_types.keys())
    with gr.Row(elem_id="detail_model_selector_row"):
        model_selector = gr.Dropdown(
            choices=available_type_names,
            value=available_type_names[0] if len(available_type_names) > 0 else "",
            interactive=True,
            show_label=False).style(container=False)
    with gr.Row(elem_id="detail_model_input"):
        model_input = gr.Textbox(lines=5, label="input")
    examples = gr.Dataset(components=[model_input],
                          label="Examples",
                          type='index',
                          samples=example_list,
                          visible=True)
    with gr.Row():
        clear_btn = gr.Button(value="Clear")
        submit_btn = gr.Button(value="Submit", variant="primary")
    with gr.Row(elem_id="detail_model_output"):
        model_output = gr.Textbox(lines=5, label="output", visible=False)

    # Read the graph container markup with a context manager so the file handle is closed.
    with open("graph.html", encoding="utf-8") as graph_file:
        graph_html = graph_file.read()
    with gr.Row(visible=False) as dag_graph:
        with gr.Column(scale=1.8):
            gr.HTML(graph_html)
        with gr.Column(scale=1):
            minimum_node_pass_prob = gr.Slider(0, 1, value=0.2, label="Show nodes with passing probability greater than", info="Nodes that predict the output sequence are always visible")
            minimum_edge_prob = gr.Slider(0, 1, value=0.1, label="Show edges with transition probability greater than")
            max_out_edge_num = gr.Slider(1, 10, value=5, step=1, label="Show top-k outgoing edges with k")
            max_out_edge_prob = gr.Slider(0, 1, value=0.9, label="Show top-p outgoing edges with p")
            force_in_edge = gr.Checkbox(True, label="Show at least one incoming edge for each node")
            show_node_detail = gr.Checkbox(False, label="Show verbose node information")
            show_edge_label = gr.Checkbox(False, label="Show transition probability")
            network_refresh = gr.Button(value="Reinitialize DAG Visualization")
    graph_parameters = [minimum_node_pass_prob, minimum_edge_prob, max_out_edge_num, max_out_edge_prob,
                        force_in_edge, show_node_detail, show_edge_label]
    # Hidden carrier for the JSON-serialised state consumed by graph_create_js.
    js_state = gr.Textbox(visible=False)

    refresh_outputs = [model_input, model_output, examples, samples]
    model_selector.change(detail_refresh, [model_selector, samples], refresh_outputs)
    clear_btn.click(detail_clear, None, [model_input, model_output, dag_graph])

    # Client-side callbacks: (re)build the network from state.graph_info, or just re-filter it.
    graph_create_js = """(state_str, minimum_node_pass_prob, minimum_edge_prob, max_out_edge_num, max_out_edge_prob, force_in_edge, show_node_detail, show_edge_label) => {
    var state = JSON.parse(state_str);
    var options = {
        minimum_node_pass_prob: minimum_node_pass_prob,
        minimum_edge_prob: minimum_edge_prob,
        max_out_edge_num: max_out_edge_num,
        max_out_edge_prob: max_out_edge_prob,
        force_in_edge: force_in_edge,
        show_node_detail: show_node_detail,
        show_edge_label: show_edge_label,
    }
    startNetwork(state.graph_info, options);
}"""
    graph_update_js = """(minimum_node_pass_prob, minimum_edge_prob, max_out_edge_num, max_out_edge_prob, force_in_edge, show_node_detail, show_edge_label) => {
    var options = {
        minimum_node_pass_prob: minimum_node_pass_prob,
        minimum_edge_prob: minimum_edge_prob,
        max_out_edge_num: max_out_edge_num,
        max_out_edge_prob: max_out_edge_prob,
        force_in_edge: force_in_edge,
        show_node_detail: show_node_detail,
        show_edge_label: show_edge_label,
    }
    updateNetwork(options);
}"""
    submit_btn.click(detail_init_state, [model_selector], [state]).\
        then(detail_submit, [model_input, state], [model_output, state, dag_graph, js_state]).\
        then(None, [js_state] + graph_parameters, None, _js=graph_create_js)
    network_refresh.click(None, [js_state] + graph_parameters, None, _js=graph_create_js)
    # Any visualisation setting change re-filters the existing network client-side.
    for slider in (minimum_node_pass_prob, minimum_edge_prob, max_out_edge_num, max_out_edge_prob):
        slider.change(None, graph_parameters, None, _js=graph_update_js)
    for checkbox in (force_in_edge, show_node_detail, show_edge_label):
        checkbox.select(None, graph_parameters, None, _js=graph_update_js)
    examples.click(detail_dataset_click, [examples, samples], [model_input])

    def load(fn):
        fn(detail_refresh, [model_selector, samples], refresh_outputs)
    return load
def build_demo():
    """Assemble the full Gradio app: notice header, comparison tab, inspection tab.

    Each tab's widgets are refreshed when the tab is selected. The comparison
    tab is also refreshed on page load, and global.js is executed client-side
    on page load.

    Returns:
        The constructed ``gr.Blocks`` demo.
    """
    # Read the page-level script with a context manager so the file handle is closed.
    with open("global.js", encoding="utf-8") as js_file:
        global_js = js_file.read()
    with gr.Blocks(title="DA-Transformer Demo", theme=gr.themes.Base(), css=css) as demo:
        gr.Markdown(notice_markdown)
        with gr.Tab("Speed Comparison") as compare_tab:
            compare_load = build_tab_compare()
            compare_load(compare_tab.select)
        with gr.Tab("DA-Transformer Inspection") as detail_tab:
            detail_load = build_tab_detail()
            detail_load(detail_tab.select)
        gr.Markdown(learn_more_markdown)
        compare_load(demo.load)
        demo.load(None, None, None, _js=global_js)
    return demo
if __name__ == "__main__":
    # ``import concurrent`` alone does not guarantee ``concurrent.futures`` is loaded.
    import concurrent.futures

    # The root logger defaults to WARNING; enable INFO so this module's logger.info calls appear.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=1)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # A single worker thread: model loading, warm-up and generation run one at a time.
    workers = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    demo = build_demo()
    demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
               api_open=False).launch(server_name=args.host, server_port=args.port,
                                      share=args.share, max_threads=5)