File size: 2,167 Bytes
1b46418 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import os
from transformers import ReformerTokenizerFast
from transformers.models.bert_japanese.tokenization_bert_japanese import MecabTokenizer
# `transformers.utils.cached_file` only exists in newer transformers releases;
# older ones expose `cached_path` + `hf_bucket_url` instead, so provide an
# equivalent two-argument `cached_file` in that case.
try:
    from transformers.utils import cached_file
except ImportError:
    from transformers.file_utils import cached_path, hf_bucket_url

    def cached_file(path_or_repo_id, filename):
        """Return a local path for *filename* from *path_or_repo_id*.

        If *path_or_repo_id* is an existing local directory the joined path is
        returned as-is (existence of the file is not checked). Otherwise the
        name is treated as a Hub repo id and resolved through
        `cached_path(hf_bucket_url(path_or_repo_id, filename))`.
        """
        if os.path.isdir(path_or_repo_id):
            return os.path.join(path_or_repo_id, filename)
        return cached_path(hf_bucket_url(path_or_repo_id, filename))
class MecabPreTokenizer(MecabTokenizer):
    """MeCab word segmentation usable as a `tokenizers` custom pre-tokenizer.

    Meant to be wrapped with `PreTokenizer.custom(...)`; `tokenizers` then
    calls `pre_tokenize(pretok)` on each input.
    """

    def mecab_split(self, i, normalized_string):
        """Cut *normalized_string* at the word boundaries MeCab reports.

        Every token from `self.tokenize` is searched for left to right,
        starting where the previously located token ended. A token that
        cannot be found in the text (for instance if `tokenize` changed its
        characters) is skipped without moving the search position, and text
        not covered by any located token does not appear in the result.
        *i* is not used.

        Returns:
            A list of slices of *normalized_string*, in text order.
        """
        text = str(normalized_string)
        pieces = []
        cursor = 0
        for token in self.tokenize(text):
            start = text.find(token, cursor)
            if start < 0:
                # Not locatable: drop it and keep searching from the same place.
                continue
            cursor = start + len(token)
            # Only an empty token found at offset 0 yields an end of 0; skip it.
            if cursor > 0:
                pieces.append(normalized_string[start:cursor])
        return pieces

    def pre_tokenize(self, pretok):
        """Split every piece of *pretok* into MeCab words, in place."""
        pretok.split(self.mecab_split)
class JumanReformerTokenizerFast(ReformerTokenizerFast):
    """ReformerTokenizerFast whose pre-tokenizer first splits text with MeCab.

    The backend pre-tokenizer becomes `Sequence([MeCab split, Metaspace()])`.
    A system JUMAN dictionary (`/var/lib/mecab/dic/juman-utf8` with
    `/etc/mecabrc`) is used when both exist; otherwise
    `mecab-jumandic-utf8.zip` is fetched via `cached_file` from this
    tokenizer's `name_or_path` and extracted into a temporary directory owned
    by the instance.
    """

    def __init__(self, *args, **kwargs):
        """Build the Reformer tokenizer, then install the MeCab pre-tokenizer.

        All arguments are forwarded unchanged to `ReformerTokenizerFast`.
        """
        from tokenizers.pre_tokenizers import PreTokenizer, Metaspace, Sequence

        super().__init__(*args, **kwargs)
        dic_dir, mecabrc = "/var/lib/mecab/dic/juman-utf8", "/etc/mecabrc"
        if not (os.path.isdir(dic_dir) and os.path.isfile(mecabrc)):
            import tempfile
            import zipfile

            # Stored on the instance: a TemporaryDirectory is removed when it is
            # garbage-collected, and MeCab may still need the dictionary files.
            self.dicdir = tempfile.TemporaryDirectory()
            dic_dir = self.dicdir.name
            with zipfile.ZipFile(cached_file(self.name_or_path, "mecab-jumandic-utf8.zip")) as z:
                z.extractall(dic_dir)
            mecabrc = os.path.join(dic_dir, "mecabrc")
            with open(mecabrc, "w", encoding="utf-8") as w:
                print("dicdir =", dic_dir, file=w)
        # Quote the paths (as transformers' own MecabTokenizer does) so that a
        # directory containing spaces is not split into several options.
        mecab_option = f'-d "{dic_dir}" -r "{mecabrc}"'
        self.custom_pre_tokenizer = Sequence([
            PreTokenizer.custom(MecabPreTokenizer(mecab_dic=None, mecab_option=mecab_option)),
            Metaspace(),
        ])
        self._tokenizer.pre_tokenizer = self.custom_pre_tokenizer

    def save_pretrained(self, save_directory, **kwargs):
        """Save the tokenizer together with the files needed to reload it.

        The custom MeCab pre-tokenizer cannot be serialized by `tokenizers`,
        so a plain `Metaspace()` is installed while the base class writes its
        files; the MeCab pre-tokenizer is restored afterwards even if saving
        raises. `auto_map` is pointed at `juman.JumanReformerTokenizerFast`,
        and this module (as `juman.py`) plus `mecab-jumandic-utf8.zip` are
        copied into *save_directory*.

        NOTE(review): with `push_to_hub=True` the base class presumably uploads
        before `juman.py` and the zip are copied here — confirm.

        Returns:
            Whatever `ReformerTokenizerFast.save_pretrained` returns.
        """
        import shutil
        from tokenizers.pre_tokenizers import Metaspace

        def copy_into(src, name):
            # Saving back into the directory the file came from makes src and
            # dst the same file; that file is already in place, so skip it.
            try:
                shutil.copy(src, os.path.join(save_directory, name))
            except shutil.SameFileError:
                pass

        self._auto_map = {"AutoTokenizer": [None, "juman.JumanReformerTokenizerFast"]}
        self._tokenizer.pre_tokenizer = Metaspace()
        try:
            saved = super().save_pretrained(save_directory, **kwargs)
        finally:
            # Never leave the live tokenizer without its MeCab pre-tokenizer.
            self._tokenizer.pre_tokenizer = self.custom_pre_tokenizer
        copy_into(os.path.abspath(__file__), "juman.py")
        copy_into(cached_file(self.name_or_path, "mecab-jumandic-utf8.zip"), "mecab-jumandic-utf8.zip")
        return saved
|