# zh<->en translation helpers built on Helsinki-NLP OPUS-MT models.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from .singleton import Singleton
device = "cuda" if torch.cuda.is_available() else "cpu"
@Singleton
class Models(object):
    """Lazily-loaded holder for the zh->en and en->zh MarianMT models.

    The first access to any of ``zh2en_model``/``zh2en_tokenizer`` or
    ``en2zh_model``/``en2zh_tokenizer`` triggers a download/load via the
    corresponding ``load_*`` classmethod; the pair is then cached on the
    instance so ``__getattr__`` is not consulted again for those names.
    """

    def __getattr__(self, item):
        # __getattr__ only runs when normal attribute lookup has already
        # failed, so the requested attribute is not yet cached here.
        if item in ('zh2en_model', 'zh2en_tokenizer'):
            self.zh2en_model, self.zh2en_tokenizer = self.load_zh2en_model()
        elif item in ('en2zh_model', 'en2zh_tokenizer'):
            self.en2zh_model, self.en2zh_tokenizer = self.load_en2zh_model()
        else:
            # The original fell through to getattr(self, item), which
            # re-entered __getattr__ and recursed forever for any unknown
            # name (e.g. dunders probed by copy/pickle/hasattr). Raise the
            # conventional AttributeError instead.
            raise AttributeError(item)
        # Read straight from __dict__ (just populated above) to avoid
        # another round-trip through attribute lookup.
        return self.__dict__[item]

    @classmethod
    def load_en2zh_model(cls):
        """Load the English->Chinese OPUS-MT model and tokenizer (eval mode)."""
        en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
        en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
        return en2zh_model, en2zh_tokenizer

    @classmethod
    def load_zh2en_model(cls):
        """Load the Chinese->English OPUS-MT model and tokenizer (eval mode)."""
        zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
        zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
        return zh2en_model, zh2en_tokenizer
models = Models.instance()
def zh2en(text):
    """Translate a single Chinese string to English and return the result."""
    tokenizer = models.zh2en_tokenizer
    with torch.no_grad():
        # Wrap in a one-element batch; generation is inference-only.
        batch = tokenizer([text], return_tensors="pt")
        output_ids = models.zh2en_model.generate(**batch)
        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded[0]
def en2zh(text):
    """Translate a single English string to Chinese and return the result."""
    tokenizer = models.en2zh_tokenizer
    with torch.no_grad():
        # Wrap in a one-element batch; generation is inference-only.
        batch = tokenizer([text], return_tensors="pt")
        output_ids = models.en2zh_model.generate(**batch)
        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded[0]
if __name__ == "__main__":
    # Round-trip demo: zh -> en -> zh. Renamed the original local `input`,
    # which shadowed the builtin of the same name.
    source_zh = "青春不能回头,所以青春没有终点。 ——《火影忍者》"
    translated_en = zh2en(source_zh)
    print(source_zh, translated_en)
    round_trip_zh = en2zh(translated_en)
    print(translated_en, round_trip_zh)