import json
import logging
import tempfile
from typing import Dict, List, Optional, Tuple

import tokenizers
import tokenizers.models
import torch
import tqdm
import transformers
from pydantic import BaseModel

from mergekit.common import ModelPath, ModelReference
from mergekit.graph import Task


def get_vocab_size(model_path: ModelPath, trust_remote_code: bool) -> Optional[int]:
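    """Return the vocabulary size declared in the model's config, or None if it cannot be read."""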
    try:
        cfg = transformers.AutoConfig.from_pretrained(
            model_path.path,
            revision=model_path.revision,
            trust_remote_code=trust_remote_code,
        )
        return cfg.vocab_size
    except Exception as e:
        logging.warning(f"Unable to get vocab size for {model_path}", exc_info=e)

    return None


def get_stripped_tokenizer(
    path: ModelPath, trust_remote_code: bool = False
) -> transformers.PreTrainedTokenizerFast:
    """
    Return a tokenizer for the given model containing only used tokens.

    Strips any tokens with indices >= model.vocab_size.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        path.path,
        revision=path.revision,
        trust_remote_code=trust_remote_code,
        use_fast=True,
    )
    vocab_size = get_vocab_size(path, trust_remote_code=trust_remote_code) or len(
        tokenizer.get_vocab()
    )

    unused_toks = [
        tok for tok, idx in tokenizer.get_vocab().items() if idx >= vocab_size
    ]
    if not unused_toks:
        # all tokens are within the declared vocab size - nothing to strip
        return tokenizer

    if not tokenizer.is_fast:
        raise RuntimeError(
            f"Model {path} has unused tokens and does not support fast "
            "tokenizer - cannot be used in tokenizer merge"
        )

    tok_dict = json.loads(tokenizer._tokenizer.to_str())
    if tok_dict["model"]["type"] != "BPE":
        raise RuntimeError(
            f"Tokenizer for {path} has type {tok_dict['model']['type']}, "
            "but only BPE is currently supported for tokenizer merge"
        )

    # drop added tokens and vocab entries that fall outside the model's vocab size
    tok_dict["added_tokens"] = [
        e for e in tok_dict["added_tokens"] if e["id"] < vocab_size
    ]

    for tok in unused_toks:
        if tok in tok_dict["model"]["vocab"]:
            del tok_dict["model"]["vocab"][tok]

    def _keep_merge(m):
        # discard merge rules that reference a removed token
        toks = m.split(" ")
        for tok in toks:
            if tok in unused_toks:
                return False
        return True

    tok_dict["model"]["merges"] = [
        e for e in tok_dict["model"]["merges"] if _keep_merge(e)
    ]
    tokenizer._tokenizer = tokenizers.Tokenizer.from_str(json.dumps(tok_dict))
    return tokenizer


def build_union_tokenizer(
    base_tok: transformers.PreTrainedTokenizerBase,
    tokenizers: Dict[ModelReference, transformers.PreTrainedTokenizerBase],
    trust_remote_code: bool = False,
) -> transformers.PreTrainedTokenizerBase:
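    """
    Build a tokenizer whose vocabulary is the union of the given tokenizers.

    Starts from a copy of `base_tok` and adds any in-range token from the other
    tokenizers that the base is missing. When the same added token appears with
    conflicting settings, the first definition encountered wins.
    """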
    out_added_tokens = {}
    out_vocab = {}

    warned_added_tokens = set()

    for model, tokenizer in tokenizers.items():
        vocab_size = (
            get_vocab_size(model.model, trust_remote_code=trust_remote_code)
            or tokenizer.vocab_size
        )
        added_tokens = tokenizer.added_tokens_decoder

        vocab = tokenizer.get_vocab()
        for tok, idx in vocab.items():
            if idx >= vocab_size:
                logging.warning(
                    f"Token {repr(tok)} present in {str(model)} tokenizer but >= vocab_size"
                )
                continue
            if idx in added_tokens:
                # added tokens are merged separately below
                # (added_tokens_decoder is keyed by token index, so check idx)
                continue

            if tok not in out_vocab:
                out_vocab[tok] = len(out_vocab)

        # collect added-token definitions, keeping the first settings seen for each token
        for tok_idx, info in tokenizer.added_tokens_decoder.items():
            tok = info.content
            if tok_idx >= vocab_size:
                continue

            if tok in out_added_tokens:
                if (out_added_tokens[tok] != info) and tok not in warned_added_tokens:
                    logging.warning(
                        f"Token '{tok}' added with multiple different settings, using first"
                    )
                    warned_added_tokens.add(tok)
                continue
            out_added_tokens[tok] = info

    # work on a copy of the base tokenizer, then extend it with the union vocabulary
    with tempfile.TemporaryDirectory() as p:
        base_tok.save_pretrained(p, legacy_format=False, safe_serialization=True)
        res = transformers.AutoTokenizer.from_pretrained(
            p, use_fast=True, trust_remote_code=trust_remote_code
        )

    orig_base_vocab = base_tok.get_vocab()
    for tok in out_vocab:
        if tok in out_added_tokens:
            continue

        if tok not in orig_base_vocab:
            res.add_tokens(tok)

    for info in out_added_tokens.values():
        res.add_tokens(info)
    return res


def build_tokenizer(
    base_model: Optional[ModelReference],
    referenced_models: List[ModelReference],
    tokenizer_source: str,
    trust_remote_code: bool,
) -> Tuple[transformers.PreTrainedTokenizer, Dict[ModelReference, Dict[int, int]]]:
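    """
    Construct the output tokenizer for a merge, plus per-model vocabulary permutations.

    `tokenizer_source` selects the output tokenizer: "base" uses the stripped base
    model tokenizer, "union" combines the vocabularies of all referenced models,
    and "model:<path>" loads the tokenizer from an arbitrary model.

    Returns the output tokenizer and, for each referenced model, a dict mapping
    output token indices to that model's original token indices (-1 if absent).
    """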
    if base_model is None:
        base_model = referenced_models[0]
    if base_model is None:
        raise RuntimeError("No models referenced")

    # load the base model's tokenizer with out-of-range tokens stripped
    tokenizer_base = get_stripped_tokenizer(
        base_model.model, trust_remote_code=trust_remote_code
    )

    # load tokenizers for all other referenced models
    logging.info("Loading tokenizers")
    tokenizers = {base_model: tokenizer_base}
    for model in referenced_models:
        if model == base_model:
            continue

        try:
            model_tok = transformers.AutoTokenizer.from_pretrained(
                model.model.path,
                revision=model.model.revision,
                trust_remote_code=trust_remote_code,
            )
        except Exception as e:
            logging.error(e)
            logging.warning(
                f"Unable to load tokenizer for {model}. Assuming same as {base_model}."
            )
            continue
        tokenizers[model] = model_tok

    logging.info("Building output tokenizer")
    if tokenizer_source == "base":
        tokenizer_out = tokenizer_base
    elif tokenizer_source == "union":
        tokenizer_out = build_union_tokenizer(
            tokenizer_base, tokenizers, trust_remote_code=trust_remote_code
        )
    elif tokenizer_source.startswith("model:"):
        tokenizer_out = transformers.AutoTokenizer.from_pretrained(
            tokenizer_source[len("model:") :],
            trust_remote_code=trust_remote_code,
        )
    else:
        raise RuntimeError(f"Unimplemented tokenizer source: {tokenizer_source}")

    vocab_out = tokenizer_out.get_vocab()

    # for each model, map output token indices to the model's original indices
    logging.info("Building permutations")
    permutations = {}
    for model in tqdm.tqdm(referenced_models):
        if model in tokenizers:
            model_vocab = tokenizers[model].get_vocab()
        else:
            model_vocab = tokenizers[base_model].get_vocab()

        vocab_size = get_vocab_size(model.model, trust_remote_code=trust_remote_code)
        if vocab_size is None:
            vocab_size = len(model_vocab)

        p = {}
        for tok in vocab_out:
            new_idx = vocab_out[tok]
            if tok not in model_vocab:
                p[new_idx] = -1
                continue

            orig_idx = model_vocab[tok]
            if orig_idx >= vocab_size:
                logging.warning(
                    f"{model} token {repr(tok)} has index {orig_idx}>{vocab_size-1} (padding?)"
                )
                continue

            p[new_idx] = orig_idx

        permutations[model] = p

    return tokenizer_out, permutations


class TokenizerInfo(BaseModel, arbitrary_types_allowed=True):
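    """Output tokenizer plus per-model permutations mapping output token indices to original indices."""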
    tokenizer: transformers.PreTrainedTokenizerBase
    permutations: Optional[Dict[ModelReference, Dict[int, int]]]


class BuildTokenizer(Task[TokenizerInfo]):
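    """Task that builds the output tokenizer and vocabulary permutations for a merge."""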
    base_model: Optional[ModelReference]
    referenced_models: Tuple[ModelReference, ...]
    tokenizer_source: str
    trust_remote_code: bool = False

    def arguments(self) -> Dict[str, Task]:
        return {}

    def execute(self, **_kwargs) -> TokenizerInfo:
        tokenizer, permutations = build_tokenizer(
            self.base_model,
            self.referenced_models,
            self.tokenizer_source,
            self.trust_remote_code,
        )
        return TokenizerInfo(tokenizer=tokenizer, permutations=permutations)