davanstrien and cointegrated committed
Commit a50a704
Parent: 6634f63

Optimize the preprocessing and generation (#11)


- harmonize the language codes list with NLLB (d0a2f64cdae2fae119a127dba13609cb1d0b7542)
- raise errors when the source or target language is not chosen (5c565ab3ea2711194390b6c1b06a499b7da4534e)
- adjust the generation parameters to avoid repetitions (d0ffdbfb40076436f5f40e7deffb7440f5c35e07)
- add punctuation normalization and load the tokenizer only once (2a62da0ac954875090a26ab5dacfef37e9000aec)
- use sentence splitters from stopes (3740b63b75a6c13c1e25911113565bbb51a584a6)


Co-authored-by: David Dale <cointegrated@users.noreply.huggingface.co>
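
For illustration, here is a minimal standalone sketch, not part of the commit, of the anti-repetition generation settings this PR introduces in app.py. The checkpoint name is just an example (the Space's actual MODEL_NAME is defined in app.py), and "fra_Latn" stands in for the chosen target language:

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    model_name = "facebook/nllb-200-distilled-600M"  # example checkpoint, an assumption
    tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang="eng_Latn")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    inputs = tokenizer("Example input.", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),  # target language token
        num_beams=5,                # beam search instead of greedy decoding
        no_repeat_ngram_size=4,     # ban repeating 4-grams; keep this below num_beams
        renormalize_logits=True,    # recompute probabilities after banning repetitions
        max_length=128,
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))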

Files changed (3)
  1. app.py +33 -9
  2. flores.py +3 -3
  3. requirements.txt +3 -1
app.py CHANGED
@@ -1,5 +1,7 @@
 import spaces
 import gradio as gr
+from sacremoses import MosesPunctNormalizer
+from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from flores import code_mapping
 import platform
@@ -28,28 +30,47 @@ def load_model():
 model = load_model()
 
 
-def load_tokenizer(src_lang, tgt_lang):
-    tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_NAME, src_lang=code_mapping[src_lang], tgt_lang=code_mapping[tgt_lang]
-    )
-    return tokenizer
+# Loading the tokenizer once, because re-loading it takes about 1.5 seconds each time
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+
+punct_normalizer = MosesPunctNormalizer(lang="en")
+
+
+@lru_cache(maxsize=202)
+def get_language_specific_sentence_splitter(language_code):
+    short_code = language_code[:3]
+    splitter = get_split_algo(short_code, "default")
+    return splitter
 
 
 # cache function
 @lru_cache(maxsize=100)
 def translate(text: str, src_lang: str, tgt_lang: str):
-    return _translate(text, src_lang,tgt_lang )
+    if not src_lang:
+        raise gr.Error("The source language is empty! Please choose it in the dropdown list.")
+    if not tgt_lang:
+        raise gr.Error("The target language is empty! Please choose it in the dropdown list.")
+    return _translate(text, src_lang, tgt_lang)
+
 
 # Only assign GPU if cache not used
 @spaces.GPU
 def _translate(text: str, src_lang: str, tgt_lang: str):
-    tokenizer = load_tokenizer(src_lang, tgt_lang)
+    src_code = code_mapping[src_lang]
+    tgt_code = code_mapping[tgt_lang]
+    tokenizer.src_lang = src_code
+    tokenizer.tgt_lang = tgt_code
+
+    # normalizing the punctuation first
+    text = punct_normalizer.normalize(text)
 
     paragraphs = text.split("\n")
     translated_paragraphs = []
 
     for paragraph in paragraphs:
-        sentences = nltk.sent_tokenize(paragraph)
+        splitter = get_language_specific_sentence_splitter(src_code)
+        sentences = list(splitter(paragraph))
         translated_sentences = []
 
         for sentence in sentences:
@@ -62,9 +83,12 @@ def _translate(text: str, src_lang: str, tgt_lang: str):
             )
             translated_chunk = model.generate(
                 input_ids=torch.tensor([input_tokens]).to(device),
-                forced_bos_token_id=tokenizer.convert_tokens_to_ids(code_mapping[tgt_lang]),
+                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                 max_length=len(input_tokens) + 50,
                 num_return_sequences=1,
+                num_beams=5,
+                no_repeat_ngram_size=4,  # repetition blocking works better if this number is below num_beams
+                renormalize_logits=True,  # recompute token probabilities after banning the repetitions
             )
             translated_chunk = tokenizer.decode(
                 translated_chunk[0], skip_special_tokens=True
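
A note on the new sentence splitting: get_split_algo takes a three-letter language code and an algorithm name and returns a callable that maps a paragraph to an iterable of sentences, which is why the code above can call list(splitter(paragraph)). A minimal sketch, assuming the better-sentence-splitters branch pinned in requirements.txt:

    from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo

    # "eng" is the first three characters of an NLLB code such as "eng_Latn";
    # "default" selects the recommended split algorithm for that language.
    splitter = get_split_algo("eng", "default")
    print(list(splitter("First sentence. Second one.")))

The lru_cache(maxsize=202) on the wrapper presumably matches the 202 languages covered by NLLB, so each splitter is constructed at most once per language.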
flores.py CHANGED
@@ -10,7 +10,7 @@ code_mapping = {
     "Amharic": "amh_Ethi",
     "North Levantine Arabic": "apc_Arab",
     "Modern Standard Arabic": "arb_Arab",
-    "Modern Standard Arabic (Romanized)": "arb_Latn",
+    # "Modern Standard Arabic (Romanized)": "arb_Latn",  # it is in FLORES, but not in NLLB
     "Najdi Arabic": "ars_Arab",
     "Moroccan Arabic": "ary_Arab",
     "Egyptian Arabic": "arz_Arab",
@@ -115,7 +115,7 @@ code_mapping = {
     "Maithili": "mai_Deva",
     "Malayalam": "mal_Mlym",
     "Marathi": "mar_Deva",
-    "Minangkabau (Arabic script)": "min_Arab",
+    # "Minangkabau (Arabic script)": "min_Arab",  # it is in FLORES, but not in NLLB
     "Minangkabau (Latin script)": "min_Latn",
     "Macedonian": "mkd_Cyrl",
     "Plateau Malagasy": "plt_Latn",
@@ -149,7 +149,7 @@ code_mapping = {
     "Russian": "rus_Cyrl",
     "Sango": "sag_Latn",
     "Sanskrit": "san_Deva",
-    "Santali": "sat_Olck",
+    "Santali": "sat_Beng",  # it is called sat_Olck in FLORES, but (less correctly) sat_Beng in NLLB
     "Sicilian": "scn_Latn",
     "Shan": "shn_Mymr",
     "Sinhala": "sin_Sinh",
requirements.txt CHANGED
@@ -3,4 +3,6 @@ transformers
 torch
 gradio==4.32.2
 spaces
-nltk
+nltk
+sacremoses
+stopes[mono] @ git+https://github.com/facebookresearch/stopes@better-sentence-splitters
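
Outside the Space, the new dependencies can be installed with pip install sacremoses "stopes[mono] @ git+https://github.com/facebookresearch/stopes@better-sentence-splitters". A minimal sketch, not part of the commit, of what sacremoses contributes to preprocessing:

    from sacremoses import MosesPunctNormalizer

    normalizer = MosesPunctNormalizer(lang="en")
    # Curly quotes, non-breaking spaces, stray spaces before punctuation, etc.
    # are rewritten into the plain forms the translation model saw in training.
    print(normalizer.normalize("« Hello , world »"))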