""" This script provides an example to wrap TencentPretrain for ChID (a multiple choice dataset) inference. """ import sys import os import torch import json import argparse import collections import numpy as np tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(tencentpretrain_dir) from tencentpretrain.utils.constants import * from tencentpretrain.utils.tokenizers import * from tencentpretrain.utils.config import load_hyperparam from tencentpretrain.model_loader import load_model from tencentpretrain.opts import infer_opts from finetune.run_classifier import batch_loader from finetune.run_c3 import MultipleChoice from finetune.run_chid import read_dataset def postprocess_chid_predictions(results): index2tag = {index: tag for index, (tag, logits) in enumerate(results)} logits_matrix = [logits for _, logits in results] logits_matrix = np.transpose(np.array(logits_matrix)) logits_matrix_list = [] for i, row in enumerate(logits_matrix): for j, value in enumerate(row): logits_matrix_list.append((i, j, value)) else: choices = set(range(i + 1)) blanks = set(range(j + 1)) logits_matrix_list = sorted(logits_matrix_list, key=lambda x: x[2], reverse=True) results = [] for i, j, v in logits_matrix_list: if (j in blanks) and (i in choices): results.append((i, j)) blanks.remove(j) choices.remove(i) results = sorted(results, key=lambda x: x[1], reverse=False) results = [[index2tag[j], i] for i, j in results] return results def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) infer_opts(parser) parser.add_argument("--vocab_path", default=None, type=str, help="Path of the vocabulary file.") parser.add_argument("--spm_model_path", default=None, type=str, help="Path of the sentence piece model.") parser.add_argument("--max_choices_num", default=10, type=int, help="The maximum number of cadicate answer, shorter than this will be padded.") args = parser.parse_args() # Load the hyperparameters from the config file. args = load_hyperparam(args) # Build tokenizer. args.tokenizer = CharTokenizer(args) # Build classification model and load parameters. model = MultipleChoice(args) model = load_model(model, args.load_model_path) # For simplicity, we use DataParallel wrapper to use multiple GPUs. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) if torch.cuda.device_count() > 1: print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) dataset = read_dataset(args, args.test_path, None) model.eval() batch_size = args.batch_size results_final = [] dataset_by_group = {} print("The number of prediction instances: ", len(dataset)) for example in dataset: if example[-1] not in dataset_by_group: dataset_by_group[example[-1]] = [example] else: dataset_by_group[example[-1]].append(example) for group_index, examples in dataset_by_group.items(): src = torch.LongTensor([example[0] for example in examples]) tgt = torch.LongTensor([example[1] for example in examples]) seg = torch.LongTensor([example[2] for example in examples]) index = 0 results = [] for i, (src_batch, _, seg_batch, _) in enumerate(batch_loader(batch_size, src, tgt, seg)): src_batch = src_batch.to(device) seg_batch = seg_batch.to(device) with torch.no_grad(): _, logits = model(src_batch, None, seg_batch) pred = torch.argmax(logits, dim=1) pred = pred.cpu().numpy().tolist() for j in range(len(pred)): results.append((examples[index][-2], logits[index].cpu().numpy())) index += 1 results_final.extend(postprocess_chid_predictions(results)) with open(args.prediction_path, "w") as f: json.dump({tag: pred for tag, pred in results_final}, f, indent=2) if __name__ == "__main__": main()