Spaces:
Runtime error
Runtime error
File size: 7,569 Bytes
913d3e3 f265950 913d3e3 e8f6bdd 913d3e3 e8f6bdd 913d3e3 35c104c f265950 913d3e3 f265950 913d3e3 f265950 913d3e3 f265950 913d3e3 f265950 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import torch.nn as nn
import torchvision
from scipy.spatial import Delaunay
import torch
import numpy as np
from torch.nn import functional as nnf
from easydict import EasyDict
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon
from torchvision import transforms
from PIL import Image
class SDSLoss(nn.Module):
    """Diffusion-guided loss built on a text-conditioned latent diffusion pipeline.

    The rendered image is encoded to latents, noised at a random timestep, and
    the pipeline's UNet predicts the noise with classifier-free guidance. The
    weighted difference between predicted and injected noise is used as a
    fixed (detached) gradient direction for the latents.

    Args:
        cfg: Configuration object. The code reads ``cfg.caption``,
            ``cfg.batch_size``, ``cfg.diffusion.timesteps`` and
            ``cfg.diffusion.guidance_scale``.
        device: Device on which embeddings and schedule tensors are placed.
        model: Pipeline object exposing ``tokenizer``, ``text_encoder``,
            ``vae``, ``unet`` and ``scheduler``.
    """

    def __init__(self, cfg, device, model):
        super(SDSLoss, self).__init__()
        self.cfg = cfg
        self.device = device
        self.pipe = model
        # Cumulative noise schedule and its complement, indexed by timestep.
        self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device)
        self.sigmas = (1 - self.pipe.scheduler.alphas_cumprod).to(self.device)

        self.text_embeddings = None
        self.embed_text()

    def embed_text(self):
        """Encode the caption and the empty prompt and cache them.

        The result is stored in ``self.text_embeddings`` as
        ``[uncond] * batch_size`` followed by ``[cond] * batch_size`` along
        dim 0, which matches the doubled latent batch built in ``forward``.
        """
        tokenizer = self.pipe.tokenizer
        text_input = tokenizer(self.cfg.caption, padding="max_length",
                               max_length=tokenizer.model_max_length,
                               truncation=True, return_tensors="pt")
        uncond_input = tokenizer([""], padding="max_length",
                                 max_length=text_input.input_ids.shape[-1],
                                 return_tensors="pt")
        with torch.no_grad():
            text_embeddings = self.pipe.text_encoder(text_input.input_ids.to(self.device))[0]
            uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.device))[0]
        self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)

    def forward(self, x_aug):
        """Return the scalar loss for a batch of rendered images.

        Args:
            x_aug: Image batch; it is mapped with ``x * 2 - 1`` before VAE
                encoding (assumes values in [0, 1] and a batch size equal to
                ``cfg.batch_size`` -- TODO confirm against the caller).

        Returns:
            A 0-dim tensor: ``(grad_z * latent_z).sum(1).mean()`` where
            ``grad_z`` is detached, so gradients flow only through the latents.
        """
        # encode rendered image
        x = x_aug * 2. - 1.
        with torch.cuda.amp.autocast():
            init_latent_z = self.pipe.vae.encode(x).latent_dist.sample()
        latent_z = 0.18215 * init_latent_z  # scaling_factor * init_latents

        with torch.inference_mode():
            # sample one timestep per batch element; `high` is exclusive
            timestep = torch.randint(
                low=50,
                high=min(950, self.cfg.diffusion.timesteps) - 1,  # avoid highest timestep | diffusion.timesteps=1000
                size=(latent_z.shape[0],),
                device=self.device, dtype=torch.long)

            # add noise: zt = alpha_t * latent_z + sigma_t * eps
            eps = torch.randn_like(latent_z)
            noised_latent_zt = self.pipe.scheduler.add_noise(latent_z, eps, timestep)

            # denoise: duplicate latents AND timesteps for classifier-free
            # guidance so both have the same batch size as the text embeddings
            z_in = torch.cat([noised_latent_zt] * 2)
            timestep_in = torch.cat([timestep] * 2)
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                eps_t_uncond, eps_t = self.pipe.unet(
                    z_in, timestep_in,
                    encoder_hidden_states=self.text_embeddings).sample.float().chunk(2)

            eps_t = eps_t_uncond + self.cfg.diffusion.guidance_scale * (eps_t - eps_t_uncond)

            # w = alphas[timestep]^0.5 * (1 - alphas[timestep]) = alphas[timestep]^0.5 * sigmas[timestep]
            # Reshape the per-sample weight to (B, 1, ..., 1) so it scales along
            # the batch axis instead of broadcasting against the last axis.
            weight = self.alphas[timestep] ** 0.5 * self.sigmas[timestep]
            weight = weight.view(-1, *([1] * (latent_z.dim() - 1)))
            grad_z = weight * (eps_t - eps)
            assert torch.isfinite(grad_z).all()
            # Fallback sanitisation for runs where assertions are disabled (-O).
            grad_z = torch.nan_to_num(grad_z.detach().float(), 0.0, 0.0, 0.0)

        # Tensors created under inference_mode cannot be saved for backward;
        # clone outside the context to obtain a normal tensor.
        sds_loss = grad_z.clone() * latent_z
        del grad_z

        sds_loss = sds_loss.sum(1).mean()
        return sds_loss
