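"""RunPod serverless handler for CogVideoX-Fun image-to-video generation.

Loads the CogVideoX-Fun transformer, VAE, and T5 text encoder from a network
volume, merges a LoRA, renders a video from a start image plus prompt, and
uploads the result to a Hugging Face repo.
"""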
import os
import uuid

import numpy as np
import requests
import runpod
import torch
from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
                       DPMSolverMultistepScheduler,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       PNDMScheduler)
from huggingface_hub import HfApi
from PIL import Image
from transformers import T5EncoderModel

from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
from cogvideox.pipeline.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from cogvideox.utils.lora_utils import merge_lora
from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid
from cogvideox.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
tokenxf = os.getenv("HF_API_TOKEN")

# Low GPU memory mode
low_gpu_memory_mode = False
lora_path = "/content/shirtlift.safetensors"
weight_dtype = torch.bfloat16
def to_pil(image):
    """Best-effort conversion of a PIL image, torch tensor, or numpy array to PIL."""
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, torch.Tensor):
        # Assumes a float tensor in [0, 1], CHW or HWC layout.
        array = image.detach().cpu().float().numpy()
        if array.ndim == 3 and array.shape[0] in (1, 3, 4):
            array = np.transpose(array, (1, 2, 0))
        return Image.fromarray((array * 255).round().clip(0, 255).astype(np.uint8))
    if isinstance(image, np.ndarray):
        # Assumes float arrays are in [0, 1]; uint8 arrays pass through.
        if image.dtype != np.uint8:
            image = (image * 255).round().clip(0, 255).astype(np.uint8)
        return Image.fromarray(image)
    raise ValueError(f"Cannot convert {type(image)} to PIL.Image")
def download_image(url, download_dir="/content"):
    # Ensure the download directory exists.
    os.makedirs(download_dir, exist_ok=True)
    # Send the request and check for a successful response.
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        # Pick the file extension from the Content-Type header (which may
        # carry a charset suffix, hence startswith).
        content_type = response.headers.get("Content-Type", "")
        if content_type.startswith("image/png"):
            ext = "png"
        elif content_type.startswith("image/jpeg"):
            ext = "jpg"
        else:
            ext = "jpg"  # default to .jpg if the content type is unrecognized
        # Generate a random filename with the chosen extension.
        filename = f"{uuid.uuid4().hex}.{ext}"
        file_path = os.path.join(download_dir, filename)
        # Stream the image to disk.
        with open(file_path, "wb") as f:
            for chunk in response.iter_content(1024):
                f.write(chunk)
        print(f"Image downloaded to {file_path}")
        os.chmod(file_path, 0o777)
        return file_path
    else:
        raise Exception(f"Failed to download image from {url}, status code: {response.status_code}")

# Usage:
#   validation_image_start = values.get("validation_image_start", "https://example.com/path/to/image.png")
#   downloaded_image_path = download_image(validation_image_start)
with torch.inference_mode():
    model_id = "/runpod-volume/model"
    transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
        model_id, subfolder="transformer"
    ).to(weight_dtype)
    vae = AutoencoderKLCogVideoX.from_pretrained(
        model_id, subfolder="vae"
    ).to(weight_dtype)
    text_encoder = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder", torch_dtype=weight_dtype
    )
    sampler_dict = {
        "Euler": EulerDiscreteScheduler,
        "Euler A": EulerAncestralDiscreteScheduler,
        "DPM++": DPMSolverMultistepScheduler,
        "PNDM": PNDMScheduler,
        "DDIM_Cog": CogVideoXDDIMScheduler,
        "DDIM_Origin": DDIMScheduler,
    }
    scheduler = sampler_dict["DPM++"].from_pretrained(model_id, subfolder="scheduler")
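    # CogVideoX-Fun ships two pipeline variants: when the transformer expects
    # more input channels than the VAE emits, the extra channels carry the
    # inpainting conditioning (masked-video and mask latents), so the inpaint
    # pipeline is the one that supports image-to-video.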
    if transformer.config.in_channels != vae.config.latent_channels:
        pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
            model_id,
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            torch_dtype=weight_dtype,
        )
    else:
        pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
            model_id,
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            torch_dtype=weight_dtype,
        )
    pipeline = merge_lora(pipeline, lora_path, 1.00)
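    # enable_sequential_cpu_offload() streams individual submodules to the GPU
    # on demand (lowest VRAM use, slowest); enable_model_cpu_offload() swaps
    # whole components (text encoder, transformer, VAE) and is much faster
    # when memory allows.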
    if low_gpu_memory_mode:
        pipeline.enable_sequential_cpu_offload()
    else:
        pipeline.enable_model_cpu_offload()
def generate(input):
    values = input["input"]
    prompt = values["prompt"]
    print("Starting generate function")
    print(prompt)
    negative_prompt = values.get("negative_prompt", "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. blurry, blurred, grainy, distortion, blurry face")
    guidance_scale = values.get("guidance_scale", 6.0)
    seed = values.get("seed", 42)
    num_inference_steps = values.get("num_inference_steps", 18)
    base_resolution = values.get("base_resolution", 512)
    video_length = values.get("video_length", 49)
    fps = values.get("fps", 10)
    save_path = "samples"
    partial_video_length = values.get("partial_video_length", None)
    overlap_video_length = values.get("overlap_video_length", 4)
    validation_image_start = values.get("validation_image_start", "asset/1.png")
    print(validation_image_start)
    downloaded_image_path = download_image(validation_image_start)
    validation_image_end = values.get("validation_image_end", None)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    print("Generator seeded")
    # Scale the 512-based aspect-ratio buckets to the requested base resolution.
    aspect_ratio_sample_size = {key: [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
    start_img = Image.open(downloaded_image_path)
    original_width, original_height = start_img.size
    closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
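    # Snap both sides down to multiples of 16 so they divide cleanly through
    # the VAE's spatial downsampling and the transformer's patching.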
    height, width = [int(x / 16) * 16 for x in closest_size]
    sample_size = [height, width]
    print("Closest aspect ratio:")
    print(closest_ratio)
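    # The 3D causal VAE compresses time in strides of temporal_compression_ratio
    # frames plus one leading frame, so round video_length down to the nearest
    # valid count.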
    video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
    input_video, input_video_mask, clip_image = get_image_to_video_latent(downloaded_image_path, validation_image_end, video_length=video_length, sample_size=sample_size)
    with torch.no_grad():
        sample = pipeline(
            prompt=prompt,
            num_frames=video_length,
            negative_prompt=negative_prompt,
            height=sample_size[0],
            width=sample_size[1],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            video=input_video,
            mask_video=input_video_mask,
        ).videos
    os.makedirs(save_path, exist_ok=True)
    # A random filename avoids collisions between concurrent jobs.
    filename2 = f"{uuid.uuid4().hex}.mp4"
    video_path = os.path.join(save_path, filename2)
    save_videos_grid(sample, video_path, fps=fps)
print("Video saved to grid, uploading to huggingface") | |
hf_api = HfApi() | |
repo_id = "meepmoo/h4h4jejdf" # Set your HF repo | |
hf_api.upload_file( | |
path_or_fileobj=video_path, | |
path_in_repo=filename2, | |
repo_id=repo_id, | |
token=tokenxf, | |
repo_type="model" | |
) | |
print("Video uploaded to huggingface returing output") | |
result_url = f"https://huggingface.co/{repo_id}/resolve/main/{filename2}" | |
job_id = values.get("job_id", "default-job-id") # For RunPod job tracking | |
return {"jobId": job_id, "result": result_url, "status": "DONE"} | |
runpod.serverless.start({"handler": generate})
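# Local smoke test (a sketch, assuming the runpod SDK's --test_input flag; the
# image URL below is a placeholder, not part of the original worker):
#   python handler.py --test_input \
#     '{"input": {"prompt": "a short test clip", "validation_image_start": "https://example.com/start.png"}}'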