In [None]:
!pip install torch


from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import os
from argparse import Namespace
from collections import Counter
import json
import re
import string
import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optima
from torch.utils.data import Dataset, DataLoader






class Vocabulary(object):
    """Bidirectional mapping between string tokens and integer indices."""

    def __init__(self, token_to_idx=None):
        """
        Args:
            token_to_idx (dict): a pre-existing map of tokens to indices. It is
                copied, so later ``add_token`` calls never mutate the caller's dict.
        """
        self._token_to_idx = dict(token_to_idx) if token_to_idx is not None else {}
        self._idx_to_token = {idx: token
                              for token, idx in self._token_to_idx.items()}

    def to_serializable(self):
        """Return a dictionary that can be serialized (e.g. with ``json``)."""
        return {'token_to_idx': self._token_to_idx}

    @classmethod
    def from_serializable(cls, contents):
        """Instantiate the Vocabulary from a dictionary made by ``to_serializable``."""
        return cls(**contents)

    def add_token(self, token):
        """Add ``token`` if it is new and return its index.

        Args:
            token (str): the item to add into the Vocabulary
        Returns:
            index (int): the integer corresponding to the token
        """
        if token in self._token_to_idx:
            return self._token_to_idx[token]
        # New tokens get the next index: the current vocabulary size.
        index = len(self._token_to_idx)
        self._token_to_idx[token] = index
        self._idx_to_token[index] = token
        return index

    def add_many(self, tokens):
        """Add a list of tokens into the Vocabulary.

        Args:
            tokens (list): a list of string tokens
        Returns:
            indices (list): a list of indices corresponding to the tokens
        """
        return [self.add_token(token) for token in tokens]

    def lookup_token(self, token):
        """Retrieve the index associated with the token.

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        Raises:
            KeyError: if the token is not in the Vocabulary
        """
        return self._token_to_idx[token]

    def lookup_index(self, index):
        """Return the token associated with the index.

        Args:
            index (int): the index to look up
        Returns:
            token (str): the token corresponding to the index
        Raises:
            KeyError: if the index is not in the Vocabulary
        """
        if index not in self._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._idx_to_token[index]

    def __str__(self):
        # The format string needs a %d placeholder; "" % n raises TypeError.
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self._token_to_idx)
 




class SequenceVocabulary(Vocabulary):
    """Vocabulary that always contains mask, unknown, begin- and end-of-sequence tokens.

    The four special tokens are added in the order mask, unk, begin, end, so on a
    fresh vocabulary their indices are assigned in that order.

    NOTE(review): all four default special tokens are the empty string "". Because
    ``add_token`` returns the existing index for a repeated token, a fresh vocabulary
    built with the defaults ends up with mask_index == unk_index ==
    begin_seq_index == end_seq_index == 0. Pass distinct token strings if those
    roles must be distinguishable.
    """

    def __init__(self, token_to_idx=None, unk_token="",
                 mask_token="", begin_seq_token="",
                 end_seq_token=""):
        super().__init__(token_to_idx)

        self._mask_token = mask_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        # Registration order fixes the indices of the special tokens.
        self.mask_index = self.add_token(mask_token)
        self.unk_index = self.add_token(unk_token)
        self.begin_seq_index = self.add_token(begin_seq_token)
        self.end_seq_index = self.add_token(end_seq_token)

    def to_serializable(self):
        """Return the base serialization plus the four special-token strings."""
        return {
            **super().to_serializable(),
            'unk_token': self._unk_token,
            'mask_token': self._mask_token,
            'begin_seq_token': self._begin_seq_token,
            'end_seq_token': self._end_seq_token,
        }

    def lookup_token(self, token):
        """Return the index of ``token``, falling back to ``unk_index`` when it is unknown.

        The fallback only applies while ``unk_index`` is non-negative; otherwise a
        missing token raises KeyError just like the base class.

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        """
        if self.unk_index < 0:
            return self._token_to_idx[token]
        return self._token_to_idx.get(token, self.unk_index)
 



