|
import gradio as gr |
|
import string |
|
import re |
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from tensorflow.keras import layers |
|
from tensorflow.keras.layers import TextVectorization |
|
import pickle |
|
|
|
# Artifacts produced alongside the trained model.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# these .pkl files must come from a trusted source.
with open("strip_chars.pkl", "rb") as _fh:
    # Characters removed from Spanish text by custom_standardization.
    strip_chars = pickle.load(_fh)

with open("train_pairs.pkl", "rb") as _fh:
    # Sequence of (english, spanish) pairs; re-adapts the vectorizers below.
    train_pairs = pickle.load(_fh)

# Upper bound on each TextVectorization vocabulary (passed as max_tokens).
vocab_size = 15000
# Token ids per English sentence; the Spanish vectorizer emits one more.
sequence_length = 20
# Not referenced elsewhere in this script as shown; presumably a leftover
# from training -- confirm before removing.
batch_size = 64
|
|
|
class TransformerEncoder(layers.Layer):
    """Transformer encoder block.

    Applies multi-head self-attention, then a two-layer feed-forward
    projection (``Dense(dense_dim, relu)`` -> ``Dense(embed_dim)``). Each
    sub-layer is wrapped in a residual connection followed by
    LayerNormalization.

    Args:
        embed_dim: Width of the token representations. ``inputs`` are
            expected to have this last dimension, because the feed-forward
            output is added back onto them.
        dense_dim: Width of the hidden Dense layer in the feed-forward part.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        # Let Keras pass the upstream padding mask through to the next layer.
        self.supports_masking = True

    def call(self, inputs, mask=None):
        """Encode ``inputs``, optionally ignoring padded key positions.

        Args:
            inputs: Token representations; assumed (batch, seq, embed_dim).
            mask: Optional rank-2 (batch, seq) padding mask, presumably the
                one propagated from ``PositionalEmbedding.compute_mask``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        if mask is not None:
            # Broadcast to (batch, 1, 1, seq) so that every query position
            # ignores padded key positions.
            padding_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype="int32")
        else:
            # Previously padding_mask was left unbound here (UnboundLocalError).
            padding_mask = None
        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-saved."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config
