""" THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to generate videos gradio web demo. set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt. Usage: OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=your_base_url python app.py """ import spaces import math import os import random import threading import time import os import cv2 import tempfile import imageio_ffmpeg import gradio as gr import torch from PIL import Image from diffusers import ( CogVideoXPipeline, CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, ) from diffusers.utils import load_video, load_image from datetime import datetime, timedelta from PIL import Image from transformers import AutoModelForCausalLM, LlamaTokenizer from diffusers.image_processor import VaeImageProcessor from openai import OpenAI import moviepy.editor as mp import utils from rife_model import load_rife_model, rife_inference_with_latents from huggingface_hub import hf_hub_download, snapshot_download device = "cuda" if torch.cuda.is_available() else "cpu" hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran") snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife") pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device) pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") pipe_image = CogVideoXImageToVideoPipeline.from_pretrained( "THUDM/CogVideoX-5b-I2V", transformer=CogVideoXTransformer3DModel.from_pretrained( "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16 ), vae=pipe.vae, scheduler=pipe.scheduler, tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, torch_dtype=torch.bfloat16, ) os.makedirs("checkpoints", exist_ok=True) # Download LoRA weights hf_hub_download( repo_id="wenqsun/DimensionX", filename="orbit_left_lora_weights.safetensors", local_dir="checkpoints" ) hf_hub_download( repo_id="wenqsun/DimensionX", filename="orbit_up_lora_weights.safetensors", local_dir="checkpoints" ) # pipe.transformer.to(memory_format=torch.channels_last) # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) # pipe_image.transformer.to(memory_format=torch.channels_last) # pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True) os.makedirs("./output", exist_ok=True) os.makedirs("./gradio_tmp", exist_ok=True) upscale_model = utils.load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device) frame_interpolation_model = load_rife_model("model_rife") sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets. For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive. There are a few rules to follow: You will only ever output a single video description per user request. When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions. 
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.

For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" will trigger your partner bot to output a video of a forest morning, as described. You will be prompted by people looking to create detailed, amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.

There are a few rules to follow:

You will only ever output a single video description per user request.

When modifications are requested, you should not simply make the description longer. You should refactor the entire description to integrate the suggestions.
Other times the user will not want modifications, but instead want a new image. In this case, you should ignore your previous conversation with the user.

Video descriptions must have the same number of words as the examples below. Extra words will be ignored.
"""


def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
    width, height = get_video_dimensions(input_video)

    if width == 720 and height == 480:
        processed_video = input_video
    else:
        processed_video = center_crop_resize(input_video)
    return processed_video


def get_video_dimensions(input_video_path):
    reader = imageio_ffmpeg.read_frames(input_video_path)
    metadata = next(reader)
    return metadata["size"]


def center_crop_resize(input_video_path, target_width=720, target_height=480):
    cap = cv2.VideoCapture(input_video_path)

    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    orig_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Scale so the target fits inside the resized frame, then center-crop the excess.
    width_factor = target_width / orig_width
    height_factor = target_height / orig_height
    resize_factor = max(width_factor, height_factor)

    inter_width = int(orig_width * resize_factor)
    inter_height = int(orig_height * resize_factor)

    # Subsample frames toward 8 fps while still keeping at least 49 frames.
    target_fps = 8
    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
    skip = min(5, ideal_skip)  # Cap at 5

    while (total_frames / (skip + 1)) < 49 and skip > 0:
        skip -= 1

    processed_frames = []
    frame_count = 0
    total_read = 0

    while frame_count < 49 and total_read < total_frames:
        ret, frame = cap.read()
        if not ret:
            break

        if total_read % (skip + 1) == 0:
            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)

            start_x = (inter_width - target_width) // 2
            start_y = (inter_height - target_height) // 2
            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]

            processed_frames.append(cropped)
            frame_count += 1

        total_read += 1

    cap.release()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_video_path = temp_file.name
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))

        for frame in processed_frames:
            out.write(frame)

        out.release()

    return temp_video_path

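# Worked example of the resize-and-crop logic above (illustrative only, not executed by the demo):
# for a 1920x1080 source and the default 720x480 target, resize_factor = max(720/1920, 480/1080)
# = max(0.375, 0.444...) = 0.444..., giving an intermediate 853x480 frame whose extra width
# (853 - 720 = 133 pixels) is then trimmed from both sides by the center crop.
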
USER: {} ASSISTANT:" # Check if image is available if image_path and os.path.isfile(image_path): image = Image.open(image_path).convert('RGB') else: image = None # Initialize history for conversation context history = [] query = prompt.strip() for _ in range(retry_times): if image is None: # Text-only query, format as required by CogAgent query = text_only_template.format(query) input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base') inputs = { 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE), 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE), 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE) } else: # Image-based input with initial query input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image]) inputs = { 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE), 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE), 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE), 'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]] } if 'cross_images' in input_by_model and input_by_model['cross_images']: inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]] # Generation settings gen_kwargs = {"max_length": 2048, "do_sample": False} with torch.no_grad(): outputs = model.generate(**inputs, **gen_kwargs) outputs = outputs[:, inputs['input_ids'].shape[1]:] response = tokenizer.decode(outputs[0], skip_special_tokens=True) response = response.split("")[0].strip() # Clean up response if response: return response # Return the response if generated successfully # Return original prompt if all retries fail return prompt @spaces.GPU def infer( prompt: str, orbit_type: str, image_input: str, num_inference_steps: int, guidance_scale: float, seed: int = -1, progress=gr.Progress(track_tqdm=True), ): if seed == -1: seed = random.randint(0, 2**8 - 1) # if video_input is not None: # video = load_video(video_input)[:49] # Limit to 49 frames # video_pt = pipe_video( # video=video, # prompt=prompt, # num_inference_steps=num_inference_steps, # num_videos_per_prompt=1, # strength=video_strenght, # use_dynamic_cfg=True, # output_type="pt", # guidance_scale=guidance_scale, # generator=torch.Generator(device="cpu").manual_seed(seed), # ).frames lora_path = "checkpoints/" weight_name = "orbit_left_lora_weights.safetensors" if orbit_type == "Left" else "orbit_up_lora_weights.safetensors" lora_rank = 256 adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Load LoRA weights on CPU pipe_image.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=f"adapter_{adapter_timestamp}") pipe_image.fuse_lora(lora_scale=1 / lora_rank) pipe_image = pipe_image.to(device) if image_input is not None: image_input = Image.fromarray(image_input).resize(size=(720, 480)) # Convert to PIL image = load_image(image_input) video_pt = pipe_image( image=image, prompt=prompt, num_inference_steps=num_inference_steps, num_videos_per_prompt=1, use_dynamic_cfg=True, output_type="pt", guidance_scale=guidance_scale, generator=torch.Generator(device="cpu").manual_seed(seed), ).frames else: video_pt = pipe( prompt=prompt, num_videos_per_prompt=1, num_inference_steps=num_inference_steps, num_frames=49, use_dynamic_cfg=True, output_type="pt", guidance_scale=guidance_scale, generator=torch.Generator(device="cpu").manual_seed(seed), ).frames return (video_pt, seed) def 
def convert_to_gif(video_path):
    clip = mp.VideoFileClip(video_path)
    clip = clip.set_fps(8)
    clip = clip.resize(height=240)
    gif_path = video_path.replace(".mp4", ".gif")
    clip.write_gif(gif_path, fps=8)
    return gif_path


def delete_old_files():
    while True:
        now = datetime.now()
        cutoff = now - timedelta(minutes=10)
        directories = ["./output", "./gradio_tmp"]

        for directory in directories:
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isfile(file_path):
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
        time.sleep(600)


threading.Thread(target=delete_old_files, daemon=True).start()

examples_images = [["example_images/beef.png"], ["example_images/candle.png"], ["example_images/person.png"]]

with gr.Blocks() as demo:
    gr.Markdown("""