from itertools import chain
from typing import List, Union

import numpy as np
import torch
from transformers import ByT5Tokenizer

from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET


def text_to_utf16_numbers(text):
    utf16_bytes = text.encode('utf-16le')  # Little-endian to simplify byte order handling

    numbers = []
    # Iterate through each pair of bytes and combine them into a single number
    for i in range(0, len(utf16_bytes), 2):
        # Combine two adjacent bytes into a single 16-bit code unit
        number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8)
        numbers.append(number)

    return numbers


def utf16_numbers_to_text(numbers):
    byte_array = bytearray()
    for number in numbers:
        # Extract the two bytes from the number and add them to the byte array
        byte_array.append(number & 0xFF)  # Lower byte
        byte_array.append((number >> 8) & 0xFF)  # Upper byte

    text = byte_array.decode('utf-16le', errors="ignore")
    return text


def _tokenize(text: str, langs: List[str], eos_token_id: int = 1, add_eos: bool = True, add_bos: bool = True):
    tokens = text_to_utf16_numbers(text)
    tokens = [t + TOKEN_OFFSET for t in tokens]  # Account for special pad, eos, etc. tokens

    # Language tokens sit above the text token range
    lang_list = []
    for lang in langs:
        code = LANGUAGE_MAP[lang]
        lang_list.append(code + TOKEN_OFFSET + TOTAL_TOKENS)

    tokens = lang_list + tokens

    if add_eos:
        tokens.append(eos_token_id)
    if add_bos:
        tokens.insert(0, eos_token_id)  # The eos token doubles as bos

    return tokens, lang_list


class Byt5LangTokenizer(ByT5Tokenizer):
    def __init__(self,
                 eos_token="</s>",
                 unk_token="<unk>",
                 pad_token="<pad>",
                 model_max_length=None,
                 **kwargs,
                 ):
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.bos_token = eos_token
        self.offset = TOKEN_OFFSET

        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2

        self.model_max_length = model_max_length
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS  # First language token id

        super().__init__()

    def __call__(self, texts: Union[List[str], str], langs: Union[List[List[str]], List[str]], pad_token_id: int = 0, **kwargs):
        tokenized = []
        all_langs = []

        # Convert to list-of-lists format if needed
        is_list = True
        if isinstance(texts, str):
            texts = [texts]
            is_list = False

        if isinstance(langs[0], str):
            langs = [langs]

        # One language list per text input
        assert len(langs) == len(texts)

        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Convert back to flat format
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        # Drop special and language tokens, then undo the offset before converting back to text
        token_ids = [t for t in token_ids if TOKEN_OFFSET <= t < self.special_token_start]
        token_ids = [t - TOKEN_OFFSET for t in token_ids]

        text = utf16_numbers_to_text(token_ids)
        return text
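

# Minimal usage sketch: assumes "en" is a key in LANGUAGE_MAP (swap in any
# language code defined in your config). Encoding yields
# [bos, language token(s), offset UTF-16 code units, eos]; decode filters the
# special and language tokens back out before reconstructing the text.
if __name__ == "__main__":
    tokenizer = Byt5LangTokenizer()
    encoded = tokenizer("Hello world", ["en"])
    print(encoded["input_ids"])                    # e.g. [1, <lang id>, <char ids>..., 1]
    print(tokenizer.decode(encoded["input_ids"]))  # "Hello world"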