# emotion_recognition/utils/compute_args.py
# Author: nouamanetazi (HF staff)
# Last commit: "linting" (c731c61)
import torch
def compute_args(args):
    """Fill in derived settings on ``args`` from its dataset and task.

    Reads ``args.dataset`` (set to ``"MOSEI"`` when the attribute is
    missing), ``args.task`` and, for MOSEI sentiment only,
    ``args.task_binary``. Depending on those values it assigns
    ``args.dataloader``, ``args.loss_fn``, ``args.ans_size`` and
    ``args.pred_func``. Combinations not handled below leave the
    corresponding attributes untouched.

    Returns the same ``args`` object, mutated in place.
    """
    # Configs saved by a previous version carry no dataset attribute.
    if not hasattr(args, "dataset"):
        args.dataset = "MOSEI"

    dataset_name = args.dataset

    if dataset_name == "MOSEI":
        args.dataloader = "Mosei_Dataset"
        task_name = args.task

        if task_name == "sentiment":
            args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
            # Seven classes by default; the binary variant collapses to two.
            args.ans_size = 7
            if args.task_binary:
                args.ans_size = 2
        elif task_name == "emotion":
            args.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="sum")
            args.ans_size = 6

        # Only the emotion task uses multi-label prediction; every other
        # MOSEI task falls back to "amax".
        args.pred_func = "multi_label" if task_name == "emotion" else "amax"

    elif dataset_name == "MELD":
        args.dataloader = "Meld_Dataset"
        args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        task_name = args.task

        if task_name == "emotion":
            args.ans_size = 7
        elif task_name == "sentiment":
            args.ans_size = 3

        args.pred_func = "amax"

    return args