davanstrien HF staff committed on
Commit
5e1003d
1 Parent(s): 3f23d73

Add nltk dependency and update translate function to handle multiple sentences

Browse files
Files changed (2) hide show
  1. app.py +15 -17
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,6 +4,9 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from flores import code_mapping
5
  import platform
6
  import torch
 
 
 
7
 
8
  device = "cpu" if platform.system() == "Darwin" else "cuda"
9
  MODEL_NAME = "facebook/nllb-200-3.3B"
@@ -28,34 +31,29 @@ def load_tokenizer(src_lang, tgt_lang):
28
 
29
 
30
  @spaces.GPU
31
- def translate(
32
- text: str,
33
- src_lang: str,
34
- tgt_lang: str,
35
- window_size: int = 800,
36
- overlap_size: int = 200,
37
- ):
38
  tokenizer = load_tokenizer(src_lang, tgt_lang)
39
 
40
- input_tokens = (
41
- tokenizer(text, return_tensors="pt").input_ids[0].cpu().numpy().tolist()
42
- )
43
- translated_chunks = []
44
 
45
- for i in range(0, len(input_tokens), window_size - overlap_size):
46
- window = input_tokens[i : i + window_size]
 
 
47
  translated_chunk = model.generate(
48
- input_ids=torch.tensor([window]).to(device),
49
  forced_bos_token_id=tokenizer.lang_code_to_id[code_mapping[tgt_lang]],
50
- max_length=window_size,
51
  num_return_sequences=1,
52
  )
53
  translated_chunk = tokenizer.decode(
54
  translated_chunk[0], skip_special_tokens=True
55
  )
56
- translated_chunks.append(translated_chunk)
57
 
58
- return " ".join(translated_chunks)
 
59
 
60
 
61
  description = """
 
4
  from flores import code_mapping
5
  import platform
6
  import torch
7
+ import nltk
8
+
9
+ nltk.download("punkt")
10
 
11
  device = "cpu" if platform.system() == "Darwin" else "cuda"
12
  MODEL_NAME = "facebook/nllb-200-3.3B"
 
31
 
32
 
33
  @spaces.GPU
34
+ def translate(text: str, src_lang: str, tgt_lang: str):
 
 
 
 
 
 
35
  tokenizer = load_tokenizer(src_lang, tgt_lang)
36
 
37
+ sentences = nltk.sent_tokenize(text)
38
+ translated_sentences = []
 
 
39
 
40
+ for sentence in sentences:
41
+ input_tokens = (
42
+ tokenizer(sentence, return_tensors="pt").input_ids[0].cpu().numpy().tolist()
43
+ )
44
  translated_chunk = model.generate(
45
+ input_ids=torch.tensor([input_tokens]).to(device),
46
  forced_bos_token_id=tokenizer.lang_code_to_id[code_mapping[tgt_lang]],
47
+ max_length=len(input_tokens) + 50,
48
  num_return_sequences=1,
49
  )
50
  translated_chunk = tokenizer.decode(
51
  translated_chunk[0], skip_special_tokens=True
52
  )
53
+ translated_sentences.append(translated_chunk)
54
 
55
+ translated_text = " ".join(translated_sentences)
56
+ return translated_text
57
 
58
 
59
  description = """
requirements.txt CHANGED
@@ -2,4 +2,5 @@
2
  transformers
3
  torch
4
  gradio
5
- spaces
 
 
2
  transformers
3
  torch
4
  gradio
5
+ spaces
6
+ nltk