class NMTVectorizer(object):
    """Coordinate the source/target vocabularies and turn sentence pairs into index vectors."""

    def __init__(self, source_vocab, target_vocab, max_source_length, max_target_length):
        """
        Args:
            source_vocab (SequenceVocabulary): maps source words to integers
            target_vocab (SequenceVocabulary): maps target words to integers
            max_source_length (int): longest source sequence (in tokens) to pad to
            max_target_length (int): longest target sequence (in tokens) to pad to
        """
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def _vectorize(self, indices, vector_length=-1, mask_index=0):
        """Copy ``indices`` into an int64 vector right-padded with ``mask_index``.

        Args:
            indices (list): a list of integers that represent a sequence
            vector_length (int): length of the output vector; a negative value
                means "exactly len(indices)"
            mask_index (int): the padding index written after the sequence
        Returns:
            numpy.ndarray: int64 vector of length ``vector_length``
        Raises:
            ValueError: if ``indices`` is longer than ``vector_length``
        """
        if vector_length < 0:
            vector_length = len(indices)
        if len(indices) > vector_length:
            # Without this check NumPy fails with an opaque broadcast error. This
            # can happen when the max lengths were computed on different text
            # (e.g. only the training split) than the text being vectorized.
            raise ValueError(
                "sequence of %d indices does not fit in a vector of length %d"
                % (len(indices), vector_length))

        vector = np.zeros(vector_length, dtype=np.int64)
        vector[:len(indices)] = indices
        vector[len(indices):] = mask_index
        return vector

    def _get_source_indices(self, text):
        """Return the source indices wrapped in begin/end-of-sequence markers.

        Args:
            text (str): the source text; tokens should be separated by spaces
        Returns:
            indices (list): [begin_seq_index, token indices..., end_seq_index]
        """
        indices = [self.source_vocab.begin_seq_index]
        indices.extend(self.source_vocab.lookup_token(token) for token in text.split(" "))
        indices.append(self.source_vocab.end_seq_index)
        return indices

    def _get_target_indices(self, text):
        """Return the decoder input and output index lists for the target text.

        Args:
            text (str): the target text; tokens should be separated by spaces
        Returns:
            a tuple: (x_indices, y_indices)
                x_indices (list): [begin_seq_index] + token indices (decoder observations)
                y_indices (list): token indices + [end_seq_index] (decoder predictions)
        """
        indices = [self.target_vocab.lookup_token(token) for token in text.split(" ")]
        x_indices = [self.target_vocab.begin_seq_index] + indices
        y_indices = indices + [self.target_vocab.end_seq_index]
        return x_indices, y_indices

    def vectorize(self, source_text, target_text, use_dataset_max_lengths=True):
        """Return the vectorized source and target text.

        The source becomes a single vector. The target becomes two aligned vectors:
        at each timestep the first (``target_x_vector``) is the decoder observation
        and the second (``target_y_vector``) is the token to predict.

        Args:
            source_text (str): text from the source language
            target_text (str): text from the target language
            use_dataset_max_lengths (bool): pad to the vectorizer's max lengths
                (+2 for the source begin/end markers, +1 for the target's single
                marker) instead of each sequence's own length
        Returns:
            dict with keys source_vector, target_x_vector, target_y_vector and
            source_length (unpadded source length including markers)
        """
        source_vector_length = -1
        target_vector_length = -1

        if use_dataset_max_lengths:
            source_vector_length = self.max_source_length + 2
            target_vector_length = self.max_target_length + 1

        source_indices = self._get_source_indices(source_text)
        source_vector = self._vectorize(source_indices,
                                        vector_length=source_vector_length,
                                        mask_index=self.source_vocab.mask_index)

        target_x_indices, target_y_indices = self._get_target_indices(target_text)
        target_x_vector = self._vectorize(target_x_indices,
                                          vector_length=target_vector_length,
                                          mask_index=self.target_vocab.mask_index)
        target_y_vector = self._vectorize(target_y_indices,
                                          vector_length=target_vector_length,
                                          mask_index=self.target_vocab.mask_index)
        return {"source_vector": source_vector,
                "target_x_vector": target_x_vector,
                "target_y_vector": target_y_vector,
                "source_length": len(source_indices)}

    @classmethod
    def from_dataframe(cls, bitext_df, min_source_length=50, min_target_length=25):
        """Instantiate the vectorizer from a parallel-text dataframe.

        Builds both vocabularies from the space-separated ``source_language`` and
        ``target_language`` columns. It also records the longest token sequence
        seen in each column, never going below the given minimums.

        Args:
            bitext_df (pandas.DataFrame): the parallel text dataset
            min_source_length (int): floor for ``max_source_length``
            min_target_length (int): floor for ``max_target_length``
        Returns:
            an instance of the NMTVectorizer
        """
        source_vocab = SequenceVocabulary()
        target_vocab = SequenceVocabulary()

        max_source_length = min_source_length
        max_target_length = min_target_length

        # Zipping the two columns avoids building a pandas Series per row
        # (which DataFrame.iterrows does).
        for source_text, target_text in zip(bitext_df["source_language"],
                                            bitext_df["target_language"]):
            source_tokens = source_text.split(" ")
            max_source_length = max(max_source_length, len(source_tokens))
            source_vocab.add_many(source_tokens)

            target_tokens = target_text.split(" ")
            max_target_length = max(max_target_length, len(target_tokens))
            target_vocab.add_many(target_tokens)

        return cls(source_vocab, target_vocab, max_source_length, max_target_length)

    @classmethod
    def from_serializable(cls, contents):
        """Rebuild the vectorizer from a dictionary made by ``to_serializable``."""
        source_vocab = SequenceVocabulary.from_serializable(contents["source_vocab"])
        target_vocab = SequenceVocabulary.from_serializable(contents["target_vocab"])

        return cls(source_vocab=source_vocab,
                   target_vocab=target_vocab,
                   max_source_length=contents["max_source_length"],
                   max_target_length=contents["max_target_length"])

    def to_serializable(self):
        """Return a JSON-serializable dictionary describing this vectorizer."""
        return {"source_vocab": self.source_vocab.to_serializable(),
                "target_vocab": self.target_vocab.to_serializable(),
                "max_source_length": self.max_source_length,
                "max_target_length": self.max_target_length}
 




