BestWishYsh committed
Commit 0739b68 · verified · 1 Parent(s): da1d89a

Update app.py

Files changed (1)
  1. app.py +72 -49
app.py CHANGED
@@ -2,7 +2,6 @@ import os
  import copy
  import torch
  import random
- import spaces
  import gradio as gr
  from glob import glob
  from omegaconf import OmegaConf
@@ -15,6 +14,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
  from utils.unet import UNet3DConditionModel
  from utils.pipeline_magictime import MagicTimePipeline
  from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
+ import spaces

  pretrained_model_path = "./ckpts/Base_Model/stable-diffusion-v1-5"
  inference_config_path = "./sample_configs/RealisticVision.yaml"
@@ -62,6 +62,7 @@ examples = [
  print(f"### Cleaning cached examples ...")
  os.system(f"rm -rf gradio_cached_examples/")

+ device = torch.device('cuda:0')

  class MagicTimeController:
      def __init__(self):
@@ -87,9 +88,9 @@ class MagicTimeController:
          self.inference_config = OmegaConf.load(inference_config_path)[1]

          self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
-         self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
-         self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
-         self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
+         self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
+         self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+         self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
          self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
          self.unet_model = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs))

@@ -153,7 +154,8 @@ class MagicTimeController:
          _, unexpected = self.unet_model.load_state_dict(motion_module_state_dict, strict=False)
          assert len(unexpected) == 0
          return gr.Dropdown()
-
+
+     @spaces.GPU(duration=300)
      def magictime(
          self,
          dreambooth_dropdown,
@@ -173,7 +175,7 @@ class MagicTimeController:
          pipeline = MagicTimePipeline(
              vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
              scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
-         ).to("cuda")
+         ).to(device)

          if int(seed_textbox) > 0: seed = int(seed_textbox)
          else: seed = random.randint(1, 1e16)
@@ -182,7 +184,7 @@ class MagicTimeController:
          assert seed == torch.initial_seed()
          print(f"### seed: {seed}")

-         generator = torch.Generator(device="cuda")
+         generator = torch.Generator(device=device)
          generator.manual_seed(seed)

          sample = pipeline(
@@ -208,51 +210,72 @@ class MagicTimeController:
              "dreambooth": dreambooth_dropdown,
          }
          return gr.Video(value=save_sample_path), gr.Json(value=json_config)
-
- controller = MagicTimeController()

- @spaces.GPU(duration=300)
- def magictime_interface(
-     dreambooth_dropdown,
-     motion_module_dropdown,
-     prompt_textbox,
-     negative_prompt_textbox,
-     width_slider,
-     height_slider,
-     seed_textbox,
- ):
-     return controller.magictime(
-         dreambooth_dropdown,
-         motion_module_dropdown,
-         prompt_textbox,
-         negative_prompt_textbox,
-         width_slider,
-         height_slider,
-         seed_textbox,
-     )
+ # inference_config = OmegaConf.load(inference_config_path)[1]
+ # tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+ # text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
+ # vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
+ # unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda()
+ # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+ # controller = MagicTimeController(tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, unet=unet, text_model=text_model)
+ controller = MagicTimeController()

- inputs = [
-     gr.Dropdown(label="DreamBooth Model", choices=controller.dreambooth_list, value=controller.dreambooth_list[0]),
-     gr.Dropdown(label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0]),
-     gr.Textbox(label="Prompt", lines=3),
-     gr.Textbox(label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo"),
-     gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64),
-     gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64),
-     gr.Textbox(label="Seed", value="-1"),
- ]
+ def ui():
+     with gr.Blocks(css=css) as demo:
+         gr.Markdown(
+             """
+             <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
+                 <img src='https://www.pnglog.com/48rWnj.png' style='width: 300px; height: auto; margin-right: 10px;' />
+             </div>
+
+             <h2 align="center"> <a href="https://github.com/PKU-YuanGroup/MagicTime">MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators</a></h2>
+             <h5 style="text-align:left;">If you like our project, please give us a star ⭐ on GitHub for the latest update.</h5>
+
+             [GitHub](https://github.com/PKU-YuanGroup/MagicTime) | [arXiv](https://arxiv.org/abs/2404.05014) | [Home Page](https://pku-yuangroup.github.io/MagicTime/) | [Dataset](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing)
+             """
+         )
+         with gr.Row():
+             with gr.Column():
+                 dreambooth_dropdown = gr.Dropdown( label="DreamBooth Model", choices=controller.dreambooth_list, value=controller.dreambooth_list[0], interactive=True )
+                 motion_module_dropdown = gr.Dropdown( label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True )

- outputs = [
-     gr.Video(label="Generated Animation"),
-     gr.Json(label="Config")
- ]
+                 dreambooth_dropdown.change(fn=controller.update_dreambooth, inputs=[dreambooth_dropdown], outputs=[dreambooth_dropdown])
+                 motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
+
+                 prompt_textbox = gr.Textbox( label="Prompt", lines=3 )
+                 negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")
+
+                 with gr.Accordion("Advance", open=False):
+                     with gr.Row():
+                         width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64 )
+                         height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64 )
+                     with gr.Row():
+                         seed_textbox = gr.Textbox( label="Seed (-1 means random)", value=-1)
+                         seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                         seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, 1e16)), inputs=[], outputs=[seed_textbox])
+
+                 generate_button = gr.Button( value="Generate", variant='primary' )
+
+             with gr.Column():
+                 result_video = gr.Video( label="Generated Animation", interactive=False )
+                 json_config = gr.Json( label="Config", value=None )
+
+         inputs = [dreambooth_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox]
+         outputs = [result_video, json_config]
+
+         generate_button.click( fn=controller.magictime, inputs=inputs, outputs=outputs )
+
+         gr.Markdown(
+             """
+             <h5 style="text-align:left;">Warning: It is worth noting that even if we use the same seed and prompt but we change a machine, the results will be different. If you find a better seed and prompt, please tell me in a GitHub issue.</h5>
+             """
+         )
+         gr.Examples( fn=controller.magictime, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True )
+
+     return demo

- iface = gr.Interface(
-     fn=magictime_interface,
-     inputs=inputs,
-     outputs=outputs,
-     title="MagicTime Controller",
-     examples=examples
- )

  if __name__ == "__main__":
-     iface.launch()
+     demo = ui()
+     demo.queue(max_size=20)
+     demo.launch()
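
For readers unfamiliar with the ZeroGPU pattern this commit adopts, here is a minimal, self-contained sketch of the same structure in isolation: a gr.Blocks UI whose GPU-bound callback is wrapped in @spaces.GPU(duration=...) and served through a request queue. The generate function and its body are illustrative placeholders, not code from this repository.

    import random

    import gradio as gr
    import spaces  # Hugging Face ZeroGPU decorator; only has an effect on ZeroGPU Spaces


    @spaces.GPU(duration=300)  # request GPU time for up to 300 s per call
    def generate(prompt, seed):
        # Placeholder for the real GPU-heavy pipeline call (e.g. controller.magictime).
        seed = int(seed) if int(seed) > 0 else random.randint(1, 2**31 - 1)
        return f"seed={seed}, prompt={prompt}"


    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt", lines=3)
        seed = gr.Textbox(label="Seed (-1 means random)", value="-1")
        out = gr.Textbox(label="Result")
        gr.Button("Generate", variant="primary").click(fn=generate, inputs=[prompt, seed], outputs=[out])

    if __name__ == "__main__":
        demo.queue(max_size=20)  # same queue bound as the updated app.py
        demo.launch()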