class ToneLoss(nn.Module):
    """MSE between Gaussian-blurred versions of a reference and a current raster.

    Args:
        cfg: Configuration object. The code reads
            ``cfg.loss.tone.dist_loss_weight``,
            ``cfg.loss.tone.pixel_dist_kernel_blur`` (used for both kernel
            dimensions) and ``cfg.loss.tone.pixel_dist_sigma``.
    """

    def __init__(self, cfg):
        super(ToneLoss, self).__init__()
        self.dist_loss_weight = cfg.loss.tone.dist_loss_weight
        self.im_init = None
        # Set by set_image_init(); initialised here so forward() can detect
        # a missing reference instead of failing on a missing attribute.
        self.init_blurred = None
        self.cfg = cfg
        self.mse_loss = nn.MSELoss()
        kernel = cfg.loss.tone.pixel_dist_kernel_blur
        self.blurrer = torchvision.transforms.GaussianBlur(
            kernel_size=(kernel, kernel),
            sigma=(cfg.loss.tone.pixel_dist_sigma))

    def set_image_init(self, im_init):
        """Store the reference image and its blurred version.

        Args:
            im_init: Reference image tensor; it is permuted with ``(2, 0, 1)``
                and given a leading batch axis (assumes an (H, W, C) layout
                -- TODO confirm against the caller).
        """
        self.im_init = im_init.permute(2, 0, 1).unsqueeze(0)
        self.init_blurred = self.blurrer(self.im_init)

    def get_scheduler(self, step=None):
        """Return the loss weight for ``step``.

        With ``step=None`` the constant ``dist_loss_weight`` is returned.
        Otherwise the weight is ``dist_loss_weight * exp(-(1/5) * ((step - 300) / 20) ** 2)``,
        i.e. it peaks at step 300 and decays on both sides.
        """
        if step is None:
            return self.dist_loss_weight
        return self.dist_loss_weight * np.exp(-(1 / 5) * ((step - 300) / 20) ** 2)

    def forward(self, cur_raster, step=None):
        """Return the weighted MSE between the blurred reference and ``cur_raster``.

        Args:
            cur_raster: Current raster, blurred with the same kernel as the reference.
            step: Optional optimisation step forwarded to ``get_scheduler``.

        Raises:
            RuntimeError: If ``set_image_init`` has not been called yet.
        """
        if self.init_blurred is None:
            raise RuntimeError("ToneLoss.set_image_init() must be called before forward()")
        blurred_cur = self.blurrer(cur_raster)
        # The reference is detached so gradients only flow through cur_raster.
        return self.mse_loss(self.init_blurred.detach(), blurred_cur) * self.get_scheduler(step)
