File size: 475 Bytes
03a8026
 
 
9c2c875
 
 
03a8026
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

from transformers import FSMTForConditionalGeneration, FSMTTokenizer
# Hugging Face checkpoint for Russian -> English translation (FSMT, WMT19).
mname = "facebook/wmt19-ru-en"

# The tokenizer and model are loaded once, at import time, and shared by
# every call in this module. ``cache_dir`` points from_pretrained at a local
# "./cache" directory (relative to the current working directory) for the
# downloaded checkpoint files.
tokenizer = FSMTTokenizer.from_pretrained(mname, cache_dir='./cache')
model = FSMTForConditionalGeneration.from_pretrained(
    mname,
    cache_dir='./cache',
)


def translate_ru_en(text, **generate_kwargs):
    """Translate Russian *text* to English using the module-level FSMT model.

    Args:
        text: Russian source string to translate.
        **generate_kwargs: Optional keyword arguments forwarded unchanged to
            ``model.generate`` (for example ``num_beams=5`` or
            ``max_new_tokens=200``). When none are given, the call is exactly
            ``model.generate(input_ids)``, i.e. the model's default generation
            settings, matching the previous behavior.

    Returns:
        The English translation as a string, decoded from the first generated
        sequence with special tokens stripped. If ``num_return_sequences`` > 1
        is passed, only the first sequence is returned.
    """
    # NOTE(review): no truncation is requested, so inputs longer than the
    # model's maximum sequence length are presumably rejected by the model
    # rather than silently cut — confirm if long inputs are expected.
    input_ids = tokenizer.encode(text, return_tensors="pt")  # PyTorch tensor
    outputs = model.generate(input_ids, **generate_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)