|
from transformers import ( |
|
EncoderDecoderModel, |
|
AutoTokenizer |
|
) |
|
import torch |
|
import streamlit as st |
|
|
|
# Identifier of the pretrained checkpoint; load_model() passes it to
# ``from_pretrained`` for both the tokenizer and the encoder-decoder model.
PRETRAINED = "raynardj/wenyanwen-chinese-translate-to-ancient"
|
|
|
def inference(text):
    """Translate a single passage with the module-level ``tokenizer``/``model``.

    Args:
        text: The passage to translate. It is tokenized as a batch of one,
            truncated and padded to 128 tokens.

    Returns:
        The first decoded sequence produced by beam search, with special
        tokens skipped and every space character removed.
    """
    encoded = tokenizer(
        [text],
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt",
    )

    generation_options = {
        "attention_mask": encoded.attention_mask,
        "num_beams": 3,
        # NOTE(review): 101 is presumably the tokenizer's [CLS] id; confirm
        # against the checkpoint before replacing it with a lookup.
        "bos_token_id": 101,
        "eos_token_id": tokenizer.sep_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }

    # Generation is inference-only, so gradient tracking is switched off.
    with torch.no_grad():
        generated_ids = model.generate(encoded.input_ids, **generation_options)
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # The decoder separates tokens with spaces; drop them from the output.
        return decoded[0].replace(" ", "")
|
|
|
# Page header: the app title followed by a short blurb with the project link
# and the input-length limit.
st.title("🪕古朴 ❄️清雅 🌊壮丽")

_intro_markdown = (
    "\n"
    "> Translate from Chinese to Ancient Chinese / 还你古朴清雅壮丽的文言文, 这[github](https://github.com/raynardj/yuan)\n"
    "> 最多100个中文字符\n"
)
st.markdown(_intro_markdown)
|
|
|
# ``st.cache`` is the legacy Streamlit caching decorator and is deprecated in
# favour of ``st.cache_resource`` for shared, unserialized objects such as
# models. Prefer the new decorator when the installed Streamlit provides it and
# fall back to the legacy call so older installations keep working.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the tokenizer and encoder-decoder model named by ``PRETRAINED``.

    The result is cached by Streamlit so the checkpoint is loaded once rather
    than on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model
|
|
|
# Module-level handles used by inference(); load_model() is cached, so reruns
# of the script reuse the same objects.
tokenizer, model = load_model()

# Input box pre-filled with a sample passage.
text = st.text_area(
    label="输入文本",
    value="轻轻地我走了,正如我轻轻地来。我挥一挥衣袖,不带走一片云彩。",
)

# Translate only when the button is pressed; inputs longer than 100 characters
# are rejected with an error message instead of being sent to the model.
if st.button("曰"):
    if len(text) <= 100:
        st.write(inference(text))
    else:
        st.error("无过百字,若过则当答此言。")
|
|
|
|