# gandhi-gpt / code / data_preprocessing.py
import glob
import math
import re
import string

from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from transformers import AutoTokenizer

# GPT-2 tokenizer with explicit BOS/EOS/PAD special tokens; used below to
# measure how many tokens each paragraph would occupy.
tokenizer = AutoTokenizer.from_pretrained('gpt2', bos_token='<|startoftext|>',
                                          eos_token='<|endoftext|>', pad_token='<|pad|>')
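# NOTE (setup, not part of the original pipeline): sent_tokenize needs NLTK's
# 'punkt' tokenizer data; if it is missing, download it once:
#   import nltk
#   nltk.download('punkt')
# NOTE: the added <|startoftext|> and <|pad|> tokens enlarge the vocabulary, so
# the separate fine-tuning script is presumably expected to call
# model.resize_token_embeddings(len(tokenizer)) before training.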
# ----------------------------- Cleaning process 1/2 -----------------------------
def sanitize(line):
    """Strip bracketed notes, month names, all-caps words, digits, quotes and punctuation."""
    line2 = re.sub(r'\[.+\]', '', line)               # drop editorial notes in square brackets
    for month in ["January", "February", "March", "April", "May", "June", "July",
                  "August", "September", "October", "November", "December"]:
        line2 = line2.replace(month, '')
    line2 = re.sub(r'\b[A-Z]+\b', '', line2.strip())  # drop all-caps words (headers, initials)
    line2 = re.sub(r'\d', '', line2)                  # drop digits
    line2 = line2.translate(str.maketrans('', '', "‟“’❝❞‚‘‛❛❜❟â€™"))  # remove typographic quotes, incl. an apparently mis-encoded ’ (â€™)
    line2 = line2.translate(str.maketrans('', '', string.punctuation))
    line2 = re.sub(r'\s+', ' ', line2).strip()        # collapse whitespace
    return line2
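# Example (hypothetical input): sanitize("[1] THE train reached Wardha on January 12, 1942.")
# returns "train reached Wardha on" -- the bracketed note, the all-caps word,
# the month name, the digits and the punctuation are all stripped before the
# length checks in the loops below.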
def remove_footnotes_and_clean(sents):
    """Join the accumulated lines of one paragraph into a single string.

    A line ending in a hyphen is joined to the next line without a space so
    that words broken across lines stay together.
    """
    sents = [x.replace("'", '').replace('*', '').replace('’®', '').replace('’', '') for x in sents]
    s = ''
    for line in sents:
        line = line.strip()
        if not line:
            continue  # skip blank lines instead of crashing on line[-1]
        if line[-1] != '-':
            s = s + line + ' '
        else:
            s = s + line
    s = re.sub(r'\s+', ' ', s)
    return s
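# Example (hypothetical input):
#   remove_footnotes_and_clean(["The talks were post-\n", "poned until Monday.\n"])
# returns "The talks were post-poned until Monday. " -- the trailing hyphen keeps
# the two halves of the word together (the hyphen itself is left in place), and
# apostrophes, asterisks and mis-encoded quote marks are removed first.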
in_path = 'text_files/'          # raw text files of the Collected Works volumes
out_path = 'clean_text_files/'   # destination for the cleaned, one-paragraph-per-line files
ml = sorted(glob.glob(in_path + '*.txt'))
show = False                     # set to True for verbose line-by-line debugging
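# Cleaning process 1/2, per volume:
#   pass 1 estimates the mean left indentation of ordinary body-text lines;
#   pass 2 keeps body lines, treats a line indented past that mean and followed
#   by a line back at the margin as the start of a new paragraph, and skips
#   footnote blocks (a line starting with a non-word marker) until the next
#   page header ('VOL. ...' or 'THE COLLECTED WORKS OF MAHATMA GANDHI').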
for k, m in enumerate(tqdm(ml, total=len(ml), ncols=100)):
    file = open(m, 'r')
    content = file.readlines()
    file.close()
    if show:
        print(m)
    paras = []
    sents = []
    mean_spaces = []
    footnote_found = False

    # Pass 1: estimate the mean left indentation of body-text lines in this volume.
    for line in content:
        line2 = sanitize(line)
        if re.search(r'^\W\s\w', line.strip()):
            footnote_found = True
        if re.search(r'^VOL.*\d\d\d\d.*\d$', line.strip()) or 'THE COLLECTED WORKS OF MAHATMA GANDHI' in line.strip():
            # page header: any footnote block from the previous page has ended
            footnote_found = False
        if len(line2) > 5 and len(line2.split()) > 4 and not footnote_found:
            if show:
                print(line.rstrip(), end='')
            li_spaces = len(line) - len(line.strip())
            if show:
                print(li_spaces)
            mean_spaces.append(li_spaces)
    mean_spaces = math.floor(sum(mean_spaces) / len(mean_spaces))
    if show:
        print('ms', mean_spaces)
        print(' ' * mean_spaces + '^')

    # Pass 2: collect body lines; an indented line followed by a line back at
    # the margin starts a new paragraph.
    footnote_found = False
    last_spaces = -1
    i = 0
    while i < len(content) - 1:
        line = content[i]
        li_spaces = len(line) - len(line.strip())
        if re.search(r'^\W\s\w', line.strip()):
            footnote_found = True
        if re.search(r'^VOL.*\d\d\d\d.*\d$', line.strip()) or 'THE COLLECTED WORKS OF MAHATMA GANDHI' in line.strip():
            # page header: skip it and stop treating lines as footnote text
            footnote_found = False
            i += 1
            continue
        if not footnote_found:
            if not (li_spaces > mean_spaces):
                # current line sits at (or left of) the mean margin: ordinary body text
                line2 = sanitize(line)
                if len(line2) > 5 and len(line2.split()) > 4:
                    if show:
                        print('++', line.rstrip())
                    sents.append(line)
                    last_spaces = li_spaces
                elif last_spaces == li_spaces:
                    if show:
                        print('++', line.rstrip())
                    sents.append(line)
                else:
                    last_spaces = -1
                    if show:
                        print('--', line.rstrip())
            else:
                # current line is indented past the mean margin: candidate paragraph start
                next_line = content[i + 1]
                lj_spaces = len(next_line) - len(next_line.strip())
                if not (lj_spaces > mean_spaces):
                    line1 = sanitize(content[i])
                    line2 = sanitize(next_line)
                    if len(line1) > 5 and len(line1.split()) > 4 and len(line2) > 5 and len(line2.split()) > 4:
                        # flush the accumulated lines as one paragraph and start a new one
                        sent_text = remove_footnotes_and_clean(sents)
                        paras.append(sent_text)
                        if show:
                            print('++', line.rstrip(), '<------NEW PARA')
                        sents = [line]
                    else:
                        last_spaces = -1
                        if show:
                            print('--', line.rstrip())
                else:
                    last_spaces = -1
                    if show:
                        print('--', line.rstrip())
        else:
            last_spaces = -1
            if show:
                print('--', line.rstrip())
        if show:
            input('wait')
        i += 1

    # write the cleaned volume; paras[0] holds the front matter collected before
    # the first detected paragraph break, so it is dropped
    file = open(out_path + m.split('/')[-1], 'w')
    file.write('\n'.join(paras[1:]))
    file.close()
# ----------------------------- Cleaning process 2/2 -----------------------------
path = 'clean_text_files/'
ml = sorted(glob.glob(path + '*.txt'))   # not used below; the volumes are read by number
text = []
for m in tqdm(range(1, 99)):             # cleaned volumes are named 1.txt .. 98.txt
    file = open(path + str(m) + '.txt', 'r')
    text += file.readlines()
    file.close()
file = open('all_paras.txt', 'w')
file.write(''.join(text))
file.close()
sents = []
tcsents = []      # transformer-compatible chunks: each fits the 200-token budget
para_stack = []
for para in tqdm(text):
    para = para.strip()
    sents += sent_tokenize(para)
    para_stack = [para] + para_stack
    # split this paragraph (recursively, via the stack) until every chunk is at
    # most 200 GPT-2 tokens including the special tokens
    while len(para_stack) != 0:
        top_para = para_stack.pop(0)
        if len(tokenizer('<|startoftext|>' + top_para + '<|endoftext|>')['input_ids']) > 200:
            ts = sent_tokenize(top_para)
            if len(ts) > 1:
                para_stack = [' '.join(ts[int(len(ts) / 2):])] + para_stack   # second half
                para_stack = [' '.join(ts[:int(len(ts) / 2)])] + para_stack   # first half
            else:
                # a single sentence longer than 200 tokens is kept as-is
                tcsents.append(top_para)
        else:
            tcsents.append(top_para)
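# Worked example (hypothetical sizes): a paragraph of ~450 tokens that splits
# into 6 sentences is pushed back onto the stack as two 3-sentence halves, the
# first half on top; each half is re-measured and halved again until every
# chunk fits within the 200-token budget (or is a single long sentence).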
file = open('all_sents.txt','w')
file.write('\n'.join(sents))
file.close()
file = open('all_tc_sents_200.txt','w')
file.write('\n'.join(tcsents))
file.close()
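# Files produced by this script:
#   all_paras.txt        - one cleaned paragraph per line
#   all_sents.txt        - one NLTK-tokenized sentence per line
#   all_tc_sents_200.txt - paragraph chunks of at most 200 GPT-2 tokens,
#                          presumably the training examples for fine-tuning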