class NMTDataset(Dataset):
    """PyTorch Dataset over a parallel-text DataFrame with train/val/test splits.

    The DataFrame needs ``source_language``, ``target_language`` and ``split``
    columns; ``split`` holds 'train', 'val' or 'test'.
    """

    def __init__(self, text_df, vectorizer):
        """
        Args:
            text_df (pandas.DataFrame): the full parallel-text dataset (all splits)
            vectorizer (NMTVectorizer): vectorizer used to turn rows into index vectors
        """
        self.text_df = text_df
        self._vectorizer = vectorizer

        split_column = self.text_df["split"]

        self.train_df = self.text_df[split_column == 'train']
        self.train_size = len(self.train_df)

        self.val_df = self.text_df[split_column == 'val']
        self.validation_size = len(self.val_df)

        self.test_df = self.text_df[split_column == 'test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')

    @classmethod
    def load_dataset_and_make_vectorizer(cls, dataset_csv):
        """Load the dataset and build a new vectorizer from its training split.

        Args:
            dataset_csv (str): location of the dataset CSV
        Returns:
            an instance of NMTDataset
        """
        # Missing cells become a single space so str.split(" ") still works.
        text_df = pd.read_csv(dataset_csv).fillna(' ')
        train_subset = text_df[text_df["split"] == 'train']
        return cls(text_df, NMTVectorizer.from_dataframe(train_subset))

    @classmethod
    def load_dataset_and_load_vectorizer(cls, dataset_csv, vectorizer_filepath):
        """Load the dataset together with a previously saved vectorizer.

        Use this when the vectorizer has been cached for re-use.

        Args:
            dataset_csv (str): location of the dataset CSV
            vectorizer_filepath (str): location of the saved vectorizer
        Returns:
            an instance of NMTDataset
        """
        text_df = pd.read_csv(dataset_csv).fillna(' ')
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        return cls(text_df, vectorizer)

    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        """Load a vectorizer from a JSON file written by ``save_vectorizer``.

        Args:
            vectorizer_filepath (str): the location of the serialized vectorizer
        Returns:
            an instance of NMTVectorizer
        """
        with open(vectorizer_filepath, encoding="utf-8") as fp:
            return NMTVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        """Save the vectorizer to disk as JSON.

        Args:
            vectorizer_filepath (str): the location to save the vectorizer
        """
        with open(vectorizer_filepath, "w", encoding="utf-8") as fp:
            json.dump(self._vectorizer.to_serializable(), fp)

    def get_vectorizer(self):
        """Return the vectorizer."""
        return self._vectorizer

    def set_split(self, split="train"):
        """Select which split ('train', 'val' or 'test') indexing and len() refer to."""
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """Return the vectorized data point at ``index`` in the active split.

        Args:
            index (int): the index to the data point
        Returns:
            dict with keys x_source, x_target, y_target (index vectors) and
            x_source_length (int)
        """
        row = self._target_df.iloc[index]

        vector_dict = self._vectorizer.vectorize(row.source_language, row.target_language)

        return {"x_source": vector_dict["source_vector"],
                "x_target": vector_dict["target_x_vector"],
                "y_target": vector_dict["target_y_vector"],
                "x_source_length": vector_dict["source_length"]}

    def get_num_batches(self, batch_size):
        """Return the number of full batches in the active split.

        Args:
            batch_size (int)
        Returns:
            number of batches in the dataset (partial batch excluded)
        """
        return len(self) // batch_size
 



