# NOTE: page-capture residue, not part of the module: "Spaces: Runtime error"
import os | |
import numpy as np | |
from tqdm import tqdm | |
from math import ceil | |
def load_vsec_dataset(base_path, corr_file, incorr_file):
    """Load parallel correct/incorrect sentences as a list of pairs.

    Blank lines are dropped from each file independently and the remaining
    lines are stripped, so pairing is by position among the non-blank lines
    of each file.

    Args:
        base_path: Directory containing both files. A falsy value (``None``
            or ``""``) skips the existence check and uses the file names
            as given.
        corr_file: Name of the file holding the correct sentences.
        incorr_file: Name of the file holding the incorrect sentences.

    Returns:
        A list of ``(corr, incorr)`` tuples. Note that the correct sentence
        comes first, even though the log message names ``incorr`` first.

    Raises:
        AssertionError: If ``base_path`` is given but does not exist, or if
            the two files contain a different number of non-blank lines.
    """
    if base_path:
        assert os.path.exists(base_path), f"base_path does not exist: {base_path}"
    # os.path.join rejects None, so fall back to an empty prefix.
    root = base_path or ""

    # The incorrect file is read first, matching the original read order.
    with open(os.path.join(root, incorr_file), "r", encoding="utf-8") as incorr_fh:
        incorr_data = [text for text in map(str.strip, incorr_fh) if text]

    with open(os.path.join(root, corr_file), "r", encoding="utf-8") as corr_fh:
        corr_data = [text for text in map(str.strip, corr_fh) if text]

    assert len(incorr_data) == len(corr_data), (
        f"line count mismatch: {len(incorr_data)} incorrect vs "
        f"{len(corr_data)} correct"
    )

    data = list(zip(corr_data, incorr_data))
    print(f"loaded tuples of (incorr, corr) examples from {base_path}")
    return data
def load_dataset(base_path, corr_file, incorr_file, length_file=None):
    """Load line-aligned correct/incorrect sentences and optional lengths.

    Every line of ``corr_file`` (blank or not) produces exactly one row, so
    row ``i`` always corresponds to line ``i`` of each file. Non-blank lines
    of ``incorr_file`` and ``length_file`` are then appended to the row with
    the same line index; blank lines in those two files are skipped, which
    leaves the corresponding row shorter.

    Args:
        base_path: Directory containing the files. A falsy value (``None``
            or ``""``) skips the existence check and uses the file names
            as given.
        corr_file: Name of the file holding the correct sentences.
        incorr_file: Name of the file holding the incorrect sentences.
        length_file: Name of the file holding one integer per line. When
            ``None`` (the default) no length column is appended.

    Returns:
        A list of rows shaped ``[corr, incorr, length]`` (``[corr, incorr]``
        without a length file). Note that the correct sentence comes first,
        even though the log message names ``incorr`` first.

    Raises:
        AssertionError: If ``base_path`` is given but does not exist.
        IndexError: If ``incorr_file`` or ``length_file`` has a non-blank
            line past the last line of ``corr_file``.
        ValueError: If a non-blank line of ``length_file`` is not an integer.
    """
    if base_path:
        assert os.path.exists(base_path), f"base_path does not exist: {base_path}"
    # os.path.join rejects None, so fall back to an empty prefix.
    root = base_path or ""

    # One row per line, including blank lines, so later files can be matched
    # to rows purely by line index.
    with open(os.path.join(root, corr_file), "r", encoding="utf-8") as corr_fh:
        data = [[line.strip()] for line in tqdm(corr_fh)]

    with open(os.path.join(root, incorr_file), "r", encoding="utf-8") as incorr_fh:
        for i, line in tqdm(enumerate(incorr_fh)):
            text = line.strip()
            if text:
                data[i].append(text)

    # The parameter defaults to None; only read lengths when a file is given.
    if length_file is not None:
        with open(os.path.join(root, length_file), "r", encoding="utf-8") as length_fh:
            for i, line in tqdm(enumerate(length_fh)):
                if line.strip():
                    data[i].append(int(line))

    print(f"loaded tuples of (incorr, corr, length) examples from {base_path}")
    return data
def load_epoch_dataset(base_path, corr_file, incorr_file, length_file, epoch: int, num_epoch: int):
    """Load the slice of a line-aligned dataset that belongs to one epoch.

    The total number of examples is the number of lines in ``length_file``.
    That total is split into ``num_epoch`` consecutive chunks of
    ``ceil(total / num_epoch)`` lines, and the chunk for ``epoch``
    (1-based) is loaded. Within the chunk, every line of ``corr_file``
    (blank or not) produces one row; non-blank lines of ``incorr_file`` and
    ``length_file`` are appended to the row with the same line index.

    Args:
        base_path: Directory containing the files. A falsy value (``None``
            or ``""``) skips the existence check and uses the file names
            as given.
        corr_file: Name of the file holding the correct sentences.
        incorr_file: Name of the file holding the incorrect sentences.
        length_file: Name of the file holding one integer per line.
        epoch: 1-based index of the chunk to load.
        num_epoch: Number of chunks the dataset is divided into.

    Returns:
        A list of rows shaped ``[corr, incorr, length]`` for the selected
        chunk. Note that the correct sentence comes first, even though the
        log message names ``incorr`` first. The last chunk may be shorter
        than the others (or empty).

    Raises:
        AssertionError: If ``base_path`` is given but does not exist, if
            ``num_epoch < 1``, or if ``epoch`` is outside
            ``[1, num_epoch]``.
        IndexError: If ``incorr_file`` or ``length_file`` has a non-blank
            line inside the chunk that ``corr_file`` does not reach.
        ValueError: If a non-blank line of ``length_file`` inside the chunk
            is not an integer.
    """
    if base_path:
        assert os.path.exists(base_path), f"base_path does not exist: {base_path}"
    assert num_epoch >= 1, "num_epoch must be >= 1"
    assert 1 <= epoch <= num_epoch, "epoch must be within [1, num_epoch]"
    # os.path.join rejects None, so fall back to an empty prefix.
    root = base_path or ""
    length_path = os.path.join(root, length_file)

    # Count every line of the length file (blank lines included).
    with open(length_path, "r", encoding="utf-8") as length_fh:
        count = sum(1 for _ in tqdm(length_fh))
    print(f"Number of training datas: {count} examples!")

    # Exact integer ceiling division; a float product such as
    # ``1 / num_epoch * count`` can misround before ``ceil`` is applied.
    epochdataset_examples = -(-count // num_epoch)
    start_index = epochdataset_examples * (epoch - 1)
    end_index = start_index + epochdataset_examples

    # One row per corr line inside [start_index, end_index).
    data = []
    with open(os.path.join(root, corr_file), "r", encoding="utf-8") as corr_fh:
        for i, line in tqdm(enumerate(corr_fh)):
            if i >= end_index:
                break
            if i >= start_index:
                data.append([line.strip()])

    with open(os.path.join(root, incorr_file), "r", encoding="utf-8") as incorr_fh:
        for i, line in tqdm(enumerate(incorr_fh)):
            if i >= end_index:
                break
            text = line.strip()
            if i >= start_index and text:
                data[i - start_index].append(text)

    with open(length_path, "r", encoding="utf-8") as length_fh:
        for i, line in tqdm(enumerate(length_fh)):
            if i >= end_index:
                break
            if i >= start_index and line.strip():
                data[i - start_index].append(int(line))

    print(f"loaded tuples of (incorr, corr, length) examples from {base_path}")
    return data