tho_ai / generate_poem.py
phamson02
update
06071ca
raw
history blame contribute delete
No virus
2.2 kB
import torch
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Module-level tokenizer and model, loaded once at import time and shared by
# every call to generate_poem below.
# The tokenizer comes from the "VietAI/vit5-base" checkpoint.
tokenizer = T5Tokenizer.from_pretrained("VietAI/vit5-base")
# The seq2seq weights come from a different checkpoint than the tokenizer.
# NOTE(review): presumably this model was fine-tuned from vit5-base so the
# vocabularies match -- confirm against the model card.
model = T5ForConditionalGeneration.from_pretrained("Libosa2707/vietnamese-poem-t5")
def _format_poem(decoded):
    """Turn the raw decoded model output into one capitalized verse per line."""
    # Every "<unk>" in the decoded text is treated as a line break; the
    # "<pad>" and "</s>" markers are removed outright.
    text = decoded.replace("<unk>", "\n").replace("<pad>", "").replace("</s>", "")
    verses = (verse.strip() for verse in text.split("\n"))
    # Blank verses are dropped; each remaining verse gets its first character
    # upper-cased and a trailing newline.
    return "".join(
        verse[0].upper() + verse[1:] + "\n" for verse in verses if verse
    )


def generate_poem(input_text):
    """Generate a poem from a free-text request.

    Args:
        input_text: The request text. It is stripped and lower-cased before
            tokenization. ``None`` is treated as an empty request.

    Returns:
        The generated poem as a string with one verse per line, each verse
        starting with an upper-case character and ending with a newline.
        Sampling is enabled, so repeated calls with the same input can
        return different text.
    """
    # Treat a missing prompt as empty instead of failing on ``None.strip()``.
    prompt = (input_text or "").strip().lower()

    # Sampling configuration passed straight to ``model.generate``.
    generation_kwargs = {
        "do_sample": True,
        "min_length": 50,
        "max_length": 512,
        "top_p": 0.92,
        "top_k": 50,
        "temperature": 0.7,
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 2,
        "num_return_sequences": 1,
    }

    # The prompt is padded/truncated to exactly 42 tokens.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=42,
    )
    device = model.device

    model.eval()
    with torch.no_grad():
        output = model.generate(
            input_ids=encoded["input_ids"].to(device),
            # Pass the tokenizer's mask explicitly so the padded positions are
            # ignored without relying on generate() re-deriving it from pad ids.
            attention_mask=encoded["attention_mask"].to(device),
            **generation_kwargs,
        )

    # Special tokens are kept on purpose: _format_poem needs the "<unk>"
    # markers to find the line breaks.
    decoded = tokenizer.decode(
        output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    return _format_poem(decoded)
# Gradio front end for generate_poem: one single-line text input, plain-text
# output. The interface object is only constructed here, not launched.
generate_poem_interface = gr.Interface(
    fn=generate_poem,
    # Title shown on the page (roughly: "Write a poem on request").
    title="Làm thơ theo yêu cầu",
    inputs=[
        gr.components.Textbox(
            # Label (roughly: "Request for verse form and title").
            label="Yêu cầu về thể thơ và tiêu đề",
            # Example request shown while the box is empty.
            placeholder="Làm thơ với thể thơ tám chữ và tiêu đề mùa xuân nho nhỏ",
            lines=1,
        ),
    ],
    outputs="text",
)