File size: 719 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
def load_poser(model: str, device: torch.device):
    """Create the poser for the named model variant on ``device``.

    The variant name selects a module under ``tha3.poser.modes``; that
    module's ``create_poser(device)`` builds the poser.

    Args:
        model: One of ``"standard_float"``, ``"standard_half"``,
            ``"separable_float"`` or ``"separable_half"``.
        device: Passed unchanged to the selected module's ``create_poser``.

    Returns:
        Whatever the selected module's ``create_poser(device)`` returns.

    Raises:
        RuntimeError: If ``model`` is not one of the supported variant names.
    """
    import importlib

    supported = (
        "standard_float",
        "standard_half",
        "separable_float",
        "separable_half",
    )
    # Validate before announcing the choice, so an invalid name is never
    # reported as "in use". Wrap in a 1-tuple so %-formatting cannot choke
    # on a tuple-valued argument.
    if model not in supported:
        raise RuntimeError("Invalid model: '%s'" % (model,))
    print("Using the %s model." % (model,))
    # Import lazily so only the selected variant's module is loaded.
    module = importlib.import_module("tha3.poser.modes." + model)
    return module.create_poser(device)