Spaces:
Running
Running
File size: 5,265 Bytes
5e4b316 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
from pathlib import Path
from typing import Optional, Union
import torch
class Tokenizer:
    """Uniform encode/decode front end over two tokenizer backends.

    The backend is chosen from the files found in the checkpoint directory:

    * ``tokenizer.json``  -> Hugging Face ``tokenizers`` (``backend == "huggingface"``)
    * ``tokenizer.model`` -> SentencePiece (``backend == "sentencepiece"``)

    Attributes set by ``__init__``:
        use_bos: default for whether ``encode`` prepends the BOS token.
        bos_id / eos_id: special-token ids, or ``None`` when not resolved.
        processor: the underlying backend tokenizer object.
        backend: ``"huggingface"`` or ``"sentencepiece"``.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
        """Load the tokenizer stored in ``checkpoint_dir``.

        Args:
            checkpoint_dir: directory holding the tokenizer files.

        Raises:
            NotADirectoryError: if ``checkpoint_dir`` does not exist.
            NotImplementedError: if neither ``tokenizer.json`` nor
                ``tokenizer.model`` is present.
            ValueError: if a special token named in ``tokenizer_config.json``
                cannot be resolved to an id (see ``token_to_id``).
        """
        checkpoint_dir = Path(checkpoint_dir)
        if not checkpoint_dir.exists():
            raise NotADirectoryError(
                f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
            )

        self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
        self.bos_id = None
        self.eos_id = None

        # some checkpoints have both files, `.json` takes precedence
        if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
            # Imported lazily so the dependency is only needed for this backend.
            from tokenizers import Tokenizer as HFTokenizer

            self.processor = HFTokenizer.from_file(str(vocabulary_path))
            self.backend = "huggingface"

            # First source of special tokens: token *strings* that are mapped
            # to ids through the loaded vocabulary.
            if (
                special_tokens_path := checkpoint_dir / "tokenizer_config.json"
            ).is_file():
                with open(special_tokens_path, encoding="utf-8") as fp:
                    config = json.load(fp)
                bos_token = config.get("bos_token")
                eos_token = config.get("eos_token")
                # A token entry may be a dict; the text is then under "content".
                if bos_token is not None and isinstance(bos_token, dict):
                    bos_token = bos_token.get("content")
                if eos_token is not None and isinstance(eos_token, dict):
                    eos_token = eos_token.get("content")
                self.bos_id = (
                    self.token_to_id(bos_token) if bos_token is not None else None
                )
                self.eos_id = (
                    self.token_to_id(eos_token) if eos_token is not None else None
                )

            # Second source: explicit ids, used only for values still missing.
            if (
                special_tokens_path := checkpoint_dir / "generation_config.json"
            ).is_file():
                with open(special_tokens_path, encoding="utf-8") as fp:
                    config = json.load(fp)
                if self.bos_id is None:
                    self.bos_id = config.get("bos_token_id")
                if self.eos_id is None:
                    self.eos_id = config.get("eos_token_id")
        elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
            # Imported lazily so the dependency is only needed for this backend.
            from sentencepiece import SentencePieceProcessor

            self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
            self.backend = "sentencepiece"
            self.bos_id = self.processor.bos_id()
            self.eos_id = self.processor.eos_id()
        else:
            raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the backend.

        For the Hugging Face backend the size is requested with
        ``with_added_tokens=False``.

        Raises:
            RuntimeError: if ``backend`` is not a known value.
        """
        if self.backend == "huggingface":
            return self.processor.get_vocab_size(with_added_tokens=False)
        if self.backend == "sentencepiece":
            return self.processor.vocab_size()
        raise RuntimeError

    def token_to_id(self, token: str) -> int:
        """Return the vocabulary id of ``token``.

        Raises:
            RuntimeError: if ``backend`` is not a known value.
            ValueError: if the backend returns ``None`` for ``token``.
        """
        if self.backend == "huggingface":
            id_ = self.processor.token_to_id(token)
        elif self.backend == "sentencepiece":
            id_ = self.processor.piece_to_id(token)
        else:
            raise RuntimeError
        if id_ is None:
            raise ValueError(f"token {token!r} not found in the collection.")
        return id_

    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
        """Decide whether a BOS token should be prepended by default.

        Reads ``tokenizer_config.json`` in ``checkpoint_dir``. Returns ``False``
        when the file is absent, the value of ``add_bos_token`` when that key
        exists, and otherwise whether ``tokenizer_class`` is ``"LlamaTokenizer"``.
        """
        if not (
            tokenizer_config_path := checkpoint_dir / "tokenizer_config.json"
        ).is_file():
            return False
        with open(tokenizer_config_path, encoding="utf-8") as fp:
            config = json.load(fp)
        if "add_bos_token" in config:
            return config["add_bos_token"]
        # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
        # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
        return config.get("tokenizer_class") == "LlamaTokenizer"

    def encode(
        self,
        string: str,
        device: Optional[torch.device] = None,
        bos: Optional[bool] = None,
        eos: bool = False,
        max_length: int = -1,
    ) -> torch.Tensor:
        """Tokenize ``string`` into a 1-D ``torch.int`` tensor of token ids.

        Args:
            string: text to tokenize.
            device: device on which the returned tensor is created.
            bos: ``True``/``False`` forces/suppresses the BOS token; ``None``
                falls back to ``self.use_bos``.
            eos: append ``self.eos_id`` unless it is already the last token.
            max_length: when positive, keep only the first ``max_length`` ids.
                Truncation happens after BOS/EOS are added, so EOS may be cut.

        Raises:
            RuntimeError: if ``backend`` is not a known value.
            ValueError: if the backend returned ``None`` instead of token ids.
            NotImplementedError: if BOS is requested but ``bos_id`` is ``None``.
        """
        if self.backend == "huggingface":
            tokens = self.processor.encode(string).ids
        elif self.backend == "sentencepiece":
            tokens = self.processor.encode(string)
        else:
            raise RuntimeError

        # Validate before the first use of `tokens`; further down this check
        # could never fire once the BOS branch had already indexed the value.
        if tokens is None:
            raise ValueError("`tokens` is None")

        if bos or (bos is None and self.use_bos):
            bos_id = self.bos_id
            if bos_id is None:
                raise NotImplementedError(
                    "This tokenizer does not have a defined a bos token"
                )
            # `not tokens` guards the empty encoding (e.g. an empty string),
            # which would otherwise raise IndexError on `tokens[0]`.
            if not tokens or tokens[0] != bos_id:
                tokens = [bos_id] + tokens

        # NOTE(review): `self.eos_id` is appended as-is; if it is None the
        # tensor construction below is expected to fail - confirm whether an
        # explicit error is wanted here.
        if eos and (not tokens or tokens[-1] != self.eos_id):
            tokens = tokens + [self.eos_id]
        if max_length > 0:
            tokens = tokens[:max_length]
        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tensor: torch.Tensor) -> str:
        """Convert a tensor of token ids back to text via the backend.

        A 0-dim tensor is treated as a single id; any other tensor is passed
        to the backend as the (possibly nested) list from ``tolist()``.
        """
        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
        return self.processor.decode(tokens)
|