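"""Tokenizer wrapper for Whisper-style speech-to-text models.

Maps task/language/timestamp special tokens to ids and splits decoded text
into word-level token groups."""
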
import string

from functools import cached_property
from typing import List, Optional, Tuple

import tokenizers


class Tokenizer:
    """Simple wrapper around a tokenizers.Tokenizer."""

    def __init__(
        self,
        tokenizer: tokenizers.Tokenizer,
        multilingual: bool,
        task: Optional[str] = None,
        language: Optional[str] = None,
    ):
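        """Initializes the tokenizer wrapper.

        Args:
          tokenizer: Underlying ``tokenizers.Tokenizer`` instance.
          multilingual: Whether the model uses the multilingual vocabulary.
          task: "transcribe" or "translate"; validated against _TASKS when
            ``multilingual`` is True.
          language: Language code (e.g. "en"); validated against
            _LANGUAGE_CODES when ``multilingual`` is True.
        """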
        self.tokenizer = tokenizer

        if multilingual:
            if task not in _TASKS:
                raise ValueError(
                    "'%s' is not a valid task (accepted tasks: %s)"
                    % (task, ", ".join(_TASKS))
                )

            if language not in _LANGUAGE_CODES:
                raise ValueError(
                    "'%s' is not a valid language code (accepted language codes: %s)"
                    % (language, ", ".join(_LANGUAGE_CODES))
                )

            self.task = self.tokenizer.token_to_id("<|%s|>" % task)
            self.language = self.tokenizer.token_to_id("<|%s|>" % language)
            self.language_code = language
        else:
            self.task = None
            self.language = None
            self.language_code = "en"

    @cached_property
    def transcribe(self) -> int:
        return self.tokenizer.token_to_id("<|transcribe|>")

    @cached_property
    def translate(self) -> int:
        return self.tokenizer.token_to_id("<|translate|>")

    @cached_property
    def sot(self) -> int:
        return self.tokenizer.token_to_id("<|startoftranscript|>")

    @cached_property
    def sot_lm(self) -> int:
        return self.tokenizer.token_to_id("<|startoflm|>")

    @cached_property
    def sot_prev(self) -> int:
        return self.tokenizer.token_to_id("<|startofprev|>")

    @cached_property
    def eot(self) -> int:
        return self.tokenizer.token_to_id("<|endoftext|>")

    @cached_property
    def no_timestamps(self) -> int:
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @property
    def timestamp_begin(self) -> int:
        # Timestamp tokens start right after <|notimestamps|> in the vocabulary.
        return self.no_timestamps + 1

    @property
    def sot_sequence(self) -> List[int]:
        sequence = [self.sot]

        if self.language is not None:
            sequence.append(self.language)

        if self.task is not None:
            sequence.append(self.task)

        return sequence

    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False).ids

    def decode(self, tokens: List[int]) -> str:
        # Keep only text tokens; special and timestamp tokens (ids >= eot) are dropped.
        text_tokens = [token for token in tokens if token < self.eot]
        return self.tokenizer.decode(text_tokens)

    def decode_with_timestamps(self, tokens: List[int]) -> str:
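        """Decodes tokens to text, rendering timestamp tokens as ``<|x.xx|>``
        markers (one timestamp step corresponds to 0.02 seconds)."""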
        outputs = [[]]

        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)

        return "".join(
            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        )

    def split_to_word_tokens(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
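        """Groups tokens into words, returning the word strings and the token
        ids that produced each word."""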
        if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so splitting on
            # whitespace is not meaningful; split at valid unicode points instead.
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
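        """Splits tokens at every position where the accumulated tokens decode
        to valid unicode (no U+FFFD replacement character)."""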
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            try:
                replacement_char_index = decoded.index(replacement_char)
                replacement_char_index += unicode_offset
            except ValueError:
                replacement_char_index = None

            # Flush the pending tokens when they decode cleanly, or when the
            # replacement character genuinely appears at this position in the
            # full decoding.
            if replacement_char_index is None or (
                replacement_char_index < len(decoded_full)
                and decoded_full[replacement_char_index] == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
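        """Merges unicode-split subwords into words, starting a new word on
        special tokens, leading spaces, standalone punctuation, or the first
        subword."""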
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens


_TASKS = (
    "transcribe",
    "translate",
)

_LANGUAGE_CODES = (
    "af",
    "am",
    "ar",
    "as",
    "az",
    "ba",
    "be",
    "bg",
    "bn",
    "bo",
    "br",
    "bs",
    "ca",
    "cs",
    "cy",
    "da",
    "de",
    "el",
    "en",
    "es",
    "et",
    "eu",
    "fa",
    "fi",
    "fo",
    "fr",
    "gl",
    "gu",
    "ha",
    "haw",
    "he",
    "hi",
    "hr",
    "ht",
    "hu",
    "hy",
    "id",
    "is",
    "it",
    "ja",
    "jw",
    "ka",
    "kk",
    "km",
    "kn",
    "ko",
    "la",
    "lb",
    "ln",
    "lo",
    "lt",
    "lv",
    "mg",
    "mi",
    "mk",
    "ml",
    "mn",
    "mr",
    "ms",
    "mt",
    "my",
    "ne",
    "nl",
    "nn",
    "no",
    "oc",
    "pa",
    "pl",
    "ps",
    "pt",
    "ro",
    "ru",
    "sa",
    "sd",
    "si",
    "sk",
    "sl",
    "sn",
    "so",
    "sq",
    "sr",
    "su",
    "sv",
    "sw",
    "ta",
    "te",
    "tg",
    "th",
    "tk",
    "tl",
    "tr",
    "tt",
    "uk",
    "ur",
    "uz",
    "vi",
    "yi",
    "yo",
    "zh",
    "yue",
)
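

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): it assumes a Whisper-compatible
    # "tokenizer.json" vocabulary file is available in the working directory;
    # the file path and the example text are hypothetical, not part of the module.
    hf_tokenizer = tokenizers.Tokenizer.from_file("tokenizer.json")
    tokenizer = Tokenizer(
        hf_tokenizer, multilingual=True, task="transcribe", language="en"
    )

    tokens = tokenizer.encode(" Hello world")
    print("ids:", tokens)
    print("text:", tokenizer.decode(tokens))
    print("words:", tokenizer.split_to_word_tokens(tokens))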