vilarin committed
Commit 96fa82a
1 Parent(s): 2ba49a8

Update app.py

Files changed (1):
  1. app.py +23 -66
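In short: the commit swaps the diffusers loading path (StableVideoDiffusionPipeline plus the ExVideo UNet fetched via hf_hub_download) for diffsynth's ModelManager and SVDVideoPipeline, which download both checkpoints themselves; drops the generate() parameters the new pipeline does not use (version, cond_aug, decoding_t, device); seeds once with torch.manual_seed before generation; renders 128 frames at 512x512 instead of 25 frames at the input resolution; raises the default and UI value of fps_id from 6 to 25; and removes the client-side resize_image upload hook.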
app.py CHANGED

@@ -10,7 +10,8 @@ from glob import glob
  from pathlib import Path
  from typing import Optional
 
- from diffusers import StableVideoDiffusionPipeline, UNetSpatioTemporalConditionControlNetModel
+ from diffsynth import ModelManager, SVDVideoPipeline, HunyuanDiTImagePipeline
+ from diffsynth import ModelManager
  from diffusers.utils import load_image, export_to_video
 
  import uuid
@@ -20,9 +21,6 @@ from huggingface_hub import hf_hub_download
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
  # Constants
- base = "stabilityai/stable-video-diffusion-img2vid-xt"
- model = "ECNU-CILab/ExVideo-SVD-128f-v1"
-
  MAX_SEED = np.iinfo(np.int32).max
 
  CSS = """
@@ -38,30 +36,15 @@ JS = """function () {
  }
  }"""
 
- downloaded_model_path = hf_hub_download(
-     repo_id=model,
-     filename="model.fp16.safetensors",
-     local_dir="model"
- )
-
- MODEL_PATH = "./model/"
-
 
  # Ensure model and scheduler are initialized in GPU-enabled function
  if torch.cuda.is_available():
-     unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
-         MODEL_PATH,
-         low_cpu_mem_usage=True,
-         variant="fp16",
-     )
-
-     pipe = StableVideoDiffusionPipeline.from_pretrained(
-         base,
-         unet=unet,
-         torch_dtype=torch.float16,
-         variant="fp16").to("cuda")
+     model_manager = ModelManager(
+         torch_dtype=torch.float16,
+         device="cuda",
+         model_id_list=["stable-video-diffusion-img2vid-xt", "ExVideo-SVD-128f-v1"])
+     pipe = SVDVideoPipeline.from_model_manager(model_manager)
 
  # function source codes modified from multimodalart/stable-video-diffusion
  @spaces.GPU(duration=120)
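For reference, the new loading path in isolation: a minimal standalone sketch, assuming a diffsynth release whose ModelManager accepts model_id_list and downloads the listed checkpoints itself, as the hunk above relies on.

import torch
from diffsynth import ModelManager, SVDVideoPipeline

# Load SVD-XT plus the ExVideo-SVD-128f extension in half precision on GPU;
# ModelManager resolves and fetches both model ids before the pipeline is built.
model_manager = ModelManager(
    torch_dtype=torch.float16,
    device="cuda",
    model_id_list=["stable-video-diffusion-img2vid-xt", "ExVideo-SVD-128f-v1"],
)
pipe = SVDVideoPipeline.from_model_manager(model_manager)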
@@ -69,11 +52,7 @@ def generate(
      image: Image,
      seed: Optional[int] = -1,
      motion_bucket_id: int = 127,
-     fps_id: int = 6,
-     version: str = "svd_xt",
-     cond_aug: float = 0.02,
-     decoding_t: int = 1,
-     device: str = "cuda",
+     fps_id: int = 25,
      output_folder: str = "outputs",
      progress=gr.Progress(track_tqdm=True)):
@@ -83,49 +62,29 @@ def generate(
      if image.mode == "RGBA":
          image = image.convert("RGB")
 
-     generator = torch.manual_seed(seed)
+     torch.manual_seed(seed)
 
      os.makedirs(output_folder, exist_ok=True)
      base_count = len(glob(os.path.join(output_folder, "*.mp4")))
      video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
 
-     frames = pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=0.1, num_frames=25).frames[0]
+     frames = pipe(
+         input_image=image.resize((512, 512)),
+         num_frames=128,
+         fps=fps_id,
+         height=512,
+         width=512,
+         motion_bucket_id=motion_bucket_id,
+         num_inference_steps=50,
+         min_cfg_scale=2,
+         max_cfg_scale=2,
+         contrast_enhance_scale=1.2
+     ).frames[0]
+
      export_to_video(frames, video_path, fps=fps_id)
-     torch.manual_seed(seed)
 
      return video_path, seed
 
- def resize_image(image, output_size=(1024, 576)):
-     # Calculate aspect ratios
-     target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size
-     image_aspect = image.width / image.height  # Aspect ratio of the original image
-
-     # Resize then crop if the original image is larger
-     if image_aspect > target_aspect:
-         # Resize the image to match the target height, maintaining aspect ratio
-         new_height = output_size[1]
-         new_width = int(new_height * image_aspect)
-         resized_image = image.resize((new_width, new_height), Image.LANCZOS)
-         # Calculate coordinates for cropping
-         left = (new_width - output_size[0]) / 2
-         top = 0
-         right = (new_width + output_size[0]) / 2
-         bottom = output_size[1]
-     else:
-         # Resize the image to match the target width, maintaining aspect ratio
-         new_width = output_size[0]
-         new_height = int(new_width / image_aspect)
-         resized_image = image.resize((new_width, new_height), Image.LANCZOS)
-         # Calculate coordinates for cropping
-         left = 0
-         top = (new_height - output_size[1]) / 2
-         right = output_size[0]
-         bottom = (new_height + output_size[1]) / 2
-
-     # Crop the image
-     cropped_image = resized_image.crop((left, top, right, bottom))
-     return cropped_image
-
 
  examples = [
      "./train.jpg",
@@ -162,7 +121,7 @@ with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
      fps_id = gr.Slider(
          label="Frames per second",
          info="The length of your video in seconds will be 25/fps",
-         value=6,
+         value=25,
          minimum=5,
          maximum=30
      )
@@ -178,8 +137,6 @@ with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
          examples_per_page=4,
      )
 
-     image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
-
      generate_btn.click(fn=generate, inputs=[image, seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")
 
  demo.queue().launch()
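Because generate is exposed with api_name="video", the Space can also be driven programmatically. A hedged sketch using gradio_client (the Space id is an assumption, substitute the real one); inputs mirror [image, seed, motion_bucket_id, fps_id] and outputs [video, seed]:

from gradio_client import Client, handle_file

client = Client("vilarin/exvideo")   # hypothetical Space id
video_path, used_seed = client.predict(
    handle_file("./train.jpg"),      # image: one of the bundled examples
    -1,                              # seed (-1 lets the app choose)
    127,                             # motion_bucket_id
    25,                              # fps_id
    api_name="/video",
)
print(video_path, used_seed)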
 