|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base") |
|
|
|
models = { |
|
"Lục Bát": AutoModelForCausalLM.from_pretrained( |
|
"Libosa2707/vietnamese-poem-luc-bat-gpt2" |
|
), |
|
"Bảy Chữ": AutoModelForCausalLM.from_pretrained( |
|
"Libosa2707/vietnamese-poem-bay-chu-gpt2" |
|
), |
|
"Tám Chữ": AutoModelForCausalLM.from_pretrained( |
|
"Libosa2707/vietnamese-poem-tam-chu-gpt2" |
|
), |
|
"Năm Chữ": AutoModelForCausalLM.from_pretrained( |
|
"Libosa2707/vietnamese-poem-nam-chu-gpt2" |
|
), |
|
} |
|
|
|
|
|
def complete_poem(text, style): |
|
|
|
text = text.strip() |
|
text = text.lower() |
|
|
|
|
|
model = models[style] |
|
|
|
|
|
input_ids = tokenizer.encode(text, return_tensors="pt")[:, :-1] |
|
|
|
|
|
output = model.generate(input_ids, max_length=100, do_sample=True, temperature=0.7) |
|
|
|
|
|
generated_text = tokenizer.decode( |
|
output[:, input_ids.shape[-1] :][0], skip_special_tokens=True |
|
) |
|
|
|
text = text + " " + generated_text |
|
|
|
|
|
text = text.replace("<unk>", "\n") |
|
pretty_text = "" |
|
for idx, line in enumerate(text.split("\n")): |
|
line = line.strip() |
|
if not line: |
|
continue |
|
line = line[0].upper() + line[1:] |
|
pretty_text += line + "\n" |
|
|
|
return pretty_text |
|
|
|
|
|
complete_poem_interface = gr.Interface( |
|
title="Viết tiếp áng thơ hay...", |
|
fn=complete_poem, |
|
inputs=[ |
|
gr.components.Textbox( |
|
lines=1, |
|
placeholder="Tôi đâu có biết làm thơ", |
|
label="Những áng thơ đầu tiên", |
|
), |
|
gr.components.Dropdown( |
|
choices=["Lục Bát", "Bảy Chữ", "Tám Chữ", "Năm Chữ"], |
|
label="Kiểu thơ", |
|
value="Lục Bát", |
|
), |
|
], |
|
outputs="text", |
|
) |
|
|