class ConformalLoss:
    """Penalise changes of triangle angles relative to a stored reference.

    Control points are triangulated once (Delaunay); for each letter the
    triangles whose centroid lies inside the letter polygon are kept. Calling
    the object returns the summed MSE between the current angles of those
    triangles and the angles captured at the last ``reset()``.

    Args:
        parameters: Object whose ``point`` attribute is a sequence of point
            tensors, one per shape (assumes each is (n_i, 2) -- TODO confirm).
        device: Device on which the face index tensors are stored.
        target_letter: Letters to process, iterated character by character.
        shape_groups: Sequence of groups, zipped with ``target_letter``; each
            group exposes ``shape_ids``.
    """

    def __init__(self, parameters: "EasyDict", device: torch.device, target_letter: str, shape_groups):
        self.parameters = parameters
        self.target_letter = target_letter
        self.shape_groups = shape_groups
        self.faces = self.init_faces(device)
        # Vertex indices shifted by one position within each triangle.
        self.faces_roll_a = [torch.roll(face, 1, 1) for face in self.faces]

        with torch.no_grad():
            self.angles = []
            self.reset()

    def get_angles(self, points: torch.Tensor) -> list:
        """Return one angle tensor per face set for the given ``points``.

        For every triangle, edge vectors are formed between each vertex and its
        rolled neighbour, scaled by ``1 / (length + 0.1)``, and the arccos of
        the dot product of consecutive scaled edges is taken.
        """
        angles_ = []
        for face, face_roll in zip(self.faces, self.faces_roll_a):
            triangles = points[face]
            triangles_roll_a = points[face_roll]
            edges = triangles_roll_a - triangles
            length = edges.norm(dim=-1)
            # The +1e-1 avoids division by zero and keeps the scaled edges
            # shorter than unit length, so the cosine stays inside (-1, 1).
            edges = edges / (length + 1e-1)[:, :, None]
            edges_roll = torch.roll(edges, 1, 1)
            cosine = torch.einsum('ned,ned->ne', edges, edges_roll)
            angles_.append(torch.arccos(cosine))
        return angles_

    def get_letter_inds(self, letter_to_insert):
        """Return ``(first_id, last_id, count)`` of the shape ids for a letter.

        The first group whose paired letter equals ``letter_to_insert`` is
        used; ``None`` is returned implicitly when no letter matches.
        NOTE(review): a repeated letter always resolves to its first
        occurrence -- confirm this is intended.
        """
        for group, l in zip(self.shape_groups, self.target_letter):
            if l == letter_to_insert:
                letter_inds = group.shape_ids
                return letter_inds[0], letter_inds[-1], len(letter_inds)

    def reset(self):
        """Capture the angles of the current (detached) points as the reference."""
        points = torch.cat([point.clone().detach() for point in self.parameters.point])
        self.angles = self.get_angles(points)

    def init_faces(self, device: torch.device) -> list:
        """Build, per letter, the int64 tensor of triangles inside that letter.

        Returns:
            A list of ``(n_faces, 3)`` index tensors on ``device``.
        """
        faces_ = []
        if not self.target_letter:
            return faces_

        # Loop-invariant work: the point set, its triangulation and the
        # triangle centroids are identical for every letter.
        points_per_shape = [point.clone().detach().cpu().numpy() for point in self.parameters.point]
        points_all = np.concatenate(points_per_shape)
        simplices = Delaunay(points_all).simplices
        centroids = [Point(points_all[face].mean(0)) for face in simplices]

        num_shapes = 0
        for c in self.target_letter:
            start_ind, end_ind, shapes_per_letter = self.get_letter_inds(c)
            print(c, start_ind, end_ind, shapes_per_letter)
            holes = []
            if shapes_per_letter > 1:
                # NOTE(review): the slice excludes shape `end_ind` itself, so a
                # two-shape letter gets no holes -- possibly an off-by-one;
                # behaviour kept as-is, confirm intent.
                holes = points_per_shape[start_ind + 1:end_ind]
            poly = Polygon(points_per_shape[start_ind], holes=holes)
            poly = poly.buffer(0)
            is_intersect = np.array([poly.contains(centroid) for centroid in centroids], dtype=np.bool_)
            faces_.append(torch.from_numpy(simplices[is_intersect]).to(device, dtype=torch.int64))
            num_shapes += shapes_per_letter
            # NOTE(review): stops once the accumulated shape count reaches the
            # number of letters -- confirm this comparison is intended.
            if num_shapes >= len(self.target_letter):
                break
        return faces_

    def __call__(self) -> torch.Tensor:
        """Return the summed MSE between current and reference angles.

        Returns the integer ``0`` when there are no face sets.
        """
        loss_angles = 0
        points = torch.cat(self.parameters.point)
        for current, reference in zip(self.get_angles(points), self.angles):
            loss_angles += nnf.mse_loss(current, reference)
        return loss_angles