|
import torch |
|
|
|
|
|
def load_poser(model: str, device: torch.device):
    """Create the poser for the named model variant.

    The name is validated first. Only then is it announced on stdout and the
    matching ``tha3.poser.modes.<model>`` submodule imported. Its
    ``create_poser`` is then called with ``device``.

    Args:
        model: One of ``"standard_float"``, ``"standard_half"``,
            ``"separable_float"`` or ``"separable_half"``.
        device: Passed unchanged to the selected ``create_poser``.

    Returns:
        Whatever the selected module's ``create_poser(device)`` returns.

    Raises:
        RuntimeError: If ``model`` is not one of the supported names. Nothing
            is printed or imported in that case.
    """
    supported = (
        "standard_float",
        "standard_half",
        "separable_float",
        "separable_half",
    )

    # Validate before printing so an invalid name is never announced as
    # "Using the ... model." right before failing.
    if model not in supported:
        raise RuntimeError(
            f"Invalid model: '{model}' (expected one of: {', '.join(supported)})"
        )

    print(f"Using the {model} model.")

    # Imports are deferred and per-variant, so only the selected mode module
    # is loaded.
    if model == "standard_float":
        from tha3.poser.modes.standard_float import create_poser
    elif model == "standard_half":
        from tha3.poser.modes.standard_half import create_poser
    elif model == "separable_float":
        from tha3.poser.modes.separable_float import create_poser
    else:  # "separable_half": the only remaining supported name
        from tha3.poser.modes.separable_half import create_poser

    return create_poser(device)