def generate_nmt_batches(dataset, batch_size, shuffle=True,
                         drop_last=True, device="cpu"):
    """Yield DataLoader minibatches re-ordered by descending source length.

    Wraps ``torch.utils.data.DataLoader``. Every tensor in each batch dict is
    permuted with the same order (largest ``x_source_length`` first) and moved
    to ``device``.
    """
    loader = DataLoader(dataset=dataset, batch_size=batch_size,
                        shuffle=shuffle, drop_last=drop_last)

    for batch in loader:
        # Ascending argsort, reversed -> longest source first.
        order = batch['x_source_length'].numpy().argsort()[::-1].tolist()
        yield {key: value[order].to(device) for key, value in batch.items()}




class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Even embedding channels hold sin(pos * freq) and odd channels cos(pos * freq).
    The table covers positions 0..max_len-1.
    """

    def __init__(self, emb_size, drop_out, max_len: int = 200):
        super(PositionalEncoding, self).__init__()
        positions = torch.arange(0, max_len).reshape(max_len, 1)
        inv_freq = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        angles = positions * inv_freq

        table = torch.zeros((max_len, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(drop_out)
        # Shape (max_len, 1, emb_size): the singleton axis broadcasts over the
        # second input dimension in forward().
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding: Tensor):
        seq_len = token_embedding.size(0)
        encoded = token_embedding + self.pos_embedding[:seq_len, :]
        return self.dropout(encoded)

class TokenEmbedding(nn.Module):
    """Embed token indices and scale the vectors by sqrt(emb_size)."""

    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # nn.Embedding needs integer (long) indices.
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(tokens.long())
        return vectors * scale


class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder ``nn.Transformer`` with token embeddings, positional encoding
    and a linear projection onto the target vocabulary.

    ``nn.Transformer`` is built without ``batch_first``, so it uses PyTorch's
    default sequence-first layout.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, nhead, src_vocab_size, tgt_vocab_size, dim_feedforward = 512, dropout = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        # Submodules are created in this order on purpose: it fixes the order in
        # which random initialisation consumes the RNG.
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            norm_first=True,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, drop_out=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor, tgt_mask: Tensor, src_padding_mask: Tensor, tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        """Return target-vocabulary logits for every decoder position."""
        embedded_src = self.positional_encoding(self.src_tok_emb(src))
        embedded_tgt = self.positional_encoding(self.tgt_tok_emb(trg))
        decoded = self.transformer(
            embedded_src,
            embedded_tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src, src_mask):
        """Run only the encoder stack and return its memory tensor."""
        embedded_src = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(embedded_src, mask=src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Run only the decoder stack over ``tgt`` attending to ``memory``."""
        embedded_tgt = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(embedded_tgt, memory, tgt_mask=tgt_mask)






