|
import math
|
|
from conjunctions import get_conjunctions, get_comma
|
|
from typing import TextIO
|
|
|
|
def normal_round(n):
|
|
if n - math.floor(n) < 0.5:
|
|
return math.floor(n)
|
|
return math.ceil(n)
|
|
|
|
|
|
def format_timestamp(seconds: float, is_vtt: bool = False):
|
|
|
|
assert seconds >= 0, "non-negative timestamp expected"
|
|
milliseconds = round(seconds * 1000.0)
|
|
|
|
hours = milliseconds // 3_600_000
|
|
milliseconds -= hours * 3_600_000
|
|
|
|
minutes = milliseconds // 60_000
|
|
milliseconds -= minutes * 60_000
|
|
|
|
seconds = milliseconds // 1_000
|
|
milliseconds -= seconds * 1_000
|
|
|
|
separator = '.' if is_vtt else ','
|
|
|
|
hours_marker = f"{hours:02d}:"
|
|
return (
|
|
f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}"
|
|
)
|
|
|
|
|
|
|
|
class SubtitlesProcessor:
|
|
def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False):
|
|
self.comma = get_comma(lang)
|
|
self.conjunctions = set(get_conjunctions(lang))
|
|
self.segments = segments
|
|
self.lang = lang
|
|
self.max_line_length = max_line_length
|
|
self.min_char_length_splitter = min_char_length_splitter
|
|
self.is_vtt = is_vtt
|
|
complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka']
|
|
if self.lang in complex_script_languages:
|
|
self.max_line_length = 30
|
|
self.min_char_length_splitter = 20
|
|
|
|
def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
|
|
k = 0.25
|
|
has_prev_end = i > 0 and 'end' in words[i - 1]
|
|
has_next_start = i < len(words) - 1 and 'start' in words[i + 1]
|
|
|
|
if has_prev_end:
|
|
words[i]['start'] = words[i - 1]['end']
|
|
if has_next_start:
|
|
words[i]['end'] = words[i + 1]['start']
|
|
else:
|
|
if next_segment_start_time:
|
|
words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5
|
|
else:
|
|
words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k
|
|
|
|
elif has_next_start:
|
|
words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k
|
|
words[i]['end'] = words[i + 1]['start']
|
|
|
|
else:
|
|
if next_segment_start_time:
|
|
words[i]['start'] = next_segment_start_time - 1
|
|
words[i]['end'] = next_segment_start_time - 0.5
|
|
else:
|
|
words[i]['start'] = 0
|
|
words[i]['end'] = 0
|
|
|
|
|
|
|
|
def process_segments(self, advanced_splitting=True):
|
|
subtitles = []
|
|
for i, segment in enumerate(self.segments):
|
|
next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None
|
|
|
|
if advanced_splitting:
|
|
|
|
split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
|
|
subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time))
|
|
else:
|
|
words = segment['words']
|
|
for i, word in enumerate(words):
|
|
if 'start' not in word or 'end' not in word:
|
|
self.estimate_timestamp_for_word(words, i, next_segment_start_time)
|
|
|
|
subtitles.append({
|
|
'start': segment['start'],
|
|
'end': segment['end'],
|
|
'text': segment['text']
|
|
})
|
|
|
|
return subtitles
|
|
|
|
def determine_advanced_split_points(self, segment, next_segment_start_time=None):
|
|
split_points = []
|
|
last_split_point = 0
|
|
char_count = 0
|
|
|
|
words = segment.get('words', segment['text'].split())
|
|
add_space = 0 if self.lang in ['zh', 'ja'] else 1
|
|
|
|
total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words)
|
|
char_count_after = total_char_count
|
|
|
|
for i, word in enumerate(words):
|
|
word_text = word['word'] if isinstance(word, dict) else word
|
|
word_length = len(word_text) + add_space
|
|
char_count += word_length
|
|
char_count_after -= word_length
|
|
|
|
char_count_before = char_count - word_length
|
|
|
|
if isinstance(word, dict) and ('start' not in word or 'end' not in word):
|
|
self.estimate_timestamp_for_word(words, i, next_segment_start_time)
|
|
|
|
if char_count >= self.max_line_length:
|
|
midpoint = normal_round((last_split_point + i) / 2)
|
|
if char_count_before >= self.min_char_length_splitter:
|
|
split_points.append(midpoint)
|
|
last_split_point = midpoint + 1
|
|
char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1))
|
|
|
|
elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
|
|
split_points.append(i)
|
|
last_split_point = i + 1
|
|
char_count = 0
|
|
|
|
elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
|
|
split_points.append(i - 1)
|
|
last_split_point = i
|
|
char_count = word_length
|
|
|
|
return split_points
|
|
|
|
|
|
def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
|
|
subtitles = []
|
|
|
|
words = segment.get('words', segment['text'].split())
|
|
total_word_count = len(words)
|
|
total_time = segment['end'] - segment['start']
|
|
elapsed_time = segment['start']
|
|
prefix = ' ' if self.lang not in ['zh', 'ja'] else ''
|
|
start_idx = 0
|
|
for split_point in split_points:
|
|
|
|
fragment_words = words[start_idx:split_point + 1]
|
|
current_word_count = len(fragment_words)
|
|
|
|
|
|
if isinstance(fragment_words[0], dict):
|
|
start_time = fragment_words[0]['start']
|
|
end_time = fragment_words[-1]['end']
|
|
next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None
|
|
if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8:
|
|
end_time = next_start_time_for_word
|
|
else:
|
|
fragment = prefix.join(fragment_words).strip()
|
|
current_duration = (current_word_count / total_word_count) * total_time
|
|
start_time = elapsed_time
|
|
end_time = elapsed_time + current_duration
|
|
elapsed_time += current_duration
|
|
|
|
|
|
subtitles.append({
|
|
'start': start_time,
|
|
'end': end_time,
|
|
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
|
|
})
|
|
|
|
start_idx = split_point + 1
|
|
|
|
|
|
if start_idx < len(words):
|
|
fragment_words = words[start_idx:]
|
|
current_word_count = len(fragment_words)
|
|
|
|
if isinstance(fragment_words[0], dict):
|
|
start_time = fragment_words[0]['start']
|
|
end_time = fragment_words[-1]['end']
|
|
else:
|
|
fragment = prefix.join(fragment_words).strip()
|
|
current_duration = (current_word_count / total_word_count) * total_time
|
|
start_time = elapsed_time
|
|
end_time = elapsed_time + current_duration
|
|
|
|
if next_start_time and (next_start_time - end_time) <= 0.8:
|
|
end_time = next_start_time
|
|
|
|
subtitles.append({
|
|
'start': start_time,
|
|
'end': end_time if end_time is not None else segment['end'],
|
|
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
|
|
})
|
|
|
|
return subtitles
|
|
|
|
|
|
|
|
def save(self, filename="subtitles.srt", advanced_splitting=True):
|
|
|
|
subtitles = self.process_segments(advanced_splitting)
|
|
|
|
def write_subtitle(file, idx, start_time, end_time, text):
|
|
|
|
file.write(f"{idx}\n")
|
|
file.write(f"{start_time} --> {end_time}\n")
|
|
file.write(text + "\n\n")
|
|
|
|
with open(filename, 'w', encoding='utf-8') as file:
|
|
if self.is_vtt:
|
|
file.write("WEBVTT\n\n")
|
|
|
|
if advanced_splitting:
|
|
for idx, subtitle in enumerate(subtitles, 1):
|
|
start_time = format_timestamp(subtitle['start'], self.is_vtt)
|
|
end_time = format_timestamp(subtitle['end'], self.is_vtt)
|
|
text = subtitle['text'].strip()
|
|
write_subtitle(file, idx, start_time, end_time, text)
|
|
|
|
return len(subtitles) |