File size: 241 Bytes
9d21d47
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from .sst2 import SST2ProbInferenceForMC


# Registry of supported task names; `load_task` looks names up here and
# returns the mapped object as-is (it is not instantiated).
task_mapper = dict(
    sst2=SST2ProbInferenceForMC,
)


def load_task(name: str):
    """Return the task registered under *name* in ``task_mapper``.

    The registered object is returned as-is; it is not instantiated here.

    Args:
        name: Key to look up in ``task_mapper`` (e.g. ``"sst2"``).

    Returns:
        The value registered for *name*.

    Raises:
        ValueError: If *name* is not a key of ``task_mapper``.
    """
    # Single EAFP lookup instead of a membership test followed by indexing.
    try:
        return task_mapper[name]
    except KeyError:
        # Re-raise as ValueError (the type callers already expect) and drop
        # the internal KeyError from the traceback.
        raise ValueError(f"Unrecognized dataset `{name}`") from None