stazizov commited on
Commit
9f0cda1
1 Parent(s): aec0103
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -3,16 +3,15 @@ from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
3
  import torch
4
  import spaces
5
 
6
- device = "cuda"
7
  model_id = "leks-forever/nllb-200-distilled-600M-v1"
8
 
9
  tokenizer = NllbTokenizer.from_pretrained(model_id)
10
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
11
 
12
  @spaces.GPU
13
  def translate(text, src_lang='lez_Cyrl', tgt_lang='rus_Cyrl', a=32, b=3, max_input_length=1024, num_beams=1, **kwargs):
14
-
15
- model = model.to(device)
16
 
17
  if src_lang in language_codes:
18
  src_lang = language_codes[src_lang]
 
3
  import torch
4
  import spaces
5
 
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
  model_id = "leks-forever/nllb-200-distilled-600M-v1"
8
 
9
  tokenizer = NllbTokenizer.from_pretrained(model_id)
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
11
 
12
  @spaces.GPU
13
  def translate(text, src_lang='lez_Cyrl', tgt_lang='rus_Cyrl', a=32, b=3, max_input_length=1024, num_beams=1, **kwargs):
14
+ global tokenizer
 
15
 
16
  if src_lang in language_codes:
17
  src_lang = language_codes[src_lang]