File size: 3,288 Bytes
39fbaa4
 
 
390173a
 
 
 
39fbaa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390173a
 
 
39fbaa4
 
390173a
 
 
 
 
 
 
39fbaa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390173a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fbaa4
 
 
 
 
 
390173a
39fbaa4
 
 
 
 
 
 
390173a
 
 
39fbaa4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from .singleton import Singleton
from transformers import (
    EncoderDecoderModel,
    AutoTokenizer
)

# Select GPU when available.
# NOTE(review): `device` is never used in this file's visible code — models
# and tensors are not moved onto it. Confirm whether external callers read
# it, or whether it is dead.
device = "cuda" if torch.cuda.is_available() else "cpu"


@Singleton
class Models(object):
    """Lazy container for the translation model/tokenizer pairs.

    Each pair is loaded on first access of any of its attribute names and
    then cached in the instance ``__dict__``, so later lookups never reach
    ``__getattr__`` again.
    """

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails, i.e. the pair
        # has not been loaded yet (or the name is unknown).
        if item in self.__dict__:
            return getattr(self, item)

        if item in ('zh2en_model', 'zh2en_tokenizer',):
            self.zh2en_model, self.zh2en_tokenizer = self.load_zh2en_model()
        elif item in ('en2zh_model', 'en2zh_tokenizer',):
            self.en2zh_model, self.en2zh_tokenizer = self.load_en2zh_model()
        elif item in ('wenyanwen2modern_tokenizer', 'wenyanwen2modern_model',):
            self.wenyanwen2modern_tokenizer, self.wenyanwen2modern_model = self.load_wenyanwen2modern_model()
        else:
            # Bug fix: the original fell through to getattr(self, item) for
            # unknown names, which re-invoked __getattr__ forever and raised
            # RecursionError instead of the conventional AttributeError.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}")

        return getattr(self, item)

    @classmethod
    def load_wenyanwen2modern_model(cls):
        """Return (tokenizer, model) for Classical->modern Chinese translation."""
        PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
        tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
        model = EncoderDecoderModel.from_pretrained(PRETRAINED)
        return tokenizer, model

    @classmethod
    def load_en2zh_model(cls):
        """Return (model, tokenizer) for English->Chinese translation."""
        en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
        en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
        return en2zh_model, en2zh_tokenizer

    @classmethod
    def load_zh2en_model(cls):
        """Return (model, tokenizer) for Chinese->English translation."""
        zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
        zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')

        return zh2en_model, zh2en_tokenizer,


# Module-level singleton; the underlying models load lazily on first use.
models = Models.instance()


def wenyanwen2modern(text: str) -> str:
    """Translate Classical Chinese (wenyanwen) *text* into modern Chinese.

    Input is truncated/padded to 128 tokens; generation uses beam search
    (3 beams) up to 256 tokens.
    """
    # Hoist the repeated attribute lookups (also triggers lazy loading once).
    tokenizer = models.wenyanwen2modern_tokenizer
    inputs = tokenizer(
        [text],
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors='pt')

    with torch.no_grad():
        output_ids = models.wenyanwen2modern_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=3,
            max_length=256,
            # cls_token_id replaces the magic constant 101 used previously —
            # same value for this BERT-derived tokenizer, and consistent with
            # the sep/pad lookups below.
            bos_token_id=tokenizer.cls_token_id,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # The tokenizer decodes to space-separated characters; strip the spaces.
    return tokenizer.batch_decode(
        output_ids, skip_special_tokens=True)[0].replace(" ", "")


def zh2en(text: str) -> str:
    """Translate Chinese *text* into English (Helsinki-NLP opus-mt-zh-en)."""
    tokenizer = models.zh2en_tokenizer
    with torch.no_grad():
        batch = tokenizer([text], return_tensors="pt")
        generated = models.zh2en_model.generate(**batch)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        return decoded[0]


def en2zh(text: str) -> str:
    """Translate English *text* into Chinese (Helsinki-NLP opus-mt-en-zh)."""
    tokenizer = models.en2zh_tokenizer
    with torch.no_grad():
        batch = tokenizer([text], return_tensors="pt")
        generated = models.en2zh_model.generate(**batch)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        return decoded[0]


if __name__ == "__main__":
    # Demo round-trip: Classical Chinese -> modern Chinese -> English -> Chinese.
    # `source` was previously named `input`, shadowing the builtin.
    source = "飞流直下三千尺,疑是银河落九天"
    modern = wenyanwen2modern(source)
    en = zh2en(modern)
    print(source, en)
    zh = en2zh(en)
    print(en, zh)