# VISOR-GPT / train / inference / run_classifier_multi_label_infer.py
# (uploaded by szukevin, commit 7900c16, 3.56 kB)
"""
This script provides an example of how to wrap TencentPretrain for multi-label classification inference.
"""
import sys
import os
import torch
import argparse
import collections
import torch.nn as nn
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.utils.seed import set_seed
from tencentpretrain.model_loader import load_model
from tencentpretrain.opts import infer_opts, tokenizer_opts
from finetune.run_classifier_multi_label import MultilabelClassifier
from inference.run_classifier_infer import read_dataset
from inference.run_classifier_infer import batch_loader
def main():
    """Run multi-label classification inference and write predictions to a file.

    Command-line options come from ``infer_opts`` and ``tokenizer_opts`` plus
    the script-specific flags registered below; the remaining hyperparameters
    are filled in from the config file by ``load_hyperparam``.

    The file at ``--prediction_path`` is tab-separated and starts with a header
    row. Each following row contains:

    * ``label``  -- comma-separated indices of every label whose sigmoid
      probability is strictly greater than ``--threshold`` (empty if none is);
    * ``logits`` -- space-separated raw logits (only with ``--output_logits``);
    * ``prob``   -- space-separated sigmoid probabilities (only with
      ``--output_prob``).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    infer_opts(parser)

    parser.add_argument("--labels_num", type=int, required=True,
                        help="Number of prediction labels.")

    tokenizer_opts(parser)

    parser.add_argument("--output_logits", action="store_true", help="Write logits to output file.")
    parser.add_argument("--output_prob", action="store_true", help="Write probabilities to output file.")
    # Previously hard-coded to 0.5; the default keeps the old behavior.
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="A label is predicted when its sigmoid probability exceeds this value.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Build tokenizer.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Build classification model and load parameters.
    # NOTE(review): the soft-label attributes are presumably read by the
    # MultilabelClassifier constructor; they are disabled here because no
    # targets are supplied at inference time -- confirm against the finetune script.
    args.soft_targets, args.soft_alpha = False, False
    model = MultilabelClassifier(args)
    model = load_model(model, args.load_model_path)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    dataset = read_dataset(args, args.test_path)

    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])

    batch_size = args.batch_size
    instances_num = src.size()[0]

    print("The number of prediction instances: ", instances_num)

    model.eval()

    header = ["label"]
    if args.output_logits:
        header.append("logits")
    if args.output_prob:
        header.append("prob")

    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
        f.write("\t".join(header) + "\n")
        for src_batch, seg_batch in batch_loader(batch_size, src, seg):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            with torch.no_grad():
                # The second positional argument (presumably the target) is
                # None, so only the logits output is used.
                _, logits = model(src_batch, None, seg_batch)

            # Multi-label: an independent sigmoid per label rather than a softmax.
            probs = torch.sigmoid(logits).cpu().tolist()
            logits = logits.cpu().tolist()

            for prob_row, logits_row in zip(probs, logits):
                predicted = [str(j) for j, p in enumerate(prob_row) if p > args.threshold]
                fields = [",".join(predicted)]
                if args.output_logits:
                    fields.append(" ".join(str(v) for v in logits_row))
                if args.output_prob:
                    fields.append(" ".join(str(v) for v in prob_row))
                f.write("\t".join(fields) + "\n")
# Run inference only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()