BestWishYsh commited on
Commit
62c169e
·
verified ·
1 Parent(s): f6d0306

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -15,7 +15,7 @@ from utils.unet import UNet3DConditionModel
15
  from utils.pipeline_magictime import MagicTimePipeline
16
  from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
17
 
18
- pretrained_model_path = "./ckpts/Base_Model/stable-diffusion-v1-5"
19
  inference_config_path = "./sample_configs/RealisticVision.yaml"
20
  magic_adapter_s_path = "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
21
  magic_adapter_t_path = "./ckpts/Magic_Weights/magic_adapter_t"
@@ -91,16 +91,10 @@ class MagicTimeController:
91
  self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
92
 
93
  self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
94
-
95
- self.update_dreambooth(self.dreambooth_list[0])
96
  self.update_motion_module(self.motion_module_list[0])
97
-
98
- from swift import Swift
99
- magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu")
100
- self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0)
101
- self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
102
- self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)
103
-
104
 
105
  def refresh_motion_module(self):
106
  motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
@@ -126,6 +120,13 @@ class MagicTimeController:
126
 
127
  text_model = copy.deepcopy(self.text_model)
128
  self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
 
 
 
 
 
 
 
129
  return gr.Dropdown()
130
 
131
  def update_motion_module(self, motion_module_dropdown):
@@ -147,8 +148,8 @@ class MagicTimeController:
147
  height_slider,
148
  seed_textbox,
149
  ):
150
- if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
151
  if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
 
152
 
153
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
154
 
 
15
  from utils.pipeline_magictime import MagicTimePipeline
16
  from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
17
 
18
+ pretrained_model_path = "runwayml/stable-diffusion-v1-5"
19
  inference_config_path = "./sample_configs/RealisticVision.yaml"
20
  magic_adapter_s_path = "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
21
  magic_adapter_t_path = "./ckpts/Magic_Weights/magic_adapter_t"
 
91
  self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
92
 
93
  self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
94
+
 
95
  self.update_motion_module(self.motion_module_list[0])
96
+ self.update_dreambooth(self.dreambooth_list[0])
97
+
 
 
 
 
 
98
 
99
  def refresh_motion_module(self):
100
  motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
 
120
 
121
  text_model = copy.deepcopy(self.text_model)
122
  self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
123
+
124
+ from swift import Swift
125
+ magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu")
126
+ self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0)
127
+ self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
128
+ self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)
129
+
130
  return gr.Dropdown()
131
 
132
  def update_motion_module(self, motion_module_dropdown):
 
148
  height_slider,
149
  seed_textbox,
150
  ):
 
151
  if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
152
+ if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
153
 
154
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
155