|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") |
|
model = AutoModelForSeq2SeqLM.from_pretrained("./checkpoint-12500/") |
|
|
|
def text_processing(text): |
|
text = text + ' ' if text[-2:] != ' ' else text |
|
inputs = [text] |
|
|
|
|
|
input_ids = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids |
|
attention_mask = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").attention_mask |
|
|
|
|
|
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=512) |
|
|
|
|
|
decoded_output = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output] |
|
|
|
return decoded_output[0] |
|
|
|
examples = [ |
|
["我们的价值观是 富强 民主 文明 和谐"], |
|
["都什么年代了 还在抽传统香烟"], |
|
["今夕是何年"], |
|
[" 三国演义 全名为 三國志通俗演义 又稱作 三國志演義 三國志傳 三國傳 三國全傳 三國英雄志傳 "], |
|
] |
|
|
|
inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")] |
|
|
|
|
|
iface = gr.Interface( |
|
fn=text_processing, |
|
inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")], |
|
outputs='text', |
|
title='Punctuation Mark Prediction', |
|
description='本模型主要用于语音识别模型输出的后处理。\n输入无符号句子,需要打标点处用空格隔开,返回带标点句子。\n仅支持中文,因为训练数据中只有中文。', |
|
examples=examples |
|
) |
|
|
|
iface.launch(inline=False) |