File size: 1,674 Bytes
ef7bae2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
from tensorflow.python.saved_model.loader_impl import parse_saved_model
from tensorflow_text.python.ops.sentencepiece_tokenizer import SentencepieceTokenizer


def _get_tokenizer_from_saved_model(saved_model: SavedModel) -> SentencepieceTokenizer:
    """
    Get tokenizer from tf SavedModel.
    :param SavedModel saved_model: tf SavedModel.
    :return: tokenizer.
    :rtype: SentencepieceTokenizer
    """

    # extract functions that contain SentencePiece somewhere in there
    functions_with_sp = [
        f
        for f in saved_model.meta_graphs[0].graph_def.library.function
        if "tokenizer" in str(f).lower()
    ]

    assert (
        len(functions_with_sp) == 1
    ), f"len(functions_with_sp) = {len(functions_with_sp)}"

    # find SentencePieceOp (contains the model) in the found function
    nodes_with_sp = [
        n for n in functions_with_sp[0].node_def if n.op == "SentencepieceOp"
    ]

    assert len(nodes_with_sp) == 1, f"len(nodes_with_sp) = {len(nodes_with_sp)}"

    # we can pretty much save the model into a file since it does not change
    model = nodes_with_sp[0].attr["model"].s

    # instantiate the model
    tokenizer = SentencepieceTokenizer(model)

    return tokenizer


def get_tokenizer(model_path: str) -> SentencepieceTokenizer:
    tokenizer = _get_tokenizer_from_saved_model(parse_saved_model(model_path))
    return tokenizer


def tokenize(
    sentence: str,  # TODO: add batch processing
    tokenizer: SentencepieceTokenizer,
) -> torch.Tensor:
    return torch.LongTensor([1] + tokenizer.tokenize([sentence]).to_list()[0] + [2])