# Eng-Urdu-Translation / load_model.py
from transformer import Transformer
import tensorflow_text as tf_text
import tensorflow as tf
from config import config
import h5py


def load_transformer(en_emb_matrix, de_emb_matrix, model_path, config):
    """Rebuild the Transformer from the given config and embedding matrices, then restore its weights."""
    model = Transformer(
        num_layers=config.num_layers,
        d_model=config.embed_dim,
        num_heads=config.num_heads,
        en_embedding_matrix=en_emb_matrix,
        de_embedding_matrix=de_emb_matrix,
        dff=config.latent_dim,
        input_vocab_size=config.vocab_size,
        target_vocab_size=config.vocab_size,
        dropout_rate=0.2,
    )
    model.load_weights(model_path)
    return model


def load_sp_model(path_en, path_ur):
    """Load the SentencePiece tokenizers for English (source) and Urdu (target)."""
    sp_model_en = tf_text.SentencepieceTokenizer(
        model=tf.io.gfile.GFile(path_en, 'rb').read(), add_bos=True, add_eos=True)
    # reverse=True makes the Urdu tokenizer emit tokens in reversed order.
    sp_model_ur = tf_text.SentencepieceTokenizer(
        model=tf.io.gfile.GFile(path_ur, 'rb').read(), reverse=True, add_bos=True, add_eos=True)
    return sp_model_en, sp_model_ur


def load_emb(emb_path):
    """Read a pre-trained embedding matrix from the 'embeddings' dataset of an HDF5 file."""
    with h5py.File(emb_path, 'r') as hf:
        embedding_matrix = hf['embeddings'][:]
    return embedding_matrix
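

# Minimal usage sketch (the file names below are illustrative placeholders, not
# actual asset names from this repo; `config` is the object imported above):
#
#   en_emb = load_emb("en_embeddings.h5")
#   ur_emb = load_emb("ur_embeddings.h5")
#   sp_en, sp_ur = load_sp_model("sp_model_en.model", "sp_model_ur.model")
#   transformer = load_transformer(en_emb, ur_emb, "transformer_weights.h5", config)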