BestWishYsh commited on
Commit
2a90ae7
·
verified ·
1 Parent(s): b82463b

Simplify the code

Browse files
Files changed (1) hide show
  1. app.py +23 -40
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import copy
 
3
  import torch
4
  import random
5
  import gradio as gr
@@ -7,7 +8,7 @@ from glob import glob
7
  from omegaconf import OmegaConf
8
  from safetensors import safe_open
9
  from diffusers import AutoencoderKL
10
- from diffusers import EulerDiscreteScheduler, DDIMScheduler
11
  from diffusers.utils.import_utils import is_xformers_available
12
  from transformers import CLIPTextModel, CLIPTokenizer
13
 
@@ -66,7 +67,6 @@ device = torch.device('cuda:0')
66
 
67
  class MagicTimeController:
68
  def __init__(self):
69
-
70
  # config dirs
71
  self.basedir = os.getcwd()
72
  self.stable_diffusion_dir = os.path.join(self.basedir, "ckpts", "Base_Model")
@@ -93,18 +93,11 @@ class MagicTimeController:
93
  self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
94
  self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
95
  self.unet_model = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs))
96
-
97
- # self.tokenizer = tokenizer
98
- # self.text_encoder = text_encoder
99
- # self.vae = vae
100
- # self.unet = unet
101
- # self.text_model = text_model
102
 
103
  self.update_motion_module(self.motion_module_list[0])
104
  self.update_motion_module_2(self.motion_module_list[0])
105
  self.update_dreambooth(self.dreambooth_list[0])
106
 
107
-
108
  def refresh_motion_module(self):
109
  motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
110
  self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
@@ -113,7 +106,7 @@ class MagicTimeController:
113
  dreambooth_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
114
  self.dreambooth_list = [os.path.basename(p) for p in dreambooth_list]
115
 
116
- def update_dreambooth(self, dreambooth_dropdown):
117
  self.selected_dreambooth = dreambooth_dropdown
118
 
119
  dreambooth_dropdown = os.path.join(self.personalized_model_dir, dreambooth_dropdown)
@@ -124,26 +117,18 @@ class MagicTimeController:
124
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
125
  self.vae.load_state_dict(converted_vae_checkpoint)
126
 
127
- if self.unet is not None:
128
- del self.unet
129
- torch.cuda.empty_cache()
130
- torch.cuda.empty_cache()
131
- torch.cuda.empty_cache()
132
- torch.cuda.empty_cache()
133
- torch.cuda.empty_cache()
134
- torch.cuda.empty_cache()
135
  converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet_model.config)
136
  self.unet = copy.deepcopy(self.unet_model)
137
  self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
138
 
139
- if self.text_encoder is not None:
140
- del self.text_encoder
141
- torch.cuda.empty_cache()
142
- torch.cuda.empty_cache()
143
- torch.cuda.empty_cache()
144
- torch.cuda.empty_cache()
145
- torch.cuda.empty_cache()
146
- torch.cuda.empty_cache()
147
  text_model = copy.deepcopy(self.text_model)
148
  self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
149
 
@@ -182,17 +167,23 @@ class MagicTimeController:
182
  height_slider,
183
  seed_textbox,
184
  ):
 
 
 
185
  if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
186
  if self.selected_motion_module != motion_module_dropdown: self.update_motion_module_2(motion_module_dropdown)
187
  if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
188
 
 
 
 
189
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
190
 
191
  pipeline = MagicTimePipeline(
192
  vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
193
  scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
194
- ).to(device)
195
-
196
  if int(seed_textbox) > 0: seed = int(seed_textbox)
197
  else: seed = random.randint(1, 1e16)
198
  torch.manual_seed(int(seed))
@@ -225,16 +216,12 @@ class MagicTimeController:
225
  "seed": seed,
226
  "dreambooth": dreambooth_dropdown,
227
  }
 
 
 
228
  return gr.Video(value=save_sample_path), gr.Json(value=json_config)
229
 
230
- # inference_config = OmegaConf.load(inference_config_path)[1]
231
- # tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
232
- # text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
233
- # vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
234
- # unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda()
235
- # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
236
- # controller = MagicTimeController(tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, unet=unet, text_model=text_model)
237
- controller = MagicTimeController()
238
 
239
  def ui():
240
  with gr.Blocks(css=css) as demo:
@@ -255,9 +242,6 @@ def ui():
255
  dreambooth_dropdown = gr.Dropdown( label="DreamBooth Model", choices=controller.dreambooth_list, value=controller.dreambooth_list[0], interactive=True )
