|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
# Shared PhoBERT tokenizer used for every poem style below.
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

# One fine-tuned GPT-2 checkpoint per supported poem style.  The Hugging
# Face repo id is derived from the style name itself ("Luc Bat" -> "luc-bat"),
# so each entry resolves to e.g. "Libosa2707/vietnamese-poem-luc-bat-gpt2".
models = {
    style: AutoModelForCausalLM.from_pretrained(
        f"Libosa2707/vietnamese-poem-{style.lower().replace(' ', '-')}-gpt2"
    )
    for style in ("Luc Bat", "Bay Chu", "Tam Chu", "Nam Chu")
}
|
|
|
|
|
def generate_poem(text, style):
    """Continue *text* as a Vietnamese poem in the requested *style*.

    Parameters
    ----------
    text : str
        The first words of the poem (the user's prompt).
    style : str
        Key into the module-level ``models`` dict
        ("Luc Bat", "Bay Chu", "Tam Chu", "Nam Chu").

    Returns
    -------
    str
        The prompt plus the generated continuation, one verse per line,
        each line stripped and with its first letter capitalized.
    """
    model = models[style]

    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Sampled decoding; temperature 0.7 trades some diversity for coherence.
    output = model.generate(input_ids, max_length=100, do_sample=True, temperature=0.7)

    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated_text = tokenizer.decode(
        output[:, input_ids.shape[-1] :][0], skip_special_tokens=True
    )

    text = text + generated_text

    # These poem models emit "<unk>" as the line-break marker.
    text = text.replace("<unk>", "\n")

    # Keep non-empty lines, capitalizing the first character of each.
    # (join-based build instead of repeated string concatenation)
    lines = []
    for line in text.split("\n"):
        line = line.strip()
        if line:
            lines.append(line[0].upper() + line[1:])

    return "\n".join(lines) + "\n" if lines else ""
|
|
|
|
|
# Build the web UI.  NOTE: the gr.inputs.* namespace was deprecated in
# Gradio 2.x and removed in 3.x+; the top-level component classes
# (gr.Textbox, gr.Dropdown) are the supported equivalents.
gradio_interface = gr.Interface(
    fn=generate_poem,
    inputs=[
        gr.Textbox(lines=1, placeholder="First words of the poem"),
        gr.Dropdown(
            choices=["Luc Bat", "Bay Chu", "Tam Chu", "Nam Chu"], label="Style"
        ),
    ],
    outputs="text",
)

# Only start the server when executed as a script, so importing this
# module (e.g. for testing) does not block on launch().
if __name__ == "__main__":
    gradio_interface.launch()
|
|