# File: VISOR-GPT/train/inference/run_regression_infer.py
# (Hosting-page header residue: szukevin, "upload", 7900c16, 2.22 kB)
"""
This script provides an example to wrap TencentPretrain for regression inference.
"""
import sys
import os
import torch
import argparse
import collections
import torch.nn as nn
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)
from finetune.run_regression import Regression
from inference.run_classifier_infer import *
def main():
    """Run regression inference and write one prediction per line.

    Parses the command-line options registered by ``infer_opts`` and
    ``tokenizer_opts``, builds a ``Regression`` model, loads its parameters
    from ``args.load_model_path`` and scores every sample read from
    ``args.test_path``. The output file ``args.prediction_path`` receives a
    ``label`` header line followed by the first output value of each sample,
    one per line.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    infer_opts(parser)
    tokenizer_opts(parser)
    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Build tokenizer (the option string is replaced by the tokenizer object).
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Build regression model and load parameters.
    # The soft-target settings are forced off before construction; presumably
    # Regression reads them from args -- confirm against finetune.run_regression.
    args.soft_targets, args.soft_alpha = False, False
    model = Regression(args)
    model = load_model(model, args.load_model_path)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    # Each dataset sample is indexed as (src, seg, ...); only the first two
    # fields are needed for inference.
    dataset = read_dataset(args, args.test_path)
    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])

    batch_size = args.batch_size
    instances_num = src.size()[0]
    print("The number of prediction instances: ", instances_num)

    model.eval()

    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
        f.write("label")
        f.write("\n")
        for src_batch, seg_batch in batch_loader(batch_size, src, seg):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            # No gradients are needed at inference time; the target argument is None.
            with torch.no_grad():
                _, logits = model(src_batch, None, seg_batch)
            # Write the first output value of every sample in the batch.
            for row in logits.cpu().numpy().tolist():
                f.write(str(row[0]))
                f.write("\n")
if __name__ == "__main__":
    # Only run inference when executed as a script, not when imported.
    main()