File size: 1,208 Bytes
97fe9c2
 
 
 
549d15e
97fe9c2
 
c813716
97fe9c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c813716
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from transformer import Transformer
import tensorflow_text as tf_text
import tensorflow as tf
from config import config
import h5py


def load_transformer(en_emb_matrix, de_emb_matrix, model_path, config):
    """Rebuild the Transformer from `config` and restore its saved weights.

    Args:
        en_emb_matrix: pretrained encoder (English) embedding matrix.
        de_emb_matrix: pretrained decoder (Urdu) embedding matrix.
        model_path: path to weights previously written with `save_weights`;
            the architecture built here must match the one that was saved.
        config: settings object exposing `num_layers`, `embed_dim`,
            `num_heads`, `latent_dim` and `vocab_size`.
            NOTE: this parameter intentionally shadows the module-level
            `config` import — callers may pass a different configuration.

    Returns:
        A `Transformer` instance with weights loaded from `model_path`.
    """
    model = Transformer(
        num_layers=config.num_layers,
        d_model=config.embed_dim,
        num_heads=config.num_heads,
        en_embedding_matrix=en_emb_matrix,
        de_embedding_matrix=de_emb_matrix,
        dff=config.latent_dim,
        # Source and target share one vocabulary size in this setup.
        input_vocab_size=config.vocab_size,
        target_vocab_size=config.vocab_size,
        dropout_rate=0.2,
    )
    model.load_weights(model_path)
    return model

def load_sp_model(path_en, path_ur):
    """Load the English and Urdu SentencePiece tokenizers.

    Args:
        path_en: path to the serialized English SentencePiece model.
        path_ur: path to the serialized Urdu SentencePiece model.

    Returns:
        Tuple `(sp_model_en, sp_model_ur)` of `SentencepieceTokenizer`s.
        Both add BOS/EOS markers; the Urdu tokenizer also reverses token
        order (presumably for right-to-left handling — confirm with caller).
    """
    # Read the serialized model protos with context managers so the file
    # handles are closed promptly (the original chained .read() leaked them).
    with tf.io.gfile.GFile(path_en, 'rb') as f:
        en_proto = f.read()
    with tf.io.gfile.GFile(path_ur, 'rb') as f:
        ur_proto = f.read()

    sp_model_en = tf_text.SentencepieceTokenizer(
        model=en_proto, add_bos=True, add_eos=True)
    sp_model_ur = tf_text.SentencepieceTokenizer(
        model=ur_proto, reverse=True, add_bos=True, add_eos=True)
    return sp_model_en, sp_model_ur

def load_emb(emb_path):
    """Load the pretrained embedding matrix from an HDF5 file.

    Args:
        emb_path: path to an HDF5 file containing an 'embeddings' dataset.

    Returns:
        The full 'embeddings' dataset, materialized in memory.
    """
    with h5py.File(emb_path, 'r') as store:
        # [:] reads the whole dataset so it outlives the closed file handle.
        return store['embeddings'][:]