Spaces:
Sleeping
Sleeping
from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation | |
from ..diffusionmodules.openaimodel import Timestep | |
import torch | |
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    """Noise-augment embeddings and return them with an embedding of the noise level.

    ``forward`` performs these steps in order:

    1. Normalize the input with stored per-dimension statistics (``scale``).
    2. Noise it with the parent's ``q_sample`` at the requested or randomly
       drawn noise level.
    3. Map it back to the original statistics (``unscale``).
    4. Return it together with ``time_embed(noise_level)``.

    Args:
        *args, **kwargs: Forwarded unchanged to ``ImageConcatWithNoiseAugmentation``.
        clip_stats_path: Optional path to a file that ``torch.load`` turns into a
            ``(mean, std)`` pair of 1-D tensors. When ``None``, mean = zeros and
            std = ones of length ``timestep_dim`` are used, so ``scale`` and
            ``unscale`` become the identity.
        timestep_dim: Dimension handed to ``Timestep``. It is also the length of
            the default statistics when ``clip_stats_path`` is ``None``.
    """

    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        super().__init__(*args, **kwargs)
        if clip_stats_path is None:
            # Identity statistics.
            # NOTE(review): these only broadcast against inputs whose last dim
            # equals timestep_dim (or 1). Confirm this coupling is intended.
            clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
        else:
            # SECURITY: torch.load unpickles the file, which can execute
            # arbitrary code. Only point clip_stats_path at trusted files.
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        # Stored with a leading singleton dim so they broadcast over the batch.
        # persistent=False keeps them out of state_dict(), so the statistics
        # come from clip_stats_path (or the defaults) on every construction.
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        """Re-normalize ``x`` to zero mean and unit variance using the stored stats."""
        # .to(x.device) guards against the module not having been moved with x.
        mean = self.data_mean.to(x.device)
        std = self.data_std.to(x.device)
        return (x - mean) / std

    def unscale(self, x):
        """Invert ``scale``: map normalized ``x`` back to the original data stats."""
        std = self.data_std.to(x.device)
        mean = self.data_mean.to(x.device)
        return x * std + mean

    def forward(self, x, noise_level=None, seed=None):
        """Noise-augment ``x`` and embed the noise level that was used.

        Args:
            x: Input batch. ``x.shape[0]`` is treated as the batch size.
            noise_level: One of:

                - ``None``: draw one level per sample uniformly from
                  ``[0, self.max_noise_level)``.
                - A Python ``int``: use that level for every sample.
                - A ``torch.Tensor``: pass it through unchanged.
            seed: Forwarded to the parent's ``q_sample``.

        Returns:
            Tuple ``(z, noise_level_emb)``:

            - ``z``: the noised input, restored to the original statistics.
            - ``noise_level_emb``: ``self.time_embed`` applied to the levels.

        Raises:
            TypeError: If ``noise_level`` is not ``None``, an ``int``, or a
                ``torch.Tensor``.
        """
        if noise_level is None:
            noise_level = torch.randint(
                0, self.max_noise_level, (x.shape[0],), device=x.device
            ).long()
        elif isinstance(noise_level, int) and not isinstance(noise_level, bool):
            # Broadcast a scalar level to the same (batch,) long layout that the
            # random branch above produces.
            noise_level = torch.full(
                (x.shape[0],), noise_level, dtype=torch.long, device=x.device
            )
        elif not isinstance(noise_level, torch.Tensor):
            # Explicit check instead of `assert`, which is stripped under -O.
            raise TypeError(
                "noise_level must be None, an int, or a torch.Tensor, "
                f"got {type(noise_level).__name__}"
            )
        x = self.scale(x)
        # q_sample is inherited from the parent. It is presumably the
        # forward-diffusion noising step at the given level; verify upstream.
        z = self.q_sample(x, noise_level, seed=seed)
        z = self.unscale(z)
        return z, self.time_embed(noise_level)