def set_seed_everywhere(seed, cuda):
    """Seed NumPy and PyTorch, print the seed, and also seed every CUDA device when ``cuda`` is truthy."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    print(seed)
    if not cuda:
        return
    torch.cuda.manual_seed_all(seed)


def generate_square_subsequent_mask(sz):
    """Return an additive (sz, sz) float causal mask on the module-level DEVICE.

    Entries on and below the diagonal are 0.0 (attention allowed). Entries strictly
    above it are -inf, so position i cannot attend to later positions.
    """
    blocked = torch.full((sz, sz), float('-inf'), device=DEVICE)
    return torch.triu(blocked, diagonal=1)



def handle_dirs(save_dirs):
    """Create the directory ``save_dirs`` (including missing parents) if nothing exists at that path.

    Args:
        save_dirs (str): path of the directory to create.
    """
    # Use the parameter: the old body read the global `save_dir` instead, which
    # raised NameError (or created the wrong directory) outside the script.
    if not os.path.exists(save_dirs):
        # exist_ok guards against the path appearing between the check and the call.
        os.makedirs(save_dirs, exist_ok=True)



def create_mask(src, tgt, PAD_IDX):
    """Build the attention and padding masks for one sequence-first batch.

    ``src`` and ``tgt`` are indexed with dimension 0 as the sequence length;
    the padding masks are transposed so dimension 0 becomes the batch.

    Returns:
        src_mask: all-False bool (src_len, src_len) matrix on DEVICE (no source masking)
        tgt_mask: additive causal float mask for the target
        src_padding_mask, tgt_padding_mask: bool, True where the token equals PAD_IDX
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_len)
    src_mask = torch.zeros((src_len, src_len), device=DEVICE, dtype=torch.bool)

    src_padding_mask = src.eq(PAD_IDX).transpose(0, 1)
    tgt_padding_mask = tgt.eq(PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask



def train_epoch(batch_size, device, model, dataset, split_value, optimizer, PAD_IDX, loss_fn):
    """Train ``model`` for one pass over ``split_value`` and return the mean batch loss.

    Args:
        batch_size (int): minibatch size; the last partial batch is dropped by
            generate_nmt_batches (drop_last=True).
        device: where generate_nmt_batches moves each batch.
        model: the Seq2SeqTransformer being trained.
        dataset (NMTDataset): dataset; its active split is switched to ``split_value``.
        split_value (str): 'train', 'val' or 'test'.
        optimizer: torch optimizer over the model's parameters.
        PAD_IDX (int): padding index used for the key-padding masks.
        loss_fn: loss over (N, vocab) logits and (N,) targets.
    Returns:
        float: mean loss over the batches actually processed (0.0 if none).
    """
    model.train()
    dataset.set_split(split_value)
    print(len(dataset))

    total_loss = 0.0
    num_batches = 0
    batch_generator = generate_nmt_batches(dataset, batch_size=batch_size, device=device)
    for batch_dict in batch_generator:
        # Batches are (batch, seq); the model expects sequence-first tensors.
        src = torch.transpose(batch_dict['x_source'], 0, 1)
        tgt = torch.transpose(batch_dict['y_target'], 0, 1)
        # NOTE(review): the decoder input is y_target[:-1] (no BEGIN marker), so the
        # first target token is never a prediction target, while greedy_decode starts
        # from BOS. The dataset's x_target ([BEGIN] + tokens) looks like the intended
        # decoder input, but only switch once BEGIN and MASK indices differ. Otherwise
        # the first decoder position would be fully padding-masked.
        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, PAD_IDX)
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        if num_batches % 50 == 0:
            print("ctr: ", num_batches, " losses: ", total_loss / num_batches, 'time', datetime.datetime.now())

    # Average over the batches that were processed. The old code divided by
    # len(list(DataLoader(...))): that re-loaded the whole split a second time and
    # also counted the partial batch that drop_last skips.
    return total_loss / num_batches if num_batches else 0.0


