minnehwg committed on
Commit
a7b7e94
1 Parent(s): f0b497f

Create util.py

Files changed (1)
1. util.py +141 -0
util.py ADDED
@@ -0,0 +1,141 @@
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from youtube_transcript_api import YouTubeTranscriptApi
+ from deepmultilingualpunctuation import PunctuationModel
+ from googletrans import Translator
+ import time
+ import torch
+ import re
+ 
+ 
+ cp_aug = 'minnehwg/finetune-newwiki-summarization-ver-augmented2'
+ 
+ 
+ def load_model(cp):
+     tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
+     model = AutoModelForSeq2SeqLM.from_pretrained(cp)
+     return tokenizer, model
+ 
+ 
+ # Module-level tokenizer and fine-tuned model used by pipeline() below.
+ tokenizer, model_aug = load_model(cp_aug)
+ 
+ 
+ def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
+     # Abstractive summarization with beam search; input is truncated to 1024 tokens.
+     model.to(device)
+     inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True).to(device)
+ 
+     with torch.no_grad():
+         summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
+         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+ 
+     return summary
+ 
+ 
+ def processed(text):
+     # Normalize translated text: collapse newlines and lowercase.
+     processed_text = text.replace('\n', ' ')
+     processed_text = processed_text.lower()
+     return processed_text
+ 
+ 
+ def get_subtitles(video_url):
+     try:
+         # Expects a standard URL of the form https://www.youtube.com/watch?v=<id>
+         video_id = video_url.split("v=")[1]
+         transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
+         subs = " ".join(entry['text'] for entry in transcript)
+ 
+         return transcript, subs
+ 
+     except Exception as e:
+         return [], f"An error occurred: {e}"
+ 
+ 
+ def restore_punctuation(text):
+     # The raw transcript has no punctuation; restore it before translating.
+     model = PunctuationModel()
+     result = model.restore_punctuation(text)
+     return result
+ 
+ 
+ def translate_long(text, language='vi'):
+     # googletrans rejects very long requests, so split the text on sentence
+     # boundaries into chunks under ~4700 characters and translate one by one.
+     translator = Translator()
+     limit = 4700
+     chunks = []
+     current_chunk = ''
+ 
+     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+ 
+     for sentence in sentences:
+         if len(current_chunk) + len(sentence) <= limit:
+             current_chunk += sentence.strip() + ' '
+         else:
+             chunks.append(current_chunk.strip())
+             current_chunk = sentence.strip() + ' '
+ 
+     if current_chunk:
+         chunks.append(current_chunk.strip())
+ 
+     translated_text = ''
+ 
+     for chunk in chunks:
+         try:
+             time.sleep(1)  # throttle requests to avoid rate limiting
+             translation = translator.translate(chunk, dest=language)
+             translated_text += translation.text + ' '
+         except Exception:
+             # If a chunk fails to translate, keep it untranslated rather than drop it.
+             translated_text += chunk + ' '
+ 
+     return translated_text.strip()
+ 
+ 
+ def split_into_chunks(text, max_words=800, overlap_sentences=2):
+     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+ 
+     chunks = []
+     current_chunk = []
+     current_word_count = 0
+ 
+     for sentence in sentences:
+         word_count = len(sentence.split())
+         if current_word_count + word_count <= max_words:
+             current_chunk.append(sentence)
+             current_word_count += word_count
+         else:
+             chunks.append(' '.join(current_chunk))
+             # Carry the last few sentences over so context is preserved
+             # across chunk boundaries.
+             current_chunk = current_chunk[-overlap_sentences:] + [sentence]
+             current_word_count = sum(len(sent.split()) for sent in current_chunk)
+     if current_chunk:
+         chunks.append(' '.join(current_chunk))
+ 
+     return chunks
+ 
+ 
+ def post_processing(text):
+     # Re-capitalize the first letter of each sentence after lowercasing.
+     sentences = re.split(r'(?<=[.!?])\s*', text)
+     for i in range(len(sentences)):
+         if sentences[i]:
+             sentences[i] = sentences[i][0].upper() + sentences[i][1:]
+     text = " ".join(sentences)
+     return text
+ 
+ 
+ def display(text):
+     sentences = re.split(r'(?<=[.!?])\s*', text)
+     # Drop the trailing empty split, then de-duplicate repeated sentences
+     # while preserving order.
+     uni = list(dict.fromkeys(sentences[:-1]))
+     for sentence in uni:
+         print(f"• {sentence}")
+ 
+ 
+ def pipeline(url):
+     # End-to-end: transcript -> punctuation -> Vietnamese translation ->
+     # chunked summarization -> cleaned, bulleted output.
+     _, sub = get_subtitles(url)
+     sub = restore_punctuation(sub)
+     vie_sub = translate_long(sub)
+     vie_sub = processed(vie_sub)
+     chunks = split_into_chunks(vie_sub, 700, 3)
+     sum_para = []
+     for chunk in chunks:
+         tmp = summarize(chunk, model_aug, tokenizer, num_beams=4)
+         sum_para.append(tmp)
+     summary = ' '.join(sum_para)  # avoid shadowing the built-in sum()
+     del sub, vie_sub, sum_para, chunks
+     summary = post_processing(summary)
+     display(summary)
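+ 
+ 
+ # Minimal usage sketch (the URL form below is a placeholder, not part of the
+ # original file; the video must have English subtitles available):
+ #
+ #     pipeline("https://www.youtube.com/watch?v=<video_id>")
+ #
+ # This fetches the transcript, restores punctuation, translates it to
+ # Vietnamese, summarizes each chunk with the fine-tuned ViT5 checkpoint,
+ # and prints the result as bullet points.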