Spaces:
Running
Running
import os | |
import orjson | |
import torch | |
import numpy as np | |
from model import TMR_textencoder | |
# Directory that the loaders in this file join with "<split>.keyids" and
# "<split>_motion_embs_unit.npy" to locate the per-split data files.
EMBS = "data/unit_motion_embs"
def load_json(path):
    """Read the file at ``path`` as raw bytes and return it parsed by orjson."""
    with open(path, "rb") as handle:
        raw = handle.read()
    return orjson.loads(raw)
def load_keyids(split):
    """Return the key ids stored in ``<EMBS>/<split>.keyids`` as a NumPy array.

    Every line of the file yields one entry, with leading and trailing
    whitespace stripped.
    """
    keyids_file = os.path.join(EMBS, f"{split}.keyids")
    with open(keyids_file) as handle:
        stripped = [line.strip() for line in handle]
    return np.array(stripped)
def load_keyids_splits(splits):
    """Build a dict mapping each entry of ``splits`` to ``load_keyids(entry)``."""
    keyids_by_split = {}
    for name in splits:
        keyids_by_split[name] = load_keyids(name)
    return keyids_by_split
def load_unit_motion_embs(split, device):
    """Load ``<EMBS>/<split>_motion_embs_unit.npy`` as a torch tensor on ``device``."""
    filename = f"{split}_motion_embs_unit.npy"
    array = np.load(os.path.join(EMBS, filename))
    return torch.from_numpy(array).to(device)
def load_unit_motion_embs_splits(splits, device):
    """Build a dict mapping each entry of ``splits`` to its embeddings on ``device``.

    Each value is whatever ``load_unit_motion_embs(entry, device)`` returns.
    """
    embs_by_split = {}
    for name in splits:
        embs_by_split[name] = load_unit_motion_embs(name, device)
    return embs_by_split
def load_model(device):
    """Build a ``TMR_textencoder``, load its checkpoint, and return it in eval mode.

    Args:
        device: Passed to ``torch.load`` as ``map_location``, i.e. where the
            checkpoint tensors are placed when deserialized.

    Returns:
        The result of ``model.eval()`` on the constructed text encoder.
    """
    text_params = {
        'latent_dim': 256,
        'ff_size': 1024,
        'num_layers': 6,
        'num_heads': 4,
        'activation': 'gelu',
        'modelpath': 'distilbert-base-uncased',
    }
    model = TMR_textencoder(**text_params)

    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True if the torch version allows it).
    state_dict = torch.load("data/textencoder.pt", map_location=device)

    # load values for the transformer only; strict=False tolerates keys that
    # do not match between the checkpoint and the model.
    model.load_state_dict(state_dict, strict=False)

    # NOTE(review): the model is never moved with .to(device) in this
    # function -- confirm the caller (or TMR_textencoder) handles placement.
    return model.eval()