import json
import os
import tempfile
from typing import List, Optional, Union

import pytest
import tokenizers
from common import make_picollama, run_and_check_merge
from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase

from mergekit.config import InputModelDefinition, MergeConfiguration, ParameterSetting
|
|
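# Base model: 64-token vocabulary, no added tokens.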
@pytest.fixture(scope="session")
def model_base(tmp_path_factory):
    model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64)
    make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path)
    return model_path
|
|
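# ChatML-style model: the 64-token base vocabulary plus <|im_start|> and
# <|im_end|>, for 66 embedding rows in total.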
@pytest.fixture(scope="session")
def model_chatml(tmp_path_factory):
    model_path = make_picollama(tmp_path_factory.mktemp("model_chatml"), vocab_size=66)
    make_tokenizer(
        vocab_size=64, added_tokens=["<|im_start|>", "<|im_end|>"]
    ).save_pretrained(model_path)
    return model_path
|
|
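# Padded model: a 60-token vocabulary padded out to the model's 64 embedding
# rows with unused placeholder tokens.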
@pytest.fixture(scope="session")
def model_padded(tmp_path_factory):
    model_path = make_picollama(tmp_path_factory.mktemp("model_padded"), vocab_size=64)
    make_tokenizer(
        vocab_size=60,
        added_tokens=["<UNUSED_0>", "<UNUSED_1>", "<UNUSED_2>", "<UNUSED_3>"],
    ).save_pretrained(model_path)
    return model_path
|
|
def make_tokenizer(
    vocab_size: int, added_tokens: List[Union[str, tokenizers.AddedToken]]
) -> PreTrainedTokenizerBase:
    # Build a minimal BPE tokenizer: Llama's three special tokens followed by
    # dummy tokens, truncated to exactly vocab_size entries.
    tokens = ["<unk>", "<s>", "</s>"] + [f"_tok_{idx}" for idx in range(3, vocab_size)]
    tokens = tokens[:vocab_size]
    tok_data = {
        "version": "1.0",
        "model": {
            "type": "BPE",
            "vocab": dict(zip(tokens, range(vocab_size))),
            "merges": [],
        },
        "added_tokens": [],
    }
    tok = tokenizers.Tokenizer.from_str(json.dumps(tok_data))
    with tempfile.TemporaryDirectory() as p:
        tok_path = os.path.join(p, "tokenizer.json")
        tok.save(tok_path)
        res = LlamaTokenizerFast(tokenizer_file=tok_path)

    res.add_tokens(added_tokens)
    return res
|
|
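# Build a validation callback for run_and_check_merge that checks the merged
# tokenizer's vocabulary size, added-token count, and membership constraints.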
def check_tokenizer(
    expected_size: int,
    expected_added_ct: Optional[int] = None,
    must_contain: Optional[List[str]] = None,
    must_not_contain: Optional[List[str]] = None,
):
    def _cb(model_path: str):
        tok: LlamaTokenizerFast = LlamaTokenizerFast.from_pretrained(model_path)
        vocab = tok.get_vocab()
        print(vocab)
        assert len(vocab) == expected_size

        if expected_added_ct is not None:
            assert len(tok.added_tokens_decoder) == expected_added_ct

        if must_contain:
            for token in must_contain:
                assert token in vocab

        if must_not_contain:
            for token in must_not_contain:
                assert token not in vocab

    return _cb
|
|
class TestTokenizerMerges:
    def test_legacy_mode(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml], base_model=model_base
        )
|
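        # with no tokenizer_source set, the merge keeps the base model's tokenizer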
        run_and_check_merge(
            config, validate=check_tokenizer(expected_size=64, expected_added_ct=3)
        )
|
    def test_source_base(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="base",
        )

        run_and_check_merge(
            config, validate=check_tokenizer(expected_size=64, expected_added_ct=3)
        )
|
    def test_source_union(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="union",
        )
|
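        # output should have all tokens used by any model,
        # but not the padding-only <UNUSED_*> tokens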
        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66,
                expected_added_ct=5,
                must_contain=["<|im_start|>", "<|im_end|>"],
                must_not_contain=[f"<UNUSED_{idx}>" for idx in range(4)],
            ),
        )
|
    def test_source_model(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="model:" + model_chatml,
        )

        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66, must_contain=["<|im_start|>", "<|im_end|>"]
            ),
        )
|
    def test_slerp_union(self, model_base: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_chatml],
            base_model=model_base,
            tokenizer_source="union",
            merge_method="slerp",
            embed_slerp=True,
            t=0.5,
        )

        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66,
                must_contain=["<|im_start|>", "<|im_end|>"],
            ),
        )
|
    def make_config(
        self,
        models: List[str],
        base_model: Optional[str] = None,
        merge_method: str = "linear",
        tokenizer_source: Optional[str] = None,
        embed_slerp: bool = False,
        t: Optional[ParameterSetting] = None,
    ):
        parameters = {"embed_slerp": embed_slerp}
        if t is not None:
            parameters["t"] = t

        config = MergeConfiguration(
            merge_method=merge_method,
            base_model=base_model,
            models=[
                InputModelDefinition(
                    model=m,
                    parameters={"weight": 1.0},
                )
                for m in models
            ],
            dtype="bfloat16",
            tokenizer_source=tokenizer_source,
            parameters=parameters,
        )
        return config