import torch
from enum import Enum
import gc
import numpy as np
import jax.numpy as jnp
import tomesd
import jax
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from flax import jax_utils
import einops
from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
    FlaxDDIMScheduler,
    FlaxAutoencoderKL,
    FlaxStableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)
from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from text_to_animation.models.controlnet_flax import FlaxControlNetModel
from text_to_animation.pipelines.text_to_video_pipeline_flax import (
    FlaxTextToVideoPipeline,
)
import utils.utils as utils
import utils.gradio_utils as gradio_utils
import os

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"

unshard = lambda x: einops.rearrange(x, "d b ... -> (d b) ...")
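# `unshard` collapses the leading per-device axis that `shard`/`pmap` introduce
# back into a single batch axis, e.g. (n_devices, batch, ...) -> (n_devices * batch, ...).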

class ModelType(Enum):
    Text2Video = 1
    ControlNetPose = 2
    StableDiffusion = 3

def replicate_devices(array):
    return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)
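
# Illustration (not in the original source): with 2 visible JAX devices,
# replicate_devices(jnp.zeros((4, 77))) returns an array of shape (2, 4, 77),
# i.e. one identical copy of the input per device, ready for a pmapped
# pipeline call.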

class ControlAnimationModel:
    def __init__(self, device, dtype, **kwargs):
        self.device = device
        self.dtype = dtype
        self.rng = jax.random.PRNGKey(0)
        # torch RNG used to initialise latents in process_controlnet_pose
        self.generator = torch.Generator(device=device)
        self.pipe_dict = {
            ModelType.Text2Video: FlaxTextToVideoPipeline,  # TODO: Replace with our TextToVideo JAX Pipeline
            ModelType.ControlNetPose: FlaxStableDiffusionControlNetPipeline,
        }
        self.pipe = None
        self.model_type = None
        self.states = {}
        self.model_name = ""
        self.from_local = True  # whether the attn model is available locally (after adaptation by adapt_attn.py)

    def set_model(
        self,
        model_type: ModelType,
        model_id: str,
        controlnet=None,
        controlnet_params=None,
        tokenizer=None,
        scheduler=None,
        scheduler_state=None,
        **kwargs,
    ):
        if hasattr(self, "pipe") and self.pipe is not None:
            del self.pipe
            self.pipe = None
            gc.collect()

        # The scheduler, tokenizer and feature extractor are (re)loaded from the
        # checkpoint regardless of what was passed in.
        scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
            model_id, subfolder="scheduler", from_pt=True
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            model_id, subfolder="feature_extractor"
        )
        if self.from_local:
            unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
                f'./{model_id.split("/")[-1]}',
                subfolder="unet",
                from_pt=True,
                dtype=self.dtype,
            )
        else:
            unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
                model_id, subfolder="unet", from_pt=True, dtype=self.dtype
            )
        vae, vae_params = FlaxAutoencoderKL.from_pretrained(
            model_id, subfolder="vae", from_pt=True, dtype=self.dtype
        )
        text_encoder = FlaxCLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
        )

        self.pipe = FlaxTextToVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        )
        self.params = {
            "unet": unet_params,
            "vae": vae_params,
            "scheduler": scheduler_state,
            "controlnet": controlnet_params,
            "text_encoder": text_encoder.params,
        }
        self.p_params = jax_utils.replicate(self.params)
        self.model_type = model_type
        self.model_name = model_id
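
    # Minimal usage sketch (illustrative, not part of the original code). The
    # checkpoint ids below are the ones used later in this module; dtype and
    # device are assumptions:
    #
    #   controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    #       "fusing/stable-diffusion-v1-5-controlnet-openpose"
    #   )
    #   model = ControlAnimationModel(device="cpu", dtype=jnp.float32)
    #   model.set_model(
    #       ModelType.ControlNetPose,
    #       model_id="runwayml/stable-diffusion-v1-5",
    #       controlnet=controlnet,
    #       controlnet_params=controlnet_params,
    #   )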

    # def inference_chunk(self, image, frame_ids, prompt, negative_prompt, **kwargs):
    #     prompt_ids = self.pipe.prepare_text_inputs(prompt)
    #     n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
    #     latents = kwargs.pop('latents')
    #     # rng = jax.random.split(self.rng, jax.device_count())
    #     prng, self.rng = jax.random.split(self.rng)
    #     # prng = jax.numpy.stack([prng] * jax.device_count())  # same prng seed on every device
    #     prng_seed = jax.random.split(prng, jax.device_count())
    #     image = replicate_devices(image[frame_ids])
    #     latents = replicate_devices(latents)
    #     prompt_ids = replicate_devices(prompt_ids)
    #     n_prompt_ids = replicate_devices(n_prompt_ids)
    #     return (self.pipe(image=image,
    #                       latents=latents,
    #                       prompt_ids=prompt_ids,
    #                       neg_prompt_ids=n_prompt_ids,
    #                       params=self.p_params,
    #                       prng_seed=prng_seed, jit=True,
    #                       ).images)[0]

    def inference(self, image, split_to_chunks=False, chunk_size=8, **kwargs):
        if not hasattr(self, "pipe") or self.pipe is None:
            return

        if "merging_ratio" in kwargs:
            merging_ratio = kwargs.pop("merging_ratio")
            # if merging_ratio > 0:
            tomesd.apply_patch(self.pipe, ratio=merging_ratio)
        # f = image.shape[0]

        assert "prompt" in kwargs
        prompt = [kwargs.pop("prompt")]
        negative_prompt = [kwargs.pop("negative_prompt", "")]

        frames_counter = 0
        # Processing chunk-by-chunk
        if split_to_chunks:
            pass
            # # not tested
            # f = image.shape[0]
            # chunk_ids = np.arange(0, f, chunk_size - 1)
            # result = []
            # for i in range(len(chunk_ids)):
            #     ch_start = chunk_ids[i]
            #     ch_end = f if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
            #     frame_ids = [0] + list(range(ch_start, ch_end))
            #     print(f'Processing chunk {i + 1} / {len(chunk_ids)}')
            #     result.append(self.inference_chunk(image=image,
            #                                        frame_ids=frame_ids,
            #                                        prompt=prompt,
            #                                        negative_prompt=negative_prompt,
            #                                        **kwargs).images[1:])
            #     frames_counter += len(chunk_ids) - 1
            #     if on_huggingspace and frames_counter >= 80:
            #         break
            # result = np.concatenate(result)
            # return result
        else:
            if "jit" in kwargs and kwargs.pop("jit"):
                # Parallel (pmapped) path: replicate all inputs across devices
                # and return the images produced on the first device.
                prompt_ids = self.pipe.prepare_text_inputs(prompt)
                n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
                latents = kwargs.pop("latents")
                prng, self.rng = jax.random.split(self.rng)
                prng_seed = jax.random.split(prng, jax.device_count())
                image = replicate_devices(image)
                latents = replicate_devices(latents)
                prompt_ids = replicate_devices(prompt_ids)
                n_prompt_ids = replicate_devices(n_prompt_ids)
                return (
                    self.pipe(
                        image=image,
                        latents=latents,
                        prompt_ids=prompt_ids,
                        neg_prompt_ids=n_prompt_ids,
                        params=self.p_params,
                        prng_seed=prng_seed,
                        jit=True,
                    ).images
                )[0]
            else:
                # Single-device path: use the unreplicated params and a single PRNG key.
                prompt_ids = self.pipe.prepare_text_inputs(prompt)
                n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
                latents = kwargs.pop("latents")
                prng_seed, self.rng = jax.random.split(self.rng)
                return self.pipe(
                    image=image,
                    latents=latents,
                    prompt_ids=prompt_ids,
                    neg_prompt_ids=n_prompt_ids,
                    params=self.params,
                    prng_seed=prng_seed,
                    jit=False,
                ).images
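
    # Hypothetical call (illustrative only; the argument values are assumptions,
    # not taken from the original app):
    #
    #   frames = model.inference(
    #       image=control,                  # per-frame conditioning images
    #       latents=latents,                # shared initial latents
    #       prompt="an astronaut dancing",  # required by the assert above
    #       negative_prompt="low quality",
    #       jit=jax.device_count() > 1,
    #   )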

    def process_controlnet_pose(
        self,
        video_path,
        prompt,
        chunk_size=8,
        watermark="Picsart AI Research",
        merging_ratio=0.0,
        num_inference_steps=20,
        controlnet_conditioning_scale=1.0,
        guidance_scale=9.0,
        seed=42,
        eta=0.0,
        resolution=512,
        use_cf_attn=True,
        save_path=None,
    ):
        print("Module Pose")
        video_path = gradio_utils.motion_to_video_path(video_path)
        if self.model_type != ModelType.ControlNetPose:
            controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
                "fusing/stable-diffusion-v1-5-controlnet-openpose"
            )
            self.set_model(
                ModelType.ControlNetPose,
                model_id="runwayml/stable-diffusion-v1-5",
                controlnet=controlnet,
                controlnet_params=controlnet_params,
            )
            self.pipe.scheduler = FlaxDDIMScheduler.from_config(
                self.pipe.scheduler.config
            )
            if use_cf_attn:
                self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
                self.pipe.controlnet.set_attn_processor(
                    processor=self.controlnet_attn_proc
                )

        video_path = (
            gradio_utils.motion_to_video_path(video_path)
            if "Motion" in video_path
            else video_path
        )

        added_prompt = "best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
        negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"

        video, fps = utils.prepare_video(
            video_path, resolution, self.device, self.dtype, False, output_fps=4
        )
        control = (
            utils.pre_process_pose(video, apply_pose_detect=False)
            .to(self.device)
            .to(self.dtype)
        )
        f, _, h, w = video.shape

        self.generator.manual_seed(seed)
        latents = torch.randn(
            (1, 4, h // 8, w // 8),
            dtype=self.dtype,
            device=self.device,
            generator=self.generator,
        )
        latents = latents.repeat(f, 1, 1, 1)
        result = self.inference(
            image=control,
            prompt=prompt + ", " + added_prompt,
            height=h,
            width=w,
            negative_prompt=negative_prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            eta=eta,
            latents=latents,
            seed=seed,
            output_type="numpy",
            split_to_chunks=True,
            chunk_size=chunk_size,
            merging_ratio=merging_ratio,
        )
        return utils.create_gif(
            result,
            fps,
            path=save_path,
            watermark=gradio_utils.logo_name_to_path(watermark),
        )

    def process_text2video(
        self,
        prompt,
        model_name="dreamlike-art/dreamlike-photoreal-2.0",
        motion_field_strength_x=12,
        motion_field_strength_y=12,
        t0=44,
        t1=47,
        n_prompt="",
        chunk_size=8,
        video_length=8,
        watermark="Picsart AI Research",
        merging_ratio=0.0,
        seed=0,
        resolution=512,
        fps=2,
        use_cf_attn=True,
        use_motion_field=True,
        smooth_bg=False,
        smooth_bg_strength=0.4,
        path=None,
    ):
        print("Module Text2Video")
        if self.model_type != ModelType.Text2Video or model_name != self.model_name:
            print("Model update")
            unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
                model_name, subfolder="unet"
            )
            self.set_model(ModelType.Text2Video, model_id=model_name, unet=unet)
            self.pipe.scheduler = FlaxDDIMScheduler.from_config(
                self.pipe.scheduler.config
            )
            if use_cf_attn:
                self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
        self.generator.manual_seed(seed)

        added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
        negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"

        prompt = prompt.rstrip()
        if len(prompt) > 0 and (prompt[-1] == "," or prompt[-1] == "."):
            prompt = prompt.rstrip()[:-1]
        prompt = prompt.rstrip()
        prompt = prompt + ", " + added_prompt
        if len(n_prompt) > 0:
            negative_prompt = n_prompt
        else:
            negative_prompt = None

        result = self.inference(
            prompt=prompt,
            video_length=video_length,
            height=resolution,
            width=resolution,
            num_inference_steps=50,
            guidance_scale=7.5,
            guidance_stop_step=1.0,
            t0=t0,
            t1=t1,
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
            use_motion_field=use_motion_field,
            smooth_bg=smooth_bg,
            smooth_bg_strength=smooth_bg_strength,
            seed=seed,
            output_type="numpy",
            negative_prompt=negative_prompt,
            merging_ratio=merging_ratio,
            split_to_chunks=True,
            chunk_size=chunk_size,
        )
        return utils.create_video(
            result, fps, path=path, watermark=gradio_utils.logo_name_to_path(watermark)
        )
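
    # Illustrative call (an assumption for documentation purposes; prompt, seed
    # and output path are placeholders):
    #
    #   out = model.process_text2video(
    #       "a corgi surfing a wave",
    #       model_name="dreamlike-art/dreamlike-photoreal-2.0",
    #       video_length=8,
    #       resolution=512,
    #       seed=0,
    #       path="out.mp4",
    #   )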

    def generate_animation(
        self,
        prompt: str,
        model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
        is_safetensor: bool = False,
        motion_field_strength_x: int = 12,
        motion_field_strength_y: int = 12,
        t0: int = 44,
        t1: int = 47,
        n_prompt: str = "",
        chunk_size: int = 8,
        video_length: int = 8,
        merging_ratio: float = 0.0,
        seed: int = 0,
        resolution: int = 512,
        fps: int = 2,
        use_cf_attn: bool = True,
        use_motion_field: bool = True,
        smooth_bg: bool = False,
        smooth_bg_strength: float = 0.4,
        path: str = None,
    ):
        if is_safetensor and model_link.endswith(".safetensors"):
            pipe = utils.load_safetensors_model(model_link)
        return

    def generate_initial_frames(
        self,
        prompt: str,
        model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
        is_safetensor: bool = False,
        n_prompt: str = "",
        width: int = 512,
        height: int = 512,
        # batch_count: int = 4,
        # batch_size: int = 1,
        cfg_scale: float = 7.0,
        seed: int = 0,
    ):
        print(f">>> prompt: {prompt}, model_link: {model_link}")

        pipe = StableDiffusionPipeline.from_pretrained(model_link)

        batch_size = 4
        prompt = [prompt] * batch_size
        negative_prompt = [n_prompt] * batch_size

        images = pipe(
            prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=cfg_scale,
        ).images

        return images
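

# Minimal smoke-test sketch (not part of the original app). The device/dtype
# values are assumptions, and the PyTorch pipeline below runs on CPU, so this
# is slow; it only exercises the initial-frame generation path.
if __name__ == "__main__":
    model = ControlAnimationModel(device="cpu", dtype=jnp.float32)
    frames = model.generate_initial_frames(
        prompt="a watercolor painting of a fox in a forest",
        model_link="dreamlike-art/dreamlike-photoreal-2.0",
    )
    print(f"generated {len(frames)} candidate first frames")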