ariwaranosai commited on
Commit
496a65a
1 Parent(s): 0c7f845

make sp_model init happen before __init__ to be compatible with transformers>=4.34

Browse files

AttributeError: 'BaichuanTokenizer' object has no attribute 'sp_model' will be raised unless sp_model is set before calling super().__init__, since PreTrainedTokenizer's __init__ func will call _add_tokens, which needs `get_vocab` (and `get_vocab` depends on sp_model).

Files changed (1) hide show
  1. tokenization_baichuan.py +7 -6
tokenization_baichuan.py CHANGED
@@ -72,6 +72,12 @@ class BaichuanTokenizer(PreTrainedTokenizer):
72
  eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
73
  unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
74
  pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
 
 
 
 
 
 
75
  super().__init__(
76
  bos_token=bos_token,
77
  eos_token=eos_token,
@@ -82,12 +88,7 @@ class BaichuanTokenizer(PreTrainedTokenizer):
82
  sp_model_kwargs=self.sp_model_kwargs,
83
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
84
  **kwargs,
85
- )
86
- self.vocab_file = vocab_file
87
- self.add_bos_token = add_bos_token
88
- self.add_eos_token = add_eos_token
89
- self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
90
- self.sp_model.Load(vocab_file)
91
 
92
  def __getstate__(self):
93
  state = self.__dict__.copy()
 
72
  eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
73
  unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
74
  pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
75
+ self.vocab_file = vocab_file
76
+ self.add_bos_token = add_bos_token
77
+ self.add_eos_token = add_eos_token
78
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
79
+ self.sp_model.Load(vocab_file)
80
+
81
  super().__init__(
82
  bos_token=bos_token,
83
  eos_token=eos_token,
 
88
  sp_model_kwargs=self.sp_model_kwargs,
89
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
90
  **kwargs,
91
+ )
 
 
 
 
 
92
 
93
  def __getstate__(self):
94
  state = self.__dict__.copy()