BestWishYsh commited on
Commit
ac6cff8
1 Parent(s): a5e282c

Update app.py

Browse files

Fix weight loading.

Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -91,7 +91,8 @@ class MagicTimeController:
91
  self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
92
  self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
93
  self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
94
-
 
95
  # self.tokenizer = tokenizer
96
  # self.text_encoder = text_encoder
97
  # self.vae = vae
@@ -99,6 +100,7 @@ class MagicTimeController:
99
  # self.text_model = text_model
100
 
101
  self.update_motion_module(self.motion_module_list[0])
 
102
  self.update_dreambooth(self.dreambooth_list[0])
103
 
104
 
@@ -121,9 +123,10 @@ class MagicTimeController:
121
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
122
  self.vae.load_state_dict(converted_vae_checkpoint)
123
 
124
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet.config)
 
125
  self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
126
-
127
  text_model = copy.deepcopy(self.text_model)
128
  self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
129
 
@@ -143,6 +146,13 @@ class MagicTimeController:
143
  assert len(unexpected) == 0
144
  return gr.Dropdown()
145
 
 
 
 
 
 
 
 
146
 
147
  def magictime(
148
  self,
@@ -155,6 +165,7 @@ class MagicTimeController:
155
  seed_textbox,
156
  ):
157
  if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
 
158
  if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
159
 
160
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
 
91
  self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
92
  self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
93
  self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
94
+ self.unet_model = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs))
95
+
96
  # self.tokenizer = tokenizer
97
  # self.text_encoder = text_encoder
98
  # self.vae = vae
 
100
  # self.text_model = text_model
101
 
102
  self.update_motion_module(self.motion_module_list[0])
103
+ self.update_motion_module_2(self.motion_module_list[0])
104
  self.update_dreambooth(self.dreambooth_list[0])
105
 
106
 
 
123
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
124
  self.vae.load_state_dict(converted_vae_checkpoint)
125
 
126
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet_model.config)
127
+ self.unet = copy.deepcopy(self.unet_model)
128
  self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
129
+
130
  text_model = copy.deepcopy(self.text_model)
131
  self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
132
 
 
146
  assert len(unexpected) == 0
147
  return gr.Dropdown()
148
 
149
+ def update_motion_module_2(self, motion_module_dropdown):
150
+ self.selected_motion_module = motion_module_dropdown
151
+ motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
152
+ motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
153
+ _, unexpected = self.unet_model.load_state_dict(motion_module_state_dict, strict=False)
154
+ assert len(unexpected) == 0
155
+ return gr.Dropdown()
156
 
157
  def magictime(
158
  self,
 
165
  seed_textbox,
166
  ):
167
  if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
168
+ if self.selected_motion_module != motion_module_dropdown: self.update_motion_module_2(motion_module_dropdown)
169
  if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
170
 
171
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()