phamson02 commited on
Commit
9087a07
1 Parent(s): 59e5fd8
Files changed (1) hide show
  1. generate_poem.py +65 -3
generate_poem.py CHANGED
@@ -1,14 +1,76 @@
 
1
  import gradio as gr
 
2
 
 
 
3
 
4
- def generate_poem(text):
5
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  generate_poem_interface = gr.Interface(
 
9
  fn=generate_poem,
10
  inputs=[
11
- gr.components.Textbox(lines=1, placeholder="First words of the poem"),
 
 
 
 
12
  ],
13
  outputs="text",
14
  )
 
1
+ import torch
2
  import gradio as gr
3
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
 
5
+ tokenizer = T5Tokenizer.from_pretrained("VietAI/vit5-base")
6
+ model = T5ForConditionalGeneration.from_pretrained("Libosa2707/vietnamese-poem-t5")
7
 
8
+
9
+ def generate_poem(input_text):
10
+ # Define the parameters for the generate function
11
+ min_length = 50
12
+ max_length = 100
13
+ rep_penalty = 1.2
14
+ temp = 0.7
15
+ top_k = 50
16
+ top_p = 0.92
17
+ no_repeat_ngram_size = 2
18
+
19
+ # Tokenize the input
20
+ input_ids = tokenizer(
21
+ input_text,
22
+ return_tensors="pt",
23
+ padding="max_length",
24
+ truncation=True,
25
+ max_length=42,
26
+ ).input_ids.to(model.device)
27
+
28
+ # Generate text
29
+ model.eval()
30
+ with torch.no_grad():
31
+ output = model.generate(
32
+ do_sample=True,
33
+ input_ids=input_ids,
34
+ min_length=min_length,
35
+ max_length=max_length,
36
+ top_p=top_p,
37
+ top_k=top_k,
38
+ temperature=temp,
39
+ repetition_penalty=rep_penalty,
40
+ no_repeat_ngram_size=no_repeat_ngram_size,
41
+ num_return_sequences=1,
42
+ )
43
+
44
+ # Process the generated text
45
+ gen = tokenizer.decode(
46
+ output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False
47
+ )
48
+ sentences = gen.split("<unk>")
49
+ gen_poem = "\n".join(sentences).replace("<pad>", "").replace("</s>", "")
50
+ gen_poem = gen_poem.strip()
51
+
52
+ # Post-process the poem text
53
+ pretty_text = ""
54
+ for line in gen_poem.split("\n"):
55
+ line = line.strip()
56
+ if not line:
57
+ continue
58
+ line = line[0].upper() + line[1:]
59
+ pretty_text += line + "\n"
60
+
61
+ # Return the generated poem
62
+ return pretty_text
63
 
64
 
65
  generate_poem_interface = gr.Interface(
66
+ title="Làm thơ theo yêu cầu",
67
  fn=generate_poem,
68
  inputs=[
69
+ gr.components.Textbox(
70
+ lines=1,
71
+ placeholder="Làm thơ với thể thơ tám chữ và tiêu đề mùa xuân nho nhỏ",
72
+ label="Yêu cầu về thể thơ và tiêu đề",
73
+ ),
74
  ],
75
  outputs="text",
76
  )