File size: 1,738 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import collections.abc as collections
from pathlib import Path
import torch
# Parent of the directory that contains this file (two levels up from
# __file__). Presumably the project root — TODO confirm against repo layout.
GLUESTICK_ROOT = Path(__file__).parent.parent
def get_class(mod_name, base_path, BaseClass):
    """Get the class object which inherits from BaseClass and is defined in
    the module named mod_name, child of base_path.

    The module ``"{base_path}.{mod_name}"`` is imported. Its members are
    narrowed to classes whose ``__module__`` is exactly that module path (so
    classes merely imported into it are ignored) and which are subclasses of
    ``BaseClass``.

    Args:
        mod_name: dotted module name, relative to ``base_path``.
        base_path: dotted name of the parent package.
        BaseClass: class the returned class must inherit from.

    Returns:
        The single matching class object.

    Raises:
        ValueError: if zero or more than one matching class is found.
        ImportError: propagated if the module cannot be imported.
    """
    import importlib
    import inspect

    mod_path = "{}.{}".format(base_path, mod_name)
    mod = importlib.import_module(mod_path)
    # Keep only classes defined in this very module (not ones it imports)
    # that inherit from BaseClass.
    classes = [
        cls
        for _, cls in inspect.getmembers(mod, inspect.isclass)
        if cls.__module__ == mod_path and issubclass(cls, BaseClass)
    ]
    # Explicit check instead of ``assert`` so the validation is not stripped
    # when Python runs with -O.
    if len(classes) != 1:
        raise ValueError(
            f"Expected exactly one subclass of {BaseClass.__name__} defined in "
            f"{mod_path}, found {len(classes)}: {[c.__name__ for c in classes]}"
        )
    return classes[0]
def get_model(name):
    """Return the model class defined in the ``models.<name>`` submodule.

    Delegates to ``get_class`` with this module's ``__name__`` as the base
    package, requiring the class to be a subclass of ``BaseModel``.
    """
    from .models.base_model import BaseModel as _BaseModel

    submodule = "models." + name
    return get_class(submodule, __name__, _BaseModel)
def numpy_image_to_torch(image):
    """Convert a numpy image into a float32 torch tensor scaled by 1/255.

    A 3-D array laid out as HxWxC is transposed to CxHxW; a 2-D array gains a
    leading channel axis. Any other rank raises ValueError. Values are divided
    by 255.0 (presumably the input is 8-bit — confirm with callers) and the
    result is cast to float32.
    """
    rank = image.ndim
    if rank == 2:
        chw = image[None]  # add channel axis
    elif rank == 3:
        chw = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    else:
        raise ValueError(f"Not an image: {image.shape}")
    scaled = chw / 255.0
    return torch.from_numpy(scaled).float()
def map_tensor(input_, func):
    """Recursively apply ``func`` to every leaf of a nested container.

    Strings and binary buffers (``str``, ``bytes``, ``bytearray``,
    ``memoryview``) are returned unchanged. Mappings are rebuilt as plain
    ``dict``s and sequences (tuples included) as ``list``s, with ``func``
    applied recursively to their values/items. Anything else is a leaf and is
    passed to ``func``.

    Args:
        input_: a leaf value or a (possibly nested) mapping/sequence of them.
        func: callable applied to each leaf.

    Returns:
        A structure mirroring ``input_`` (as ``dict``/``list``) whose leaves
        are replaced by ``func(leaf)``.
    """
    # Binary buffers pass through just like ``bytes``. ``bytearray`` is a
    # registered MutableSequence and would otherwise be iterated byte by byte,
    # calling ``func`` on individual ints.
    if isinstance(input_, (str, bytes, bytearray, memoryview)):
        return input_
    # ``collections`` is the file-level alias for ``collections.abc``.
    if isinstance(input_, collections.Mapping):
        return {key: map_tensor(value, func) for key, value in input_.items()}
    if isinstance(input_, collections.Sequence):
        return [map_tensor(item, func) for item in input_]
    return func(input_)
def batch_to_np(batch):
    """Convert every leaf of a nested batch to numpy, keeping only item 0.

    Each leaf goes through ``.detach().cpu().numpy()`` and is then indexed
    with ``[0]``. Presumably leaves are torch tensors with a leading batch
    dimension — confirm with callers. Container handling follows
    ``map_tensor``.
    """
    def _first_item_as_numpy(tensor):
        array = tensor.detach().cpu().numpy()
        return array[0]

    return map_tensor(batch, _first_item_as_numpy)
|