""" This script provides an example to wrap TencentPretrain for classification inference. """ import sys import os import torch import argparse import collections import torch.nn as nn tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(tencentpretrain_dir) from tencentpretrain.utils.constants import * from tencentpretrain.utils import * from tencentpretrain.utils.config import load_hyperparam from tencentpretrain.utils.seed import set_seed from tencentpretrain.model_loader import load_model from tencentpretrain.opts import infer_opts, tokenizer_opts from finetune.run_classifier_siamese import SiameseClassifier def batch_loader(batch_size, src, seg): src_a, src_b = src seg_a, seg_b = seg instances_num = src_a.size()[0] for i in range(instances_num // batch_size): src_a_batch = src_a[i * batch_size : (i + 1) * batch_size, :] src_b_batch = src_b[i * batch_size : (i + 1) * batch_size, :] seg_a_batch = seg_a[i * batch_size : (i + 1) * batch_size, :] seg_b_batch = seg_b[i * batch_size : (i + 1) * batch_size, :] yield (src_a_batch, src_b_batch), (seg_a_batch, seg_b_batch) if instances_num > instances_num // batch_size * batch_size: src_a_batch = src_a[instances_num // batch_size * batch_size :, :] src_b_batch = src_b[instances_num // batch_size * batch_size :, :] seg_a_batch = seg_a[instances_num // batch_size * batch_size :, :] seg_b_batch = seg_b[instances_num // batch_size * batch_size :, :] yield (src_a_batch, src_b_batch), (seg_a_batch, seg_b_batch) def read_dataset(args, path): dataset, columns = [], {} with open(path, mode="r", encoding="utf-8") as f: for line_id, line in enumerate(f): if line_id == 0: line = line.rstrip("\r\n").split("\t") for i, column_name in enumerate(line): columns[column_name] = i continue line = line.rstrip("\r\n").split("\t") text_a, text_b = line[columns["text_a"]], line[columns["text_b"]] src_a = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN]) src_b = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_b) + [SEP_TOKEN]) seg_a = [1] * len(src_a) seg_b = [1] * len(src_b) PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] if len(src_a) >= args.seq_length: src_a = src_a[:args.seq_length] seg_a = seg_a[:args.seq_length] while len(src_a) < args.seq_length: src_a.append(PAD_ID) seg_a.append(0) if len(src_b) >= args.seq_length: src_b = src_b[:args.seq_length] seg_b = seg_b[:args.seq_length] while len(src_b) < args.seq_length: src_b.append(PAD_ID) seg_b.append(0) dataset.append(((src_a, src_b), (seg_a, seg_b))) return dataset def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) infer_opts(parser) parser.add_argument("--labels_num", type=int, required=True, help="Number of prediction labels.") tokenizer_opts(parser) parser.add_argument("--output_logits", action="store_true", help="Write logits to output file.") parser.add_argument("--output_prob", action="store_true", help="Write probabilities to output file.") args = parser.parse_args() # Load the hyperparameters from the config file. args = load_hyperparam(args) # Build tokenizer. args.tokenizer = str2tokenizer[args.tokenizer](args) # Build classification model and load parameters. args.soft_targets, args.soft_alpha = False, False model = SiameseClassifier(args) model = load_model(model, args.load_model_path) # For simplicity, we use DataParallel wrapper to use multiple GPUs. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) if torch.cuda.device_count() > 1: print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) dataset = read_dataset(args, args.test_path) src_a = torch.LongTensor([example[0][0] for example in dataset]) src_b = torch.LongTensor([example[0][1] for example in dataset]) seg_a = torch.LongTensor([example[1][0] for example in dataset]) seg_b = torch.LongTensor([example[1][1] for example in dataset]) batch_size = args.batch_size instances_num = src_a.size()[0] print("The number of prediction instances: ", instances_num) model.eval() with open(args.prediction_path, mode="w", encoding="utf-8") as f: f.write("label") if args.output_logits: f.write("\t" + "logits") if args.output_prob: f.write("\t" + "prob") f.write("\n") for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, (src_a, src_b), (seg_a, seg_b))): src_a_batch, src_b_batch = src_batch seg_a_batch, seg_b_batch = seg_batch src_a_batch = src_a_batch.to(device) src_b_batch = src_b_batch.to(device) seg_a_batch = seg_a_batch.to(device) seg_b_batch = seg_b_batch.to(device) with torch.no_grad(): _, logits = model((src_a_batch, src_b_batch), None, (seg_a_batch, seg_b_batch)) pred = torch.argmax(logits, dim=1) pred = pred.cpu().numpy().tolist() prob = nn.Softmax(dim=1)(logits) logits = logits.cpu().numpy().tolist() prob = prob.cpu().numpy().tolist() for j in range(len(pred)): f.write(str(pred[j])) if args.output_logits: f.write("\t" + " ".join([str(v) for v in logits[j]])) if args.output_prob: f.write("\t" + " ".join([str(v) for v in prob[j]])) f.write("\n") if __name__ == "__main__": main()