svjack's picture
Upload SPIGA with huggingface_hub
9390e2c
raw
history blame
825 Bytes
from torchvision import transforms
import numpy as np
from PIL import Image
import cv2
from spiga.data.loaders.transforms import TargetCrop, ToOpencv, AddModel3D
def get_transformers(data_config):
transformer_seq = [
Opencv2Pil(),
TargetCrop(data_config.image_size, data_config.target_dist),
ToOpencv(),
NormalizeAndPermute()]
return transforms.Compose(transformer_seq)
class NormalizeAndPermute:
def __call__(self, sample):
image = np.array(sample['image'], dtype=float)
image = np.transpose(image, (2, 0, 1))
sample['image'] = image / 255
return sample
class Opencv2Pil:
def __call__(self, sample):
image = cv2.cvtColor(sample['image'], cv2.COLOR_BGR2RGB)
sample['image'] = Image.fromarray(image)
return sample