# leffa/transform.py

import logging
from typing import Any, Dict
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from PIL import Image
from torch import nn

logger: logging.Logger = logging.getLogger(__name__)


class LeffaTransform(nn.Module):
def __init__(
self,
height: int = 1024,
width: int = 768,
dataset: str = "virtual_tryon", # virtual_tryon or pose_transfer
):
super().__init__()
self.height = height
self.width = width
self.dataset = dataset
self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=8,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
        )

def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
batch_size = len(batch["src_image"])
src_image_list = []
ref_image_list = []
mask_list = []
densepose_list = []
for i in range(batch_size):
# 1. get original data
src_image = batch["src_image"][i]
ref_image = batch["ref_image"][i]
mask = batch["mask"][i]
densepose = batch["densepose"][i]
            # 2. process data
src_image = self.vae_processor.preprocess(
src_image, self.height, self.width)[0]
ref_image = self.vae_processor.preprocess(
ref_image, self.height, self.width)[0]
mask = self.mask_processor.preprocess(
mask, self.height, self.width)[0]
            if self.dataset == "pose_transfer":
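                # Nearest-neighbor resize keeps the discrete DensePose part
                # labels intact; prepare_densepose normalizes them later.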
densepose = densepose.resize(
(self.width, self.height), Image.NEAREST)
else:
densepose = self.vae_processor.preprocess(
densepose, self.height, self.width
)[0]
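            # 3. convert to normalized tensors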
src_image = self.prepare_image(src_image)
ref_image = self.prepare_image(ref_image)
mask = self.prepare_mask(mask)
            if self.dataset == "pose_transfer":
densepose = self.prepare_densepose(densepose)
else:
densepose = self.prepare_image(densepose)
src_image_list.append(src_image)
ref_image_list.append(ref_image)
mask_list.append(mask)
densepose_list.append(densepose)
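
        # Stack the per-sample tensors into batched (B, C, H, W) tensors.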
src_image = torch.cat(src_image_list, dim=0)
ref_image = torch.cat(ref_image_list, dim=0)
mask = torch.cat(mask_list, dim=0)
densepose = torch.cat(densepose_list, dim=0)
batch["src_image"] = src_image
batch["ref_image"] = ref_image
batch["mask"] = mask
batch["densepose"] = densepose
        return batch

@staticmethod
def prepare_image(image):
if isinstance(image, torch.Tensor):
# Batch single image
if image.ndim == 3:
image = image.unsqueeze(0)
image = image.to(dtype=torch.float32)
else:
# preprocess image
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
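            # NHWC -> NCHW, then scale pixel values from [0, 255] to [-1, 1].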
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(
dtype=torch.float32) / 127.5 - 1.0
        return image

@staticmethod
def prepare_mask(mask):
if isinstance(mask, torch.Tensor):
if mask.ndim == 2:
# Batch and add channel dim for single mask
mask = mask.unsqueeze(0).unsqueeze(0)
elif mask.ndim == 3 and mask.shape[0] == 1:
# Single mask, the 0'th dimension is considered to be
# the existing batch size of 1
mask = mask.unsqueeze(0)
elif mask.ndim == 3 and mask.shape[0] != 1:
# Batch of mask, the 0'th dimension is considered to be
# the batching dimension
mask = mask.unsqueeze(1)
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
else:
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask],
axis=0,
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
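            # Binarize mask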
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
        return mask

@staticmethod
def prepare_densepose(densepose):
"""
For internal (meta) densepose, the first and second channel should be normalized to 0~1 by 255.0,
and the third channel should be normalized to 0~1 by 24.0
"""
if isinstance(densepose, torch.Tensor):
# Batch single densepose
if densepose.ndim == 3:
densepose = densepose.unsqueeze(0)
densepose = densepose.to(dtype=torch.float32)
else:
# preprocess densepose
if isinstance(densepose, (Image.Image, np.ndarray)):
densepose = [densepose]
if isinstance(densepose, list) and isinstance(
densepose[0], Image.Image
):
densepose = [np.array(i.convert("RGB"))[None, :]
for i in densepose]
densepose = np.concatenate(densepose, axis=0)
elif isinstance(densepose, list) and isinstance(densepose[0], np.ndarray):
densepose = np.concatenate(
[i[None, :] for i in densepose], axis=0)
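            # NHWC -> NCHW; normalize each channel to [0, 1], then shift
            # to [-1, 1] to match the range of the image inputs.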
densepose = densepose.transpose(0, 3, 1, 2)
densepose = densepose.astype(np.float32)
densepose[:, 0:2, :, :] /= 255.0
densepose[:, 2:3, :, :] /= 24.0
densepose = torch.from_numpy(densepose).to(
dtype=torch.float32) * 2.0 - 1.0
return densepose
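

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # it builds a one-sample batch of synthetic PIL images with the default
    # virtual_tryon settings and checks the batched tensor shapes produced
    # by forward(). Real Leffa inputs would come from a dataset instead.
    transform = LeffaTransform(height=1024, width=768)
    gray = Image.new("RGB", (768, 1024), (128, 128, 128))
    batch = {
        "src_image": [gray],
        "ref_image": [gray.copy()],
        "mask": [Image.new("L", (768, 1024), 255)],
        "densepose": [gray.copy()],
    }
    batch = transform(batch)
    print(batch["src_image"].shape)  # torch.Size([1, 3, 1024, 768])
    print(batch["mask"].shape)       # torch.Size([1, 1, 1024, 768])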