File size: 1,932 Bytes
cf76446
5afd42b
 
73ab36f
5afd42b
 
 
73ab36f
 
cf76446
5afd42b
 
cf76446
 
 
 
 
 
10a5d6b
5afd42b
cf76446
 
 
 
 
5afd42b
 
73ab36f
cf76446
73ab36f
 
5afd42b
cf76446
73ab36f
cf76446
 
 
 
 
 
 
 
 
 
 
 
96aee55
 
 
73ab36f
 
cf76446
73ab36f
cf76446
 
 
96aee55
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import sys, os
import ctranslate2
import sentencepiece as spm
from typing import Union



def indexOf(arr: list, value) -> int:
    """Return the position of the first occurrence of ``value`` in ``arr``.

    Args:
        arr: Sequence to search (anything exposing ``.index``).
        value: Item to look for.

    Returns:
        The zero-based index of ``value``, or ``-1`` when it is not present.
    """
    try:
        return arr.index(value)
    except ValueError:
        # ``.index`` signals "not found" with ValueError; any other exception
        # is a real error and is deliberately left to propagate.
        return -1


class SugoiTranslator:
    """Translate text with a CTranslate2 model and SentencePiece tokenizers.

    The source tokenizer is read from ``spm.ja.nopretok.model`` and the target
    tokenizer from ``spm.en.nopretok.model`` inside the model directory
    (presumably Japanese -> English, judging by the file names).
    """

    def __init__(self, modelDir= "./model") -> None:
        """Create the CTranslate2 translator and record the tokenizer paths.

        Args:
            modelDir: Directory containing the CTranslate2 model and the two
                SentencePiece model files.
        """
        self.modelDir = modelDir
        self.sp_source_model = os.path.join(modelDir, "spm.ja.nopretok.model")
        self.sp_target_model = os.path.join(modelDir, "spm.en.nopretok.model")
        # SentencePiece processors keyed by model path, filled on first use so
        # the model files are not re-read on every (de)tokenize call.
        self._processors = {}
        # inter_threads: how many independent operations can run simultaneously
        self.translator = ctranslate2.Translator(
            modelDir, device="cpu", intra_threads=4, inter_threads=2
        )

    def _get_processor(self, model_path):
        """Return the SentencePiece processor for ``model_path``, loading it once."""
        processor = self._processors.get(model_path)
        if processor is None:
            processor = spm.SentencePieceProcessor(model_path)
            self._processors[model_path] = processor
        return processor

    def tokenizeBatch(self, text):
        """Split source text into SentencePiece string pieces.

        Args:
            text: A single string or a list of strings.

        Returns:
            A list of token lists, one per input sentence. A single string is
            wrapped so the result is always a batch.

        Raises:
            TypeError: If ``text`` is neither a ``str`` nor a ``list``.
        """
        # Validate before touching the tokenizer so bad input fails fast with
        # a clear message instead of silently yielding ``None``.
        if isinstance(text, str):
            processor = self._get_processor(self.sp_source_model)
            return [processor.encode(text, out_type=str)]
        if isinstance(text, list):
            processor = self._get_processor(self.sp_source_model)
            return processor.encode(text, out_type=str)
        raise TypeError(
            f"text must be a str or a list of str, got {type(text).__name__}"
        )

    def detokenizeBatch(self, token):
        """Decode target-side tokens back into text using the target tokenizer.

        Args:
            token: Tokens as produced by the translator (e.g. one hypothesis).

        Returns:
            Whatever ``SentencePieceProcessor.decode`` returns for ``token``.
        """
        return self._get_processor(self.sp_target_model).decode(token)

    def translate(self, text: Union[str, list]) -> list:
        """Translate one string or a list of strings.

        Args:
            text: A single string or a list of strings to translate.

        Returns:
            A list with one translated string per input sentence (a single
            input string yields a one-element list).

        Raises:
            TypeError: If ``text`` is neither a ``str`` nor a ``list``.
        """
        translated = self.translator.translate_batch(
            source=self.tokenizeBatch(text),
            num_hypotheses=1,
            return_alternatives=False,
            replace_unknowns=False,
            no_repeat_ngram_size=3,  # prevents any 3-gram from repeating in the output
            disable_unk=True,
            beam_size=5,
            sampling_temperature=0,
        )

        # Only the best hypothesis of each result is decoded; ''.join keeps the
        # original post-processing of the decoded value.
        return [''.join(self.detokenizeBatch(result.hypotheses[0])) for result in translated]


if __name__ == "__main__":
    # Demo entry point: translate one sample sentence and print the result.
    # The model directory defaults to "./model" and can be overridden with
    # "-modelDir <path>" on the command line.
    modelDir = "./model"
    index = indexOf(sys.argv, "-modelDir")
    if index != -1:
        # The flag needs a value after it; fail with a clear message rather
        # than an IndexError when it is the last argument.
        if index + 1 >= len(sys.argv):
            sys.exit("error: -modelDir requires a path argument")
        modelDir = sys.argv[index + 1]

    sugoiTranslator = SugoiTranslator(modelDir)
    translated = sugoiTranslator.translate("ダンガンロンパ 希望の学園と絶望の高校生")
    print(translated)