dayyass's picture
Upload 3 files
ef7bae2 verified
raw
history blame
No virus
1.67 kB
import torch
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
from tensorflow.python.saved_model.loader_impl import parse_saved_model
from tensorflow_text.python.ops.sentencepiece_tokenizer import SentencepieceTokenizer
def _get_tokenizer_from_saved_model(saved_model: SavedModel) -> SentencepieceTokenizer:
"""
Get tokenizer from tf SavedModel.
:param SavedModel saved_model: tf SavedModel.
:return: tokenizer.
:rtype: SentencepieceTokenizer
"""
# extract functions that contain SentencePiece somewhere in there
functions_with_sp = [
f
for f in saved_model.meta_graphs[0].graph_def.library.function
if "tokenizer" in str(f).lower()
]
assert (
len(functions_with_sp) == 1
), f"len(functions_with_sp) = {len(functions_with_sp)}"
# find SentencePieceOp (contains the model) in the found function
nodes_with_sp = [
n for n in functions_with_sp[0].node_def if n.op == "SentencepieceOp"
]
assert len(nodes_with_sp) == 1, f"len(nodes_with_sp) = {len(nodes_with_sp)}"
# we can pretty much save the model into a file since it does not change
model = nodes_with_sp[0].attr["model"].s
# instantiate the model
tokenizer = SentencepieceTokenizer(model)
return tokenizer
def get_tokenizer(model_path: str) -> SentencepieceTokenizer:
tokenizer = _get_tokenizer_from_saved_model(parse_saved_model(model_path))
return tokenizer
def tokenize(
sentence: str, # TODO: add batch processing
tokenizer: SentencepieceTokenizer,
) -> torch.Tensor:
return torch.LongTensor([1] + tokenizer.tokenize([sentence]).to_list()[0] + [2])