import re

import numpy as np
import tensorflow as tf

from config import config


def clean_text(text: str) -> str:
    """Normalize a raw sentence before it is tokenized.

    The steps run in a fixed order: lowercase, strip leading punctuation,
    drop URLs, collapse repeated punctuation, remove remaining symbols and
    finally squeeze whitespace.

    Args:
        text: The raw input sentence.

    Returns:
        The cleaned sentence with no leading/trailing whitespace.
    """
    text = text.lower()

    # Remove any punctuation at the very start of the sentence. An opening
    # parenthesis is excluded from the removed set, so it is kept.
    text = re.sub(r'^[^\w\s(]+', '', text)

    # Remove URLs.
    text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)

    # Reduce multiple instances of the same punctuation mark to one. The
    # underscore is listed here because `\w` matches it, so the generic
    # rule below would not catch it.
    text = re.sub(r"([*'_.,!?؟،۔()])\1+", r'\1', text)

    # Reduce runs of any other repeated non-word, non-space character to one.
    text = re.sub(r'([^\w\s])\1+', r'\1', text)

    # Remove special characters and symbols.
    # NOTE(review): inside this class `'-.` is a character *range*
    # (U+0027..U+002E), so `'`, `(`, `)`, `*`, `+`, `,`, `-` and `.` are all
    # kept. Confirm whether keeping `(`, `)`, `*` and `+` is intended before
    # tightening the pattern; it is left unchanged here.
    text = re.sub(r"[^\w,'-.?!؟،۔\s]", '', text)

    # Collapse whitespace runs and trim both ends.
    text = re.sub(r'\s+', ' ', text).strip()
    return text


class Translator(tf.Module):
    """Greedy decoder that translates one sentence with a transformer model.

    The source sentence is tokenized with ``sp_model_en``; target tokens are
    produced one at a time by taking the arg-max of the model's last-step
    prediction and are detokenized with ``sp_model_ur``.
    """

    def __init__(self, sp_model_en, sp_model_ur, transformer):
        """Store the tokenizers and the model.

        Args:
            sp_model_en: Source-side tokenizer; must provide ``tokenize``.
                Presumably a SentencePiece tokenizer for English - confirm
                against the caller.
            sp_model_ur: Target-side tokenizer; must provide ``tokenize`` and
                ``detokenize``. Presumably a SentencePiece tokenizer for Urdu
                - confirm against the caller.
            transformer: Callable invoked as
                ``transformer([encoder_input, output], training=False)``.
        """
        # Initialise tf.Module state (name, name scope) before attributes
        # are attached to the module.
        super().__init__()
        self.sp_model_en = sp_model_en
        self.sp_model_ur = sp_model_ur
        self.transformer = transformer

    def __call__(self, sentence: str,
                 max_length: int = config.sequence_length) -> str:
        """Translate ``sentence`` and return the decoded text.

        Args:
            sentence: Raw source sentence; it is passed through
                ``clean_text`` first.
            max_length: Maximum number of decoding steps.

        Returns:
            The detokenized output as a Python ``str``.
        """
        sentence = clean_text(sentence)
        sentence = tf.constant(sentence)
        # A scalar string needs a leading batch dimension.
        if len(sentence.shape) == 0:
            sentence = sentence[tf.newaxis]

        # Tokenize the source sentence and densify the result.
        encoder_input = self.sp_model_en.tokenize(sentence).to_tensor()

        # Tokenizing an empty string is used to obtain the start and end
        # token ids (positions 0 and 1 of the result). This assumes the
        # target tokenizer is configured to add both markers - TODO confirm.
        # One call is enough; both ids come from the same result.
        start_end = self.sp_model_ur.tokenize([''])[0]
        start = start_end[0][tf.newaxis]
        end = start_end[1][tf.newaxis]

        # Growable buffer of generated ids, seeded with the start token.
        output_array = tf.TensorArray(dtype=tf.int32, size=0,
                                      dynamic_size=True)
        output_array = output_array.write(0, start)

        for i in tf.range(max_length):
            # Everything generated so far, transposed so the batch axis
            # comes first.
            output = tf.transpose(output_array.stack())
            predictions = self.transformer([encoder_input, output],
                                           training=False)

            # Keep only the prediction for the last position.
            predictions = predictions[:, -1:, :]  # Shape `(batch_size, 1, vocab_size)`

            # Greedy choice: the highest-scoring token id.
            predicted_id = tf.argmax(predictions, axis=-1)
            predicted_id = tf.cast(predicted_id, tf.int32)

            output_array = output_array.write(i + 1, predicted_id[0])

            # Stop once the end token has been produced (it has already
            # been written to the buffer at this point).
            if predicted_id == end:
                break

        output = tf.transpose(output_array.stack())
        text = self.sp_model_ur.detokenize(output)[0]  # Shape: `()`

        return text.numpy().decode('utf-8')