File size: 2,098 Bytes
9087a07
ebf036f
9087a07
ebf036f
9087a07
 
ebf036f
9087a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf036f
 
 
9087a07
ebf036f
 
9087a07
 
 
 
 
ebf036f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the tokenizer and the fine-tuned poem model once, at module import, so
# every call to generate_poem() below reuses the same objects.
# The two come from different Hub repos: the tokenizer from the base ViT5
# checkpoint, the weights from the fine-tuned poem checkpoint.
# NOTE(review): presumably the fine-tuned model keeps the base ViT5 vocabulary
# (and its repo may not ship tokenizer files) — confirm the vocabularies match.
tokenizer = T5Tokenizer.from_pretrained("VietAI/vit5-base")
model = T5ForConditionalGeneration.from_pretrained("Libosa2707/vietnamese-poem-t5")


def generate_poem(
    input_text,
    *,
    min_length=50,
    max_length=100,
    rep_penalty=1.2,
    temp=0.7,
    top_k=50,
    top_p=0.92,
    no_repeat_ngram_size=2,
):
    """Generate a Vietnamese poem from a free-text request.

    The request (e.g. poem form and title) is encoded with the module-level
    ``tokenizer`` and fed to the module-level T5 ``model``, which samples an
    output sequence. The decoded text is then split into lines, and each
    non-empty line is stripped and capitalised.

    Args:
        input_text: The user's request. It is truncated / padded to 42 tokens.
        min_length: Minimum generated sequence length, passed to ``generate``.
        max_length: Maximum generated sequence length, passed to ``generate``.
        rep_penalty: ``repetition_penalty`` for ``generate``.
        temp: Sampling temperature.
        top_k: Top-k sampling cut-off.
        top_p: Nucleus (top-p) sampling cut-off.
        no_repeat_ngram_size: Forbid repeating n-grams of this size.

    Returns:
        The poem as a string. Each non-empty line is stripped, its first
        character is upper-cased, and it is terminated by a newline. Returns
        an empty string if the model produced no visible text.
    """
    # Tokenize the request. The fixed-length padding to 42 tokens is kept
    # unchanged (presumably it mirrors the fine-tuning setup — TODO confirm).
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=42,
    )
    input_ids = encoded.input_ids.to(model.device)
    # Pass the mask explicitly so the pad positions are never attended to,
    # instead of relying on generate() to infer it from pad_token_id.
    attention_mask = encoded.attention_mask.to(model.device)

    model.eval()
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=True,
            min_length=min_length,
            max_length=max_length,
            top_p=top_p,
            top_k=top_k,
            temperature=temp,
            repetition_penalty=rep_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=1,
        )

    # Decode WITHOUT skipping special tokens. The <unk> markers are treated as
    # line breaks here (presumably the vocabulary has no "\n" token), and
    # skip_special_tokens=True would discard them. The <pad> and </s> markers
    # are removed by hand instead.
    gen = tokenizer.decode(
        output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    gen_poem = gen.replace("<unk>", "\n").replace("<pad>", "").replace("</s>", "")

    # Drop blank lines and capitalise the first character of each remaining one.
    lines = (line.strip() for line in gen_poem.split("\n"))
    return "".join(line[0].upper() + line[1:] + "\n" for line in lines if line)


# Gradio UI wrapping generate_poem: one single-line text box in, plain text out.
# The UI strings are Vietnamese:
#   title       -> "Write poems on request"
#   placeholder -> "Write a poem in the eight-word verse form with the title
#                   'mùa xuân nho nhỏ' (a tiny spring)"
#   label       -> "Request for poem form and title"
# This block only constructs the interface; nothing here calls .launch().
generate_poem_interface = gr.Interface(
    fn=generate_poem,
    inputs=[
        gr.components.Textbox(
            label="Yêu cầu về thể thơ và tiêu đề",
            placeholder="Làm thơ với thể thơ tám chữ và tiêu đề mùa xuân nho nhỏ",
            lines=1,
        ),
    ],
    outputs="text",
    title="Làm thơ theo yêu cầu",
)