256
  motion_module_dropdown = gr.Dropdown( label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True )
257
 
258
- dreambooth_dropdown.change(fn=controller.update_dreambooth, inputs=[dreambooth_dropdown], outputs=[dreambooth_dropdown])
259
- motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
260
-
261
  prompt_textbox = gr.Textbox( label="Prompt", lines=3 )
262
  negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")
263
 
@@ -290,7 +274,6 @@ def ui():
290
 
291
  return demo
292
 
293
-
294
  if __name__ == "__main__":
295
  demo = ui()
296
  demo.queue(max_size=20)
 
1
  import os
2
  import copy
3
+ import time
4
  import torch
5
  import random
6
  import gradio as gr
 
8
  from omegaconf import OmegaConf
9
  from safetensors import safe_open
10
  from diffusers import AutoencoderKL
11
+ from diffusers import DDIMScheduler
12
  from diffusers.utils.import_utils import is_xformers_available
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
 
67
 
68
  class MagicTimeController:
69
  def __init__(self):
 
70
  # config dirs
71
  self.basedir = os.getcwd()
72
  self.stable_diffusion_dir = os.path.join(self.basedir, "ckpts", "Base_Model")
 
93
  self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
94
  self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
95
  self.unet_model = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs))
 
 
 
 
 
 
96
 
97
  self.update_motion_module(self.motion_module_list[0])
98
  self.update_motion_module_2(self.motion_module_list[0])
99
  self.update_dreambooth(self.dreambooth_list[0])
100
 
 
101
  def refresh_motion_module(self):
102
  motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
103
  self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
 
106
  dreambooth_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
107
  self.dreambooth_list = [os.path.basename(p) for p in dreambooth_list]
108
 
109
+ def update_dreambooth(self, dreambooth_dropdown, motion_module_dropdown=None):
110
  self.selected_dreambooth = dreambooth_dropdown
111
 
112
  dreambooth_dropdown = os.path.join(self.personalized_model_dir, dreambooth_dropdown)
 
117
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
118
  self.vae.load_state_dict(converted_vae_checkpoint)
119
 
120
+ del self.unet
121
+ self.unet = None
122
+ torch.cuda.empty_cache()
123
+ time.sleep(1)
 
 
 
 
124
  converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet_model.config)
125
  self.unet = copy.deepcopy(self.unet_model)
126
  self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
127
 
128
+ del self.text_encoder
129
+ self.text_encoder = None
130
+ torch.cuda.empty_cache()
131
+ time.sleep(1)
 
 
 
 
132
  text_model = copy.deepcopy(self.text_model)
133
  self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
134
 
 
167
  height_slider,
168
  seed_textbox,
169
  ):
170
+ torch.cuda.empty_cache()
171
+ time.sleep(1)
172
+
173
  if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
174
  if self.selected_motion_module != motion_module_dropdown: self.update_motion_module_2(motion_module_dropdown)
175
  if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
176
 
177
+ while self.text_encoder is None or self.unet is None:
178
+ self.update_dreambooth(dreambooth_dropdown, motion_module_dropdown)
179
+
180
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
181
 
182
  pipeline = MagicTimePipeline(
183
  vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
184
  scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
185
+ ).to(device)
186
+
187
  if int(seed_textbox) > 0: seed = int(seed_textbox)
188
  else: seed = random.randint(1, 1e16)
189
  torch.manual_seed(int(seed))
 
216
  "seed": seed,
217
  "dreambooth": dreambooth_dropdown,
218
  }
219
+
220
+ torch.cuda.empty_cache()
221
+ time.sleep(1)
222
  return gr.Video(value=save_sample_path), gr.Json(value=json_config)
223
 
224
+ controller = MagicTimeController()
 
 
 
 
 
 
 
225
 
226
  def ui():
227
  with gr.Blocks(css=css) as demo:
 
242
  dreambooth_dropdown = gr.Dropdown( label="DreamBooth Model", choices=controller.dreambooth_list, value=controller.dreambooth_list[0], interactive=True )
243
  motion_module_dropdown = gr.Dropdown( label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True )
244
 
 
 
 
245
  prompt_textbox = gr.Textbox( label="Prompt", lines=3 )
246
  negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")
247
 
 
274
 
275
  return demo
276
 
 
277
  if __name__ == "__main__":
278
  demo = ui()
279
  demo.queue(max_size=20)