|
from typing import List, Dict, Optional, Tuple |
|
from transformers import PreTrainedTokenizer


class PhyloGPNTokenizer(PreTrainedTokenizer):
    """Character-level DNA tokenizer with one token per base, "N" as the unknown
    token, and "-" as the padding token."""

    model_input_names = ["input_ids"]

    def __init__(
        self,
        model_max_length: Optional[int] = None,
        unk_token="N",
        pad_token="-",
        bos_token=None,
        eos_token=None,
        sep_token=None,
        cls_token=None,
        mask_token=None,
        **kwargs,
    ):
        self.model_max_length = model_max_length
        # Define the vocabulary before calling the parent constructor so that the
        # special tokens above resolve to existing ids instead of being registered
        # as added tokens.
        self._vocab = {k: v for v, k in enumerate("ACGTN-")}
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        add_prefix_space = kwargs.pop("add_prefix_space", False)
        padding_side = kwargs.pop("padding_side", "right")
        super().__init__(
            model_max_length=model_max_length,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            padding_side=padding_side,
            **kwargs,
        )
|
    def _tokenize(self, seq: str) -> List[str]:
        # Split the sequence into one token per base.
        assert len(seq) >= 481, "Input must be at least 481 bp long"
        return list(seq)
|
    def _convert_token_to_id(self, token: str) -> int:
        # Any character outside the vocabulary is mapped to the id of the unknown base "N".
        return self._vocab.get(token, self._vocab["N"])
|
    def _convert_id_to_token(self, idx: int) -> str:
        # self._vocab maps tokens to ids, so use the inverse table for the lookup.
        return self._ids_to_tokens[idx]
|
    @property
    def vocab_size(self) -> int:
        return len(self._vocab)
|
    def get_vocab(self) -> Dict[str, int]:
        # Return a copy so callers cannot mutate the tokenizer's internal vocabulary.
        return dict(self._vocab)
|
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        # The vocabulary is fixed in code, so there are no files to write.
        return ()
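

# Example usage: a minimal sketch of encoding a DNA sequence with this tokenizer.
# The sequence below is synthetic and only serves to satisfy the 481 bp minimum
# enforced by _tokenize; tokenization itself needs no model weights.
if __name__ == "__main__":
    tokenizer = PhyloGPNTokenizer()
    seq = "ACGT" * 121  # 484 bp, just above the 481 bp minimum
    input_ids = tokenizer(seq)["input_ids"]
    print(len(input_ids))  # one id per base: 484
    print(tokenizer.convert_ids_to_tokens(input_ids[:8]))  # ['A', 'C', 'G', 'T', 'A', 'C', 'G', 'T']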