import string from functools import cached_property from typing import List, Optional, Tuple import tokenizers class Tokenizer: """Simple wrapper around a tokenizers.Tokenizer.""" def __init__( self, tokenizer: tokenizers.Tokenizer, multilingual: bool, task: Optional[str] = None, language: Optional[str] = None, ): self.tokenizer = tokenizer if multilingual: if task not in _TASKS: raise ValueError( "'%s' is not a valid task (accepted tasks: %s)" % (task, ", ".join(_TASKS)) ) if language not in _LANGUAGE_CODES: raise ValueError( "'%s' is not a valid language code (accepted language codes: %s)" % (language, ", ".join(_LANGUAGE_CODES)) ) self.task = self.tokenizer.token_to_id("<|%s|>" % task) self.language = self.tokenizer.token_to_id("<|%s|>" % language) self.language_code = language else: self.task = None self.language = None self.language_code = "en" @cached_property def transcribe(self) -> int: return self.tokenizer.token_to_id("<|transcribe|>") @cached_property def translate(self) -> int: return self.tokenizer.token_to_id("<|translate|>") @cached_property def sot(self) -> int: return self.tokenizer.token_to_id("<|startoftranscript|>") @cached_property def sot_lm(self) -> int: return self.tokenizer.token_to_id("<|startoflm|>") @cached_property def sot_prev(self) -> int: return self.tokenizer.token_to_id("<|startofprev|>") @cached_property def eot(self) -> int: return self.tokenizer.token_to_id("<|endoftext|>") @cached_property def no_timestamps(self) -> int: return self.tokenizer.token_to_id("<|notimestamps|>") @property def timestamp_begin(self) -> int: return self.no_timestamps + 1 @property def sot_sequence(self) -> List[int]: sequence = [self.sot] if self.language is not None: sequence.append(self.language) if self.task is not None: sequence.append(self.task) return sequence def encode(self, text: str) -> List[int]: return self.tokenizer.encode(text, add_special_tokens=False).ids def decode(self, tokens: List[int]) -> str: text_tokens = [token for token in tokens if token < self.eot] return self.tokenizer.decode(text_tokens) def decode_with_timestamps(self, tokens: List[int]) -> str: outputs = [[]] for token in tokens: if token >= self.timestamp_begin: timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" outputs.append(timestamp) outputs.append([]) else: outputs[-1].append(token) return "".join( [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] ) def split_to_word_tokens( self, tokens: List[int] ) -> Tuple[List[str], List[List[int]]]: if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}: # These languages don't typically use spaces, so it is difficult to split words # without morpheme analysis. Here, we instead split words at any # position where the tokens are decoded as valid unicode points return self.split_tokens_on_unicode(tokens) return self.split_tokens_on_spaces(tokens) def split_tokens_on_unicode( self, tokens: List[int] ) -> Tuple[List[str], List[List[int]]]: decoded_full = self.decode_with_timestamps(tokens) replacement_char = "\ufffd" words = [] word_tokens = [] current_tokens = [] unicode_offset = 0 for token in tokens: current_tokens.append(token) decoded = self.decode_with_timestamps(current_tokens) try: replacement_char_index = decoded.index(replacement_char) replacement_char_index += unicode_offset except ValueError: replacement_char_index = None if replacement_char_index is None or ( replacement_char_index < len(decoded_full) and decoded_full[replacement_char_index] == replacement_char ): words.append(decoded) word_tokens.append(current_tokens) current_tokens = [] unicode_offset += len(decoded) return words, word_tokens def split_tokens_on_spaces( self, tokens: List[int] ) -> Tuple[List[str], List[List[int]]]: subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) words = [] word_tokens = [] for subword, subword_tokens in zip(subwords, subword_tokens_list): special = subword_tokens[0] >= self.eot with_space = subword.startswith(" ") punctuation = subword.strip() in string.punctuation if special or with_space or punctuation or len(words) == 0: words.append(subword) word_tokens.append(subword_tokens) else: words[-1] = words[-1] + subword word_tokens[-1].extend(subword_tokens) return words, word_tokens _TASKS = ( "transcribe", "translate", ) _LANGUAGE_CODES = ( "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo", "br", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es", "et", "eu", "fa", "fi", "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi", "hr", "ht", "hu", "hy", "id", "is", "it", "ja", "jw", "ka", "kk", "km", "kn", "ko", "la", "lb", "ln", "lo", "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt", "my", "ne", "nl", "nn", "no", "oc", "pa", "pl", "ps", "pt", "ro", "ru", "sa", "sd", "si", "sk", "sl", "sn", "so", "sq", "sr", "su", "sv", "sw", "ta", "te", "tg", "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi", "yi", "yo", "zh", "yue", )