|
|
|
|
|
class PositionalEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        sequence_length: Number of distinct positions that can be embedded.
            Inputs are assumed to be no longer than this along their last
            axis.
        vocab_size: Number of distinct token ids.
        embed_dim: Size of each embedding vector.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super(PositionalEmbedding, self).__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        """Embed integer token ids and add a position vector per index.

        Args:
            inputs: Integer token ids; the last axis is the sequence axis.

        Returns:
            ``inputs.shape + (embed_dim,)`` tensor of summed embeddings.
        """
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        # (length, embed_dim) broadcasts across the batch dimension.
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        """Mark non-padding positions; token id 0 is treated as padding."""
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-saved."""
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config
|
|
|
|
|
class TransformerDecoder(layers.Layer):
    """Transformer decoder block.

    Applies three sub-layers in order:

    1. Causal self-attention over ``inputs``.
    2. Attention from the result onto ``encoder_outputs``.
    3. A two-layer feed-forward projection
       (``Dense(latent_dim, relu)`` -> ``Dense(embed_dim)``).

    Each sub-layer is followed by a residual connection and LayerNormalization.

    Args:
        embed_dim: Width of the token representations. ``inputs`` are
            expected to have this last dimension (residual additions).
        latent_dim: Width of the hidden Dense layer in the feed-forward part.
        num_heads: Number of attention heads in both attention layers.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [layers.Dense(latent_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        # Let Keras pass the upstream padding mask through to the next layer.
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        """Decode ``inputs`` while attending to ``encoder_outputs``.

        Args:
            inputs: Target-side representations; assumed
                (batch, seq, embed_dim).
            encoder_outputs: Output of the encoder stack.
            mask: Optional rank-2 (batch, seq) padding mask, presumably the
                one Keras propagates from ``inputs``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            # NOTE(review): this mask comes from the target side, yet it is
            # used to mask the encoder keys in attention_2 below. Shapes only
            # line up because source and target lengths are equal in this
            # app. Kept unchanged because the loaded weights were presumably
            # trained with this behaviour.
            padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
            padding_mask = tf.minimum(padding_mask, causal_mask)
        else:
            # Previously padding_mask was left unbound here (UnboundLocalError).
            padding_mask = None

        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        """Build a lower-triangular (batch, seq, seq) int32 attention mask.

        Entry ``[b, i, j]`` is 1 when ``i >= j``, so query position ``i`` may
        only attend to key positions at or before it.
        """
        input_shape = tf.shape(inputs)
        # Renamed from batch_size/sequence_length to avoid shadowing globals.
        batch, seq_len = input_shape[0], input_shape[1]
        i = tf.range(seq_len)[:, tf.newaxis]
        j = tf.range(seq_len)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, seq_len, seq_len))
        # Tile multiples [batch, 1, 1]: repeat the single mask per example.
        mult = tf.concat(
            [tf.expand_dims(batch, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-saved."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config
|
|
|
# Map each custom layer's class name to the class so that Keras can rebuild
# the saved architecture from model.h5.
custom_objects = {
    layer_cls.__name__: layer_cls
    for layer_cls in (TransformerEncoder, TransformerDecoder, PositionalEmbedding)
}

transformer = keras.models.load_model(
    "model.h5",
    custom_objects=custom_objects,
)
|
|
|
def custom_standardization(input_string):
    """Standardize Spanish text for ``spa_vectorization``.

    Lowercases the input, then deletes every character that appears in the
    module-level ``strip_chars``.

    Args:
        input_string: String tensor, presumably as supplied by
            ``TextVectorization``.

    Returns:
        String tensor with the same shape as ``input_string``.

    NOTE(review): ``get_translate`` relies on the ``[start]``/``[end]``
    markers surviving this step, so ``strip_chars`` presumably excludes
    ``[`` and ``]`` -- confirm.
    """
    lowered = tf.strings.lower(input_string)
    # One regex character class containing every (escaped) strip character.
    char_class = f"[{re.escape(strip_chars)}]"
    return tf.strings.regex_replace(lowered, char_class, "")
|
|
|
|
|
# English side: Keras' default standardization, output padded/truncated to
# `sequence_length` token ids.
eng_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Spanish side: one extra position (get_translate trims it with [:, :-1]);
# punctuation handling comes from custom_standardization.
spa_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)

# The vocabularies are not shipped with the model, so they are rebuilt from
# the training pairs on every start-up.
train_eng_texts = [example[0] for example in train_pairs]
train_spa_texts = [example[1] for example in train_pairs]

eng_vectorization.adapt(train_eng_texts)
spa_vectorization.adapt(train_spa_texts)
|
|
|
# Gradio UI components. The `gr.inputs.*` / `gr.outputs.*` namespaces were
# deprecated in Gradio 3.x and removed in 4.0; the unified `gr.Textbox`
# component serves as both input and output.
inputs = gr.Textbox(lines=1, label="Text in English")
outputs = [gr.Textbox(label="Translated text in Spanish")]
# One example; with a single input component Gradio accepts a flat list.
examples = ["How are you"]
|
|
|
# Index -> token table for the Spanish vocabulary. The vocabulary is fixed
# once spa_vectorization.adapt() has run, so build it once here instead of on
# every request.
_spa_index_lookup = dict(enumerate(spa_vectorization.get_vocabulary()))


def get_translate(input_sentence):
    """Translate an English sentence to Spanish with greedy decoding.

    Starting from ``"[start]"``, each step re-vectorizes the text decoded so
    far and runs ``transformer`` on it. It then appends the highest-scoring
    token at the current position. Decoding stops after ``"[end]"`` or after
    ``sequence_length`` steps.

    Args:
        input_sentence: English text from the Gradio textbox.

    Returns:
        The decoded Spanish text, with the ``[start]``/``[end]`` markers
        removed and surrounding whitespace stripped.
    """
    # Position i is read from the predictions at step i. The decoder input has
    # sequence_length positions (see the [:, :-1] slice), so the step count
    # must not exceed it. Previously this was a separate hard-coded 20.
    max_decoded_sentence_length = sequence_length
    tokenized_input_sentence = eng_vectorization([input_sentence])
    decoded_sentence = "[start]"
    for i in range(max_decoded_sentence_length):
        # spa_vectorization emits sequence_length + 1 ids; drop the last one
        # to get the decoder-input length.
        tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer(
            [tokenized_input_sentence, tokenized_target_sentence], training=False
        )

        sampled_token_index = int(np.argmax(predictions[0, i, :]))
        sampled_token = _spa_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token

        if sampled_token == "[end]":
            break
    return decoded_sentence.replace("[start]", "").replace("[end]", "").strip()
|
iface = gr.Interface(
    fn=get_translate,
    inputs=inputs,
    outputs=outputs,
    title="EnglishToSpanish Translator",
    examples=examples,
)

# Launch only when run as a script (e.g. `python app.py`), so that importing
# this module (tests, `gradio` reload mode) does not start a blocking server.
if __name__ == "__main__":
    # debug=True keeps the main thread blocked and surfaces errors in the
    # console / notebook output.
    iface.launch(debug=True)