import os
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch
import nlp
from transformers import T5Tokenizer, BartTokenizer, HfArgumentParser

logger = logging.getLogger(__name__)

# Directory under which the processed train/valid datasets are written.
DATA_DIR = "data"
# Model families this script knows how to tokenize for.
SUPPORTED_MODEL_TYPES = ("t5", "bart")


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    # Task subset to keep; must be a key of TASK_TO_FILTER_FN (validated in main()).
    task: str = field(
        metadata={"help": "Which task 'qa', 'qg', 'e2e_qg', 'ans_ext', 'multi'. 'multi' means 'qa', 'qg', 'ans_ext' tasks"},
    )
    # Tokenizer family; must be one of SUPPORTED_MODEL_TYPES (validated in main()).
    model_type: str = field(metadata={"help": "One of 't5', 'bart'"})
    # Passed to nlp.load_dataset as the dataset path/script.
    dataset_path: Optional[str] = field(
        default="data/squad_multitask",
        metadata={"help": "Path for dataset directory"},
    )
    # Output file names (joined onto DATA_DIR); None means "derive from task/format/model".
    train_file_name: Optional[str] = field(
        default=None,
        metadata={"help": "name for cached train dataset"},
    )
    valid_file_name: Optional[str] = field(
        default=None,
        metadata={"help": "name for cached valid dataset"},
    )
    # Only consulted when task == 'multi'.
    valid_for_qg_only: bool = field(
        default=False,
        metadata={"help": "For multitask dataset valid split should contain only qg task or all tasks."}
    )
    # Passed to nlp.load_dataset as the dataset config name.
    qg_format: Optional[str] = field(
        default='highlight_qg_format',
        metadata={"help": "How to format inputs for que generation, 'highlight_qg_format' or 'prepend_qg_format'"},
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "Max input length for the source text"},
    )
    max_target_length: Optional[int] = field(
        default=32,
        metadata={"help": "Max input length for the target text"},
    )


class DataProcessor:
    """Turn ``source_text``/``target_text`` examples into fixed-length token ids.

    :meth:`process` optionally appends the T5 end-of-sequence marker, replaces
    the literal ``{hl_token}`` / ``{sep_token}`` placeholders in the text with
    concrete special tokens, then tokenizes in batches.
    """

    def __init__(self, tokenizer, model_type="t5", max_source_length=512, max_target_length=32):
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.model_type = model_type
        # Marker substituted for the "{hl_token}" placeholder in source_text.
        self.hl_token = "<hl>"

        # Separator substituted for the "{sep_token}" placeholder in target_text.
        if model_type in ("t5", "bart"):
            self.sep_token = "<sep>"
        else:
            self.sep_token = "[SEP]"

    def process(self, dataset):
        """Return ``dataset`` mapped to ``source_ids``/``target_ids``/``attention_mask`` columns."""
        if self.model_type == "t5":
            dataset = dataset.map(self._add_eos_examples)

        dataset = dataset.map(self._add_special_tokens)
        dataset = dataset.map(self._convert_to_features, batched=True)

        return dataset

    def _add_eos_examples(self, example):
        """Append the T5 end-of-sequence marker to both source and target text."""
        example['source_text'] = example['source_text'] + " </s>"
        example['target_text'] = example['target_text'] + " </s>"
        return example

    def _add_special_tokens(self, example):
        """Replace the literal placeholders with this processor's special tokens."""
        example['source_text'] = example['source_text'].replace("{hl_token}", self.hl_token)
        example['target_text'] = example['target_text'].replace("{sep_token}", self.sep_token)
        return example

    def _encode(self, texts, max_length):
        """Tokenize ``texts``, truncating/padding every sequence to exactly ``max_length``."""
        # `padding='max_length'` already pads to max_length; the deprecated
        # `pad_to_max_length` flag is redundant and must not be passed as well.
        return self.tokenizer.batch_encode_plus(
            texts,
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )

    # tokenize the examples
    def _convert_to_features(self, example_batch):
        """Batched map fn: tokenize source/target text into id and mask columns."""
        source_encoding = self._encode(example_batch['source_text'], self.max_source_length)
        target_encoding = self._encode(example_batch['target_text'], self.max_target_length)

        encodings = {
            'source_ids': source_encoding['input_ids'],
            'target_ids': target_encoding['input_ids'],
            'attention_mask': source_encoding['attention_mask'],
        }

        return encodings


