import string
from typing import Optional

import h5py
import torch

from ..datasets.base_dataset import collate
from ..settings import DATA_PATH
from ..utils.tensor import batch_to_device
from .base_model import BaseModel
from .utils.misc import pad_to_length


def pad_local_features(pred: dict, seq_l: int):
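    """Pad the variable-length local-feature tensors in `pred` to length `seq_l`.

    Padding makes predictions for images with different numbers of keypoints
    stackable into a single batch. Keypoint coordinates and descriptors are
    filled with random values, while scores, scales, orientations, and depth
    entries are zero-padded so that padded slots can be masked out later.

    Example (a minimal sketch with hypothetical shapes, assuming that the
    `pad_to_length` imported above pads along the given dimension):

        pred = {"keypoints": torch.rand(3, 2), "keypoint_scores": torch.rand(3)}
        pred = pad_local_features(pred, 512)
        assert pred["keypoints"].shape == (512, 2)
        assert pred["keypoint_scores"].shape == (512,)
    """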
pred["keypoints"] = pad_to_length(
pred["keypoints"],
seq_l,
-2,
mode="random_c",
)
if "keypoint_scores" in pred.keys():
pred["keypoint_scores"] = pad_to_length(
pred["keypoint_scores"], seq_l, -1, mode="zeros"
)
if "descriptors" in pred.keys():
pred["descriptors"] = pad_to_length(
pred["descriptors"], seq_l, -2, mode="random"
)
if "scales" in pred.keys():
pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros")
if "oris" in pred.keys():
pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros")
if "depth_keypoints" in pred.keys():
pred["depth_keypoints"] = pad_to_length(
pred["depth_keypoints"], seq_l, -1, mode="zeros"
)
if "valid_depth_keypoints" in pred.keys():
pred["valid_depth_keypoints"] = pad_to_length(
pred["valid_depth_keypoints"], seq_l, -1, mode="zeros"
)
return pred
def pad_line_features(pred: dict, seq_l: Optional[int] = None):
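    """Pad line features (endpoints, descriptors, ...) to a fixed length.

    Not implemented yet: a placeholder mirroring pad_local_features.
    """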
    raise NotImplementedError


def recursive_load(grp, pkeys):
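    """Recursively load an HDF5 group into a (nested) dict of torch tensors."""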
    return {
        k: (
            torch.from_numpy(grp[k].__array__())
            if isinstance(grp[k], h5py.Dataset)
            # recurse into subgroups with their own keys, not the parent's
            else recursive_load(grp[k], list(grp[k].keys()))
        )
        for k in pkeys
    }
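
# A minimal sketch of what recursive_load returns, using an in-memory HDF5
# file (hypothetical keys, for illustration only):
#
#   import numpy as np
#   with h5py.File("demo.h5", "w", driver="core", backing_store=False) as f:
#       f.create_dataset("keypoints", data=np.zeros((100, 2), np.float32))
#       f.create_group("nested").create_dataset(
#           "scores", data=np.ones(100, np.float32)
#       )
#       out = recursive_load(f, list(f.keys()))
#   # out: {"keypoints": float32 tensor (100, 2),
#   #       "nested": {"scores": float32 tensor (100,)}}
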
class CacheLoader(BaseModel):
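    """Load cached predictions (e.g. keypoints and descriptors exported to
    HDF5 files) and return them as if a model had just computed them.
    """
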
    default_conf = {
        "path": "???",  # can be a format string like exports/{scene}/
        "data_keys": None,  # load all keys
        "device": None,  # load to same device as data
        "trainable": False,
        "add_data_path": True,
        "collate": True,
        "scale": ["keypoints", "lines", "orig_lines"],
        "padding_fn": None,
        "padding_length": None,  # required for batching!
        "numeric_type": "float32",  # [None, "float16", "float32", "float64"]
    }

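    # Example conf overriding the defaults (hypothetical path and keys,
    # for illustration only):
    #   {
    #       "path": "exports/superpoint/{scene}.h5",
    #       "data_keys": ["keypoints", "keypoint_scores", "descriptors"],
    #       "padding_fn": "pad_local_features",
    #       "padding_length": 2048,
    #   }
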
    required_data_keys = ["name"]  # we need an identifier

    def _init(self, conf):
        self.hfiles = {}
        self.padding_fn = conf.padding_fn
        if self.padding_fn is not None:
            # resolve the configured name, e.g. "pad_local_features",
            # to a function defined in this module
            self.padding_fn = eval(self.padding_fn)
        self.numeric_dtype = {
            None: None,
            "float16": torch.float16,
            "float32": torch.float32,
            "float64": torch.float64,
        }[conf.numeric_type]

    def _forward(self, data):
        """Load the cached predictions for each name in the batch."""
        preds = []
        device = self.conf.device
        if not device:
            # infer the device from the input tensors if none is configured
            devices = set(
                [v.device for v in data.values() if isinstance(v, torch.Tensor)]
            )
            if len(devices) == 0:
                device = "cpu"
            else:
                assert len(devices) == 1
                device = devices.pop()
        # placeholder names in the path format string, e.g. {scene}
        var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]
        for i, name in enumerate(data["name"]):
            fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
            if self.conf.add_data_path:
                fpath = DATA_PATH / fpath
            hfile = h5py.File(str(fpath), "r")
            grp = hfile[name]
            pkeys = (
                self.conf.data_keys if self.conf.data_keys is not None else grp.keys()
            )
            pred = recursive_load(grp, pkeys)
            # cast floating-point tensors to the configured numeric type
            if self.numeric_dtype is not None:
                pred = {
                    k: (
                        v.to(dtype=self.numeric_dtype)
                        if isinstance(v, torch.Tensor) and torch.is_floating_point(v)
                        else v
                    )
                    for k, v in pred.items()
                }
            pred = batch_to_device(pred, device)
            # rescale coordinate-like entries (keypoints, lines, ...) that
            # were exported at a different image resolution
            for k, v in pred.items():
                for pattern in self.conf.scale:
                    if k.startswith(pattern):
                        view_idx = k.replace(pattern, "")
                        scales = (
                            data["scales"]
                            if len(view_idx) == 0
                            else data[f"view{view_idx}"]["scales"]
                        )
                        pred[k] = pred[k] * scales[i]
            # use this function to fix the number of keypoints etc.
            if self.padding_fn is not None:
                pred = self.padding_fn(pred, self.conf.padding_length)
            preds.append(pred)
            hfile.close()
        if self.conf.collate:
            # stack the per-image predictions into a single batch
            return batch_to_device(collate(preds), device)
        else:
            assert len(preds) == 1
            return batch_to_device(preds[0], device)

    def loss(self, pred, data):
        raise NotImplementedError
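
# A minimal end-to-end sketch (hypothetical path, keys, and image names;
# assumes the conf handling of BaseModel and that the export file contains
# one group per image name):
#
#   loader = CacheLoader(
#       {
#           "path": "exports/superpoint/{scene}.h5",
#           "data_keys": ["keypoints", "keypoint_scores", "descriptors"],
#           "padding_fn": "pad_local_features",
#           "padding_length": 2048,
#       }
#   )
#   data = {
#       "name": ["scene0001/image_0001.jpg"],
#       "scene": ["scene0001"],
#       "scales": torch.ones(1, 2),  # rescaling applied to "keypoints"
#   }
#   pred = loader(data)  # e.g. pred["keypoints"] has shape (1, 2048, 2)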