import torch
from enum import Enum
import gc
import numpy as np
import jax.numpy as jnp
import tomesd
import jax
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from flax import jax_utils
import einops
from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
FlaxDDIMScheduler,
FlaxAutoencoderKL,
FlaxStableDiffusionControlNetPipeline,
StableDiffusionPipeline,
)
from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from text_to_animation.models.controlnet_flax import FlaxControlNetModel
from text_to_animation.pipelines.text_to_video_pipeline_flax import (
FlaxTextToVideoPipeline,
)
import utils.utils as utils
import utils.gradio_utils as gradio_utils
import os
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
def unshard(x):
    return einops.rearrange(x, "d b ... -> (d b) ...")
class ModelType(Enum):
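    """Pipeline variants that ControlAnimationModel knows how to load."""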
Text2Video = 1
ControlNetPose = 2
StableDiffusion = 3
def replicate_devices(array):
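    """Add a leading device axis and repeat the array once per local JAX device (for pmapped calls)."""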
return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)
class ControlAnimationModel:
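    """Wrapper around the Flax/JAX diffusion pipelines: loads model weights, runs inference, and post-processes the generated frames."""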
def __init__(self, device, dtype, **kwargs):
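        """Store device/dtype settings and initialise the JAX PRNG key; no weights are loaded until set_model() is called."""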
self.device = device
self.dtype = dtype
self.rng = jax.random.PRNGKey(0)
self.pipe_dict = {
ModelType.Text2Video: FlaxTextToVideoPipeline, # TODO: Replace with our TextToVideo JAX Pipeline
ModelType.ControlNetPose: FlaxStableDiffusionControlNetPipeline,
}
self.pipe = None
self.model_type = None
self.states = {}
self.model_name = ""
        self.from_local = True  # whether the adapted attention model is available locally (after adaptation by adapt_attn.py)
def set_model(
self,
model_type: ModelType,
model_id: str,
        controlnet=None,
        controlnet_params=None,
        tokenizer=None,
        scheduler=None,
        scheduler_state=None,
**kwargs,
):
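        """Load the Flax sub-models for `model_id` and assemble a FlaxTextToVideoPipeline.

        The tokenizer, feature extractor, scheduler, and scheduler state are always
        reloaded from `model_id`, so the corresponding arguments are effectively
        ignored. All parameters are collected in `self.params` and replicated across
        devices into `self.p_params` for pmapped inference.
        """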
if hasattr(self, "pipe") and self.pipe is not None:
del self.pipe
self.pipe = None
gc.collect()
scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
model_id, subfolder="scheduler", from_pt=True
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
feature_extractor = CLIPFeatureExtractor.from_pretrained(
model_id, subfolder="feature_extractor"
)
if self.from_local:
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
f'./{model_id.split("/")[-1]}',
subfolder="unet",
from_pt=True,
dtype=self.dtype,
)
else:
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", from_pt=True, dtype=self.dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
model_id, subfolder="vae", from_pt=True, dtype=self.dtype
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
)
self.pipe = FlaxTextToVideoPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor,
)
self.params = {
"unet": unet_params,
"vae": vae_params,
"scheduler": scheduler_state,
"controlnet": controlnet_params,
"text_encoder": text_encoder.params,
}
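        # Replicate the parameter tree across local devices for the pmapped (jit=True) inference path.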
self.p_params = jax_utils.replicate(self.params)
self.model_type = model_type
self.model_name = model_id
# def inference_chunk(self, image, frame_ids, prompt, negative_prompt, **kwargs):
# prompt_ids = self.pipe.prepare_text_inputs(prompt)
# n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
# latents = kwargs.pop('latents')
# # rng = jax.random.split(self.rng, jax.device_count())
# prng, self.rng = jax.random.split(self.rng)
# #prng = jax.numpy.stack([prng] * jax.device_count())#same prng seed on every device
# prng_seed = jax.random.split(prng, jax.device_count())
# image = replicate_devices(image[frame_ids])
# latents = replicate_devices(latents)
# prompt_ids = replicate_devices(prompt_ids)
# n_prompt_ids = replicate_devices(n_prompt_ids)
# return (self.pipe(image=image,
# latents=latents,
# prompt_ids=prompt_ids,
# neg_prompt_ids=n_prompt_ids,
# params=self.p_params,
# prng_seed=prng_seed, jit = True,
# ).images)[0]
    def inference(self, image=None, split_to_chunks=False, chunk_size=8, **kwargs):
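        """Run the current pipeline for a single prompt.

        With `jit=True` in kwargs, prompts, image, and latents are replicated across
        devices and the pmapped pipeline call is used; otherwise a single-device call
        is made. The chunked path (`split_to_chunks=True`) is not implemented yet.
        """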
if not hasattr(self, "pipe") or self.pipe is None:
return
if "merging_ratio" in kwargs:
merging_ratio = kwargs.pop("merging_ratio")
# if merging_ratio > 0:
tomesd.apply_patch(self.pipe, ratio=merging_ratio)
# f = image.shape[0]
assert "prompt" in kwargs
prompt = [kwargs.pop("prompt")]
negative_prompt = [kwargs.pop("negative_prompt", "")]
frames_counter = 0
# Processing chunk-by-chunk
if split_to_chunks:
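            # TODO: chunked inference is not implemented for the Flax pipelines yet;
            # the torch implementation below is kept (commented out) for reference.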
pass
# # not tested
# f = image.shape[0]
# chunk_ids = np.arange(0, f, chunk_size - 1)
# result = []
# for i in range(len(chunk_ids)):
# ch_start = chunk_ids[i]
# ch_end = f if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
# frame_ids = [0] + list(range(ch_start, ch_end))
# print(f'Processing chunk {i + 1} / {len(chunk_ids)}')
# result.append(self.inference_chunk(image=image,
# frame_ids=frame_ids,
# prompt=prompt,
# negative_prompt=negative_prompt,
# **kwargs).images[1:])
# frames_counter += len(chunk_ids)-1
# if on_huggingspace and frames_counter >= 80:
# break
# result = np.concatenate(result)
# return result
else:
if "jit" in kwargs and kwargs.pop("jit"):
prompt_ids = self.pipe.prepare_text_inputs(prompt)
n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
latents = kwargs.pop("latents")
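                # Split the persistent RNG so each call uses fresh noise, then fan one key out per device.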
prng, self.rng = jax.random.split(self.rng)
prng_seed = jax.random.split(prng, jax.device_count())
image = replicate_devices(image)
latents = replicate_devices(latents)
prompt_ids = replicate_devices(prompt_ids)
n_prompt_ids = replicate_devices(n_prompt_ids)
return (
self.pipe(
image=image,
latents=latents,
prompt_ids=prompt_ids,
neg_prompt_ids=n_prompt_ids,
params=self.p_params,
prng_seed=prng_seed,
jit=True,
).images
)[0]
else:
prompt_ids = self.pipe.prepare_text_inputs(prompt)
n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
                latents = kwargs.pop("latents", None)
prng_seed, self.rng = jax.random.split(self.rng)
return self.pipe(
image=image,
latents=latents,
prompt_ids=prompt_ids,
neg_prompt_ids=n_prompt_ids,
params=self.params,
prng_seed=prng_seed,
jit=False,
).images
def process_controlnet_pose(
self,
video_path,
prompt,
chunk_size=8,
watermark="Picsart AI Research",
merging_ratio=0.0,
num_inference_steps=20,
controlnet_conditioning_scale=1.0,
guidance_scale=9.0,
seed=42,
eta=0.0,
resolution=512,
use_cf_attn=True,
save_path=None,
):
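        """Generate a pose-guided animation: load the openpose ControlNet if needed, extract pose maps from `video_path`, run inference, and write the frames to a GIF."""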
print("Module Pose")
video_path = gradio_utils.motion_to_video_path(video_path)
if self.model_type != ModelType.ControlNetPose:
            controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
                "fusing/stable-diffusion-v1-5-controlnet-openpose"
            )
            self.set_model(
                ModelType.ControlNetPose,
                model_id="runwayml/stable-diffusion-v1-5",
                controlnet=controlnet,
                controlnet_params=controlnet_params,
            )
self.pipe.scheduler = FlaxDDIMScheduler.from_config(
self.pipe.scheduler.config
)
if use_cf_attn:
self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
self.pipe.controlnet.set_attn_processor(
processor=self.controlnet_attn_proc
)
video_path = (
gradio_utils.motion_to_video_path(video_path)
if "Motion" in video_path
else video_path
)
added_prompt = "best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
video, fps = utils.prepare_video(
video_path, resolution, self.device, self.dtype, False, output_fps=4
)
control = (
utils.pre_process_pose(video, apply_pose_detect=False)
.to(self.device)
.to(self.dtype)
)
f, _, h, w = video.shape
        # Sample the shared initial latent with JAX so its dtype matches the Flax pipeline, then repeat it per frame.
        latents = jax.random.normal(
            jax.random.PRNGKey(seed), (1, 4, h // 8, w // 8), dtype=self.dtype
        )
        latents = jnp.repeat(latents, f, axis=0)
result = self.inference(
image=control,
prompt=prompt + ", " + added_prompt,
height=h,
width=w,
negative_prompt=negative_prompts,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
eta=eta,
latents=latents,
seed=seed,
output_type="numpy",
split_to_chunks=True,
chunk_size=chunk_size,
merging_ratio=merging_ratio,
)
return utils.create_gif(
result,
fps,
path=save_path,
watermark=gradio_utils.logo_name_to_path(watermark),
)
def process_text2video(
self,
prompt,
model_name="dreamlike-art/dreamlike-photoreal-2.0",
motion_field_strength_x=12,
motion_field_strength_y=12,
t0=44,
t1=47,
n_prompt="",
chunk_size=8,
video_length=8,
watermark="Picsart AI Research",
merging_ratio=0.0,
seed=0,
resolution=512,
fps=2,
use_cf_attn=True,
use_motion_field=True,
smooth_bg=False,
smooth_bg_strength=0.4,
path=None,
):
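        """Generate a short video from `prompt` with the text-to-video pipeline, reloading the model if `model_name` changed, and save it via utils.create_video."""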
print("Module Text2Video")
if self.model_type != ModelType.Text2Video or model_name != self.model_name:
print("Model update")
            unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
                model_name, subfolder="unet", from_pt=True, dtype=self.dtype
            )
            self.set_model(ModelType.Text2Video, model_id=model_name, unet=unet)
self.pipe.scheduler = FlaxDDIMScheduler.from_config(
self.pipe.scheduler.config
)
if use_cf_attn:
self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
        # Re-seed the JAX RNG used for sampling (the Flax pipeline does not use a torch generator).
        self.rng = jax.random.PRNGKey(seed)
added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
        negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
        prompt = prompt.rstrip()
        if prompt.endswith((",", ".")):
            prompt = prompt[:-1].rstrip()
        prompt = prompt + ", " + added_prompt
        negative_prompt = n_prompt if len(n_prompt) > 0 else ""
result = self.inference(
prompt=prompt,
video_length=video_length,
height=resolution,
width=resolution,
num_inference_steps=50,
guidance_scale=7.5,
guidance_stop_step=1.0,
t0=t0,
t1=t1,
motion_field_strength_x=motion_field_strength_x,
motion_field_strength_y=motion_field_strength_y,
use_motion_field=use_motion_field,
smooth_bg=smooth_bg,
smooth_bg_strength=smooth_bg_strength,
seed=seed,
output_type="numpy",
negative_prompt=negative_prompt,
merging_ratio=merging_ratio,
split_to_chunks=True,
chunk_size=chunk_size,
)
return utils.create_video(
result, fps, path=path, watermark=gradio_utils.logo_name_to_path(watermark)
)
def generate_animation(
self,
prompt: str,
model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
is_safetensor: bool = False,
motion_field_strength_x: int = 12,
motion_field_strength_y: int = 12,
t0: int = 44,
t1: int = 47,
n_prompt: str = "",
chunk_size: int = 8,
video_length: int = 8,
merging_ratio: float = 0.0,
seed: int = 0,
resolution: int = 512,
fps: int = 2,
use_cf_attn: bool = True,
use_motion_field: bool = True,
smooth_bg: bool = False,
smooth_bg_strength: float = 0.4,
path: str = None,
):
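        """Entry point for full animation generation; currently only loads a .safetensors checkpoint when requested and is otherwise unimplemented."""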
        if is_safetensor and model_link.endswith(".safetensors"):
pipe = utils.load_safetensors_model(model_link)
return
def generate_initial_frames(
self,
prompt: str,
model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
is_safetensor: bool = False,
n_prompt: str = "",
width: int = 512,
height: int = 512,
# batch_count: int = 4,
# batch_size: int = 1,
cfg_scale: float = 7.0,
seed: int = 0,
):
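        """Sample a batch of four candidate first frames for `prompt` using a (PyTorch) StableDiffusionPipeline."""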
print(f">>> prompt: {prompt}, model_link: {model_link}")
pipe = StableDiffusionPipeline.from_pretrained(model_link)
batch_size = 4
prompt = [prompt] * batch_size
        negative_prompt = [n_prompt] * batch_size
        # Use the provided seed so the sampled frames are reproducible.
        generator = torch.Generator().manual_seed(seed)
        images = pipe(
            prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=cfg_scale,
            generator=generator,
        ).images
return images