# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import json
import logging
import tempfile
from typing import Dict, List, Optional, Tuple

import tokenizers
import tokenizers.models
import tqdm
import transformers
from pydantic import BaseModel

from mergekit.common import ModelPath, ModelReference
from mergekit.graph import Task


def get_vocab_size(model_path: ModelPath, trust_remote_code: bool) -> Optional[int]:
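    """Read vocab_size from the model's config, or None if the config cannot be loaded."""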
    try:
        cfg = transformers.AutoConfig.from_pretrained(
            model_path.path,
            revision=model_path.revision,
            trust_remote_code=trust_remote_code,
        )
        return cfg.vocab_size
    except Exception as e:
        logging.warning(f"Unable to get vocab size for {model_path}", exc_info=e)

    return None


def get_stripped_tokenizer(
    path: ModelPath, trust_remote_code: bool = False
) -> transformers.PreTrainedTokenizerFast:
    """
    Return a tokenizer for a model that only contains used tokens.

    Strips any tokens with indices >= model.vocab_size.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        path.path,
        revision=path.revision,
        trust_remote_code=trust_remote_code,
        use_fast=True,
    )
    vocab_size = get_vocab_size(path, trust_remote_code=trust_remote_code) or len(
        tokenizer.get_vocab()
    )

    unused_toks = {
        tok for tok, idx in tokenizer.get_vocab().items() if idx >= vocab_size
    }
    if not unused_toks:
        # we're good, ship it
        return tokenizer

    if not tokenizer.is_fast:
        raise RuntimeError(
            f"Model {path} has unused tokens and does not support fast "
            "tokenizer - can not be used in tokenizer merge"
        )

    tok_dict = json.loads(tokenizer._tokenizer.to_str())
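    # Edit the serialized fast-tokenizer JSON directly: drop out-of-range added
    # tokens, vocab entries, and merges referencing unused tokens, then rebuild
    # the tokenizer from the edited JSON.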
    if tok_dict["model"]["type"] != "BPE":
        raise RuntimeError(
            f"Tokenizer for {path} has type {tok_dict['model']['type']}, "
            "but only BPE is currently supported for tokenizer merge"
        )

    tok_dict["added_tokens"] = [
        e for e in tok_dict["added_tokens"] if e["id"] < vocab_size
    ]

    for tok in unused_toks:
        if tok in tok_dict["model"]["vocab"]:
            del tok_dict["model"]["vocab"][tok]

    def _keep_merge(m):
        toks = m.split(" ")
        for tok in toks:
            if tok in unused_toks:
                return False
        return True

    tok_dict["model"]["merges"] = [
        e for e in tok_dict["model"]["merges"] if _keep_merge(e)
    ]
    tokenizer._tokenizer = tokenizers.Tokenizer.from_str(json.dumps(tok_dict))
    return tokenizer


def build_union_tokenizer(
    base_tok: transformers.PreTrainedTokenizerBase,
    tokenizers: Dict[ModelReference, transformers.PreTrainedTokenizerBase],
    trust_remote_code: bool = False,
) -> transformers.PreTrainedTokenizerBase:
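    """
    Build a tokenizer containing the union of all input vocabularies.

    Tokens and added tokens from each model are layered on top of a copy of
    base_tok; when the same added token appears with conflicting settings,
    the first definition encountered wins.
    """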
    out_added_tokens = {}
    out_vocab = {}

    warned_added_tokens = set()

    for model, tokenizer in tokenizers.items():
        vocab_size = (
            get_vocab_size(model.model, trust_remote_code=trust_remote_code)
            or tokenizer.vocab_size
        )
        added_tokens = tokenizer.added_tokens_decoder

        vocab = tokenizer.get_vocab()
        for tok, idx in vocab.items():
            if idx >= vocab_size:
                logging.warning(
                    f"Token {repr(tok)} present in {str(model)} tokenizer but >= vocab_size"
                )
                continue
            if tok in added_tokens:
                # added tokens are handled separately below
                continue

            if tok not in out_vocab:
                out_vocab[tok] = len(out_vocab)

        for tok_idx, info in tokenizer.added_tokens_decoder.items():
            tok = info.content
            if tok_idx >= vocab_size:
                continue

            if tok in out_added_tokens:
                if (out_added_tokens[tok] != info) and tok not in warned_added_tokens:
                    logging.warning(
                        f"Token '{tok}' added with multiple different settings, using first"
                    )
                    warned_added_tokens.add(tok)

                continue
            out_added_tokens[tok] = info

    # HACK: save base tokenizer to temp dir and reload to avoid mutating base_tok
    with tempfile.TemporaryDirectory() as p:
        base_tok.save_pretrained(p, legacy_format=False, safe_serialization=True)
        res = transformers.AutoTokenizer.from_pretrained(
            p, use_fast=True, trust_remote_code=trust_remote_code
        )

    orig_base_vocab = base_tok.get_vocab()
    for tok in out_vocab:
        if tok in out_added_tokens:
            continue

        if tok not in orig_base_vocab:
            res.add_tokens(tok)

    for info in out_added_tokens.values():
        res.add_tokens(info)
    return res


def build_tokenizer(
    base_model: Optional[ModelReference],
    referenced_models: List[ModelReference],
    tokenizer_source: str,
    trust_remote_code: bool,
) -> Tuple[transformers.PreTrainedTokenizer, Dict[ModelReference, Dict[int, int]]]:
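    """
    Build the output tokenizer and a per-model permutation of token indices.

    For each referenced model the permutation maps an output-vocabulary index
    to that model's token index, with -1 for tokens the model lacks. For
    example (illustrative values only), {0: 0, 1: 5, 2: -1} means output
    tokens 0 and 1 correspond to model indices 0 and 5, and output token 2 is
    absent from the model.
    """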
    if base_model is None:
        if not referenced_models:
            raise RuntimeError("No models referenced")
        base_model = referenced_models[0]

    # load the base tokenizer, stripped of any unused tokens
    tokenizer_base = get_stripped_tokenizer(
        base_model.model, trust_remote_code=trust_remote_code
    )

    # load all tokenizers
    logging.info("Loading tokenizers")
    tokenizers = {base_model: tokenizer_base}
    for model in referenced_models:
        if model == base_model:
            continue

        try:
            model_tok = transformers.AutoTokenizer.from_pretrained(
                model.model.path,
                revision=model.model.revision,
                trust_remote_code=trust_remote_code,
            )
        except Exception as e:
            logging.error(e)
            logging.warning(
                f"Unable to load tokenizer for {model}. Assuming same as {base_model}."
            )
            continue
        tokenizers[model] = model_tok

    logging.info("Building output tokenizer")
    # build final vocabulary
    if tokenizer_source == "base":
        # the stripped base tokenizer is already the output
        tokenizer_out = tokenizer_base
    elif tokenizer_source == "union":
        tokenizer_out = build_union_tokenizer(
            tokenizer_base, tokenizers, trust_remote_code=trust_remote_code
        )
    elif tokenizer_source.startswith("model:"):
        tokenizer_out = transformers.AutoTokenizer.from_pretrained(
            tokenizer_source[len("model:") :],
            trust_remote_code=trust_remote_code,
        )
    else:
        raise RuntimeError(f"Unimplemented tokenizer source: {tokenizer_source}")

    vocab_out = tokenizer_out.get_vocab()

    logging.info("Building permutations")
    permutations = {}
    for model in tqdm.tqdm(referenced_models):
        if model in tokenizers:
            model_vocab = tokenizers[model].get_vocab()
        else:
            model_vocab = tokenizers[base_model].get_vocab()

        vocab_size = get_vocab_size(model.model, trust_remote_code=trust_remote_code)
        if vocab_size is None:
            vocab_size = len(model_vocab)

        p = {}
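        # p maps output-vocab index -> this model's token index (-1 if absent)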
        for tok in vocab_out:
            new_idx = vocab_out[tok]
            if tok not in model_vocab:
                p[new_idx] = -1
                continue

            orig_idx = model_vocab[tok]
            if orig_idx >= vocab_size:
                logging.warning(
                    f"{model} token {repr(tok)} has index {orig_idx}>{vocab_size-1} (padding?)"
                )
                continue

            p[new_idx] = orig_idx

        permutations[model] = p

    return tokenizer_out, permutations


class TokenizerInfo(BaseModel, arbitrary_types_allowed=True):
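    """Merged tokenizer plus per-model index permutations (see build_tokenizer)."""
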
    tokenizer: transformers.PreTrainedTokenizerBase
    permutations: Optional[Dict[ModelReference, Dict[int, int]]]


class BuildTokenizer(Task[TokenizerInfo]):
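    """Graph task that builds the merged tokenizer for the referenced models."""
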
    base_model: Optional[ModelReference]
    referenced_models: Tuple[ModelReference, ...]
    tokenizer_source: str
    trust_remote_code: bool = False

    def arguments(self) -> Dict[str, Task]:
        return {}

    def execute(self, **_kwargs) -> TokenizerInfo:
        tokenizer, permutations = build_tokenizer(
            self.base_model,
            self.referenced_models,
            self.tokenizer_source,
            self.trust_remote_code,
        )
        return TokenizerInfo(tokenizer=tokenizer, permutations=permutations)