import hashlib import os import uuid from typing import List, Tuple, Union, Dict import regex as re import sentencepiece as spm from indicnlp.normalize import indic_normalize from indicnlp.tokenize import indic_detokenize, indic_tokenize from indicnlp.tokenize.sentence_tokenize import DELIM_PAT_NO_DANDA, sentence_split from indicnlp.transliterate import unicode_transliterate from mosestokenizer import MosesSentenceSplitter from nltk.tokenize import sent_tokenize from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer from tqdm import tqdm from .flores_codes_map_indic import flores_codes, iso_to_flores from .normalize_punctuation import punc_norm from .normalize_regex_inference import EMAIL_PATTERN, normalize def split_sentences(paragraph: str, lang: str) -> List[str]: """ Splits the input text paragraph into sentences. It uses `moses` for English and `indic-nlp` for Indic languages. Args: paragraph (str): input text paragraph. lang (str): flores language code. Returns: List[str] -> list of sentences. """ if lang == "eng_Latn": with MosesSentenceSplitter(flores_codes[lang]) as splitter: sents_moses = splitter([paragraph]) sents_nltk = sent_tokenize(paragraph) if len(sents_nltk) < len(sents_moses): sents = sents_nltk else: sents = sents_moses return [sent.replace("\xad", "") for sent in sents] else: return sentence_split(paragraph, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA) def add_token(sent: str, src_lang: str, tgt_lang: str, delimiter: str = " ") -> str: """ Add special tokens indicating source and target language to the start of the input sentence. The resulting string will have the format: "`{src_lang} {tgt_lang} {input_sentence}`". Args: sent (str): input sentence to be translated. src_lang (str): flores lang code of the input sentence. tgt_lang (str): flores lang code in which the input sentence will be translated. delimiter (str): separator to add between language tags and input sentence (default: " "). Returns: str: input sentence with the special tokens added to the start. """ return src_lang + delimiter + tgt_lang + delimiter + sent def apply_lang_tags(sents: List[str], src_lang: str, tgt_lang: str) -> List[str]: """ Add special tokens indicating source and target language to the start of the each input sentence. Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`". Args: sent (str): input sentence to be translated. src_lang (str): flores lang code of the input sentence. tgt_lang (str): flores lang code in which the input sentence will be translated. Returns: List[str]: list of input sentences with the special tokens added to the start. """ tagged_sents = [] for sent in sents: tagged_sent = add_token(sent.strip(), src_lang, tgt_lang) tagged_sents.append(tagged_sent) return tagged_sents def truncate_long_sentences( sents: List[str], placeholder_entity_map_sents: List[Dict] ) -> Tuple[List[str], List[Dict]]: """ Truncates the sentences that exceed the maximum sequence length. The maximum sequence for the IndicTrans2 model is limited to 256 tokens. Args: sents (List[str]): list of input sentences to truncate. Returns: Tuple[List[str], List[Dict]]: tuple containing the list of sentences with truncation applied and the updated placeholder entity maps. """ MAX_SEQ_LEN = 256 new_sents = [] placeholders = [] for j, sent in enumerate(sents): words = sent.split() num_words = len(words) if num_words > MAX_SEQ_LEN: sents = [] i = 0 while i <= len(words): sents.append(" ".join(words[i : i + MAX_SEQ_LEN])) i += MAX_SEQ_LEN placeholders.extend([placeholder_entity_map_sents[j]] * (len(sents))) new_sents.extend(sents) else: placeholders.append(placeholder_entity_map_sents[j]) new_sents.append(sent) return new_sents, placeholders class Model: """ Model class to run the IndicTransv2 models using python interface. """ def __init__( self, ckpt_dir: str, device: str = "cuda", input_lang_code_format: str = "flores", model_type: str = "ctranslate2", ): """ Initialize the model class. Args: ckpt_dir (str): path of the model checkpoint directory. device (str, optional): where to load the model (defaults: cuda). """ self.ckpt_dir = ckpt_dir self.en_tok = MosesTokenizer(lang="en") self.en_normalizer = MosesPunctNormalizer() self.en_detok = MosesDetokenizer(lang="en") self.xliterator = unicode_transliterate.UnicodeIndicTransliterator() print("Initializing sentencepiece model for SRC and TGT") self.sp_src = spm.SentencePieceProcessor( model_file=os.path.join(ckpt_dir, "vocab", "model.SRC") ) self.sp_tgt = spm.SentencePieceProcessor( model_file=os.path.join(ckpt_dir, "vocab", "model.TGT") ) self.input_lang_code_format = input_lang_code_format print("Initializing model for translation") # initialize the model if model_type == "ctranslate2": import ctranslate2 self.translator = ctranslate2.Translator( self.ckpt_dir, device=device ) # , compute_type="auto") self.translate_lines = self.ctranslate2_translate_lines elif model_type == "fairseq": from .custom_interactive import Translator self.translator = Translator( data_dir=os.path.join(self.ckpt_dir, "final_bin"), checkpoint_path=os.path.join(self.ckpt_dir, "model", "checkpoint_best.pt"), batch_size=100, ) self.translate_lines = self.fairseq_translate_lines else: raise NotImplementedError(f"Unknown model_type: {model_type}") def ctranslate2_translate_lines(self, lines: List[str]) -> List[str]: tokenized_sents = [x.strip().split(" ") for x in lines] translations = self.translator.translate_batch( tokenized_sents, max_batch_size=9216, batch_type="tokens", max_input_length=160, max_decoding_length=256, beam_size=5, ) translations = [" ".join(x.hypotheses[0]) for x in translations] return translations def fairseq_translate_lines(self, lines: List[str]) -> List[str]: return self.translator.translate(lines) def paragraphs_batch_translate__multilingual(self, batch_payloads: List[tuple]) -> List[str]: """ Translates a batch of input paragraphs (including pre/post processing) from any language to any language. Args: batch_payloads (List[tuple]): batch of long input-texts to be translated, each in format: (paragraph, src_lang, tgt_lang) Returns: List[str]: batch of paragraph-translations in the respective languages. """ paragraph_id_to_sentence_range = [] global__sents = [] global__preprocessed_sents = [] global__preprocessed_sents_placeholder_entity_map = [] for i in range(len(batch_payloads)): paragraph, src_lang, tgt_lang = batch_payloads[i] if self.input_lang_code_format == "iso": src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang] batch = split_sentences(paragraph, src_lang) global__sents.extend(batch) preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch( batch, src_lang, tgt_lang ) global_sentence_start_index = len(global__preprocessed_sents) global__preprocessed_sents.extend(preprocessed_sents) global__preprocessed_sents_placeholder_entity_map.extend(placeholder_entity_map_sents) paragraph_id_to_sentence_range.append( (global_sentence_start_index, len(global__preprocessed_sents)) ) translations = self.translate_lines(global__preprocessed_sents) translated_paragraphs = [] for paragraph_id, sentence_range in enumerate(paragraph_id_to_sentence_range): tgt_lang = batch_payloads[paragraph_id][2] if self.input_lang_code_format == "iso": tgt_lang = iso_to_flores[tgt_lang] postprocessed_sents = self.postprocess( translations[sentence_range[0] : sentence_range[1]], global__preprocessed_sents_placeholder_entity_map[ sentence_range[0] : sentence_range[1] ], tgt_lang, ) translated_paragraph = " ".join(postprocessed_sents) translated_paragraphs.append(translated_paragraph) return translated_paragraphs # translate a batch of sentences from src_lang to tgt_lang def batch_translate(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]: """ Translates a batch of input sentences (including pre/post processing) from source language to target language. Args: batch (List[str]): batch of input sentences to be translated. src_lang (str): flores source language code. tgt_lang (str): flores target language code. Returns: List[str]: batch of translated-sentences generated by the model. """ assert isinstance(batch, list) if self.input_lang_code_format == "iso": src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang] preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch( batch, src_lang, tgt_lang ) translations = self.translate_lines(preprocessed_sents) return self.postprocess(translations, placeholder_entity_map_sents, tgt_lang) # translate a paragraph from src_lang to tgt_lang def translate_paragraph(self, paragraph: str, src_lang: str, tgt_lang: str) -> str: """ Translates an input text paragraph (including pre/post processing) from source language to target language. Args: paragraph (str): input text paragraph to be translated. src_lang (str): flores source language code. tgt_lang (str): flores target language code. Returns: str: paragraph translation generated by the model. """ assert isinstance(paragraph, str) if self.input_lang_code_format == "iso": flores_src_lang = iso_to_flores[src_lang] else: flores_src_lang = src_lang sents = split_sentences(paragraph, flores_src_lang) postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang) translated_paragraph = " ".join(postprocessed_sents) return translated_paragraph def preprocess_batch(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]: """ Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. It also tokenizes the normalized text sequences using sentence piece tokenizer and also adds language tags. Args: batch (List[str]): input list of sentences to preprocess. src_lang (str): flores language code of the input text sentences. tgt_lang (str): flores language code of the output text sentences. Returns: Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary mapping placeholders to their original values. """ preprocessed_sents, placeholder_entity_map_sents = self.preprocess(batch, lang=src_lang) tokenized_sents = self.apply_spm(preprocessed_sents) tokenized_sents, placeholder_entity_map_sents = truncate_long_sentences( tokenized_sents, placeholder_entity_map_sents ) tagged_sents = apply_lang_tags(tokenized_sents, src_lang, tgt_lang) return tagged_sents, placeholder_entity_map_sents def apply_spm(self, sents: List[str]) -> List[str]: """ Applies sentence piece encoding to the batch of input sentences. Args: sents (List[str]): batch of the input sentences. Returns: List[str]: batch of encoded sentences with sentence piece model """ return [" ".join(self.sp_src.encode(sent, out_type=str)) for sent in sents] def preprocess_sent( self, sent: str, normalizer: Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory], lang: str, ) -> Tuple[str, Dict]: """ Preprocess an input text sentence by normalizing, tokenization, and possibly transliterating it. Args: sent (str): input text sentence to preprocess. normalizer (Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]): an object that performs normalization on the text. lang (str): flores language code of the input text sentence. Returns: Tuple[str, Dict]: A tuple containing the preprocessed input text sentence and a corresponding dictionary mapping placeholders to their original values. """ iso_lang = flores_codes[lang] sent = punc_norm(sent, iso_lang) sent, placeholder_entity_map = normalize(sent) transliterate = True if lang.split("_")[1] in ["Arab", "Aran", "Olck", "Mtei", "Latn"]: transliterate = False if iso_lang == "en": processed_sent = " ".join( self.en_tok.tokenize(self.en_normalizer.normalize(sent.strip()), escape=False) ) elif transliterate: # transliterates from the any specific language to devanagari # which is why we specify lang2_code as "hi". processed_sent = self.xliterator.transliterate( " ".join( indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang) ), iso_lang, "hi", ).replace(" ् ", "्") else: # we only need to transliterate for joint training processed_sent = " ".join( indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang) ) return processed_sent, placeholder_entity_map def preprocess(self, sents: List[str], lang: str): """ Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. Args: batch (List[str]): input list of sentences to preprocess. lang (str): flores language code of the input text sentences. Returns: Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary mapping placeholders to their original values. """ processed_sents, placeholder_entity_map_sents = [], [] if lang == "eng_Latn": normalizer = None else: normfactory = indic_normalize.IndicNormalizerFactory() normalizer = normfactory.get_normalizer(flores_codes[lang]) for sent in sents: sent, placeholder_entity_map = self.preprocess_sent(sent, normalizer, lang) processed_sents.append(sent) placeholder_entity_map_sents.append(placeholder_entity_map) return processed_sents, placeholder_entity_map_sents def postprocess( self, sents: List[str], placeholder_entity_map: List[Dict], lang: str, common_lang: str = "hin_Deva", ) -> List[str]: """ Postprocesses a batch of input sentences after the translation generations. Args: sents (List[str]): batch of translated sentences to postprocess. placeholder_entity_map (List[Dict]): dictionary mapping placeholders to the original entity values. lang (str): flores language code of the input sentences. common_lang (str, optional): flores language code of the transliterated language (defaults: hin_Deva). Returns: List[str]: postprocessed batch of input sentences. """ lang_code, script_code = lang.split("_") # SPM decode for i in range(len(sents)): # sent_tokens = sents[i].split(" ") # sents[i] = self.sp_tgt.decode(sent_tokens) sents[i] = sents[i].replace(" ", "").replace("▁", " ").strip() # Fixes for Perso-Arabic scripts # TODO: Move these normalizations inside indic-nlp-library if script_code in {"Arab", "Aran"}: # UrduHack adds space before punctuations. Since the model was trained without fixing this issue, let's fix it now sents[i] = sents[i].replace(" ؟", "؟").replace(" ۔", "۔").replace(" ،", "،") # Kashmiri bugfix for palatalization: https://github.com/AI4Bharat/IndicTrans2/issues/11 sents[i] = sents[i].replace("ٮ۪", "ؠ") assert len(sents) == len(placeholder_entity_map) for i in range(0, len(sents)): for key in placeholder_entity_map[i].keys(): sents[i] = sents[i].replace(key, placeholder_entity_map[i][key]) # Detokenize and transliterate to native scripts if applicable postprocessed_sents = [] if lang == "eng_Latn": for sent in sents: postprocessed_sents.append(self.en_detok.detokenize(sent.split(" "))) else: for sent in sents: outstr = indic_detokenize.trivial_detokenize( self.xliterator.transliterate( sent, flores_codes[common_lang], flores_codes[lang] ), flores_codes[lang], ) # Oriya bug: indic-nlp-library produces ଯ଼ instead of ୟ when converting from Devanagari to Odia # TODO: Find out what's the issue with unicode transliterator for Oriya and fix it if lang_code == "ory": outstr = outstr.replace("ଯ଼", 'ୟ') postprocessed_sents.append(outstr) return postprocessed_sents