File size: 1,759 Bytes
1a8c724 4d3d295 dd465f4 4d3d295 1a8c724 97ca765 4d3d295 9820b04 4d3d295 1a8c724 dd97139 97ca765 77dbb7f dd97139 a5284e6 dd97139 a5284e6 77dbb7f a5284e6 1a8c724 e9e5a55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("./checkpoint-25000/")
def text_processing(text):
text = text + ' ' if text[-2:] != ' ' else text # 在末尾加上空格有利于模型预测
inputs = [text]
# Tokenize and prepare the inputs for model
input_ids = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids
attention_mask = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").attention_mask
# Generate prediction
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=512)
# Decode the prediction
decoded_output = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
return decoded_output[0]
examples = [
["我们的价值观是 富强 民主 文明 和谐"],
["都什么年代了 还在抽传统香烟"],
["今夕是何年"],
[" 三国演义 全名为 三國志通俗演义 又稱作 三國志演義 三國志傳 三國傳 三國全傳 三國英雄志傳 "],
]
inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")]
iface = gr.Interface(
fn=text_processing,
inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")],
outputs='text',
title='Punctuation Mark Prediction',
description='本模型主要用于语音识别模型输出的后处理。\n输入无符号句子,需要打标点处用空格隔开,返回带标点句子。\n仅支持中文,因为训练数据中只有中文。',
examples=examples
)
iface.launch(inline=False) |