Spaces:
Runtime error
Runtime error
import argparse | |
import six | |
from packaging import version | |
from tencentpretrain.utils import * | |
from tencentpretrain.opts import * | |
assert version.parse(six.__version__) >= version.parse("1.12.0") | |
def main(): | |
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |
# Path options. | |
parser.add_argument("--corpus_path", type=str, required=True, | |
help="Path of the corpus for pretraining.") | |
parser.add_argument("--dataset_path", type=str, default="dataset.pt", | |
help="Path of the preprocessed dataset.") | |
# Preprocess options. | |
tokenizer_opts(parser) | |
tgt_tokenizer_opts(parser) | |
parser.add_argument("--processes_num", type=int, default=1, | |
help="Split the whole dataset into `processes_num` parts, " | |
"and process them with `processes_num` processes.") | |
parser.add_argument("--data_processor", | |
choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls", "prefixlm", | |
"gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle"], default="bert", | |
help="The data processor of the pretraining model.") | |
parser.add_argument("--docs_buffer_size", type=int, default=100000, | |
help="The buffer size of documents in memory, specific to targets that require negative sampling.") | |
parser.add_argument("--seq_length", type=int, default=128, help="Sequence length of instances.") | |
parser.add_argument("--tgt_seq_length", type=int, default=128, help="Target sequence length of instances.") | |
parser.add_argument("--dup_factor", type=int, default=5, | |
help="Duplicate instances multiple times.") | |
parser.add_argument("--short_seq_prob", type=float, default=0.1, | |
help="Probability of truncating sequence." | |
"The larger value, the higher probability of using short (truncated) sequence.") | |
parser.add_argument("--full_sentences", action="store_true", help="Full sentences.") | |
parser.add_argument("--seed", type=int, default=7, help="Random seed.") | |
# Masking options. | |
parser.add_argument("--dynamic_masking", action="store_true", help="Dynamic masking.") | |
parser.add_argument("--whole_word_masking", action="store_true", help="Whole word masking.") | |
parser.add_argument("--span_masking", action="store_true", help="Span masking.") | |
parser.add_argument("--span_geo_prob", type=float, default=0.2, | |
help="Hyperparameter of geometric distribution for span masking.") | |
parser.add_argument("--span_max_length", type=int, default=10, | |
help="Max length for span masking.") | |
# Sentence selection strategy options. | |
parser.add_argument("--sentence_selection_strategy", choices=["lead", "random"], default="lead", | |
help="Sentence selection strategy for gap-sentences generation task.") | |
args = parser.parse_args() | |
# Dynamic masking. | |
if args.dynamic_masking: | |
args.dup_factor = 1 | |
# Build tokenizer. | |
tokenizer = str2tokenizer[args.tokenizer](args) | |
if args.data_processor == "mt": | |
args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, False) | |
print(tokenizer.convert_ids_to_tokens([ 1046, | |
30536, | |
25926, | |
1047, | |
30545, | |
28358, | |
1048, | |
26628, | |
21005, | |
1049, | |
21679, | |
25162, | |
1050, | |
24841, | |
24232, | |
1051, | |
1014, | |
1014, | |
1052, | |
21679, | |
16710, | |
1053, | |
1014, | |
1014, | |
1054, | |
20405, | |
5018, | |
1033, | |
102])) | |
# Build and save dataset. | |
dataset = str2dataset[args.data_processor](args, tokenizer.vocab, tokenizer) | |
dataset.build_and_save(args.processes_num) | |
if __name__ == "__main__": | |
main() | |