import hashlib
import os
import uuid
from typing import List, Tuple, Union, Dict
import regex as re
import sentencepiece as spm
from indicnlp.normalize import indic_normalize
from indicnlp.tokenize import indic_detokenize, indic_tokenize
from indicnlp.tokenize.sentence_tokenize import DELIM_PAT_NO_DANDA, sentence_split
from indicnlp.transliterate import unicode_transliterate
from mosestokenizer import MosesSentenceSplitter
from nltk.tokenize import sent_tokenize
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
from tqdm import tqdm
from .flores_codes_map_indic import flores_codes, iso_to_flores
from .normalize_punctuation import punc_norm
from .normalize_regex_inference import EMAIL_PATTERN, normalize
def split_sentences(paragraph: str, lang: str) -> List[str]:
    """
    Splits the input text paragraph into sentences. It uses `moses` and `nltk` for
    English and `indic-nlp` for Indic languages.

    Args:
        paragraph (str): input text paragraph.
        lang (str): flores language code.

    Returns:
        List[str] -> list of sentences.
    """
    if lang != "eng_Latn":
        return sentence_split(paragraph, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA)

    with MosesSentenceSplitter(flores_codes[lang]) as splitter:
        sents_moses = splitter([paragraph])
    sents_nltk = sent_tokenize(paragraph)

    # Keep whichever splitter produced fewer sentences; the moses output wins ties.
    sents = sents_nltk if len(sents_nltk) < len(sents_moses) else sents_moses

    # Strip soft hyphens (U+00AD) from every sentence.
    return [sent.replace("\xad", "") for sent in sents]
def add_token(sent: str, src_lang: str, tgt_lang: str, delimiter: str = " ") -> str:
    """
    Add special tokens indicating source and target language to the start of the input sentence.
    The resulting string will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".

    Args:
        sent (str): input sentence to be translated.
        src_lang (str): flores lang code of the input sentence.
        tgt_lang (str): flores lang code in which the input sentence will be translated.
        delimiter (str): separator to add between language tags and input sentence (default: " ").

    Returns:
        str: input sentence with the special tokens added to the start.
    """
    return delimiter.join((src_lang, tgt_lang, sent))
