edaiofficial committed on
Commit
b13cebd
1 Parent(s): 6c20b5b

Upload mmtafrica.py

Files changed (1)
  1. mmtafrica.py +961 -0
mmtafrica.py ADDED
@@ -0,0 +1,961 @@
1
+ from locale import strcoll
2
+ from datasets import load_dataset
3
+ import numpy as np
4
+ import torch
5
+ from torch import optim
6
+ from torch.nn import functional as F
7
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8
+ from transformers.optimization import Adafactor
9
+ from transformers import get_linear_schedule_with_warmup
10
+ from tqdm.notebook import tqdm
11
+ import random
12
+ import sacrebleu
13
+ import os
14
+ import pandas as pd
15
+ from sklearn.model_selection import train_test_split
16
+ import torch.multiprocessing as mp
17
+ from torch.multiprocessing import Process, Queue
18
+ from joblib import Parallel, delayed,parallel_backend
19
+ import sys
20
+ from functools import partial
21
+ import json
22
+ import time
23
+ import numpy as np
24
+ from datetime import datetime
25
+
26
+
27
+ class Config():
28
+ def __init__(self,args) -> None:
29
+
30
+ self.homepath = args.homepath
31
+ self.prediction_path = os.path.join(args.homepath,args.prediction_path)
32
+ # Use 'google/mt5-small' for non-pro Colab users
33
+ self.model_repo = 'google/mt5-base'
34
+ self.model_path_dir = args.homepath
35
+ self.model_name = f'{args.model_name}.pt'
36
+ self.bt_data_dir = os.path.join(args.homepath,args.bt_data_dir)
37
+
38
+ #Data part
39
+ self.parallel_dir= os.path.join(args.homepath,args.parallel_dir)
40
+ self.mono_dir= os.path.join(args.homepath,args.mono_dir)
41
+
42
+ self.log = os.path.join(args.homepath,args.log)
43
+ self.mono_data_limit = args.mono_data_limit
44
+ self.mono_data_for_noise_limit=args.mono_data_for_noise_limit
45
+ #Training params
46
+ self.n_epochs = args.n_epochs
47
+ self.n_bt_epochs=args.n_bt_epochs
48
+
49
+ self.batch_size = args.batch_size
50
+ self.max_seq_len = args.max_seq_len
51
+ self.min_seq_len = args.min_seq_len
52
+ self.checkpoint_freq = args.checkpoint_freq
53
+ self.lr = 1e-4
54
+ self.print_freq = args.print_freq
55
+ self.use_multiprocessing = args.use_multiprocessing
56
+
57
+ self.num_cores = mp.cpu_count()
58
+ self.NUM_PRETRAIN = args.num_pretrain_steps
59
+ self.NUM_BACKTRANSLATION_TIMES =args.num_backtranslation_steps
60
+ self.do_backtranslation=args.do_backtranslation
61
+ self.now_on_bt=False
62
+ self.bt_time=0
63
+ self.using_reconstruction= args.use_reconstruction
64
+ self.num_return_sequences_bt=2
65
+ self.use_torch_data_parallel = args.use_torch_data_parallel
66
+
67
+ self.gradient_accumulation_batch = args.gradient_accumulation_batch
68
+ self.num_beams = args.num_beams
69
+
70
+ self.best_loss = 1000
71
+ self.best_loss_delta = 0.00000001
72
+ self.patience=args.patience
73
+ self.L2=0.0000001
74
+ self.dropout=args.dropout
75
+
76
+ self.drop_prob=args.drop_probability
77
+ self.num_swaps=args.num_swaps
78
+
79
+ self.verbose=args.verbose
80
+
81
+ self.now_on_test=False
82
+
83
+ #Initialization of state dict which will be saved during training
84
+ self.state_dict = {'batch_idx': 0,'epoch':0,'bt_time':self.bt_time,'best_loss':self.best_loss}
85
+ self.state_dict_check = {'batch_idx': 0,'epoch':0,'bt_time':self.bt_time,'best_loss':self.best_loss} #this is for resuming training after an abrupt end!
86
+
87
+
88
+
89
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90
+
91
+ #We will be leveraging parallel and monolingual data for each of these languages.
92
+ #parallel data will be saved in a central 'parallel_data' folder as 'src'_'tgt'_parallel.tsv
93
+ #monolingual data will be saved in another folder called 'monolingual_data' as 'lg'_mono.tsv
94
+
95
+ #Each tsv file is of the form "input", "output"
96
+ self.LANG_TOKEN_MAPPING = {
97
+ 'ig': '<ig>',
98
+ 'fon': '<fon>',
99
+ 'en': '<en>',
100
+ 'fr': '<fr>',
101
+ 'rw':'<rw>',
102
+ 'yo':'<yo>',
103
+ 'xh':'<xh>',
104
+ 'sw':'<sw>'
105
+ }
106
+
107
+
108
+ self.truncation=True
109
+
110
+
111
+
112
+
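# --- Illustrative sketch (not part of the original commit) ---
# The Config comments above describe the expected data layout. A minimal way to
# create toy files in that layout, assuming the default folder names 'parallel'
# and 'mono' used by the command-line flags at the bottom of this file:
# parallel files are named <src>_<tgt>_<train|eval|test>.tsv with 'input'/'target'
# columns, and monolingual files are named <lang>_mono.tsv with an 'input' column
# (this is the naming that make_dataset() and mono_data_() below look for).
import os
import pandas as pd

os.makedirs('parallel', exist_ok=True)
os.makedirs('mono', exist_ok=True)

pd.DataFrame({
    'input':  ['Good morning', 'Thank you'],
    'target': ['Habari za asubuhi', 'Asante'],
}).to_csv(os.path.join('parallel', 'en_sw_train.tsv'), sep='\t', index=False)

pd.DataFrame({
    'input': ['Habari za asubuhi', 'Asante sana'],
}).to_csv(os.path.join('mono', 'sw_mono.tsv'), sep='\t', index=False)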
113
+ def beautify_time(time):
114
+ hr = time//(3600)
115
+ mins = (time-(hr*3600))//60
116
+ rest = time -(hr*3600) - (mins*60)
117
+ #DARIA's implementation!
118
+ sp = ""
119
+ if hr >=1:
120
+ sp += '{} hours'.format(hr)
121
+ if mins >=1:
122
+ sp += ' {} mins'.format(mins)
123
+ if rest >=1:
124
+ sp += ' {} seconds'.format(rest)
125
+ return sp
126
+
127
+
128
+
129
+ def word_delete(x,config):
130
+ noise=[]
131
+ words = x.split(' ')
132
+ if len(words) == 1:
133
+ return x
134
+ for w in words:
135
+ a= np.random.choice([0,1], 1, p=[config.drop_prob, 1-config.drop_prob])
136
+ if a[0]==1: #It means don't delete
137
+ noise.append(w)
138
+ #if you end up deleting all words, just return a random word
139
+ if len(noise) == 0:
140
+ rand_int = random.randint(0, len(words)-1)
141
+ return words[rand_int]
142
+
143
+ return ' '.join(noise)
144
+
145
+ def swap_word(new_words):
146
+
147
+ random_idx_1 = random.randint(0, len(new_words)-1)
148
+ random_idx_2 = random_idx_1
149
+ counter = 0
150
+
151
+ while random_idx_2 == random_idx_1:
152
+ random_idx_2 = random.randint(0, len(new_words)-1)
153
+ counter += 1
154
+
155
+ if counter > 3:
156
+ return new_words
157
+
158
+ new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
159
+ return new_words
160
+
161
+ def random_swap(words, n):
162
+
163
+ words = words.split()
164
+ new_words = words.copy()
165
+
166
+ for _ in range(n):
167
+ new_words = swap_word(new_words)
168
+
169
+ sentence = ' '.join(new_words)
170
+
171
+ return sentence
172
+
173
+
174
+
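# --- Illustrative sketch (not part of the original commit) ---
# How the two noise functions above are combined for the denoising objective,
# mirroring the call in get_dict_mono_noise() below. `cfg` is a hypothetical
# stand-in for Config with only the fields these functions need.
from types import SimpleNamespace

cfg = SimpleNamespace(drop_prob=0.2, num_swaps=3)
sentence = 'the quick brown fox jumps over the lazy dog'
noised = word_delete(random_swap(sentence, cfg.num_swaps), cfg)
print(noised)  # e.g. 'the quick fox brown jumps the lazy dog' (random)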
175
+ def get_dict(input,target,src,tgt):
176
+ inp = [i for i in input]
177
+ target_ = [ i for i in target]
178
+ s= [src for i in range(len(inp))]
179
+ t = [tgt for i in range(len(target_))]
180
+ return [{'inputs':inp_,'targets':target__,'src':s_,'tgt':t_} for inp_,target__,s_,t_ in zip(inp,target_,s,t)]
181
+
182
+ def get_dict_mono(input,src,config):
183
+ index = [i for i in range(len(input))]
184
+ ids = random.sample(index,config.mono_data_limit)
185
+ inp = [input[i] for i in ids]
186
+ s= [src for i in range(len(inp))]
187
+ data=[]
188
+ for lang in config.LANG_TOKEN_MAPPING.keys():
189
+ if lang!=src and lang not in ['en','fr']:
190
+ data.extend([{'inputs':inp_,'src':s_,'tgt':lang} for inp_,s_ in zip(inp,s)])
191
+ return data
192
+
193
+ def get_dict_mono_noise(input,src,config):
194
+ index = [i for i in range(len(input))]
195
+ ids = random.sample(index,config.mono_data_for_noise_limit)
196
+ inp = [input[i] for i in ids]
197
+ noised = [word_delete(random_swap(str(x),config.num_swaps),config) for x in inp]
198
+ s= [src for i in range(len(inp))]
199
+ data=[]
200
+ data.extend([{'inputs':noise_,'targets':inp_,'src':s_,'tgt':s_} for inp_,s_,noise_ in zip(inp,s,noised)])
201
+ return data
202
+
203
+
204
+ def compress(input,target,src,tgt):
205
+ return {'inputs':input,'targets':target,'src':src,'tgt':tgt}
206
+
207
+
208
+ def make_dataset(config,mode):
209
+ if mode!='eval' and mode!='train' and mode!='test':
210
+ raise Exception('mode is either train or eval or test!')
211
+ else:
212
+
213
+ files = [f.name for f in os.scandir(config.parallel_dir) ]
214
+ files = [f for f in files if f.split('.')[-1]=='tsv' and f.split('.tsv')[0].endswith(mode) and len(f.split('_'))>2 ]
215
+ data = [(f_.split('_')[0],f_.split('_')[1],pd.read_csv(os.path.join(config.parallel_dir,f_), sep="\t")) for f_ in files]
216
+ dict_ = [get_dict(df['input'],df['target'],src,tgt) for src,tgt,df in data]
217
+ return [item for sublist in dict_ for item in sublist]
218
+
219
+
220
+
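# --- Illustrative sketch (not part of the original commit) ---
# Reading the toy 'parallel' folder from the earlier sketch with make_dataset().
# `cfg` is a hypothetical stand-in for Config; only parallel_dir is needed here.
from types import SimpleNamespace

cfg = SimpleNamespace(parallel_dir='parallel')
train_samples = make_dataset(cfg, 'train')
print(train_samples[0])
# e.g. {'inputs': 'Good morning', 'targets': 'Habari za asubuhi', 'src': 'en', 'tgt': 'sw'}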
221
+ def get_model_translation(config,model,tokenizer,sentence,tgt):
222
+ if config.use_torch_data_parallel:
223
+ max_seq_len_ = model.module.config.max_length
224
+ else:
225
+ max_seq_len_ = model.config.max_length
226
+ input_ids = encode_input_str(config,text = sentence,target_lang = tgt,tokenizer = tokenizer,seq_len = max_seq_len_).unsqueeze(0).to(config.device)
227
+ if config.use_torch_data_parallel:
228
+ out = model.module.generate(input_ids,num_beams=3,do_sample=True, num_return_sequences=config.num_return_sequences_bt,max_length=config.max_seq_len,min_length=config.min_seq_len)
229
+ else:
230
+ out = model.generate(input_ids,num_beams=3, do_sample=True,num_return_sequences=config.num_return_sequences_bt,max_length=config.max_seq_len,min_length=config.min_seq_len)
231
+
232
+ out_id = [i for i in range(config.num_return_sequences_bt)]
233
+ id_ = random.sample(out_id,1)
234
+
235
+ return tokenizer.decode(out[id_][0], skip_special_tokens=True)
236
+
237
+
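# --- Illustrative sketch (not part of the original commit) ---
# The target-side back-translation flip performed in do_job() below: the model
# translation of a monolingual sentence becomes the (synthetic) source and the
# original sentence becomes the target. `fake_translate` is a hypothetical
# stand-in for get_model_translation().
def fake_translate(sentence, tgt):
    return f'[{tgt} translation of: {sentence}]'

mono_sample = {'inputs': 'Habari za asubuhi', 'src': 'sw', 'tgt': 'en'}
bt_sample = {
    'inputs': fake_translate(mono_sample['inputs'], mono_sample['tgt']),
    'targets': mono_sample['inputs'],
    'src': mono_sample['tgt'],
    'tgt': mono_sample['src'],
}
print(bt_sample)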
238
+ def do_job(t,id_,tokenizers):
239
+ tokenizer = tokenizers[id_ % len(tokenizers)]
240
+ #We flip input and target in order to have target-side backtranslation (where the source side is synthetic).
241
+ return {'inputs':get_model_translation(config,model,tokenizer,t['inputs'],t['tgt']),'targets':t['inputs'],'src':t['tgt'],'tgt':t['src']}
242
+ #return {'inputs':t['inputs'],'targets':get_model_translation(config,model,tokenizer,t['inputs'],t['tgt']),'src':t['src'],'tgt':t['tgt']}
243
+
244
+
245
+ def do_job_pmap(t):
246
+ #tokenizer = tokenizers[id_ % len(tokenizers)]
247
+ return {'inputs':t['inputs'],'targets':get_model_translation(config,model,tokenizer,t['inputs'],t['tgt']),'src':t['src'],'tgt':t['tgt']}
248
+
249
+ def do_job_pool(bt_data,model,id_,tokenizers,config,mono_data):
250
+ tokenizer = tokenizers[id_]
251
+ if config.verbose:
252
+ print(f"Mono data inside job pool: {mono_data}")
253
+ sys.stdout.flush()
254
+ res = [{'inputs':t['inputs'],'targets':get_model_translation(config,model,tokenizer,t['inputs'],t['tgt']),'src':t['src'],'tgt':t['tgt']} for t in mono_data]
255
+ bt_data.put(res)
256
+ return None
257
+
258
+ def mono_data_(config):
259
+ #Find and prepare all the mono data in the directory
260
+ files_ = [f.name for f in os.scandir(config.mono_dir) ]
261
+ files = [f for f in files_ if f.endswith('tsv') and f.split('.tsv')[0].endswith('mono')]
262
+ if config.verbose:
263
+ print("Generating data for back translation")
264
+ print(f"Files found in mono dir: {files}")
265
+ data = [(f_.split('_')[0],pd.read_csv(os.path.join(config.mono_dir,f_), sep="\t")) for f_ in files]
266
+ dict_ = [get_dict_mono(df['input'],src,config) for src,df in data]
267
+ mono_data = [item for sublist in dict_ for item in sublist]
268
+ return mono_data
269
+
270
+ def mono_data_noise(config):
271
+ #Find and prepare all the mono data in the directory
272
+ files_ = [f.name for f in os.scandir(config.mono_dir) ]
273
+ files = [f for f in files_ if f.endswith('tsv') and f.split('.tsv')[0].endswith('mono')]
274
+ if config.verbose:
275
+ print("Generating data for back translation")
276
+ print(f"Files found in mono dir: {files}")
277
+ data = [(f_.split('_')[0],pd.read_csv(os.path.join(config.mono_dir,f_), sep="\t")) for f_ in files]
278
+ dict_ = [get_dict_mono_noise(df['input'],src,config) for src,df in data]
279
+ mono_data = [item for sublist in dict_ for item in sublist]
280
+ return mono_data
281
+
282
+
283
+
284
+ def get_mono_data(config,model):
285
+ mono_data = mono_data_(config)
286
+
287
+ if config.use_multiprocessing:
288
+ if config.verbose:
289
+ print(f"Using multiprocessing on {config.num_cores} processes")
290
+ if __name__ == "__main__":
291
+ ctx = mp.get_context('spawn')
292
+ #mp.set_start_method("spawn",force=True)
293
+ bt_data = ctx.Queue()
294
+ model.share_memory()
295
+ num_processes = config.num_cores
296
+ NUM_TO_USE = len(mono_data)//num_processes
297
+ mini_mono_data = [mono_data[i:i + NUM_TO_USE] for i in range(0, len(mono_data), NUM_TO_USE)]
298
+ #print(f"Length of mini mono data {len(mini_mono_data)}. Length of processes: {num_processes}")
299
+ assert len(mini_mono_data) == num_processes, "Length of mini mono data and number of processes do not match."
300
+
301
+ num_processes_range = [i for i in range(num_processes)]
302
+ processes = []
303
+ for rank,data_ in tqdm(zip(num_processes_range,mini_mono_data)):
304
+ p = ctx.Process(target=do_job_pool, args=(bt_data,model,rank,tokenizers_for_parallel,config,data_))
305
+ p.start()
306
+ if config.verbose:
307
+ print(f"Bt data: {bt_data.get()}")
308
+ sys.stdout.flush()
309
+ processes.append(p)
310
+
311
+ for p in processes:
312
+ p.join()
313
+
314
+ return bt_data
315
+
316
+
317
+
318
+ #output = multiprocessing.Queue()
319
+ #multiprocessing.set_start_method("spawn",force=True)
320
+ #pool = mp.Pool(processes=config.num_cores)
321
+ #bt_data = [pool.apply(do_job, args=(data_,i,tokenizers_for_parallel,)) for i,data_ in enumerate(mono_data)]
322
+
323
+ '''
324
+ # Setup a list of processes that we want to run
325
+ processes = [mp.Process(target=do_job, args=(5, output)) for x in range(config.num_cores)]
326
+ if __name__ == "__main__":
327
+ #pool = mp.Pool(processes=config.num_cores)
328
+ with parallel_backend('loky'):
329
+ bt_data = Parallel(n_jobs = config.num_cores, require='sharedmem')(delayed(do_job)(data_,i,tokenizers_for_parallel) for i,data_ in enumerate(mono_data))
330
+ '''
331
+ else:
332
+ bt_data = [{'inputs':t['inputs'],'targets':get_model_translation(config,model,tokenizer,t['inputs'],t['tgt']),'src':t['src'],'tgt':t['tgt']} for t in tqdm(mono_data)]
333
+ return bt_data
334
+
335
+
336
+
337
+ def encode_input_str(config,text, target_lang, tokenizer, seq_len):
338
+
339
+ target_lang_token = config.LANG_TOKEN_MAPPING[target_lang]
340
+
341
+ # Tokenize and add special tokens
342
+ input_ids = tokenizer.encode(
343
+ text = str(target_lang_token) + str(text),
344
+ return_tensors = 'pt',
345
+ padding = 'max_length',
346
+ truncation = config.truncation,
347
+ max_length = seq_len)
348
+
349
+ return input_ids[0]
350
+
351
+ def encode_target_str(config,text, tokenizer, seq_len):
352
+ token_ids = tokenizer.encode(
353
+ text = str(text),
354
+ return_tensors = 'pt',
355
+ padding = 'max_length',
356
+ truncation = config.truncation,
357
+ max_length = seq_len)
358
+
359
+ return token_ids[0]
360
+
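# --- Illustrative sketch (not part of the original commit) ---
# The convention used by encode_input_str() above: the target-language token
# (e.g. '<sw>') is prepended to the source text so that a single model can be
# asked for any target language. The tokenizer name matches the model_repo used
# in this file; the example sentence and max_length=50 are arbitrary.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('google/mt5-base')
tok.add_special_tokens({'additional_special_tokens': ['<ig>', '<fon>', '<en>', '<fr>', '<rw>', '<yo>', '<xh>', '<sw>']})

ids = tok.encode('<sw>' + 'Good morning', return_tensors='pt',
                 padding='max_length', truncation=True, max_length=50)
print(ids.shape)  # torch.Size([1, 50])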
361
+ def format_translation_data(config,sample,tokenizer,seq_len):
362
+
363
+ # sample is of the form {'inputs':input,'targets':target,'src':src,'tgt':tgt}
364
+
365
+ # Get the translations for the batch
366
+
367
+ input_lang = sample['src']
368
+ target_lang = sample['tgt']
369
+
370
+
371
+ input_text = sample['inputs']
372
+ target_text = sample['targets']
373
+
374
+ if input_text is None or target_text is None:
375
+ return None
376
+
377
+ input_token_ids = encode_input_str(config,input_text, target_lang, tokenizer, seq_len)
378
+
379
+ target_token_ids = encode_target_str(config,target_text, tokenizer, seq_len)
380
+
381
+ return input_token_ids, target_token_ids
382
+
383
+ def transform_batch(config,batch,tokenizer,max_seq_len):
384
+ inputs = []
385
+ targets = []
386
+ for sample in batch:
387
+ formatted_data = format_translation_data(config,sample,tokenizer,max_seq_len)
388
+
389
+ if formatted_data is None:
390
+ continue
391
+
392
+ input_ids, target_ids = formatted_data
393
+ inputs.append(input_ids.unsqueeze(0))
394
+ targets.append(target_ids.unsqueeze(0))
395
+
396
+ batch_input_ids = torch.cat(inputs)
397
+ batch_target_ids = torch.cat(targets)
398
+
399
+ return batch_input_ids, batch_target_ids
400
+
401
+ def get_data_generator(config,dataset,tokenizer,max_seq_len,batch_size):
402
+ random.shuffle(dataset)
403
+
404
+ for i in range(0, len(dataset), batch_size):
405
+ raw_batch = dataset[i:i+batch_size]
406
+ yield transform_batch(config,raw_batch, tokenizer,max_seq_len)
407
+
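# --- Illustrative sketch (not part of the original commit) ---
# Driving the batch generator above with a tiny in-memory dataset. `cfg` is a
# hypothetical stand-in for Config and `tok` for the special-token-augmented
# tokenizer; max_seq_len=50 and batch_size=2 are arbitrary.
from types import SimpleNamespace
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('google/mt5-base')
tok.add_special_tokens({'additional_special_tokens': ['<en>', '<sw>']})
cfg = SimpleNamespace(truncation=True, LANG_TOKEN_MAPPING={'en': '<en>', 'sw': '<sw>'})

toy_dataset = [
    {'inputs': 'Good morning', 'targets': 'Habari za asubuhi', 'src': 'en', 'tgt': 'sw'},
    {'inputs': 'Asante', 'targets': 'Thank you', 'src': 'sw', 'tgt': 'en'},
]
for input_batch, target_batch in get_data_generator(cfg, toy_dataset, tok, 50, 2):
    print(input_batch.shape, target_batch.shape)  # torch.Size([2, 50]) torch.Size([2, 50])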
408
+ def eval_model(config,tokenizer,model, gdataset, max_iters=8):
409
+ test_generator = get_data_generator(config,gdataset,tokenizer,config.max_seq_len, config.batch_size)
410
+ eval_losses = []
411
+ for i, (input_batch, label_batch) in enumerate(test_generator):
412
+
413
+ input_batch, label_batch = input_batch.to(config.device), label_batch.to(config.device)
414
+ model_out = model.forward(
415
+ input_ids = input_batch,
416
+ labels = label_batch)
417
+
418
+ if config.use_torch_data_parallel:
419
+ loss = torch.mean(model_out.loss)
420
+ else:
421
+ loss = model_out.loss
422
+
423
+ eval_losses.append(loss.item())
424
+
425
+ return np.mean(eval_losses)
426
+
427
+
428
+
429
+ def evaluate(config,tokenizer,model,test_dataset,src_lang=None,tgt_lang=None):
430
+ if src_lang!=None and tgt_lang!=None:
431
+ if config.verbose:
432
+ with open(config.log,'a+') as fl:
433
+ print(f"Getting evaluation set for source language -> {src_lang} and target language -> {tgt_lang}",file=fl)
434
+ data = [t for t in test_dataset if t['src']==src_lang and t['tgt']==tgt_lang]
435
+
436
+ else:
437
+ data= [t for t in test_dataset]
438
+
439
+ inp = [t['inputs'] for t in data]
440
+ truth = [t['targets'] for t in data]
441
+ tgt_lang_ = [t['tgt'] for t in data]
442
+
443
+ seq_len__ = config.max_seq_len
444
+
445
+ input_tokens = [encode_input_str(config,text = inp[i],target_lang = tgt_lang_[i],tokenizer = tokenizer,seq_len =seq_len__).unsqueeze(0).to(config.device) for i in range(len(inp))]
446
+
447
+ if config.use_torch_data_parallel:
448
+ output = [model.module.generate(input_ids, num_beams=config.num_beams, num_return_sequences=1,max_length=config.max_seq_len,min_length=config.min_seq_len) for input_ids in tqdm(input_tokens)]
449
+ else:
450
+ output = [model.generate(input_ids, num_beams=config.num_beams, num_return_sequences=1,max_length=config.max_seq_len,min_length=config.min_seq_len) for input_ids in tqdm(input_tokens)]
451
+ output = [tokenizer.decode(out[0], skip_special_tokens=True) for out in tqdm(output)]
452
+
453
+ df= pd.DataFrame({'predictions':output,'truth':truth,'inputs':inp})
454
+ if config.now_on_bt and config.using_reconstruction:
455
+ filename = f'{src_lang}_{tgt_lang}_bt_{config.bt_time}_rec.tsv'
456
+ elif config.now_on_bt:
457
+ filename = f'{src_lang}_{tgt_lang}_bt_{config.bt_time}.tsv'
458
+ elif config.now_on_test:
459
+ filename = f'{src_lang}_{tgt_lang}_TEST.tsv'
460
+ else:
461
+ filename = f'{src_lang}_{tgt_lang}.tsv'
462
+ df.to_csv(os.path.join(config.prediction_path,filename),sep='\t',index=False)
463
+ try:
464
+ spbleu = sacrebleu.corpus_bleu(output, [truth])
465
+ except Exception:
466
+ raise Exception(f'There is a problem with {src_lang}_{tgt_lang}. Truth is {truth} \n Input is {inp} ')
467
+
468
+
469
+
470
+ return spbleu.score
471
+
472
+
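# --- Illustrative sketch (not part of the original commit) ---
# The corpus BLEU computation used in evaluate() above, on toy
# predictions/references. sacrebleu expects one list of references per
# reference stream, aligned with the predictions.
import sacrebleu

preds = ['Habari za asubuhi', 'Asante sana']
refs = [['Habari za asubuhi', 'Asante']]
print(sacrebleu.corpus_bleu(preds, refs).score)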
473
+ def do_evaluation(config,tokenizer,model,test_dataset):
474
+ LANGS = list(config.LANG_TOKEN_MAPPING.keys())
475
+ if config.now_on_bt and config.using_reconstruction:
476
+ s=f'---------------------------AFTER BACKTRANSLATION {config.bt_time} with RECONSTRUCTION---------------------------'+'\n'
477
+ elif config.now_on_bt:
478
+ s=f'---------------------------AFTER BACKTRANSLATION {config.bt_time}---------------------------'+'\n'
479
+ elif config.now_on_test:
480
+ s=f'---------------------------TESTING EVALUATION---------------------------'+'\n'
481
+ else:
482
+ s=f'---------------------------EVALUATION ON DEV---------------------------'+'\n'
483
+ for i in range(len(LANGS)):
484
+ for j in range(len(LANGS)):
485
+ if LANGS[j]!=LANGS[i]:
486
+ eval_bleu = evaluate(config,tokenizer,model,test_dataset,src_lang=LANGS[i],tgt_lang=LANGS[j])
487
+ a = f'Bleu Score for {LANGS[i]} to {LANGS[j]} -> {eval_bleu} '+'\n'
488
+ s+=a
489
+
490
+
491
+ s+='------------------------------------------------------'
492
+ with open(os.path.join(config.homepath,'bleu_log.txt'), 'a+') as fl:
493
+ print(s,file=fl)
494
+
495
+
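# --- Illustrative note (not part of the original commit) ---
# do_evaluation() above scores every ordered language pair, so with the 8
# languages in LANG_TOKEN_MAPPING it reports 8 * 7 = 56 translation directions.
langs = ['ig', 'fon', 'en', 'fr', 'rw', 'yo', 'xh', 'sw']
pairs = [(s, t) for s in langs for t in langs if s != t]
print(len(pairs))  # 56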
496
+ def train(config,n_epochs,optimizer,tokenizer,train_dataset,dev_dataset,n_batches,model,save_with_bt=False):
497
+ patience=0
498
+ losses = []
499
+ for epoch_idx in range(n_epochs):
500
+ if epoch_idx>=config.state_dict_check['epoch']+1:
501
+ st_time = time.time()
502
+ avg_loss=0
503
+ # Randomize data order
504
+ data_generator = get_data_generator(config,train_dataset,tokenizer,config.max_seq_len, config.batch_size)
505
+ optimizer.zero_grad()
506
+ for batch_idx, (input_batch, label_batch) in tqdm(enumerate(data_generator), total=n_batches):
507
+ if batch_idx >= config.state_dict_check['batch_idx']:
508
+
509
+ input_batch,label_batch = input_batch.to(config.device),label_batch.to(config.device)
510
+ # Forward pass
511
+ model_out = model.forward(input_ids = input_batch, labels = label_batch)
512
+
513
+ # Calculate loss and update weights
514
+ if config.use_torch_data_parallel:
515
+ loss = torch.mean(model_out.loss)
516
+ else:
517
+ loss = model_out.loss
518
+
519
+ losses.append(loss.item())
520
+ loss.backward()
521
+
522
+ #Gradient accumulation
523
+ if (batch_idx+1) % config.gradient_accumulation_batch == 0:
524
+ optimizer.step()
525
+ optimizer.zero_grad()
526
+ # Print training update info
527
+ if (batch_idx + 1) % config.print_freq == 0:
528
+ avg_loss = np.mean(losses)
529
+ losses=[]
530
+ if config.verbose:
531
+ with open(config.log,'a+') as fl:
532
+ print('Epoch: {} | Step: {} | Avg. loss: {:.3f}'.format(epoch_idx+1, batch_idx+1, avg_loss),file=fl)
533
+
534
+ if (batch_idx + 1) % config.checkpoint_freq == 0:
535
+ test_loss = eval_model(config,tokenizer,model, dev_dataset)
536
+ if config.best_loss-test_loss > config.best_loss_delta:
537
+ config.best_loss = test_loss
538
+ patience=0
539
+ if config.verbose:
540
+ with open(config.log,'a+') as fl:
541
+ print('Saving model with best test loss of {:.3f}'.format(test_loss),file=fl)
542
+
543
+ if save_with_bt:
544
+ model_name = config.model_name.split('.')[0]+'_bt.pt'
545
+ else:
546
+ model_name = config.model_name
547
+
548
+ config.state_dict.update({'batch_idx': batch_idx,'epoch':epoch_idx,'bt_time':config.bt_time-1,'best_loss':config.best_loss})
549
+ if config.use_torch_data_parallel:
550
+ config.state_dict['model_state_dict']=model.module.state_dict()
551
+ torch.save(config.state_dict, os.path.join(config.model_path_dir,model_name))
552
+ else:
553
+ config.state_dict['model_state_dict']=model.state_dict()
554
+ torch.save(config.state_dict, os.path.join(config.model_path_dir,model_name))
555
+ else:
556
+ if config.verbose:
557
+ with open(config.log,'a+') as fl:
558
+ print(f'No improvement in loss {test_loss} over best loss {config.best_loss}. Not saving model checkpoint',file=fl)
559
+ patience+=1
560
+ if patience >= config.patience:
561
+ with open(config.log,'a+') as fl:
562
+ print("Stopping model training due to early stopping",file=fl)
563
+ break
564
+ with open(config.log,'a+') as fl:
565
+ print('Epoch: {} | Step: {} | Avg. loss: {:.3f} | Time taken: {} | Time: {}'.format(epoch_idx+1, batch_idx+1, avg_loss, beautify_time(time.time()-st_time),datetime.now()),file=fl)
566
+
567
+ # Do this after epochs to get status of model at end of training----
568
+ test_loss = eval_model(config,tokenizer,model, dev_dataset)
569
+ if config.best_loss-test_loss > config.best_loss_delta:
570
+ config.best_loss = test_loss
571
+ patience=0
572
+ if config.verbose:
573
+ with open(config.log,'a+') as fl:
574
+ print('Saving model with best test loss of {:.3f}'.format(test_loss),file=fl)
575
+
576
+ if save_with_bt:
577
+ model_name = config.model_name.split('.')[0]+'_bt.pt'
578
+ else:
579
+ model_name = config.model_name
580
+
581
+ config.state_dict.update({'batch_idx': n_batches-1,'epoch':n_epochs-1,'bt_time':config.bt_time-1,'best_loss':config.best_loss})
582
+ if config.use_torch_data_parallel:
583
+ config.state_dict['model_state_dict']=model.module.state_dict()
584
+ torch.save(config.state_dict, os.path.join(config.model_path_dir,model_name))
585
+ else:
586
+ config.state_dict['model_state_dict']=model.state_dict()
587
+ torch.save(config.state_dict, os.path.join(config.model_path_dir,model_name))
588
+ else:
589
+ if config.verbose:
590
+ with open(config.log,'a+') as fl:
591
+ print(f'No improvement in loss {test_loss} over best loss {config.best_loss}. Not saving model checkpoint',file=fl)
592
+ patience+=1
593
+ #---------------------------------------------
594
+
595
+
596
+
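# --- Illustrative note (not part of the original commit) ---
# In train() above, optimizer.step() fires once every gradient_accumulation_batch
# mini-batches, so with the default flags defined at the bottom of this file the
# effective batch size per optimizer step is:
batch_size = 64
gradient_accumulation_batch = 4096 // 64   # = 64
print(batch_size * gradient_accumulation_batch)  # 4096 sentences per optimizer step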
597
+ def main(args):
598
+ if not os.path.exists(args.homepath):
599
+ raise Exception(f'HOMEPATH {args.homepath} does not exist!')
600
+ config = Config(args)
601
+ if not os.path.exists(config.prediction_path):
602
+ os.makedirs(config.prediction_path)
603
+ if not os.path.exists(config.bt_data_dir):
604
+ os.makedirs(config.bt_data_dir)
605
+ """# Load Tokenizer & Model"""
606
+
607
+ tokenizer = AutoTokenizer.from_pretrained(config.model_repo)
608
+ if config.use_multiprocessing:
609
+ tokenizers_for_parallel = [AutoTokenizer.from_pretrained(config.model_repo) for i in range(config.num_cores)]
610
+
611
+ model = AutoModelForSeq2SeqLM.from_pretrained(config.model_repo)
612
+
613
+ if not os.path.exists(config.parallel_dir):
614
+ raise Exception(f'Directory `{config.parallel_dir}` does not exist! It must contain the parallel files')
615
+
616
+ train_dataset = make_dataset(config,'train')
617
+ with open(config.log,'a+') as fl:
618
+ print(f"Length of train dataset: {len(train_dataset)}",file=fl)
619
+
620
+ dev_dataset = make_dataset(config,'eval')
621
+ with open(config.log,'a+') as fl:
622
+ print(f"Length of dev dataset: {len(dev_dataset)}",file=fl)
623
+
624
+ """## Update tokenizer"""
625
+ special_tokens_dict = {'additional_special_tokens': list(config.LANG_TOKEN_MAPPING.values())}
626
+ tokenizer.add_special_tokens(special_tokens_dict)
627
+ if config.use_multiprocessing:
628
+ for tk in tokenizers_for_parallel:
629
+ tk.add_special_tokens(special_tokens_dict)
630
+ model.resize_token_embeddings(len(tokenizer))
631
+
632
+
633
+ """# Train/Finetune MT5"""
634
+ if os.path.exists(os.path.join(config.model_path_dir,config.model_name)):
635
+ if config.verbose:
636
+ with open(config.log,'a+') as fl:
637
+ print("-----------Using model checkpoint-----------",file=fl)
638
+
639
+ try:
640
+ state_dict = torch.load(os.path.join(config.model_path_dir,config.model_name.split('.')[0]+'_bt.pt'))
641
+ except Exception:
642
+ with open(config.log,'a+') as fl:
643
+ print('No mmt_translation_bt.pt present. Defaulting to the original mmt_translation.pt',file=fl)
644
+ state_dict = torch.load(os.path.join(config.model_path_dir,config.model_name))
645
+
646
+
647
+ # Note to self: Make this better.
648
+ config.state_dict_check['epoch']=state_dict['epoch']
649
+ config.state_dict_check['bt_time']=state_dict['bt_time']
650
+ config.state_dict_check['best_loss']=state_dict['best_loss']
651
+ config.best_loss = config.state_dict_check['best_loss']
652
+ config.state_dict_check['batch_idx']=state_dict['batch_idx']
653
+ model.load_state_dict(state_dict['model_state_dict'])
654
+
655
+ #Temp change
656
+ config.state_dict_check['epoch']=-1
657
+ config.state_dict_check['batch_idx']=0
658
+ config.state_dict_check['bt_time']=-1
659
+
660
+
661
+ #Using DataParallel
662
+ if config.use_torch_data_parallel:
663
+ model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
664
+ model = model.to(config.device)
665
+ #-----
666
+
667
+ # Optimizer
668
+ optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=config.lr)
669
+
670
+ #Normal training
671
+ n_batches = int(np.ceil(len(train_dataset) / config.batch_size))
672
+ total_steps = config.n_epochs * n_batches
673
+ n_warmup_steps = int(total_steps * 0.01)
674
+
675
+ #scheduler = get_linear_schedule_with_warmup(optimizer, n_warmup_steps, total_steps)
676
+ #scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=config.lr, max_lr=0.001,cycle_momentum=False)
677
+
678
+ train(config,config.n_epochs,optimizer,tokenizer,train_dataset,dev_dataset,n_batches,model)
679
+ if config.verbose:
680
+ with open(config.log,'a+') as fl:
681
+ print('Evaluation...',file=fl)
682
+ do_evaluation(config,tokenizer,model,dev_dataset)
683
+ config.state_dict_check['epoch']=-1
684
+ config.state_dict_check['batch_idx']=0
685
+
686
+ if config.do_backtranslation:
687
+ #Backtranslation time
688
+ config.now_on_bt=True
689
+ with open(config.log,'a+') as fl:
690
+ print('---------------Start of Backtranslation---------------',file=fl)
691
+ for n_bt in range(config.NUM_BACKTRANSLATION_TIMES):
692
+ if n_bt>=config.state_dict_check['bt_time']+1:
693
+ with open(config.log,'a+') as fl:
694
+ print(f"Backtranslation {n_bt+1} of {config.NUM_BACKTRANSLATION_TIMES}--------------",file=fl)
695
+ config.bt_time = n_bt+1
696
+ save_bt_file_path = os.path.join(config.bt_data_dir,'bt'+str(n_bt+1)+'.json')
697
+ if not os.path.exists(save_bt_file_path):
698
+ mono_data = mono_data_(config)
699
+ start_time = time.time()
700
+ if config.use_multiprocessing:
701
+ if config.verbose:
702
+ with open(config.log,'a+') as fl:
703
+ print(f"Using multiprocessing on {config.num_cores} processes",file=fl)
704
+ if __name__ == "__main__":
705
+ model.share_memory()
706
+ with parallel_backend('loky'):
707
+ bt_data = Parallel(n_jobs = config.num_cores, require='sharedmem')(delayed(do_job)(data_,i,tokenizers_for_parallel) for i,data_ in tqdm(enumerate(mono_data)))
708
+ else:
709
+ bt_data = [{'inputs':get_model_translation(config,model,tokenizer,t['inputs'],t['tgt']),'targets':t['inputs'],'src':t['tgt'],'tgt':t['src']} for t in tqdm(mono_data)]
710
+ with open(config.log,'a+') as fl:
711
+ print(f'Time taken for backtranslation of data: {beautify_time(time.time()-start_time)}',file=fl)
712
+ with open(save_bt_file_path,'w') as fp:
713
+ json.dump(bt_data,fp)
714
+
715
+ else:
716
+ with open(save_bt_file_path,'r') as f:
717
+ bt_data = json.load(f)
718
+ with open(config.log,'a+') as fl:
719
+ print('-'*15+'Printing 5 random BT Data'+'-'*15,file=fl)
720
+ ids_print = random.sample([i for i in range(len(bt_data))],5)
721
+ with open(config.log,'a+') as fl:
722
+ for ids_print_ in ids_print:
723
+
724
+ print(bt_data[ids_print_],file=fl)
725
+
726
+ augmented_dataset = train_dataset + bt_data + mono_data_noise(config) #mono_data_noise adds denoising objective
727
+ random.shuffle(augmented_dataset)
728
+
729
+ with open(config.log,'a+') as fl:
730
+ print(f'New length of dataset: {len(augmented_dataset)}',file=fl)
731
+
732
+ n_batches = int(np.ceil(len(augmented_dataset) / config.batch_size))
733
+ total_steps = config.n_bt_epochs * n_batches
734
+ n_warmup_steps = int(total_steps * 0.01)
735
+
736
+ #scheduler = get_linear_schedule_with_warmup(optimizer, n_warmup_steps, total_steps)
737
+ #scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=config.lr, max_lr=0.001,cycle_momentum=False)
738
+
739
+ train(config,config.n_bt_epochs,optimizer,tokenizer,augmented_dataset,dev_dataset,n_batches,model,save_with_bt=True)
740
+
741
+ if config.verbose:
742
+ with open(config.log,'a+') as fl:
743
+ print('Evaluation...',file=fl)
744
+ do_evaluation(config,tokenizer,model,dev_dataset)
745
+
746
+ config.state_dict_check['epoch']=-1
747
+ config.state_dict_check['batch_idx']=0
748
+ with open(config.log,'a+') as fl:
749
+ print('---------------End of Backtranslation---------------',file=fl)
750
+
751
+ with open(config.log,'a+') as fl:
752
+ print('---------------End of Training---------------',file=fl)
753
+ config.now_on_bt=False
754
+ config.now_on_test=True
755
+ with open(config.log,'a+') as fl:
756
+ print('Evaluating on test set',file=fl)
757
+ test_dataset = make_dataset(config,'test')
758
+ with open(config.log,'a+') as fl:
759
+ print(f"Length of test dataset: {len(test_dataset)}",file=fl)
760
+ do_evaluation(config,tokenizer,model,test_dataset)
761
+
762
+ with open(config.log,'a+') as fl:
763
+ print("ALL DONE",file=fl)
764
+
765
+
766
+ def load_params(args: dict) -> dict:
767
+ """
768
+ Load the parameters passed to `translate`
769
+ """
770
+ #if not os.path.exists(args['checkpoint']):
771
+ # raise Exception(f'Checkpoint file does not exist')
772
+
773
+ params = {}
774
+ model_repo = 'google/mt5-base'
775
+ LANG_TOKEN_MAPPING = {
776
+ 'ig': '<ig>',
777
+ 'fon': '<fon>',
778
+ 'en': '<en>',
779
+ 'fr': '<fr>',
780
+ 'rw':'<rw>',
781
+ 'yo':'<yo>',
782
+ 'xh':'<xh>',
783
+ 'sw':'<sw>'
784
+ }
785
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
786
+
787
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_repo)
788
+
789
+
790
+ """## Update tokenizer"""
791
+ special_tokens_dict = {'additional_special_tokens': list(LANG_TOKEN_MAPPING.values())}
792
+ tokenizer.add_special_tokens(special_tokens_dict)
793
+
794
+ model.resize_token_embeddings(len(tokenizer))
795
+
796
+ state_dict = torch.load(args['checkpoint'],map_location=args['device'])
797
+
798
+ model.load_state_dict(state_dict['model_state_dict'])
799
+
800
+ model = model.to(args['device'])
801
+
802
+ #Load the model, load the tokenizer, max and min seq len
803
+ params['model'] = model
804
+ params['device'] = args['device']
805
+ params['max_seq_len'] = args['max_seq_len'] if 'max_seq_len' in args else 50
806
+ params['min_seq_len'] = args['min_seq_len'] if 'min_seq_len' in args else 2
807
+ params['tokenizer'] = tokenizer
808
+ params['num_beams'] = args['num_beams'] if 'num_beams' in args else 4
809
+ params['lang_token'] = LANG_TOKEN_MAPPING
810
+ params['truncation'] = args['truncation'] if 'truncation' in args else True
811
+
812
+ return params
813
+
814
+ def encode_input_str_translate(params,text, target_lang, tokenizer, seq_len):
815
+
816
+ target_lang_token = params['lang_token'][target_lang]
817
+
818
+ # Tokenize and add special tokens
819
+ input_ids = tokenizer.encode(
820
+ text = str(target_lang_token) + str(text),
821
+ return_tensors = 'pt',
822
+ padding = 'max_length',
823
+ truncation = params['truncation'] ,
824
+ max_length = seq_len)
825
+
826
+ return input_ids[0]
827
+
828
+ def translate(
829
+ params: dict,
830
+ sentence: str,
831
+ source_lang: str,
832
+ target_lang: str
833
+ ) -> str:
834
+ """
835
+ Given a sentence, its source language and its target language, this translates
836
+ the sentence into the target language.
837
+ """
838
+
839
+
840
+ if source_lang!='' and target_lang!='':
841
+ inp = [sentence]
842
+
843
+ input_tokens = [encode_input_str_translate(params,text = inp[i],target_lang = target_lang,tokenizer = params['tokenizer'],seq_len =params['max_seq_len']).unsqueeze(0).to(params['device']) for i in range(len(inp))]
844
+ output = [params['model'].generate(input_ids, num_beams=params['num_beams'], num_return_sequences=1,max_length=params['max_seq_len'],min_length=params['min_seq_len']) for input_ids in input_tokens]
845
+ output = [params['tokenizer'].decode(out[0], skip_special_tokens=True) for out in tqdm(output)]
846
+
847
+ return output[0]
848
+
849
+ else:
850
+ return ''
851
+
852
+
853
+
854
+
855
+
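# --- Illustrative sketch (not part of the original commit) ---
# Inference with the two helpers above. The checkpoint path is hypothetical and
# should point to a state dict saved by train() (e.g. mmt_translation.pt).
import torch

params = load_params({
    'checkpoint': 'mmt_translation.pt',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
})
print(translate(params, 'Good morning', source_lang='en', target_lang='sw'))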
856
+ if __name__=="__main__":
857
+ from argparse import ArgumentParser
858
+ import json
859
+ import os
860
+
861
+
862
+ parser = ArgumentParser('MMTAfrica Experiments')
863
+
864
+ parser.add_argument('-homepath', type=str, default=os.getcwd(),
865
+ help="Homepath directory. Where all experiments are saved and all \
866
+ necessary files/folders are saved. (default: current working directory)")
867
+
868
+ parser.add_argument('--prediction_path', type=str, default='./predictions',
869
+ help='directory path to save predictions (default: %(default)s)')
870
+
871
+ parser.add_argument('--model_name', type=str, default='mmt_translation',
872
+ help='Name of model (default: %(default)s)')
873
+
874
+ parser.add_argument('--bt_data_dir', type=str, default='btData',
875
+ help='Directory to save back-translation files (default: %(default)s)')
876
+
877
+ parser.add_argument('--parallel_dir', type=str, default='parallel',
878
+ help='name of directory where parallel corpora are saved (default: %(default)s)')
879
+
880
+ parser.add_argument('--mono_dir', type=str, default='mono',
881
+ help='name of directory where monolingual files are saved (default: %(default)s)')
882
+
883
+ parser.add_argument('--log', type=str, default='train.log',
884
+ help='name of file to log experiments (default: %(default)s)')
885
+
886
+ parser.add_argument('--mono_data_limit', type=int, default=300,
887
+ help='limit of monolingual sentences to use for training (default: %(default)s)')
888
+
889
+ parser.add_argument('--mono_data_for_noise_limit', type=int, default=50,
890
+ help='limit of monolingual sentences to use for noise (default: %(default)s)')
891
+
892
+ parser.add_argument('--n_epochs', type=int, default=10,
893
+ help='number of training epochs (default: %(default)s)')
894
+
895
+ parser.add_argument('--n_bt_epochs', type=int, default=3,
896
+ help='number of backtranslation epochs (default: %(default)s)')
897
+
898
+ parser.add_argument('--batch_size', type=int, default=64,
899
+ help='batch size (default: %(default)s)')
900
+
901
+ parser.add_argument('--max_seq_len', type=int, default=50,
902
+ help='maximum length of sentence. All sentences beyond this length will be skipped. (default: %(default)s)')
903
+
904
+ parser.add_argument('--min_seq_len', type=int, default=2,
905
+ help='minimum length of generated sentences. (default: %(default)s)')
906
+
907
+ parser.add_argument('--checkpoint_freq', type=int, default=10_000,
908
+ help='frequency (in training steps) at which to evaluate and checkpoint the model. (default: %(default)s)')
909
+
910
+ parser.add_argument('--lr', type=float, default=1e-4,
911
+ help='learning rate. (default: %(default)s)')
912
+
913
+ parser.add_argument('--print_freq', type=int, default=5_000,
914
+ help='frequency at which to print to log. (default: %(default)s)')
915
+
916
+ parser.add_argument('--use_multiprocessing', type=bool, default=False,
917
+ help='whether or not to use multiprocessing. (default: %(default)s)')
918
+
919
+ parser.add_argument('--num_pretrain_steps', type=int, default=20,
920
+ help='number of pretrain steps. (default: %(default)s)')
921
+
922
+ parser.add_argument('--num_backtranslation_steps', type=int, default=5,
923
+ help='number of backtranslation steps. (default: %(default)s)')
924
+
925
+ parser.add_argument('--do_backtranslation', type=bool, default=True,
926
+ help='whether or not to do backtranslation during training. (default: %(default)s)')
927
+
928
+ parser.add_argument('--use_reconstruction', type=bool, default=True,
929
+ help='whether or not to use reconstruction during training. (default: %(default)s)')
930
+
931
+ parser.add_argument('--use_torch_data_parallel', type=bool, default=False,
932
+ help='whether or not to use torch data parallelism. (default: %(default)s)')
933
+
934
+ parser.add_argument('--gradient_accumulation_batch', type=int, default=4096//64,
935
+ help='number of mini-batches to accumulate gradients over before each optimizer step. (default: %(default)s)')
936
+
937
+ parser.add_argument('--num_beams', type=int, default=4,
938
+ help='number of beams to use for inference. (default: %(default)s)')
939
+
940
+ parser.add_argument('--patience', type=int, default=15_000_000,
941
+ help='patience for early stopping. (default: %(default)s)')
942
+
943
+ parser.add_argument('--drop_probability', type=float, default=0.2,
944
+ help='drop probability for reconstruction. (default: %(default)s)')
945
+
946
+ parser.add_argument('--dropout', type=float, default=0.1,
947
+ help='dropout probability. (default: %(default)s)')
948
+
949
+ parser.add_argument('--num_swaps', type=int, default=3,
950
+ help='number of word swaps to perform during reconstruction. (default: %(default)s)')
951
+
952
+ parser.add_argument('--verbose', type=bool, default=True,
953
+ help='whether or not to print information during experiments. (default: %(default)s)')
954
+
955
+ args = parser.parse_args()
956
+
957
+
958
+ main(args)
959
+
960
+
961
+
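# --- Illustrative sketch (not part of the original commit) ---
# The same entry point can be driven programmatically by handing parse_args an
# explicit argument list (paths hypothetical); the equivalent shell call would
# pass the flags defined above on the command line.
args = parser.parse_args([
    '-homepath', '/path/to/experiment',
    '--parallel_dir', 'parallel',
    '--mono_dir', 'mono',
    '--n_epochs', '10',
    '--batch_size', '64',
])
main(args)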