davanstrien HF staff committed on
Commit
5cb6981
1 Parent(s): 71ae380

use src lang

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -29,14 +29,20 @@ def translate(
29
  window_size: int = 800,
30
  overlap_size: int = 200,
31
  ):
32
- input_tokens = tokenizer.encode(text, return_tensors="pt")[0].cpu().numpy().tolist()
 
 
 
 
 
 
33
  translated_chunks = []
34
 
35
  for i in range(0, len(input_tokens), window_size - overlap_size):
36
  window = input_tokens[i : i + window_size]
37
  translated_chunk = model.generate(
38
  input_ids=torch.tensor([window]).to(device),
39
- forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
40
  max_length=window_size,
41
  num_return_sequences=1,
42
  )
 
29
  window_size: int = 800,
30
  overlap_size: int = 200,
31
  ):
32
+ input_tokens = (
33
+ tokenizer(text, return_tensors="pt", src_lang=code_mapping[src_lang])
34
+ .input_ids[0]
35
+ .cpu()
36
+ .numpy()
37
+ .tolist()
38
+ )
39
  translated_chunks = []
40
 
41
  for i in range(0, len(input_tokens), window_size - overlap_size):
42
  window = input_tokens[i : i + window_size]
43
  translated_chunk = model.generate(
44
  input_ids=torch.tensor([window]).to(device),
45
+ forced_bos_token_id=tokenizer.lang_code_to_id[code_mapping[tgt_lang]],
46
  max_length=window_size,
47
  num_return_sequences=1,
48
  )