def evaluate(batch_size, device, model, dataset, split_value, PAD_IDX, loss_fn):
    """Compute the mean batch loss of ``model`` on ``split_value`` without updating it.

    Uses the same decoder input/target pairing as train_epoch (y_target shifted by
    one), so the numbers are comparable with the training loss.

    Args:
        batch_size (int): minibatch size (partial last batch dropped).
        device: where generate_nmt_batches moves each batch.
        model: the Seq2SeqTransformer to evaluate; it is left in eval mode.
        dataset (NMTDataset): dataset; its active split is switched to ``split_value``.
        split_value (str): 'train', 'val' or 'test'.
        PAD_IDX (int): padding index used for the key-padding masks.
        loss_fn: loss over (N, vocab) logits and (N,) targets.
    Returns:
        float: mean loss over the batches actually evaluated (0.0 if none).
    """
    model.eval()
    dataset.set_split(split_value)
    total_loss = 0.0
    num_batches = 0
    # Evaluation needs no gradients; this avoids building an autograd graph per batch.
    with torch.no_grad():
        for batch_dict in generate_nmt_batches(dataset, batch_size=batch_size, device=device):
            src = torch.transpose(batch_dict['x_source'], 0, 1)
            tgt = torch.transpose(batch_dict['y_target'], 0, 1)
            tgt_input = tgt[:-1, :]
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, PAD_IDX)
            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            total_loss += loss.item()
            num_batches += 1
            print(num_batches, "validation", total_loss / num_batches)
    # Divide by the batches processed instead of re-iterating a second DataLoader.
    return total_loss / num_batches if num_batches else 0.0



