0x7o commited on
Commit
2d930f5
1 Parent(s): db12f76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -15,7 +15,7 @@ def load_models():
15
 
16
  for call_name, real_name in model_name_dict.items():
17
  print('\tLoading model: %s' % call_name)
18
- model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
19
  tokenizer = AutoTokenizer.from_pretrained(real_name)
20
  model_dict[call_name+'_model'] = model
21
  model_dict[call_name+'_tokenizer'] = tokenizer
@@ -35,7 +35,7 @@ def translation(source, target, text):
35
  model = model_dict[model_name + '_model']
36
  tokenizer = model_dict[model_name + '_tokenizer']
37
 
38
- translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
39
  output = translator(text, max_length=400)
40
 
41
  end_time = time.time()
 
15
 
16
  for call_name, real_name in model_name_dict.items():
17
  print('\tLoading model: %s' % call_name)
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(real_name).cuda()
19
  tokenizer = AutoTokenizer.from_pretrained(real_name)
20
  model_dict[call_name+'_model'] = model
21
  model_dict[call_name+'_tokenizer'] = tokenizer
 
35
  model = model_dict[model_name + '_model']
36
  tokenizer = model_dict[model_name + '_tokenizer']
37
 
38
+ translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target, device=0)
39
  output = translator(text, max_length=400)
40
 
41
  end_time = time.time()