import os
import random
import pickle
import torch
from multiprocessing import Pool

from tencentpretrain.utils.constants import *
from tencentpretrain.utils.tokenizers import *
from tencentpretrain.utils.misc import count_lines
from tencentpretrain.utils.seed import set_seed
from tencentpretrain.utils.mask import mask_seq
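
# Each Dataset subclass below implements worker(): it reads its slice of the
# corpus (lines [start, end)), turns text into token-id instances, and pickles
# them to "dataset-tmp-<proc_id>.pt". build_and_save() splits the corpus by line
# count across processes, and merge_dataset() concatenates the temporary files
# into the final dataset file.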


def merge_dataset(dataset_path, workers_num):
    # Merge datasets.
    dataset_writer = open(dataset_path, "wb")
    for i in range(workers_num):
        tmp_dataset_reader = open("dataset-tmp-" + str(i) + ".pt", "rb")
        while True:
            tmp_data = tmp_dataset_reader.read(2**20)
            if tmp_data:
                dataset_writer.write(tmp_data)
            else:
                break
        tmp_dataset_reader.close()
        os.remove("dataset-tmp-" + str(i) + ".pt")
    dataset_writer.close()
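
# Note: the byte-wise concatenation above is valid because a file containing
# consecutive pickle.dump() records can be read back one record at a time with
# pickle.load() until EOF is reached.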


def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
    """ Truncate a sequence pair so that its total length does not exceed max_num_tokens. """
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_num_tokens:
            break
        trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        if random.random() < 0.5:
            del trunc_tokens[0]
        else:
            trunc_tokens.pop()
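
# Minimal usage sketch for truncate_seq_pair (illustrative values only); the
# longer list is trimmed in place, randomly from the front or the back:
#
#   tokens_a = list(range(100))
#   tokens_b = list(range(40))
#   truncate_seq_pair(tokens_a, tokens_b, 128)
#   # Afterwards len(tokens_a) + len(tokens_b) <= 128.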


class Dataset(object):
    def __init__(self, args, vocab, tokenizer):
        self.vocab = vocab
        self.tokenizer = tokenizer
        self.corpus_path = args.corpus_path
        self.dataset_path = args.dataset_path
        self.seq_length = args.seq_length
        self.seed = args.seed
        self.dynamic_masking = args.dynamic_masking
        self.whole_word_masking = args.whole_word_masking
        self.span_masking = args.span_masking
        self.span_geo_prob = args.span_geo_prob
        self.span_max_length = args.span_max_length
        self.docs_buffer_size = args.docs_buffer_size
        self.dup_factor = args.dup_factor
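
    # Typical driver-side usage (a sketch; the surrounding argument handling
    # follows TencentPretrain's preprocess stage and may differ in your setup):
    #
    #   dataset = BertDataset(args, vocab, tokenizer)
    #   dataset.build_and_save(workers_num=8)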

    def build_and_save(self, workers_num):
        """
        Build the dataset from the given corpus.
        Start workers_num processes, each handling a part of the data.
        """
        lines_num = count_lines(self.corpus_path)
        print("Starting %d workers for building datasets ... " % workers_num)
        assert workers_num >= 1
        if workers_num == 1:
            self.worker(0, 0, lines_num)
        else:
            pool = Pool(workers_num)
            for i in range(workers_num):
                start = i * lines_num // workers_num
                end = (i + 1) * lines_num // workers_num
                pool.apply_async(func=self.worker, args=[i, start, end])
            pool.close()
            pool.join()
        # Merge datasets.
        merge_dataset(self.dataset_path, workers_num)

    def worker(self, proc_id, start, end):
        raise NotImplementedError()


class BertDataset(Dataset):
    """
    Construct a dataset for the MLM and NSP tasks from the given corpus.
    Each document consists of multiple sentences,
    and each sentence occupies a single line.
    Documents in the corpus must be separated by empty lines.
    """

    def __init__(self, args, vocab, tokenizer):
        super(BertDataset, self).__init__(args, vocab, tokenizer)
        self.short_seq_prob = args.short_seq_prob
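
    # Expected corpus layout (one sentence per line, blank line between documents):
    #
    #   doc1 sentence1
    #   doc1 sentence2
    #
    #   doc2 sentence1
    #   doc2 sentence2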

    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        docs_buffer = []
        document = []
        pos = 0
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1

                if pos >= end:
                    if len(docs_buffer) > 0:
                        instances = self.build_instances(docs_buffer)
                        for instance in instances:
                            pickle.dump(instance, dataset_writer)
                    break

                if not line.strip():
                    if len(document) >= 1:
                        docs_buffer.append(document)
                    document = []
                    if len(docs_buffer) == self.docs_buffer_size:
                        # Build instances from documents.
                        instances = self.build_instances(docs_buffer)
                        # Save instances.
                        for instance in instances:
                            pickle.dump(instance, dataset_writer)
                        # Clear buffer.
                        docs_buffer = []
                    continue
                sentence = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line))
                if len(sentence) > 0:
                    document.append(sentence)

        dataset_writer.close()
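
    # build_instances() repeats instance creation dup_factor times; each pass
    # re-samples sentence pairs (and masks, when static masking is used), the
    # same duplication trick BERT's original preprocessing uses to vary
    # training data across epochs.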

    def build_instances(self, all_documents):
        instances = []
        for _ in range(self.dup_factor):
            for doc_index in range(len(all_documents)):
                instances.extend(self.create_ins_from_doc(all_documents, doc_index))
        return instances

    def create_ins_from_doc(self, all_documents, document_index):
        document = all_documents[document_index]
        max_num_tokens = self.seq_length - 3
        target_seq_length = max_num_tokens
        if random.random() < self.short_seq_prob:
            target_seq_length = random.randint(2, max_num_tokens)
        instances = []
        current_chunk = []
        current_length = 0
        i = 0
        while i < len(document):
            segment = document[i]
            current_chunk.append(segment)
            current_length += len(segment)
            if i == len(document) - 1 or current_length >= target_seq_length:
                if current_chunk:
                    a_end = 1
                    if len(current_chunk) >= 2:
                        a_end = random.randint(1, len(current_chunk) - 1)

                    tokens_a = []
                    for j in range(a_end):
                        tokens_a.extend(current_chunk[j])

                    tokens_b = []
                    is_random_next = 0
                    if len(current_chunk) == 1 or random.random() < 0.5:
                        is_random_next = 1
                        target_b_length = target_seq_length - len(tokens_a)

                        for _ in range(10):
                            random_document_index = random.randint(0, len(all_documents) - 1)
                            if random_document_index != document_index:
                                break

                        random_document = all_documents[random_document_index]
                        random_start = random.randint(0, len(random_document) - 1)
                        for j in range(random_start, len(random_document)):
                            tokens_b.extend(random_document[j])
                            if len(tokens_b) >= target_b_length:
                                break
                        num_unused_segments = len(current_chunk) - a_end
                        i -= num_unused_segments
                    else:
                        is_random_next = 0
                        for j in range(a_end, len(current_chunk)):
                            tokens_b.extend(current_chunk[j])

                    truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)

                    src = []
                    src.append(self.vocab.get(CLS_TOKEN))
                    src.extend(tokens_a)
                    src.append(self.vocab.get(SEP_TOKEN))
                    seg_pos = [len(src)]
                    src.extend(tokens_b)
                    src.append(self.vocab.get(SEP_TOKEN))
                    seg_pos.append(len(src))

                    pad_num = 0
                    if len(src) != self.seq_length:
                        pad_num = self.seq_length - len(src)

                    if not self.dynamic_masking:
                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                        src = (src, pad_num)
                        instance = (src, tgt_mlm, is_random_next, seg_pos)
                    else:
                        src = (src, pad_num)
                        instance = (src, is_random_next, seg_pos)

                    instances.append(instance)
                current_chunk = []
                current_length = 0
            i += 1
        return instances
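
    # Each BERT instance is a tuple:
    #   ((src_ids, pad_num), tgt_mlm, is_random_next, seg_pos)  with static masking
    #   ((src_ids, pad_num), is_random_next, seg_pos)           with dynamic masking
    # where seg_pos marks the end of segment A (after the first SEP) and the end
    # of segment B (before padding).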


class MlmDataset(Dataset):
    def __init__(self, args, vocab, tokenizer):
        super(MlmDataset, self).__init__(args, vocab, tokenizer)
        self.full_sentences = args.full_sentences

    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        docs_buffer = []
        for _ in range(self.dup_factor):
            pos = 0
            with open(self.corpus_path, mode="r", encoding="utf-8") as f:
                while pos < start:
                    f.readline()
                    pos += 1
                while True:
                    line = f.readline()
                    pos += 1
                    document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]
                    if self.full_sentences:
                        if len(document) > 0:
                            docs_buffer.append(document)
                        if len(docs_buffer) == self.docs_buffer_size:
                            # Build instances from documents.
                            all_documents = self.concatenate_docs(docs_buffer)
                            instances = self.build_instances(all_documents)
                            # Save instances.
                            for instance in instances:
                                pickle.dump(instance, dataset_writer)
                            # Clear buffer.
                            docs_buffer = []
                        if pos >= end:
                            if len(docs_buffer) > 0:
                                all_documents = self.concatenate_docs(docs_buffer)
                                instances = self.build_instances(all_documents)
                                # Save instances.
                                for instance in instances:
                                    pickle.dump(instance, dataset_writer)
                            break
                    else:
                        if len(document) > 0:
                            instances = self.build_instances(document)
                            # Save instances.
                            for instance in instances:
                                pickle.dump(instance, dataset_writer)
                        if pos >= end:
                            break

        dataset_writer.close()
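
    # With full_sentences enabled, documents are concatenated and sliced into
    # fixed seq_length blocks that may cross document boundaries (similar to
    # RoBERTa's FULL-SENTENCES setting); otherwise each input line becomes its
    # own document and is sliced independently.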

    def concatenate_docs(self, docs_buffer):
        all_documents = []
        for i in range(len(docs_buffer)):
            all_documents += docs_buffer[i]
        return all_documents

    def build_instances(self, all_documents):
        instances = []
        instances_num = len(all_documents) // self.seq_length
        for i in range(instances_num):
            src = all_documents[i * self.seq_length: (i + 1) * self.seq_length]
            seg_pos = [len(src)]

            if not self.dynamic_masking:
                src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                instance = ((src, 0), tgt, seg_pos)
            else:
                instance = ((src, 0), seg_pos)

            instances.append(instance)

        src = all_documents[instances_num * self.seq_length:]
        if len(src) == 0:
            return instances
        seg_pos = [len(src)]
        pad_num = self.seq_length - len(src)

        if not self.dynamic_masking:
            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
            instance = ((src, pad_num), tgt, seg_pos)
        else:
            instance = ((src, pad_num), seg_pos)

        instances.append(instance)
        return instances


class AlbertDataset(Dataset):
    """
    Construct a dataset for the MLM and SOP tasks from the given corpus.
    Each document consists of multiple sentences,
    and each sentence occupies a single line.
    Documents in the corpus must be separated by empty lines.
    """

    def __init__(self, args, vocab, tokenizer):
        super(AlbertDataset, self).__init__(args, vocab, tokenizer)
        self.short_seq_prob = args.short_seq_prob

    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        document = []
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        for _ in range(self.dup_factor):
            pos = 0
            with open(self.corpus_path, mode="r", encoding="utf-8") as f:
                while pos < start:
                    f.readline()
                    pos += 1
                while True:
                    line = f.readline()
                    pos += 1
                    if not line.strip():
                        if len(document) >= 1:
                            instances = self.build_instances(document)
                            for instance in instances:
                                pickle.dump(instance, dataset_writer)
                        document = []
                    sentence = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line))
                    if len(sentence) > 0:
                        document.append(sentence)
                    if pos >= end:
                        if len(document) >= 1:
                            instances = self.build_instances(document)
                            for instance in instances:
                                pickle.dump(instance, dataset_writer)
                        break

        dataset_writer.close()

    def build_instances(self, document):
        instances = []
        instances.extend(self.create_ins_from_doc(document))
        return instances

    def create_ins_from_doc(self, document):
        max_num_tokens = self.seq_length - 3
        target_seq_length = max_num_tokens
        if random.random() < self.short_seq_prob:
            target_seq_length = random.randint(2, max_num_tokens)
        instances = []
        current_chunk = []
        current_length = 0
        i = 0
        while i < len(document):
            segment = document[i]
            current_chunk.append(segment)
            current_length += len(segment)
            if i == len(document) - 1 or current_length >= target_seq_length:
                if current_chunk:
                    a_end = 1
                    if len(current_chunk) >= 2:
                        a_end = random.randint(1, len(current_chunk) - 1)

                    tokens_a = []
                    for j in range(a_end):
                        tokens_a.extend(current_chunk[j])

                    tokens_b = []
                    is_wrong_order = 0
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])
                    if random.random() < 0.5:
                        is_wrong_order = 1
                        tokens_a, tokens_b = tokens_b, tokens_a

                    truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)

                    src = []
                    src.append(self.vocab.get(CLS_TOKEN))
                    src.extend(tokens_a)
                    src.append(self.vocab.get(SEP_TOKEN))
                    seg_pos = [len(src)]
                    src.extend(tokens_b)
                    src.append(self.vocab.get(SEP_TOKEN))
                    seg_pos.append(len(src))

                    pad_num = 0
                    if len(src) != self.seq_length:
                        pad_num = self.seq_length - len(src)

                    if not self.dynamic_masking:
                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                        src = (src, pad_num)
                        instance = (src, tgt_mlm, is_wrong_order, seg_pos)
                    else:
                        src = (src, pad_num)
                        instance = (src, is_wrong_order, seg_pos)

                    instances.append(instance)
                current_chunk = []
                current_length = 0
            i += 1
        return instances
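
    # AlbertDataset produces sentence-order-prediction labels instead of NSP:
    # is_wrong_order == 1 means segments A and B were swapped.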


class LmDataset(Dataset):
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                document = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line))
                document = [self.vocab.get(CLS_TOKEN)] + document + [self.vocab.get(SEP_TOKEN)]

                instances_num = len(document) // (self.seq_length + 1)
                for i in range(instances_num):
                    src = document[i * (self.seq_length + 1): (i + 1) * (self.seq_length + 1)]
                    seg_pos = [self.seq_length]
                    src = (src, 0)
                    pickle.dump((src, seg_pos), dataset_writer)

                src = document[instances_num * (self.seq_length + 1):]
                if len(src) > 0:
                    seg_pos = [len(src)]
                    pad_num = self.seq_length + 1 - len(src)
                    src = (src, pad_num)
                    pickle.dump((src, seg_pos), dataset_writer)

                if pos >= end:
                    break

        dataset_writer.close()
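
    # LmDataset slices each line into blocks of seq_length + 1 token ids; the
    # extra token presumably lets the loader derive input (tokens[:-1]) and
    # target (tokens[1:]) pairs of length seq_length.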


class BilmDataset(Dataset):
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                document = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line))

                instances_num = len(document) // self.seq_length
                for i in range(instances_num):
                    src = document[i * self.seq_length: (i + 1) * self.seq_length]
                    tgt_forward = src[1:] + [self.vocab.get(SEP_TOKEN)]
                    tgt_backward = [self.vocab.get(CLS_TOKEN)] + src[:-1]
                    seg_pos = [self.seq_length]
                    src = (src, 0)
                    pickle.dump((src, tgt_forward, tgt_backward, seg_pos), dataset_writer)

                src = document[instances_num * self.seq_length:]
                if len(src) < 1:
                    # Nothing left over from this line; still honor the end of
                    # this worker's slice before moving on.
                    if pos >= end:
                        break
                    continue
                tgt_forward = src[1:] + [self.vocab.get(SEP_TOKEN)]
                tgt_backward = [self.vocab.get(CLS_TOKEN)] + src[:-1]
                seg_pos = [len(src)]
                pad_num = self.seq_length - len(src)
                src = (src, pad_num)
                pickle.dump((src, tgt_forward, tgt_backward, seg_pos), dataset_writer)

                if pos >= end:
                    break

        dataset_writer.close()
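
    # For the bidirectional LM, tgt_forward is src shifted left by one position
    # (with SEP appended) and tgt_backward is src shifted right by one position
    # (with CLS prepended).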


class MtDataset(Dataset):
    def __init__(self, args, vocab, tokenizer):
        super(MtDataset, self).__init__(args, vocab, tokenizer)
        self.tgt_seq_length = args.tgt_seq_length
        self.src_vocab, self.src_tokenizer = vocab, tokenizer
        self.tgt_tokenizer = args.tgt_tokenizer
        self.tgt_vocab = self.tgt_tokenizer.vocab

    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                if len(line.strip().split("\t")) != 2:
                    if pos >= end:
                        break
                    continue
                document_src, document_tgt = line.strip().split("\t")
                src = self.src_tokenizer.convert_tokens_to_ids(self.src_tokenizer.tokenize(document_src))
                tgt = self.tgt_tokenizer.convert_tokens_to_ids(self.tgt_tokenizer.tokenize(document_tgt))
                src = [self.src_vocab.get(CLS_TOKEN)] + src + [self.src_vocab.get(SEP_TOKEN)]
                tgt = [self.tgt_vocab.get(CLS_TOKEN)] + tgt + [self.tgt_vocab.get(SEP_TOKEN)]

                src, tgt = src[:self.seq_length], tgt[:self.tgt_seq_length + 1]
                seg_pos = [len(src)]
                pad_num = self.seq_length - len(src)
                src = (src, pad_num)
                pad_num = self.tgt_seq_length + 1 - len(tgt)
                tgt = (tgt, pad_num)
                pickle.dump((src, tgt, seg_pos), dataset_writer)

                if pos >= end:
                    break

        dataset_writer.close()
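
    # Expected corpus format: one "source_text<TAB>target_text" pair per line;
    # lines that do not split into exactly two tab-separated fields are skipped.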


class T5Dataset(MlmDataset):
    """
    T5 can reuse the code of MlmDataset.
    """
    pass


class GsgDataset(BertDataset):
    def __init__(self, args, vocab, tokenizer):
        super(GsgDataset, self).__init__(args, vocab, tokenizer)
        self.sentence_selection_strategy = args.sentence_selection_strategy
        self.tgt_seq_length = args.tgt_seq_length

    def create_single_instance(self, src, tgt):
        src = [self.vocab.get(CLS_TOKEN)] + src + [self.vocab.get(SEP_TOKEN)]
        tgt = [self.vocab.get(CLS_TOKEN)] + tgt + [self.vocab.get(SEP_TOKEN)]
        seg_pos = [len(src)]
        pad_num = self.seq_length - len(src)
        src = (src, pad_num)
        pad_num = self.tgt_seq_length - len(tgt)
        tgt = (tgt, pad_num)
        instance = (src, tgt, seg_pos)
        return instance

    def create_ins_from_doc(self, all_documents, document_index):
        sentence_selection_strategy = self.sentence_selection_strategy
        instances = []
        mask_seq_list = []
        tmp_document = []
        src = []
        tgt = []
        i = 0
        document = all_documents[document_index]
        target_seq_length, target_tgt_seq_length = self.seq_length - 2, self.tgt_seq_length - 2
        for segment in document:
            if len(segment) < target_seq_length and len(segment) < target_tgt_seq_length:
                tmp_document.append(segment)
        document = tmp_document
        mask_seq_num = int(round(len(document) * 0.3, 0))
        if sentence_selection_strategy == "random":
            mask_seq_list = random.sample(range(0, len(document) - 1), mask_seq_num)
        else:
            mask_seq_list = list(range(0, mask_seq_num))
        while i < len(document):
            segment = document[i]
            if i in mask_seq_list and len(tgt) + len(segment) < target_tgt_seq_length and len(src) + 1 < target_seq_length:
                tgt = tgt + segment
                src = src + [self.vocab.get(MASK_TOKEN)]
            elif i not in mask_seq_list and len(src) + len(segment) < target_seq_length:
                src = src + segment
            else:
                if len(tgt) > 0 and len(src) > 0:
                    instance = self.create_single_instance(src, tgt)
                    instances.append(instance)
                if i in mask_seq_list:
                    tgt = segment
                    src = [self.vocab.get(MASK_TOKEN)]
                else:
                    src = segment
                    tgt = []
            i += 1
        if len(tgt) > 0 and len(src) > 0:
            instance = self.create_single_instance(src, tgt)
            instances.append(instance)
        return instances
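
    # Gap-sentence generation (PEGASUS-style): roughly 30% of the sentences in a
    # document are selected (at random, or the leading ones when the strategy is
    # not "random"); each selected sentence is replaced by a single MASK token in
    # src and appended to tgt, which the decoder learns to reconstruct.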


class BartDataset(BertDataset):
    def create_single_instance(self, src, tgt):
        src = [self.vocab.get(CLS_TOKEN)] + src + [self.vocab.get(SEP_TOKEN)]
        tgt = [self.vocab.get(CLS_TOKEN)] + tgt + [self.vocab.get(SEP_TOKEN)]
        seg_pos = [len(src)]
        pad_num = self.seq_length - len(src)
        src = (src, pad_num)
        tgt = (tgt, pad_num)
        instance = (src, tgt, seg_pos)
        return instance

    def create_ins_from_doc(self, all_documents, document_index):
        document = all_documents[document_index]
        target_seq_length = self.seq_length - 2
        src = []
        tgt = []
        instances = []
        current_chunk = []
        current_length = 0
        i = 0
        while i < len(document):
            segment = document[i]
            if len(segment) > target_seq_length:
                i += 1
                continue
            if current_length + len(segment) < target_seq_length:
                current_chunk.append(segment)
                current_length += len(segment)
            else:
                shuf_chunk = current_chunk.copy()
                random.shuffle(shuf_chunk)
                for k in range(len(current_chunk)):
                    src = src + shuf_chunk[k]
                    tgt = tgt + current_chunk[k]
                instance = self.create_single_instance(src, tgt)
                instances.append(instance)
                current_length = len(segment)
                current_chunk = [segment]
                src = []
                tgt = []
            i += 1
        if len(current_chunk) > 0:
            shuf_chunk = current_chunk.copy()
            random.shuffle(shuf_chunk)
            for k in range(len(current_chunk)):
                src = src + shuf_chunk[k]
                tgt = tgt + current_chunk[k]
            instance = self.create_single_instance(src, tgt)
            instances.append(instance)
        return instances
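
    # Sentence permutation: src is the chunk with its sentences shuffled and tgt
    # keeps the original order, so both sides contain the same tokens and the
    # single pad_num computed in create_single_instance applies to both.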


class ClsDataset(Dataset):
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                line = line.strip().split("\t")
                if len(line) == 2:
                    label = int(line[0])
                    text = line[1]
                    src = [self.vocab.get(t) for t in self.tokenizer.tokenize(text)]
                    src = [self.vocab.get(CLS_TOKEN)] + src
                    tgt = label
                    seg_pos = [len(src)]
                    if len(src) >= self.seq_length:
                        pad_num = 0
                        src = (src[:self.seq_length], pad_num)
                        seg_pos = [self.seq_length]
                    else:
                        pad_num = self.seq_length - len(src)
                        src = (src, pad_num)
                    pickle.dump((src, tgt, seg_pos), dataset_writer)
                elif len(line) == 3:  # For sentence pair input.
                    label = int(line[0])
                    text_a, text_b = line[1], line[2]
                    src_a = [self.vocab.get(t) for t in self.tokenizer.tokenize(text_a)]
                    src_a = [self.vocab.get(CLS_TOKEN)] + src_a + [self.vocab.get(SEP_TOKEN)]
                    src_b = [self.vocab.get(t) for t in self.tokenizer.tokenize(text_b)]
                    src_b = src_b + [self.vocab.get(SEP_TOKEN)]
                    src = src_a + src_b
                    tgt = label
                    seg_pos = [len(src_a)] + [len(src_b)]
                    if len(src) >= self.seq_length:
                        pad_num = 0
                        src = (src[:self.seq_length], pad_num)
                        if len(src_a) >= self.seq_length:
                            seg_pos = [self.seq_length]
                        else:
                            seg_pos = [len(src_a)] + [self.seq_length - len(src_a)]
                    else:
                        pad_num = self.seq_length - len(src)
                        src = (src, pad_num)
                    pickle.dump((src, tgt, seg_pos), dataset_writer)
                else:
                    pass

                if pos >= end:
                    break

        dataset_writer.close()
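
    # Expected corpus format, one example per line:
    #   label<TAB>text                  for single-sentence classification
    #   label<TAB>text_a<TAB>text_b     for sentence-pair classification
    # Lines with any other number of fields are ignored.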


class PrefixlmDataset(Dataset):
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                if len(line.strip().split("\t")) != 2:
                    if pos >= end:
                        break
                    continue
                document_src, document_tgt = line.strip().split("\t")
                src = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(document_src))
                tgt = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(document_tgt))
                src = [self.vocab.get(CLS_TOKEN)] + src + [self.vocab.get(SEP_TOKEN)]
                tgt = tgt + [self.vocab.get(SEP_TOKEN)]
                seg_pos = [len(src)]
                if seg_pos[0] >= self.seq_length:
                    continue
                src = src + tgt
                tgt = [0] * (seg_pos[0] - 1) + tgt + [self.vocab.get(PAD_TOKEN)]
                seg_pos.append(len(src))
                src, tgt = src[:self.seq_length], tgt[:self.seq_length]
                pad_num = self.seq_length - len(src)
                src = (src, pad_num)
                if seg_pos[1] > self.seq_length:
                    seg_pos[1] = self.seq_length
                pickle.dump((src, tgt, seg_pos), dataset_writer)

                if pos >= end:
                    break

        dataset_writer.close()
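
    # Prefix LM: src is "CLS prefix SEP target SEP"; tgt is aligned to src with a
    # one-position shift and filled with zeros over the prefix positions
    # (presumably so only the target segment contributes to the loss). seg_pos
    # records where the prefix ends and where the full sequence ends.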


class ClsMlmDataset(Dataset):
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                line = line.strip().split("\t")
                if len(line) == 2:
                    label = int(line[0])
                    text = line[1]
                    src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
                    tgt_cls = label
                    seg_pos = [len(src)]
                elif len(line) == 3:  # For sentence pair input.
                    label = int(line[0])
                    text_a, text_b = line[1], line[2]
                    src_a = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_a))
                    src_a = [self.vocab.get(CLS_TOKEN)] + src_a + [self.vocab.get(SEP_TOKEN)]
                    src_b = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_b))
                    src_b = src_b + [self.vocab.get(SEP_TOKEN)]
                    src = src_a + src_b
                    tgt_cls = label
                    seg_pos = [len(src_a)] + [len(src_b)]
                else:
                    if pos >= end:
                        break
                    continue

                if len(src) >= self.seq_length:
                    pad_num = 0
                    src = (src[:self.seq_length], pad_num)
                    if len(seg_pos) == 1:
                        seg_pos = [self.seq_length]
                    else:
                        if len(src_a) >= self.seq_length:
                            seg_pos = [self.seq_length]
                        else:
                            seg_pos = [len(src_a)] + [self.seq_length - len(src_a)]
                else:
                    pad_num = self.seq_length - len(src)
                    src = (src, pad_num)

                if not self.dynamic_masking:
                    src_single, pad_num = src
                    src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    src = (src_single, pad_num)
                    instance = (src, tgt_mlm, tgt_cls, seg_pos)
                else:
                    instance = (src, tgt_cls, seg_pos)
                pickle.dump(instance, dataset_writer)

                if pos >= end:
                    break

        dataset_writer.close()
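
    # ClsMlmDataset uses the same corpus format as ClsDataset and additionally
    # attaches MLM targets (or leaves masking to the loader when dynamic masking
    # is enabled).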


class FileWithTextDataset(Dataset):
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                line = line.strip().split("\t")
                text = line[0]
                path = line[1]
                if pos == 1 and text == "text":
                    continue
                src = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
                src = src[:self.seq_length - 2]
                src = [self.vocab.get(CLS_TOKEN)] + src + [self.vocab.get(SEP_TOKEN)]
                seg_pos = [len(src)]
                pad_num = self.seq_length - len(src)
                src = (src, pad_num)
                pickle.dump((src, seg_pos, path), dataset_writer)
                if pos >= end:
                    break

        dataset_writer.close()
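
    # Expected corpus format: "text<TAB>path" per line (a header row whose first
    # field is "text" is skipped); only the tokenized text and the file path are
    # stored here, and the referenced file is presumably read later in the
    # data-loading stage.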


class FileWithLabelDataset(Dataset):
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                line = line.strip().split("\t")
                label = int(line[0])
                path = line[1]
                pickle.dump((label, path), dataset_writer)
                if pos >= end:
                    break

        dataset_writer.close()


class FileDataset(Dataset):
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1
                path = line.strip()
                pickle.dump(path, dataset_writer)
                if pos >= end:
                    break

        dataset_writer.close()


class VitDataset(FileWithLabelDataset):
    pass


class ViltDataset(FileWithTextDataset):
    pass


class ClipDataset(FileWithTextDataset):
    pass


class S2tDataset(FileWithTextDataset):
    pass


class BeitDataset(FileDataset):
    pass


class DalleDataset(FileWithTextDataset):
    pass