def greedy_decode(DEVICE, model, src, src_mask, max_len, start_symbol, EOS_IDX):
    """Greedily decode one source sequence, one argmax token at a time.

    Args:
        DEVICE: device to decode on (this parameter shadows the module-level DEVICE).
        model: object exposing ``encode``, ``decode`` and ``generator`` (Seq2SeqTransformer).
        src (Tensor): sequence-first source indices. The batch size must be 1,
            because each step reads the predicted token with ``.item()``.
        src_mask (Tensor): encoder self-attention mask (src_len, src_len).
        max_len (int): maximum number of tokens to generate.
        start_symbol (int): index that seeds the decoder (BOS).
        EOS_IDX (int): decoding stops right after this index is generated.
    Returns:
        Tensor: (n, 1) long tensor starting with ``start_symbol``, followed by the
        generated tokens (including EOS_IDX if it was produced).
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
    # Inference only: no autograd graph is needed for any step.
    with torch.no_grad():
        memory = model.encode(src, src_mask).to(DEVICE)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
        for _ in range(max_len):
            # Boolean causal mask: True (from -inf) marks future positions.
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool).to(DEVICE)
            out = model.decode(ys, memory, tgt_mask).transpose(0, 1)
            prob = model.generator(out)[:, -1]
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            next_token = torch.full((1, 1), next_word, dtype=torch.long, device=DEVICE)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == EOS_IDX:
                break
    return ys



def translate(device, model: torch.nn.Module, src_sentence: torch.Tensor, BOS_IDX, EOS_IDX):
    """Greedily translate one already-vectorized source sequence.

    Args:
        device: device passed through to greedy_decode.
        model (torch.nn.Module): the Seq2SeqTransformer; it is switched to eval mode.
        src_sentence (torch.Tensor): source token indices (not a string). Dimension 0
            is the sequence length; greedy_decode needs a batch size of 1.
        BOS_IDX (int): index that seeds the decoder.
        EOS_IDX (int): index that stops decoding.
    Returns:
        torch.Tensor: 1-D tensor of target indices, starting with BOS_IDX.
            At most ``src_sentence.shape[0]`` tokens are generated after it.
    """
    model.eval()
    src = src_sentence
    num_tokens = src.shape[0]
    # All-False mask: every source position may attend to every other one.
    src_mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
    tgt_tokens = greedy_decode(device, model, src, src_mask, max_len=num_tokens,
                               start_symbol=BOS_IDX, EOS_IDX=EOS_IDX).flatten()
    return tgt_tokens
















# ---------------------------------------------------------------------------
# Training script: builds or loads the vectorizer, trains the Seq2SeqTransformer,
# saves a checkpoint after every epoch, and saves the final model into save_dir.
# ---------------------------------------------------------------------------
from timeit import default_timer as timer

input_df = 'dataset_for_APE_hinglish_to_english2.csv'
fpath = "nmt_IITB_APE2"

# ---- configuration --------------------------------------------------------
vectorizer_file = 'vectorizer_APE_2.json'
print(vectorizer_file)
model_state_file = 'APE_2.pth'
save_dir = "nmt_DG2_FFNN8192"
print(save_dir)
reload_from_files = True
cuda = False  # recomputed below from the device actually used
seed = 13
# NOTE(review): learning_rate is not used; the optimizer below hard-codes lr=0.004.
learning_rate = 8e-3
batch_size = 1024  # overridden by BATCH_SIZE further down
batch_size_val = 1
num_epochs = 40
source_embedding_size = 256
target_embedding_size = 256
encoding_size = 256
use_glove = False
expand_filepaths_to_save_dir = True
# NOTE(review): early_stopping_criteria is not used by the training loop.
early_stopping_criteria = True
dataset_to_evaluate = 'dataset_for_APE_hinglish_to_english2.csv'
path_to_save = 'APE_1_new.csv'
saved_model_path = 'APE_1_new.pt'
file_exist = 0
catch_keyboard_interrupt = True

# The training CSV is always `input_df`.
dataset_path = fpath
existing_file_name = input_df
fname = existing_file_name
dataset_csv = fname
print(save_dir)

if expand_filepaths_to_save_dir:
    vectorizer_file = os.path.join(save_dir, vectorizer_file)
model_state_file = os.path.join(save_dir, model_state_file)

# The model placement and the mask helpers (create_mask /
# generate_square_subsequent_mask) are hard-wired to the module-level DEVICE,
# so batches must be sent there too. A separate CPU-only `device` would leave
# batches on the CPU while the model sits on CUDA.
device = DEVICE
cuda = device.type == "cuda"
set_seed_everywhere(seed, cuda)
handle_dirs(save_dir)

if reload_from_files and os.path.exists(vectorizer_file):
    dataset = NMTDataset.load_dataset_and_load_vectorizer(dataset_csv, vectorizer_file)
    print('load_dataset_and_load_vectorizer______')
else:
    dataset = NMTDataset.load_dataset_and_make_vectorizer(dataset_csv)
    dataset.save_vectorizer(vectorizer_file)
    print('_________load_dataset_and_make_vectorizer______')
vectorizer = dataset.get_vectorizer()

# Read the special indices from the vocabulary attributes instead of looking up
# literal token strings, so this works whatever strings the vocabulary uses.
target_vocab = vectorizer.target_vocab
PAD_IDX = target_vocab.mask_index
BOS_IDX = target_vocab.begin_seq_index
EOS_IDX = target_vocab.end_seq_index
SRC_VOCAB_SIZE = len(vectorizer.source_vocab)
TGT_VOCAB_SiZE = len(target_vocab)
print('target vocab size', TGT_VOCAB_SiZE)
print('dataset_size 1: ', len(dataset), dataset_path, dataset_csv)
# text_df is the CSV that NMTDataset already read; no need to parse the file again.
print(' dataset csv length', len(dataset.text_df))

# ---- model ----------------------------------------------------------------
EMB_SIZE = 256
NHEAD = 16
FFN_HID_DIM = 8192
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
batch_size = BATCH_SIZE
NUM_EPOCHS = num_epochs

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD,
                                 SRC_VOCAB_SIZE, TGT_VOCAB_SiZE, FFN_HID_DIM)
transformer = transformer.to(device)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.004, betas=(0.99, 0.99), eps=1e-9)

# ---- training loop ----------------------------------------------------------
split_value_train = 'train'
split_value_validate = 'val'
for epoch in range(1, NUM_EPOCHS + 1):
    print("==================Training started==================", epoch)
    start_time = timer()
    train_loss = train_epoch(batch_size, device, transformer, dataset, split_value_train,
                             optimizer, PAD_IDX, loss_fn)
    end_time = timer()
    print("epoch", epoch, "train loss", train_loss, "epoch time (s)", end_time - start_time)
    # Per-epoch checkpoint (written to the current working directory).
    torch.save(transformer, 'epoch' + str(epoch) + '_APE_2_new.pt')
torch.save(transformer, os.path.join(save_dir, saved_model_path))