def filter_qa(example):
    return example['task'] == 'qa'

def filter_qg(example):
    return example['task'] == 'qg'

def filter_e2e_qg(example):
    return example['task'] == 'e2e_qg'

def filter_ans_ext(example):
    return example['task'] == 'ans_ext'

def filter_multi(example):
    # 'multi' keeps every task except end-to-end question generation.
    return example['task'] != 'e2e_qg'


# Maps DataTrainingArguments.task to the predicate used with Dataset.filter.
TASK_TO_FILTER_FN = {
    'qa': filter_qa,
    'qg': filter_qg,
    'e2e_qg': filter_e2e_qg,
    'ans_ext': filter_ans_ext,
    'multi': filter_multi
}


def main():
    """Build, filter, tokenize and save the train/valid datasets and the tokenizer.

    Reads :class:`DataTrainingArguments` from the command line, writes the
    processed datasets with ``torch.save`` under ``DATA_DIR`` and saves the
    tokenizer (with ``<sep>``/``<hl>`` added) to ``{model_type}_qg_tokenizer``.

    Raises:
        ValueError: if ``task`` or ``model_type`` is not a supported value.
    """
    parser = HfArgumentParser((DataTrainingArguments,))

    data_args = parser.parse_args_into_dataclasses()[0]

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
    )

    # Fail fast, before the slow dataset load, on values the rest of the script cannot handle.
    if data_args.task not in TASK_TO_FILTER_FN:
        raise ValueError(
            f"Unknown task {data_args.task!r}; expected one of {sorted(TASK_TO_FILTER_FN)}"
        )
    if data_args.model_type not in SUPPORTED_MODEL_TYPES:
        raise ValueError(
            f"Unknown model_type {data_args.model_type!r}; expected one of {list(SUPPORTED_MODEL_TYPES)}"
        )

    if data_args.model_type == 't5':
        tokenizer = T5Tokenizer.from_pretrained("t5-base")
    else:
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    # Register the separator/highlight markers DataProcessor inserts into the text.
    tokenizer.add_tokens(['<sep>', '<hl>'])

    train_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.TRAIN)
    valid_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.VALIDATION)

    processor = DataProcessor(
        tokenizer,
        model_type=data_args.model_type,
        max_source_length=data_args.max_source_length,
        max_target_length=data_args.max_target_length
    )

    train_dataset = train_dataset.filter(TASK_TO_FILTER_FN[data_args.task])
    if data_args.task == 'multi' and data_args.valid_for_qg_only:
        logger.info("processing valid data only for qg task")
        valid_dataset = valid_dataset.filter(filter_qg)
    else:
        valid_dataset = valid_dataset.filter(TASK_TO_FILTER_FN[data_args.task])

    train_dataset = processor.process(train_dataset)
    valid_dataset = processor.process(valid_dataset)

    columns = ["source_ids", "target_ids", "attention_mask"]
    train_dataset.set_format(type='torch', columns=columns)
    valid_dataset.set_format(type='torch', columns=columns)

    # Each output name falls back to its own default, so supplying only one of
    # train_file_name / valid_file_name neither crashes nor is silently ignored.
    default_suffix = f"{data_args.task}_{data_args.qg_format}_{data_args.model_type}.pt"
    train_file_name = (
        data_args.train_file_name if data_args.train_file_name is not None
        else f"train_data_{default_suffix}"
    )
    valid_file_name = (
        data_args.valid_file_name if data_args.valid_file_name is not None
        else f"valid_data_{default_suffix}"
    )
    train_path = os.path.join(DATA_DIR, train_file_name)
    valid_path = os.path.join(DATA_DIR, valid_file_name)

    # torch.save does not create parent directories.
    os.makedirs(os.path.dirname(train_path), exist_ok=True)
    os.makedirs(os.path.dirname(valid_path), exist_ok=True)

    torch.save(train_dataset, train_path)
    logger.info(f"saved train dataset at {train_path}")

    torch.save(valid_dataset, valid_path)
    logger.info(f"saved validation dataset at {valid_path}")

    tokenizer_path = f"{data_args.model_type}_qg_tokenizer"
    os.makedirs(tokenizer_path, exist_ok=True)
    tokenizer.save_pretrained(tokenizer_path)
    logger.info(f"saved tokenizer at {tokenizer_path}")


if __name__ == "__main__":
    main()