import torch
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
from tensorflow.python.saved_model.loader_impl import parse_saved_model
from tensorflow_text.python.ops.sentencepiece_tokenizer import SentencepieceTokenizer


def _get_tokenizer_from_saved_model(saved_model: SavedModel) -> SentencepieceTokenizer:
    """
    Extract the SentencePiece tokenizer embedded in a TF SavedModel.

    :param SavedModel saved_model: parsed SavedModel protobuf.
    :return: tokenizer rebuilt from the serialized SentencePiece model.
    :rtype: SentencepieceTokenizer
    """
    # The function library of the first MetaGraph holds the graph functions;
    # keep the single one whose serialized form mentions the tokenizer.
    functions_with_sp = [
        f
        for f in saved_model.meta_graphs[0].graph_def.library.function
        if "tokenizer" in str(f).lower()
    ]
    assert (
        len(functions_with_sp) == 1
    ), f"len(functions_with_sp) = {len(functions_with_sp)}"

    # Within that function, locate the single SentencepieceOp node.
    nodes_with_sp = [
        n for n in functions_with_sp[0].node_def if n.op == "SentencepieceOp"
    ]
    assert len(nodes_with_sp) == 1, f"len(nodes_with_sp) = {len(nodes_with_sp)}"

    # The op's "model" attribute carries the serialized SentencePiece model,
    # which is enough to rebuild the tokenizer outside the TF graph.
    model = nodes_with_sp[0].attr["model"].s
    tokenizer = SentencepieceTokenizer(model)
    return tokenizer


def get_tokenizer(model_path: str) -> SentencepieceTokenizer:
    """Parse the SavedModel at ``model_path`` and return its tokenizer."""
    tokenizer = _get_tokenizer_from_saved_model(parse_saved_model(model_path))
    return tokenizer


def tokenize(
    sentence: str,
    tokenizer: SentencepieceTokenizer,
) -> torch.Tensor:
    # Wrap the subword ids with BOS (1) and EOS (2), the usual SentencePiece
    # control-token ids (model-dependent; verify against your vocabulary).
    return torch.LongTensor([1] + tokenizer.tokenize([sentence]).to_list()[0] + [2])
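

# Minimal usage sketch, assuming a SavedModel directory whose graph embeds a
# SentencepieceOp. "path/to/saved_model" is a placeholder, not a path from
# the original source.
if __name__ == "__main__":
    tokenizer = get_tokenizer("path/to/saved_model")
    token_ids = tokenize("Hello, world!", tokenizer)
    print(token_ids)  # 1-D LongTensor: [BOS, subword ids..., EOS]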