def apply_lang_tags(sents: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    """
    Add special tokens indicating source and target language to the start of the each input sentence.
    Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".

    Args:
        sents (List[str]): input sentences to be translated.
        src_lang (str): flores lang code of the input sentences.
        tgt_lang (str): flores lang code in which the input sentences will be translated.

    Returns:
        List[str]: list of input sentences with the special tokens added to the start.
    """
    # Surrounding whitespace is stripped from each sentence before the tags are prepended.
    return [add_token(sent.strip(), src_lang, tgt_lang) for sent in sents]
def truncate_long_sentences(
    sents: List[str], placeholder_entity_map_sents: List[Dict], max_seq_len: int = 256
) -> Tuple[List[str], List[Dict]]:
    """
    Truncates the sentences that exceed the maximum sequence length.
    The maximum sequence for the IndicTrans2 model is limited to 256 tokens.

    A sentence with more than `max_seq_len` whitespace-separated tokens is split into
    consecutive chunks of at most `max_seq_len` tokens; its placeholder entity map is
    repeated once per chunk so that both returned lists stay aligned. Sentences within
    the limit are passed through unchanged.

    Args:
        sents (List[str]): list of input sentences to truncate.
        placeholder_entity_map_sents (List[Dict]): placeholder entity map of every input sentence.
        max_seq_len (int, optional): maximum number of tokens per sentence (defaults: 256).

    Returns:
        Tuple[List[str], List[Dict]]: tuple containing the list of sentences with truncation applied and the updated placeholder entity maps.

    Raises:
        ValueError: if `max_seq_len` is smaller than 1.
    """
    if max_seq_len < 1:
        raise ValueError(f"max_seq_len must be a positive integer, got {max_seq_len}")

    new_sents = []
    placeholders = []

    for j, sent in enumerate(sents):
        words = sent.split()

        if len(words) <= max_seq_len:
            new_sents.append(sent)
            placeholders.append(placeholder_entity_map_sents[j])
            continue

        # Stepping with range() never yields an empty trailing chunk, even when the
        # token count is an exact multiple of max_seq_len.
        chunks = [
            " ".join(words[start : start + max_seq_len])
            for start in range(0, len(words), max_seq_len)
        ]
        new_sents.extend(chunks)
        placeholders.extend([placeholder_entity_map_sents[j]] * len(chunks))

    return new_sents, placeholders
class Model:
    """
    Model class to run the IndicTransv2 models using python interface.
    """

    def __init__(
        self,
        ckpt_dir: str,
        device: str = "cuda",
        input_lang_code_format: str = "flores",
        model_type: str = "ctranslate2",
    ):
        """
        Initialize the model class.

        Args:
            ckpt_dir (str): path of the model checkpoint directory.
            device (str, optional): where to load the model (defaults: cuda).
            input_lang_code_format (str, optional): format of the language codes passed to
                the translation methods; "iso" codes are mapped through `iso_to_flores`,
                any other value leaves the codes untouched (defaults: flores).
            model_type (str, optional): inference backend, "ctranslate2" or "fairseq"
                (defaults: ctranslate2).

        Raises:
            NotImplementedError: if `model_type` is not a supported backend.
        """
        self.ckpt_dir = ckpt_dir
        self.en_tok = MosesTokenizer(lang="en")
        self.en_normalizer = MosesPunctNormalizer()
        self.en_detok = MosesDetokenizer(lang="en")
        self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()

        print("Initializing sentencepiece model for SRC and TGT")
        self.sp_src = spm.SentencePieceProcessor(
            model_file=os.path.join(ckpt_dir, "vocab", "model.SRC")
        )
        self.sp_tgt = spm.SentencePieceProcessor(
            model_file=os.path.join(ckpt_dir, "vocab", "model.TGT")
        )

        self.input_lang_code_format = input_lang_code_format

        print("Initializing model for translation")
        # The backend is imported lazily so that only the selected one has to be installed.
        if model_type == "ctranslate2":
            import ctranslate2

            self.translator = ctranslate2.Translator(self.ckpt_dir, device=device)
            self.translate_lines = self.ctranslate2_translate_lines
        elif model_type == "fairseq":
            from .custom_interactive import Translator

            # NOTE(review): the last path component is an empty string, so this resolves
            # to the "model" directory itself -- confirm the intended checkpoint file name.
            self.translator = Translator(
                data_dir=os.path.join(self.ckpt_dir, "final_bin"),
                checkpoint_path=os.path.join(self.ckpt_dir, "model", ""),
            )
            self.translate_lines = self.fairseq_translate_lines
        else:
            raise NotImplementedError(f"Unknown model_type: {model_type}")

    def ctranslate2_translate_lines(self, lines: List[str]) -> List[str]:
        """
        Translates preprocessed lines with the ctranslate2 backend.

        Args:
            lines (List[str]): input lines made of space-separated tokens.

        Returns:
            List[str]: first hypothesis of every line, with its tokens joined by spaces.
        """
        tokenized_sents = [x.strip().split(" ") for x in lines]
        # NOTE(review): no decoding options (beam size, batch limits, ...) are passed,
        # so the library defaults apply -- confirm this is intended.
        translations = self.translator.translate_batch(tokenized_sents)
        return [" ".join(x.hypotheses[0]) for x in translations]

    def fairseq_translate_lines(self, lines: List[str]) -> List[str]:
        """
        Translates preprocessed lines with the fairseq backend.

        Args:
            lines (List[str]): input lines to translate.

        Returns:
            List[str]: whatever the underlying translator returns for `lines`.
        """
        return self.translator.translate(lines)

    def paragraphs_batch_translate__multilingual(self, batch_payloads: List[tuple]) -> List[str]:
        """
        Translates a batch of input paragraphs (including pre/post processing)
        from any language to any language.

        Args:
            batch_payloads (List[tuple]): batch of long input-texts to be translated, each in format: (paragraph, src_lang, tgt_lang)

        Returns:
            List[str]: batch of paragraph-translations in the respective languages.
        """
        # One (start, end, flores_tgt_lang) entry per paragraph; start/end index into the
        # flat sentence lists below so every paragraph can be reassembled after translation.
        paragraph_sentence_ranges = []
        global__preprocessed_sents = []
        global__preprocessed_sents_placeholder_entity_map = []

        for paragraph, src_lang, tgt_lang in batch_payloads:
            if self.input_lang_code_format == "iso":
                src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]

            batch = split_sentences(paragraph, src_lang)
            preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
                batch, src_lang, tgt_lang
            )

            global_sentence_start_index = len(global__preprocessed_sents)
            global__preprocessed_sents.extend(preprocessed_sents)
            global__preprocessed_sents_placeholder_entity_map.extend(placeholder_entity_map_sents)
            paragraph_sentence_ranges.append(
                (global_sentence_start_index, len(global__preprocessed_sents), tgt_lang)
            )

        # All sentences of all paragraphs go through the model in a single call.
        translations = (
            self.translate_lines(global__preprocessed_sents) if global__preprocessed_sents else []
        )

        translated_paragraphs = []
        for start, end, tgt_lang in paragraph_sentence_ranges:
            postprocessed_sents = self.postprocess(
                translations[start:end],
                global__preprocessed_sents_placeholder_entity_map[start:end],
                tgt_lang,
            )
            translated_paragraphs.append(" ".join(postprocessed_sents))

        return translated_paragraphs

    # translate a batch of sentences from src_lang to tgt_lang
    def batch_translate(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
        """
        Translates a batch of input sentences (including pre/post processing)
        from source language to target language.

        Args:
            batch (List[str]): batch of input sentences to be translated.
            src_lang (str): flores source language code.
            tgt_lang (str): flores target language code.

        Returns:
            List[str]: batch of translated-sentences generated by the model.
        """
        assert isinstance(batch, list)

        if self.input_lang_code_format == "iso":
            src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]

        preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
            batch, src_lang, tgt_lang
        )
        translations = self.translate_lines(preprocessed_sents)
        return self.postprocess(translations, placeholder_entity_map_sents, tgt_lang)

    # translate a paragraph from src_lang to tgt_lang
    def translate_paragraph(self, paragraph: str, src_lang: str, tgt_lang: str) -> str:
        """
        Translates an input text paragraph (including pre/post processing)
        from source language to target language.

        Args:
            paragraph (str): input text paragraph to be translated.
            src_lang (str): flores source language code.
            tgt_lang (str): flores target language code.

        Returns:
            str: paragraph translation generated by the model.
        """
        assert isinstance(paragraph, str)

        # Sentence splitting needs a flores code; `batch_translate` performs its own
        # conversion, so it still receives the codes exactly as they were passed in.
        if self.input_lang_code_format == "iso":
            flores_src_lang = iso_to_flores[src_lang]
        else:
            flores_src_lang = src_lang

        sents = split_sentences(paragraph, flores_src_lang)
        postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
        return " ".join(postprocessed_sents)

    def preprocess_batch(
        self, batch: List[str], src_lang: str, tgt_lang: str
    ) -> Tuple[List[str], List[Dict]]:
        """
        Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. It also tokenizes the
        normalized text sequences using sentence piece tokenizer and also adds language tags.

        Args:
            batch (List[str]): input list of sentences to preprocess.
            src_lang (str): flores language code of the input text sentences.
            tgt_lang (str): flores language code of the output text sentences.

        Returns:
            Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
                mapping placeholders to their original values.
        """
        preprocessed_sents, placeholder_entity_map_sents = self.preprocess(batch, lang=src_lang)
        tokenized_sents = self.apply_spm(preprocessed_sents)
        tokenized_sents, placeholder_entity_map_sents = truncate_long_sentences(
            tokenized_sents, placeholder_entity_map_sents
        )
        tagged_sents = apply_lang_tags(tokenized_sents, src_lang, tgt_lang)
        return tagged_sents, placeholder_entity_map_sents

    def apply_spm(self, sents: List[str]) -> List[str]:
        """
        Applies sentence piece encoding to the batch of input sentences.

        Args:
            sents (List[str]): batch of the input sentences.

        Returns:
            List[str]: batch of encoded sentences with sentence piece model
        """
        return [" ".join(self.sp_src.encode(sent, out_type=str)) for sent in sents]

    def preprocess_sent(
        self,
        sent: str,
        normalizer: "Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]",
        lang: str,
    ) -> Tuple[str, Dict]:
        """
        Preprocess an input text sentence by normalizing, tokenization, and possibly transliterating it.

        Args:
            sent (str): input text sentence to preprocess.
            normalizer (Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]): an object that performs normalization on the text.
                It is not used for English input.
            lang (str): flores language code of the input text sentence.

        Returns:
            Tuple[str, Dict]: A tuple containing the preprocessed input text sentence and a corresponding dictionary
                mapping placeholders to their original values.
        """
        iso_lang = flores_codes[lang]
        sent = punc_norm(sent, iso_lang)
        sent, placeholder_entity_map = normalize(sent)

        if iso_lang == "en":
            processed_sent = " ".join(
                self.en_tok.tokenize(self.en_normalizer.normalize(sent.strip()), escape=False)
            )
            return processed_sent, placeholder_entity_map

        tokenized_sent = " ".join(
            indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang)
        )

        # Sentences written in these scripts are never transliterated.
        transliterate = lang.split("_")[1] not in {"Arab", "Aran", "Olck", "Mtei", "Latn"}
        if transliterate:
            # transliterates from the any specific language to devanagari
            # which is why we specify lang2_code as "hi".
            processed_sent = self.xliterator.transliterate(
                tokenized_sent.replace(" ् ", "्"), iso_lang, "hi"
            )
        else:
            # we only need to transliterate for joint training
            processed_sent = tokenized_sent

        return processed_sent, placeholder_entity_map

    def preprocess(self, sents: List[str], lang: str) -> Tuple[List[str], List[Dict]]:
        """
        Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it.

        Args:
            sents (List[str]): input list of sentences to preprocess.
            lang (str): flores language code of the input text sentences.

        Returns:
            Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
                mapping placeholders to their original values.
        """
        if lang == "eng_Latn":
            # English goes through the Moses normalizer held on the instance instead.
            normalizer = None
        else:
            normfactory = indic_normalize.IndicNormalizerFactory()
            normalizer = normfactory.get_normalizer(flores_codes[lang])

        processed_sents, placeholder_entity_map_sents = [], []
        for sent in sents:
            processed_sent, placeholder_entity_map = self.preprocess_sent(sent, normalizer, lang)
            processed_sents.append(processed_sent)
            placeholder_entity_map_sents.append(placeholder_entity_map)

        return processed_sents, placeholder_entity_map_sents

    def postprocess(
        self,
        sents: List[str],
        placeholder_entity_map: List[Dict],
        lang: str,
        common_lang: str = "hin_Deva",
    ) -> List[str]:
        """
        Postprocesses a batch of input sentences after the translation generations.

        The input list is not modified.

        Args:
            sents (List[str]): batch of translated sentences to postprocess.
            placeholder_entity_map (List[Dict]): dictionary mapping placeholders to the original entity values.
            lang (str): flores language code of the input sentences.
            common_lang (str, optional): flores language code of the transliterated language (defaults: hin_Deva).

        Returns:
            List[str]: postprocessed batch of input sentences.
        """
        assert len(sents) == len(placeholder_entity_map)

        lang_code, script_code = lang.split("_")

        decoded_sents = []
        for sent, entity_map in zip(sents, placeholder_entity_map):
            # SPM decode: drop the spaces between pieces, then turn every "▁" marker into a space.
            sent = sent.replace(" ", "").replace("▁", " ").strip()

            # Fixes for Perso-Arabic scripts
            # TODO: Move these normalizations inside indic-nlp-library
            if script_code in {"Arab", "Aran"}:
                # UrduHack adds space before punctuations. Since the model was trained without fixing this issue, let's fix it now
                sent = sent.replace(" ؟", "؟").replace(" ۔", "۔").replace(" ،", "،")
                # Kashmiri bugfix for palatalization:
                sent = sent.replace("ٮ۪", "ؠ")

            # Put the original entities back in place of their placeholders.
            for key, value in entity_map.items():
                sent = sent.replace(key, value)

            decoded_sents.append(sent)

        # Detokenize and transliterate to native scripts if applicable
        if lang == "eng_Latn":
            return [self.en_detok.detokenize(sent.split(" ")) for sent in decoded_sents]

        postprocessed_sents = []
        for sent in decoded_sents:
            outstr = indic_detokenize.trivial_detokenize(
                self.xliterator.transliterate(
                    sent, flores_codes[common_lang], flores_codes[lang]
                ),
                flores_codes[lang],
            )

            # Oriya bug: indic-nlp-library produces ଯ଼ instead of ୟ when converting from Devanagari to Odia
            # TODO: Find out what's the issue with unicode transliterator for Oriya and fix it
            if lang_code == "ory":
                outstr = outstr.replace("ଯ଼", 'ୟ')

            postprocessed_sents.append(outstr)

        return postprocessed_sents