Spaces:
Paused
Paused
import os | |
from typing import Union, List, Optional | |
import math | |
import numpy as np | |
import torch | |
from PIL import Image | |
from einops import rearrange | |
from imwatermark import WatermarkEncoder | |
from omegaconf import ListConfig | |
from torch import autocast | |
from sgm.util import append_dims | |
class WatermarkEmbedder:
    """Embed a fixed bit-string watermark into image tensors using ``imwatermark``."""

    def __init__(self, watermark):
        """
        Args:
            watermark: sequence of 0/1 ints passed to
                ``WatermarkEncoder.set_watermark("bits", ...)``.
        """
        self.watermark = watermark
        # Count the bits of the watermark actually supplied rather than reading
        # the module-level WATERMARK_BITS constant: a custom watermark now reports
        # its own length, and the class no longer depends on a global that is
        # defined after it.
        self.num_bits = len(self.watermark)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor):
        """
        Adds a predefined watermark to the input image.

        Args:
            image: ([N,] B, C, H, W) in range [0, 1]

        Returns:
            same shape as input, but watermarked, clamped to [0, 1]
        """
        # Promote a single (B, C, H, W) batch to (1, B, C, H, W) so one code
        # path handles both layouts; undone before returning.
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        # torch (n, b, c, h, w) in [0, 1] -> numpy (n*b, h, w, c) in [0, 255].
        # The channel axis is reversed because the watermarking library expects
        # cv2-style BGR input (assumes C holds RGB channels — TODO confirm).
        image_np = rearrange(
            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
        ).numpy()[:, :, :, ::-1]
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        # Undo the channel flip and restore the (n, b, c, h, w) layout on the
        # input's device.
        image = torch.from_numpy(
            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
        ).to(image.device)
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        return image
# A fixed 48-bit message chosen at random; in hex this is 0xB3EC907BB19E.
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# Expand the message into its individual bits, most significant first, as 0/1 ints.
WATERMARK_BITS = list(map(int, format(WATERMARK_MESSAGE, "b")))
# Shared embedder instance; perform_save_locally applies it before writing images.
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    Duplicates are removed while keeping first-seen order. A plain ``set``
    would also de-duplicate, but its iteration order for strings varies between
    interpreter runs (hash randomization), which made the order of downstream
    batch construction and logging non-deterministic.
    """
    return list(
        dict.fromkeys(embedder.input_key for embedder in conditioner.embedders)
    )
def perform_save_locally(save_path, samples):
    """Watermark ``samples`` and write each one to ``save_path`` as a numbered PNG.

    Args:
        save_path: output directory; created if it does not exist.
        samples: batch passed to ``embed_watermark``; each element is rearranged
            ``c h w -> h w c`` and scaled by 255 before saving (values presumably
            in [0, 1] — TODO confirm at callers).

    Files are named ``{index:09}.png``. Numbering starts at the number of
    entries already in the directory and skips any name that already exists.
    """
    os.makedirs(save_path, exist_ok=True)
    base_count = len(os.listdir(save_path))
    samples = embed_watermark(samples)
    for sample in samples:
        sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        # The directory entry count can equal an existing file's index (e.g.
        # after earlier images were deleted), which would silently overwrite a
        # previous result; advance to the next unused name instead.
        out_path = os.path.join(save_path, f"{base_count:09}.png")
        while os.path.exists(out_path):
            base_count += 1
            out_path = os.path.join(save_path, f"{base_count:09}.png")
        Image.fromarray(sample.astype(np.uint8)).save(out_path)
        base_count += 1
class Img2ImgDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas it returns, for img2img.

    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas
            are returned); otherwise ``max(int(strength * len(sigmas)), 1)``
            sigmas are kept.
    """

    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # Assumes the discretization returns sigmas largest-first (decreasing).
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # Compute the cut-off once, from the unpruned length. Previously the
        # logged value was recomputed from the already-pruned tensor and did
        # not match the index actually used for slicing.
        prune_index = max(int(self.strength * len(sigmas)), 1)
        # Flip to ascending order, keep the first ``prune_index`` entries (the
        # smallest sigmas, given decreasing input), then flip back so the
        # result keeps the original ordering.
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas
def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: Optional[List] = None,
    batch2model_input: Optional[List] = None,
    return_latents=False,
    filter=None,
    device="cuda",
):
    """Sample images starting from pure Gaussian noise.

    Args:
        model: diffusion model; this function uses its ``conditioner``,
            ``denoiser``, ``model``, ``ema_scope()`` and ``decode_first_stage``.
        sampler: called as ``sampler(denoiser, noise, cond=c, uc=uc)`` and
            expected to return latents.
        value_dict: conditioning values consumed by ``get_batch``.
        num_samples: number of samples to draw.
        H, W: requested image height and width.
        C: channel count of the initial noise tensor.
        F: divisor applied to ``H``/``W`` to get the noise tensor's spatial size.
        force_uc_zero_embeddings: forwarded to
            ``get_unconditional_conditioning``; defaults to ``[]``.
        batch2model_input: batch keys forwarded to ``model.denoiser`` as extra
            keyword arguments; defaults to ``[]``.
        return_latents: if true, also return the sampled latents.
        filter: optional callable applied to the decoded samples.
        device: device for conditioning tensors and the initial noise; its
            type is also used for ``autocast``.

    Returns:
        ``samples`` (clamped to [0, 1] before the optional ``filter``), or
        ``(samples, samples_z)`` when ``return_latents`` is true.
    """
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []
    # autocast takes a device *type* ("cuda", "cpu"), so normalise values such
    # as "cuda:0" or a torch.device instance.
    with torch.no_grad(), autocast(torch.device(device).type), model.ema_scope():
        num_samples = [num_samples]
        total = math.prod(num_samples)
        # Pass the device explicitly; get_batch would otherwise default to "cuda".
        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner),
            value_dict,
            num_samples,
            device=device,
        )
        # Debug output describing the assembled batch.
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                print(key, value.shape)
            elif isinstance(value, list):
                print(key, [len(item) for item in value])
            else:
                print(key, value)
        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=force_uc_zero_embeddings,
        )
        # Trim every conditioning entry except cross-attention to the batch size
        # and move it to the target device.
        for k in c:
            if k != "crossattn":
                c[k], uc[k] = map(lambda y: y[k][:total].to(device), (c, uc))
        additional_model_inputs = {k: batch[k] for k in batch2model_input}
        shape = (total, C, H // F, W // F)
        randn = torch.randn(shape).to(device)

        def denoiser(input, sigma, c):
            return model.denoiser(
                model.model, input, sigma, c, **additional_model_inputs
            )

        samples_z = sampler(denoiser, randn, cond=c, uc=uc)
        samples_x = model.decode_first_stage(samples_z)
        # Map decoder output from [-1, 1] to [0, 1].
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
        if filter is not None:
            samples = filter(samples)
        if return_latents:
            return samples, samples_z
        return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    """Assemble conditional and unconditional input batches for embedder keys.

    Hardcoded demo setups; might undergo some changes in the future.

    Args:
        keys: embedder input keys to populate.
        value_dict: raw conditioning values, looked up per key.
        N: batch shape. Text entries are reshaped to ``N``; tensor entries are
            tiled with ``repeat(*N, 1)``.
        device: device that tensor entries are moved to.

    Returns:
        ``(batch, batch_uc)``. Every tensor entry of ``batch`` that received no
        explicit unconditional value is cloned into ``batch_uc``; other keys
        without an explicit unconditional value are absent from ``batch_uc``.
    """
    # Keys whose value is a pair of scalars taken from ``value_dict``.
    pair_sources = {
        "original_size_as_tuple": ("orig_height", "orig_width"),
        "crop_coords_top_left": ("crop_coords_top", "crop_coords_left"),
        "target_size_as_tuple": ("target_height", "target_width"),
    }

    def tile_text(text):
        # One copy of ``text`` per batch element, shaped like ``N``, as nested lists.
        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()

    def tile_tensor(values):
        # 1-D tensor of ``values`` on ``device``, tiled to (*N, len(values)).
        return torch.tensor(values).to(device).repeat(*N, 1)

    batch = {}
    batch_uc = {}
    for key in keys:
        if key == "txt":
            batch["txt"] = tile_text(value_dict["prompt"])
            batch_uc["txt"] = tile_text(value_dict["negative_prompt"])
        elif key in pair_sources:
            first, second = pair_sources[key]
            batch[key] = tile_tensor([value_dict[first], value_dict[second]])
        elif key == "aesthetic_score":
            batch[key] = tile_tensor([value_dict["aesthetic_score"]])
            batch_uc[key] = tile_tensor([value_dict["negative_aesthetic_score"]])
        else:
            batch[key] = value_dict[key]

    # Tensors without an explicit unconditional counterpart reuse the
    # conditional value (as an independent copy).
    for key, value in batch.items():
        if key not in batch_uc and isinstance(value, torch.Tensor):
            batch_uc[key] = torch.clone(value)
    return batch, batch_uc
def get_input_image_tensor(image: Image.Image, device="cuda"):
    """Convert a PIL image into a normalised float tensor.

    The image is converted to RGB and resized so that each side becomes the
    largest multiple of 64 not exceeding the original size.

    Args:
        image: input PIL image (any mode; converted to RGB).
        device: device the returned tensor is moved to.

    Returns:
        float32 tensor of shape (1, 3, height, width) with values in [-1, 1].

    Raises:
        ValueError: if either side is smaller than 64 pixels, since it would
            otherwise be resized to zero size and fail further downstream.
    """
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    # Round each side down to an integer multiple of 64.
    width, height = (x - x % 64 for x in (w, h))
    if width == 0 or height == 0:
        raise ValueError(
            f"input image must be at least 64x64 pixels, got ({w}, {h})"
        )
    # Convert before resizing: PIL always uses nearest-neighbour resampling
    # for "1" and "P" mode images, so resizing first degrades palette images.
    image = image.convert("RGB").resize((width, height))
    # (H, W, 3) uint8 -> (1, 3, H, W)
    image_array = np.array(image)[None].transpose(0, 3, 1, 2)
    # Map [0, 255] to [-1, 1].
    image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
    return image_tensor.to(device)
def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings=None,
    additional_kwargs=None,
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    device="cuda",
):
    """Sample images starting from a noised encoding of ``img``.

    Args:
        img: input image tensor, or its latent if ``skip_encode`` is true.
        model: diffusion model; this function uses its ``conditioner``,
            ``denoiser``, ``model``, ``ema_scope()``, ``encode_first_stage``
            and ``decode_first_stage``.
        sampler: exposes ``discretization`` and ``num_steps``, and is called as
            ``sampler(denoiser, noised_z, cond=c, uc=uc)``.
        value_dict: conditioning values consumed by ``get_batch``.
        num_samples: number of samples; conditioning entries are trimmed to it.
        force_uc_zero_embeddings: forwarded to
            ``get_unconditional_conditioning``; defaults to ``[]``.
        additional_kwargs: extra entries written into both ``c`` and ``uc``;
            defaults to ``{}``.
        offset_noise_level: if > 0, adds per-sample offset noise of this scale.
        return_latents: if true, also return the sampled latents.
        skip_encode: treat ``img`` as an already-encoded latent.
        filter: optional callable applied to the decoded samples.
        device: device for conditioning tensors; its type is also used for
            ``autocast``.

    Returns:
        ``samples`` (clamped to [0, 1] before the optional ``filter``), or
        ``(samples, samples_z)`` when ``return_latents`` is true.
    """
    # None sentinels instead of shared mutable defaults ([] / {}).
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if additional_kwargs is None:
        additional_kwargs = {}
    # autocast takes a device *type* ("cuda", "cpu"), so normalise values such
    # as "cuda:0" or a torch.device instance.
    with torch.no_grad(), autocast(torch.device(device).type), model.ema_scope():
        # Pass the device explicitly; get_batch would otherwise default to "cuda".
        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner),
            value_dict,
            [num_samples],
            device=device,
        )
        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=force_uc_zero_embeddings,
        )
        for k in c:
            c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
        for k, v in additional_kwargs.items():
            c[k] = uc[k] = v
        z = img if skip_encode else model.encode_first_stage(img)
        noise = torch.randn_like(z)
        # Start from the first (presumably largest) sigma of the schedule.
        sigmas = sampler.discretization(sampler.num_steps)
        sigma = sigmas[0].to(z.device)
        if offset_noise_level > 0.0:
            noise = noise + offset_noise_level * append_dims(
                torch.randn(z.shape[0], device=z.device), z.ndim
            )
        noised_z = z + noise * append_dims(sigma, z.ndim)
        # Note: hardcoded to DDPM-like scaling. need to generalize later.
        noised_z = noised_z / torch.sqrt(1.0 + sigma ** 2.0)

        def denoiser(x, sigma, c):
            return model.denoiser(model.model, x, sigma, c)

        samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
        samples_x = model.decode_first_stage(samples_z)
        # Map decoder output from [-1, 1] to [0, 1].
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
        if filter is not None:
            samples = filter(samples)
        if return_latents:
            return samples, samples_z
        return samples